diff --git a/custom_components/auth_oidc/oidc_client.py b/custom_components/auth_oidc/oidc_client.py index ab10b85..f2228fc 100644 --- a/custom_components/auth_oidc/oidc_client.py +++ b/custom_components/auth_oidc/oidc_client.py @@ -47,6 +47,10 @@ class OIDCStateInvalid(OIDCClientException): "Raised when the state for your request cannot be matched against a stored state." +class OIDCUserinfoInvalid(OIDCClientException): + "Raised when the user info is invalid or cannot be obtained." + + class OIDCIdTokenSigningAlgorithmInvalid(OIDCTokenResponseInvalid): "Raised when the id_token is signed with the wrong algorithm, adjust your config accordingly." @@ -220,6 +224,19 @@ class OIDCClient: raise OIDCTokenResponseInvalid from e + async def _get_userinfo(self, userinfo_uri, access_token): + """Fetches userinfo from the given URL.""" + try: + session = await self._get_http_session() + headers = {"Authorization": "Bearer " + access_token} + + async with session.get(userinfo_uri, headers=headers) as response: + await self.http_raise_for_status(response) + return await response.json() + except HTTPClientError as e: + _LOGGER.warning("Error fetching userinfo: %s", e) + raise OIDCUserinfoInvalid from e + async def _parse_id_token( self, id_token: str, access_token: str | None ) -> Optional[dict]: @@ -395,6 +412,57 @@ class OIDCClient: _LOGGER.warning("Error generating authorization URL: %s", e) return None + async def parse_user_details(self, id_token: str, access_token: str) -> UserDetails: + """Parses the ID token and/or userinfo into user details.""" + + # Fetch userinfo if there is an userinfo_endpoint available + # and use the data to supply the missing values in id_token + if "userinfo_endpoint" in self.discovery_document: + userinfo_endpoint = self.discovery_document["userinfo_endpoint"] + userinfo = await self._get_userinfo(userinfo_endpoint, access_token) + + # Replace missing claims in the id_token with their userinfo version + for claim in ( + self.groups_claim, + self.display_name_claim, + self.username_claim, + ): + if claim not in id_token and claim in userinfo: + id_token[claim] = userinfo[claim] + + # Get and parse groups (to check if it's an array) + groups = id_token.get(self.groups_claim, []) + if not isinstance(groups, list): + _LOGGER.warning("Groups claim is not a list, using empty list instead.") + groups = [] + + # Assign role if user has the required groups + role = "invalid" + if self.user_role in groups or self.user_role is None: + role = "system-users" + + if self.admin_role in groups: + role = "system-admin" + + # Create a user details dict based on the contents of the id_token & userinfo + return { + # Subject Identifier. A locally unique and never reassigned identifier within the + # Issuer for the End-User, which is intended to be consumed by the Client + # Only unique per issuer, so we combine it with the issuer and hash it. + # This might allow multiple OIDC providers to be used with this integration. + "sub": hashlib.sha256( + f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode( + "utf-8" + ) + ).hexdigest(), + # Display name, configurable + "display_name": id_token.get(self.display_name_claim), + # Username, configurable + "username": id_token.get(self.username_claim), + # Role + "role": role, + } + async def async_complete_token_flow( self, redirect_uri: str, code: str, state: str ) -> Optional[UserDetails]: @@ -451,40 +519,7 @@ class OIDCClient: _LOGGER.warning("Nonce mismatch!") return None - # TODO: If the configured claims are not present in id_token, we should fetch userinfo - - # Get and parse groups (to check if it's an array) - groups = id_token.get(self.groups_claim, []) - if not isinstance(groups, list): - _LOGGER.warning("Groups claim is not a list, using empty list instead.") - groups = [] - - # Assign role if user has the required groups - role = "invalid" - if self.user_role in groups or self.user_role is None: - role = "system-users" - - if self.admin_role in groups: - role = "system-admin" - - # Create a user details dict based on the contents of the id_token & userinfo - data: UserDetails = { - # Subject Identifier. A locally unique and never reassigned identifier within the - # Issuer for the End-User, which is intended to be consumed by the Client - # Only unique per issuer, so we combine it with the issuer and hash it. - # This might allow multiple OIDC providers to be used with this integration. - "sub": hashlib.sha256( - f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode( - "utf-8" - ) - ).hexdigest(), - # Display name, configurable - "display_name": id_token.get(self.display_name_claim), - # Username, configurable - "username": id_token.get(self.username_claim), - # Role - "role": role, - } + data = await self.parse_user_details(id_token, access_token) # Log which details were obtained for debugging # Also log the original subject identifier such that you can look it up in your provider