Fix compatibility with Microsoft Entra ID (#48)
* Fixes necessary for Entra ID * Better error * Bump 0.6.1 * Also bump manifest * Linting
This commit is contained in:
committed by
GitHub
parent
f24519787b
commit
6e56311176
@@ -51,6 +51,19 @@ class OIDCIdTokenSigningAlgorithmInvalid(OIDCTokenResponseInvalid):
|
||||
"Raised when the id_token is signed with the wrong algorithm, adjust your config accordingly."
|
||||
|
||||
|
||||
class HTTPClientError(aiohttp.ClientResponseError):
|
||||
"Raised when the HTTP client encounters not OK (200) status code."
|
||||
|
||||
body: str
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.body = kwargs.pop("body")
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.status} ({self.message}) with response body: {self.body}"
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
@@ -105,6 +118,23 @@ class OIDCClient:
|
||||
_LOGGER.debug("Closing HTTP session")
|
||||
self.http_session.close()
|
||||
|
||||
async def http_raise_for_status(self, response: aiohttp.ClientResponse) -> None:
|
||||
"""Raises an exception if the response is not OK."""
|
||||
if not response.ok:
|
||||
# reason should always be not None for a started response
|
||||
assert response.reason is not None
|
||||
|
||||
body = await response.text()
|
||||
|
||||
raise HTTPClientError(
|
||||
response.request_info,
|
||||
response.history,
|
||||
status=response.status,
|
||||
message=response.reason,
|
||||
headers=response.headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
def _base64url_encode(self, value: str) -> str:
|
||||
"""Uses base64url encoding on a given string"""
|
||||
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
|
||||
@@ -145,15 +175,15 @@ class OIDCClient:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.get(self.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
await self.http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
except HTTPClientError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document not found at %s", self.discovery_url
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning("Error: %s - %s", e.status, e.message)
|
||||
_LOGGER.warning("Error fetching discovery: %s", e)
|
||||
raise OIDCDiscoveryInvalid from e
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
@@ -162,10 +192,10 @@ class OIDCClient:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
await self.http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
|
||||
except HTTPClientError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s", e)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
@@ -174,18 +204,20 @@ class OIDCClient:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
await self.http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
except HTTPClientError as e:
|
||||
if e.status == 400:
|
||||
_LOGGER.warning(
|
||||
"Error: Token could not be obtained (Bad Request), "
|
||||
+ "did you forget the client_secret?"
|
||||
"Error: Token could not be obtained (%s, %s), "
|
||||
+ "did you forget the client_secret? Server returned: %s",
|
||||
e.status,
|
||||
e.message,
|
||||
e.body,
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Unexpected error exchanging token: %s - %s", e.status, e.message
|
||||
)
|
||||
_LOGGER.warning("Unexpected error exchanging token: %s", e)
|
||||
|
||||
raise OIDCTokenResponseInvalid from e
|
||||
|
||||
async def _parse_id_token(
|
||||
@@ -257,6 +289,10 @@ class OIDCClient:
|
||||
_LOGGER.warning("Could not find matching key with kid: %s", kid)
|
||||
return None
|
||||
|
||||
# If signing_key does not have alg, set it to the one passed in the token
|
||||
if "alg" not in signing_key:
|
||||
signing_key["alg"] = alg
|
||||
|
||||
# Construct the JWK from the RSA key
|
||||
jwk_obj = jwk.construct(signing_key)
|
||||
|
||||
@@ -459,5 +495,5 @@ class OIDCClient:
|
||||
)
|
||||
return data
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error completing token flow: %s", e)
|
||||
_LOGGER.warning("Failed to complete token flow, returning None. (%s)", e)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user