Migrate to joserfc, remove python-jose (#150)

This commit is contained in:
Christiaan Goossens
2025-10-31 10:16:45 +01:00
committed by GitHub
parent a8e0162d25
commit 674c342a81
3 changed files with 37 additions and 106 deletions

View File

@@ -9,7 +9,7 @@ import ssl
from typing import Optional
from functools import partial
import aiohttp
from jose import jwt, jwk
from joserfc import jwt, jwk, jws, errors as joserfc_errors
from homeassistant.core import HomeAssistant
from .types import UserDetails
@@ -433,9 +433,7 @@ class OIDCClient:
"""Fetches JWKS."""
return await self.discovery_class.fetch_jwks(jwks_uri)
async def _parse_id_token(
self, id_token: str, access_token: str | None
) -> Optional[dict]:
async def _parse_id_token(self, id_token: str) -> Optional[dict]:
"""Parses the ID token into a dict containing token contents."""
if self.discovery_document is None:
self.discovery_document = await self._fetch_discovery_document()
@@ -445,7 +443,8 @@ class OIDCClient:
try:
# Obtain the id_token header
unverified_header = jwt.get_unverified_header(id_token)
token_obj = jws.extract_compact(id_token.encode())
unverified_header = token_obj.protected
if not unverified_header:
_LOGGER.warning("Could not get header from received id_token.")
return None
@@ -474,7 +473,7 @@ class OIDCClient:
)
raise OIDCIdTokenSigningAlgorithmInvalid()
jwk_obj = jwk.construct(
jwk_obj = jwk.import_key(
{
"kty": "oct",
"k": base64.urlsafe_b64encode(
@@ -507,9 +506,9 @@ class OIDCClient:
signing_key["alg"] = alg
# Construct the JWK from the RSA key
jwk_obj = jwk.construct(signing_key)
jwk_obj = jwk.import_key(signing_key)
# Verify the token
# Decode the token, decode does not verify it
decoded_token = jwt.decode(
id_token,
jwk_obj,
@@ -518,48 +517,31 @@ class OIDCClient:
# according to JWS [JWS] using the algorithm specified in the JWT
# alg Header Parameter.
algorithms=[self.id_token_signing_alg],
)
# Create Claims Registry for validation
id_token_validator = jwt.JWTClaimsRegistry(
leeway=5,
# OpenID Connect Core 1.0 Section 3.1.3.7.3
# The Client MUST validate that the aud (audience) Claim contains
# its client_id value registered at the Issuer identified by the
# iss (issuer) Claim as an audience.
audience=self.client_id,
aud={"essential": True, "value": self.client_id},
# OpenID Connect Core 1.0 Section 3.1.3.7.2
# The Issuer Identifier for the OpenID Provider MUST exactly
# match the value of the iss (issuer) Claim.
issuer=self.discovery_document["issuer"],
access_token=access_token,
options={
# Verify everything if present
"verify_signature": True,
"verify_aud": True,
"verify_iat": True,
"verify_exp": True,
"verify_nbf": True,
"verify_iss": True,
"verify_sub": True,
"verify_jti": True,
"verify_at_hash": True,
# OpenID Connect Core 1.0 Section 3.1.3.7.3
"require_aud": True,
# OpenID Connect Core 1.0 Section 3.1.3.7.10
"require_iat": True,
# OpenID Connect Core 1.0 Section 3.1.3.7.9
"require_exp": True,
# OpenID Connect Core 1.0 Section 3.1.3.7.2
"require_iss": True,
# We need the sub as it's used to identify the user
"require_sub": True,
# Other values, not required.
"require_nbf": False,
"require_jti": False,
"require_at_hash": False,
"leeway": 5,
},
iss={"essential": True, "value": self.discovery_document["issuer"]},
# OpenID Connect Core 1.0 Section 3.1.3.7.9
# OpenID Connect Core 1.0 Section 3.1.3.7.10
# No need to specify exp, nbf, iat, they are in here by default
sub={"essential": True},
)
return decoded_token
except jwt.JWTError as e:
_LOGGER.warning("JWT Verification failed: %s", e)
id_token_validator.validate(decoded_token.claims)
return decoded_token.claims
except joserfc_errors.JoseError as e:
_LOGGER.warning("JWT verification failed: %s", e)
return None
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
@@ -692,11 +674,9 @@ class OIDCClient:
)
id_token = token_response.get("id_token")
access_token = token_response.get("access_token")
# Parse the id token to obtain the relevant details
# Access token is supplied to check at_hash if present
id_token = await self._parse_id_token(id_token, access_token)
id_token = await self._parse_id_token(id_token)
if id_token is None:
_LOGGER.warning("ID token could not be parsed!")
@@ -710,6 +690,7 @@ class OIDCClient:
_LOGGER.warning("Nonce mismatch!")
return None
access_token = token_response.get("access_token")
data = await self.parse_user_details(id_token, access_token)
# Log which details were obtained for debugging