From 0ca300c385f8450bf0cd2b4de1d182fa7177a628 Mon Sep 17 00:00:00 2001 From: Christiaan Goossens <9487666+christiaangoossens@users.noreply.github.com> Date: Tue, 14 Apr 2026 15:29:06 +0200 Subject: [PATCH] Add tests for other signing methods (#246) * Add tests for other signing methods #151 * Add doc for list source --- tests/test_hass_oidc_client_unit.py | 120 ++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/tests/test_hass_oidc_client_unit.py b/tests/test_hass_oidc_client_unit.py index 4592ca3..a486dfa 100644 --- a/tests/test_hass_oidc_client_unit.py +++ b/tests/test_hass_oidc_client_unit.py @@ -23,6 +23,25 @@ from custom_components.auth_oidc.tools.oidc_client import ( http_raise_for_status, ) +# List from https://jose.authlib.org/en/guide/algorithms/#json-web-signature +ALL_ID_TOKEN_SIGNING_ALGORITHMS = ( + "HS256", + "HS384", + "HS512", + "RS256", + "RS384", + "RS512", + "ES256", + "ES384", + "ES512", + "PS256", + "PS384", + "PS512", + "ES256K", + "Ed25519", + "Ed448", +) + def make_client(hass: HomeAssistant, **kwargs) -> OIDCClient: """Build an OIDC client with explicit defaults for unit testing.""" @@ -67,6 +86,51 @@ def make_signed_hs256_jwt(secret: str, claims: dict) -> str: return jwt.encode({"alg": "HS256"}, claims, jwk_obj) +def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[str, dict]: + """Build a real signed token and matching JWKS payload for a given algorithm.""" + if algorithm.startswith("HS"): + signing_key = jwk.import_key( + { + "kty": "oct", + "k": base64.urlsafe_b64encode(secret.encode()).decode().rstrip("="), + "alg": algorithm, + } + ) + token = jwt.encode({"alg": algorithm}, claims, signing_key, algorithms=[algorithm]) + return token, {"keys": []} + + if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"): + key = jwk.generate_key( + "RSA", 2048, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True + ) + elif algorithm in ("ES256", "ES384", "ES512", "ES256K"): + curve = { + "ES256": "P-256", + "ES384": "P-384", + "ES512": "P-521", + "ES256K": "secp256k1", + }[algorithm] + key = jwk.generate_key( + "EC", curve, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True + ) + elif algorithm in ("Ed25519", "Ed448"): + key = jwk.generate_key( + "OKP", algorithm, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True + ) + else: + raise ValueError(f"Unsupported test algorithm: {algorithm}") + + kid = key.kid + token = jwt.encode( + {"alg": algorithm, "kid": kid}, + claims, + key, + algorithms=[algorithm], + ) + public_key = key.as_dict(private=False) + return token, {"keys": [public_key]} + + @pytest.mark.asyncio async def test_complete_token_flow_rejects_missing_state(hass: HomeAssistant): """Flow state must exist; missing state should fail closed.""" @@ -447,6 +511,62 @@ async def test_parse_id_token_rejects_invalid_registered_claims(hass: HomeAssist assert parsed is None +@pytest.mark.asyncio +@pytest.mark.parametrize("algorithm", ALL_ID_TOKEN_SIGNING_ALGORITHMS) +async def test_parse_id_token_validates_real_signed_tokens_and_decode_inputs( + hass: HomeAssistant, algorithm: str +): + """Use real signatures and verify token/key/algorithm passed into joserfc.""" + secret = "top-secret-value" + client_kwargs = {"id_token_signing_alg": algorithm} + if algorithm.startswith("HS"): + client_kwargs["client_secret"] = secret + + client = make_client(hass, **client_kwargs) + client.discovery_document = { + "issuer": "https://issuer", + "jwks_uri": "https://issuer/jwks", + } + + now = int(time.time()) + claims = { + "sub": "subject-1", + "aud": "test-client", + "iss": "https://issuer", + "nbf": now, + "iat": now, + "exp": now + 3600, + } + + token, jwks_payload = build_real_signed_token(algorithm, claims, secret) + + with ( + patch.object(client, "_fetch_jwks", new=AsyncMock(return_value=jwks_payload)), + patch( + "custom_components.auth_oidc.tools.oidc_client.jwt.decode", + wraps=jwt.decode, + ) as decode_spy, + patch( + "custom_components.auth_oidc.tools.oidc_client.jwk.import_key", + wraps=jwk.import_key, + ) as import_key_spy, + ): + parsed = await client._parse_id_token(token) + + assert parsed == claims + decode_spy.assert_called_once() + assert decode_spy.call_args.args[0] == token + assert decode_spy.call_args.kwargs["algorithms"] == [algorithm] + + import_key_spy.assert_called() + imported_key_payload = import_key_spy.call_args.args[0] + assert imported_key_payload["alg"] == algorithm + if algorithm.startswith("HS"): + assert imported_key_payload["kty"] == "oct" + else: + assert imported_key_payload["kid"] is not None + + @pytest.mark.asyncio async def test_get_authorization_url_returns_none_when_discovery_fails( hass: HomeAssistant,