Files
hass-oidc-auth/tests/test_hass_oidc_client_unit.py
Christiaan Goossens 07c1e3a4c4 Fix regression of storeToken parameter (#248)
* Try a different method to set ?storeToken

* Formatting

* Only insert storeToken on web client & fix tests
2026-04-15 12:07:19 +02:00

819 lines
25 KiB
Python

"""Unit tests for OIDC client token and security behavior."""
# pylint: disable=protected-access
import hashlib
import json
import base64
import time
from urllib.parse import parse_qs, urlparse
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from homeassistant.core import HomeAssistant
from joserfc import errors as joserfc_errors, jwt, jwk
from custom_components.auth_oidc.tools.oidc_client import (
HTTPClientError,
OIDCClient,
OIDCDiscoveryInvalid,
OIDCIdTokenSigningAlgorithmInvalid,
OIDCTokenResponseInvalid,
OIDCUserinfoInvalid,
http_raise_for_status,
)
# JWS signing algorithms taken from the joserfc algorithm guide:
# https://jose.authlib.org/en/guide/algorithms/#json-web-signature
# The order is fixed because this tuple drives pytest parametrization below.
ALL_ID_TOKEN_SIGNING_ALGORITHMS = (
    # HS*, RS*, ES*, PS* at each SHA-2 digest size (256, 384, 512).
    *(
        f"{family}{bits}"
        for family in ("HS", "RS", "ES", "PS")
        for bits in (256, 384, 512)
    ),
    # ECDSA over the secp256k1 curve.
    "ES256K",
    # EdDSA variants, named after their curve.
    "Ed25519",
    "Ed448",
)
def make_client(hass: HomeAssistant, **kwargs) -> OIDCClient:
    """Create an ``OIDCClient`` bound to fixed test issuer and client values.

    The ``features``, ``claims``, ``roles`` and ``network`` sections default to
    fresh empty dicts when not supplied; every other keyword argument (for
    example ``client_secret`` or ``id_token_signing_alg``) is forwarded as-is.
    """
    config_sections = {
        section_name: kwargs.pop(section_name, {})
        for section_name in ("features", "claims", "roles", "network")
    }
    return OIDCClient(
        hass=hass,
        discovery_url="https://issuer/.well-known/openid-configuration",
        client_id="test-client",
        scope="openid profile",
        **config_sections,
        **kwargs,
    )
def make_jwt(
header: dict | None,
payload: dict | None = None,
signature: str = "sig",
) -> str:
"""Build a compact JWT string for parser-focused tests."""
def _b64url_json(data: dict) -> str:
encoded = json.dumps(data, separators=(",", ":")).encode("utf-8")
return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8")
protected = _b64url_json(header) if header is not None else ""
claims = _b64url_json(payload or {"sub": "subject"})
return f"{protected}.{claims}.{signature}"
def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
    """Return a real HS256-signed compact JWT for ``claims`` keyed by ``secret``.

    Delegates to ``build_real_signed_token`` so the HMAC key construction
    (the raw ``secret`` bytes wrapped as a base64url "oct" JWK) exists in a
    single place. The JWKS payload that helper also returns is empty for HMAC
    algorithms and is discarded here.
    """
    token, _ = build_real_signed_token("HS256", claims, secret)
    return token
def build_real_signed_token(
    algorithm: str, claims: dict, secret: str
) -> tuple[str, dict]:
    """Sign ``claims`` with ``algorithm`` and return ``(token, jwks_payload)``.

    ``HS*`` algorithms sign with ``secret`` as a symmetric "oct" key and return
    an empty JWKS, since no public key is published for HMAC. Every other
    supported algorithm gets a freshly generated private key with an
    auto-derived ``kid``. The token header carries that ``kid`` and the JWKS
    holds the matching public key.

    Raises:
        ValueError: if ``algorithm`` is not one this helper knows how to build.
    """
    if algorithm.startswith("HS"):
        hmac_key = jwk.import_key(
            {
                "kty": "oct",
                "k": base64.urlsafe_b64encode(secret.encode()).decode().rstrip("="),
                "alg": algorithm,
            }
        )
        hmac_token = jwt.encode(
            {"alg": algorithm}, claims, hmac_key, algorithms=[algorithm]
        )
        return hmac_token, {"keys": []}

    ec_curve_by_alg = {
        "ES256": "P-256",
        "ES384": "P-384",
        "ES512": "P-521",
        "ES256K": "secp256k1",
    }
    # Pick (key type, RSA bit size or curve name) for jwk.generate_key.
    if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"):
        key_type, size_or_curve = "RSA", 2048
    elif algorithm in ec_curve_by_alg:
        key_type, size_or_curve = "EC", ec_curve_by_alg[algorithm]
    elif algorithm in ("Ed25519", "Ed448"):
        key_type, size_or_curve = "OKP", algorithm
    else:
        raise ValueError(f"Unsupported test algorithm: {algorithm}")

    private_key = jwk.generate_key(
        key_type,
        size_or_curve,
        {"alg": algorithm, "use": "sig"},
        private=True,
        auto_kid=True,
    )
    signed_token = jwt.encode(
        {"alg": algorithm, "kid": private_key.kid},
        claims,
        private_key,
        algorithms=[algorithm],
    )
    return signed_token, {"keys": [private_key.as_dict(private=False)]}
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_missing_state(hass: HomeAssistant):
    """Completing a flow for a state with no stored flow entry yields None."""
    oidc_client = make_client(hass)

    outcome = await oidc_client.async_complete_token_flow(
        "https://example.com/callback",
        "code",
        "missing-state",
    )

    assert outcome is None
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_nonce_mismatch(hass: HomeAssistant):
    """An id_token nonce that differs from the stored one aborts the flow."""
    oidc_client = make_client(hass)
    oidc_client.flows["state-1"] = {"code_verifier": "verifier", "nonce": "expected"}

    fake_discovery = AsyncMock(
        return_value={"token_endpoint": "https://issuer/token"}
    )
    fake_token_response = AsyncMock(
        return_value={"id_token": "id", "access_token": "access"}
    )
    fake_id_claims = AsyncMock(return_value={"sub": "abc", "nonce": "wrong"})

    with (
        patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery),
        patch.object(oidc_client, "_make_token_request", new=fake_token_response),
        patch.object(oidc_client, "_parse_id_token", new=fake_id_claims),
    ):
        outcome = await oidc_client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-1"
        )

    assert outcome is None
    # The stored flow entry must not survive a rejected attempt.
    assert "state-1" not in oidc_client.flows
@pytest.mark.asyncio
async def test_complete_token_flow_handles_token_request_failure(hass: HomeAssistant):
    """An invalid token endpoint response surfaces to the caller as None."""
    oidc_client = make_client(hass)
    oidc_client.flows["state-2"] = {"code_verifier": "verifier", "nonce": "nonce"}

    fake_discovery = AsyncMock(
        return_value={"token_endpoint": "https://issuer/token"}
    )
    failing_token_request = AsyncMock(side_effect=OIDCTokenResponseInvalid())

    with (
        patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery),
        patch.object(oidc_client, "_make_token_request", new=failing_token_request),
    ):
        outcome = await oidc_client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-2"
        )

    assert outcome is None
@pytest.mark.asyncio
async def test_parse_user_details_handles_non_list_groups(hass: HomeAssistant):
    """A string-valued ``groups`` claim must not be treated as group membership."""
    oidc_client = make_client(hass, roles={"user": "users", "admin": "admins"})
    # "admins" equals the whole groups string, so a naive ``in`` test on the
    # raw value would wrongly grant the admin role.
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        "groups": "admins",
    }
    fake_discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery):
        details = await oidc_client.parse_user_details(
            id_token_claims, "access-token"
        )

    assert details["role"] == "invalid"
    assert details["display_name"] == "Display Name"
    assert details["username"] == "username"
@pytest.mark.asyncio
async def test_parse_user_details_uses_userinfo_for_missing_claims(
    hass: HomeAssistant,
):
    """Claims absent from the id_token are taken from the userinfo response."""
    oidc_client = make_client(hass)
    discovery_doc = {
        "issuer": "https://issuer",
        "userinfo_endpoint": "https://issuer/userinfo",
    }
    userinfo_claims = {
        "name": "From UserInfo",
        "preferred_username": "userinfo-user",
        "groups": ["admins"],
    }

    with (
        patch.object(
            oidc_client,
            "_fetch_discovery_document",
            new=AsyncMock(return_value=discovery_doc),
        ),
        patch.object(
            oidc_client,
            "_get_userinfo",
            new=AsyncMock(return_value=userinfo_claims),
        ),
    ):
        details = await oidc_client.parse_user_details(
            {"sub": "subject"}, "access-token"
        )

    # Expected subject: hex SHA-256 of "<issuer>.<sub>".
    assert details["sub"] == hashlib.sha256(b"https://issuer.subject").hexdigest()
    assert details["display_name"] == "From UserInfo"
    assert details["username"] == "userinfo-user"
    assert details["role"] == "system-admin"
@pytest.mark.asyncio
async def test_parse_user_details_assigns_system_users_role(hass: HomeAssistant):
    """Membership in the configured user group maps to ``system-users``."""
    oidc_client = make_client(hass, roles={"user": "users", "admin": "admins"})
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        "groups": ["users"],
    }
    fake_discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery):
        details = await oidc_client.parse_user_details(
            id_token_claims, "access-token"
        )

    assert details["role"] == "system-users"
@pytest.mark.asyncio
async def test_parse_user_details_admin_role_overrides_user_role(
    hass: HomeAssistant,
):
    """When both user and admin groups are present, the admin role wins."""
    oidc_client = make_client(hass, roles={"user": "users", "admin": "admins"})
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        "groups": ["users", "admins"],
    }
    fake_discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery):
        details = await oidc_client.parse_user_details(
            id_token_claims, "access-token"
        )

    assert details["role"] == "system-admin"
@pytest.mark.asyncio
async def test_get_authorization_url_omits_pkce_when_disabled(
    hass: HomeAssistant,
):
    """With ``disable_rfc7636`` set, the authorization URL carries no PKCE params."""
    oidc_client = make_client(hass, features={"disable_rfc7636": True})
    fake_discovery = AsyncMock(
        return_value={"authorization_endpoint": "https://issuer/authorize"}
    )

    with patch.object(oidc_client, "_fetch_discovery_document", new=fake_discovery):
        url = await oidc_client.async_get_authorization_url(
            "https://example.com/callback", "state-xyz"
        )

    assert url is not None
    query = parse_qs(urlparse(url).query)
    assert query["state"] == ["state-xyz"]
    assert "nonce" in query
    # RFC 7636 (PKCE) parameters must be absent in compatibility mode.
    for pkce_param in ("code_challenge", "code_challenge_method"):
        assert pkce_param not in query
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_missing(hass: HomeAssistant):
    """An RS256 id_token whose header has no ``kid`` is rejected with None."""
    oidc_client = make_client(hass)
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token_without_kid = make_jwt({"alg": "RS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with patch.object(oidc_client, "_fetch_jwks", new=empty_jwks):
        parsed = await oidc_client._parse_id_token(token_without_kid)

    assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_not_found(hass: HomeAssistant):
    """A ``kid`` that matches no JWKS entry causes the id_token to be rejected."""
    oidc_client = make_client(hass)
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token_with_unknown_kid = make_jwt({"alg": "RS256", "kid": "missing"})
    unrelated_jwks = AsyncMock(return_value={"keys": [{"kid": "other"}]})

    with patch.object(oidc_client, "_fetch_jwks", new=unrelated_jwks):
        parsed = await oidc_client._parse_id_token(token_with_unknown_kid)

    assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_rejects_hs_without_client_secret(hass: HomeAssistant):
    """An HS256 id_token cannot be verified when no client_secret is configured."""
    oidc_client = make_client(hass, id_token_signing_alg="HS256")
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    hmac_token = make_jwt({"alg": "HS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with (
        patch.object(oidc_client, "_fetch_jwks", new=empty_jwks),
        pytest.raises(OIDCIdTokenSigningAlgorithmInvalid),
    ):
        await oidc_client._parse_id_token(hmac_token)
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_decode_fails_jose(hass: HomeAssistant):
    """A ``JoseError`` raised by ``jwt.decode`` results in None, not an exception."""
    oidc_client = make_client(hass)
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token = make_jwt({"alg": "RS256", "kid": "kid1"})
    matching_jwks = AsyncMock(return_value={"keys": [{"kid": "kid1", "kty": "RSA"}]})

    with (
        patch.object(oidc_client, "_fetch_jwks", new=matching_jwks),
        # Key import is stubbed out so that only the decode step can fail.
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
            return_value=object(),
        ),
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwt.decode",
            side_effect=joserfc_errors.JoseError("bad token"),
        ),
    ):
        parsed = await oidc_client._parse_id_token(token)

    assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_rejects_wrong_signing_algorithm(hass: HomeAssistant):
    """A token whose ``alg`` differs from the configured algorithm is refused."""
    oidc_client = make_client(hass, id_token_signing_alg="RS256")
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    # Configured for RS256 but presented with an HS256 header.
    mismatched_token = make_jwt({"alg": "HS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with (
        patch.object(oidc_client, "_fetch_jwks", new=empty_jwks),
        pytest.raises(OIDCIdTokenSigningAlgorithmInvalid),
    ):
        await oidc_client._parse_id_token(mismatched_token)
@pytest.mark.asyncio
async def test_parse_id_token_rejects_missing_header(hass: HomeAssistant):
    """A token with an empty protected-header segment is rejected with None."""
    oidc_client = make_client(hass)
    oidc_client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    headerless_token = make_jwt(None)
    empty_jwks = AsyncMock(return_value={"keys": []})

    with patch.object(oidc_client, "_fetch_jwks", new=empty_jwks):
        parsed = await oidc_client._parse_id_token(headerless_token)

    assert parsed is None
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("audience", "issuer"),
    [
        ("wrong-audience", "https://issuer"),
        ("test-client", "https://wrong-issuer"),
        ("wrong-audience", "https://wrong-issuer"),
    ],
    ids=["wrong-aud", "wrong-iss", "wrong-aud-and-iss"],
)
async def test_parse_id_token_rejects_invalid_registered_claims(
    hass: HomeAssistant, audience: str, issuer: str
):
    """A correctly signed id_token with a bad ``aud`` or ``iss`` must fail closed.

    Each claim is invalidated on its own as well as together, so removing
    either the audience check or the issuer check is caught by at least one
    case. The valid values ("test-client", "https://issuer") match the
    client_id set by ``make_client`` and the issuer in the discovery document.
    """
    hs_secret = "top-secret-value"
    client = make_client(
        hass,
        id_token_signing_alg="HS256",
        client_secret=hs_secret,
    )
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    now = int(time.time())
    # Signature and time-based claims are valid; only aud/iss vary per case.
    token = make_signed_hs256_jwt(
        hs_secret,
        {
            "sub": "abc",
            "aud": audience,
            "iss": issuer,
            "nbf": now,
            "iat": now,
            "exp": now + 3600,
        },
    )
    with patch.object(
        client,
        "_fetch_jwks",
        new=AsyncMock(return_value={"keys": []}),
    ):
        parsed = await client._parse_id_token(token)
    assert parsed is None
@pytest.mark.asyncio
@pytest.mark.parametrize("algorithm", ALL_ID_TOKEN_SIGNING_ALGORITHMS)
async def test_parse_id_token_validates_real_signed_tokens_and_decode_inputs(
    hass: HomeAssistant, algorithm: str
):
    """Verify real signatures for each algorithm and inspect the joserfc inputs.

    The token is genuinely signed: HMAC keyed by ``client_secret`` for ``HS*``,
    otherwise a freshly generated key whose public half is served by a mocked
    JWKS. Spies wrap ``jwt.decode`` and ``jwk.import_key`` so the test can check
    the token, algorithm allow-list and key material the client passed in.
    """
    secret = "top-secret-value"
    client_kwargs = {"id_token_signing_alg": algorithm}
    if algorithm.startswith("HS"):
        client_kwargs["client_secret"] = secret
    client = make_client(hass, **client_kwargs)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    now = int(time.time())
    claims = {
        "sub": "subject-1",
        "aud": "test-client",
        "iss": "https://issuer",
        "nbf": now,
        "iat": now,
        "exp": now + 3600,
    }
    # Sign before the spies are installed so they only record client calls.
    token, jwks_payload = build_real_signed_token(algorithm, claims, secret)
    with (
        patch.object(client, "_fetch_jwks", new=AsyncMock(return_value=jwks_payload)),
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwt.decode",
            wraps=jwt.decode,
        ) as decode_spy,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
            wraps=jwk.import_key,
        ) as import_key_spy,
    ):
        parsed = await client._parse_id_token(token)
    assert parsed == claims
    decode_spy.assert_called_once()
    assert decode_spy.call_args.args[0] == token
    # The allow-list passed to joserfc must be exactly the configured algorithm.
    assert decode_spy.call_args.kwargs["algorithms"] == [algorithm]
    import_key_spy.assert_called()
    imported_key_payload = import_key_spy.call_args.args[0]
    assert imported_key_payload["alg"] == algorithm
    if algorithm.startswith("HS"):
        assert imported_key_payload["kty"] == "oct"
    else:
        # The verification key must be the JWKS entry matching the token's
        # kid, not merely some key that carries a kid.
        assert imported_key_payload["kid"] == jwks_payload["keys"][0]["kid"]
@pytest.mark.asyncio
async def test_get_authorization_url_returns_none_when_discovery_fails(
    hass: HomeAssistant,
):
    """An invalid discovery document makes URL generation return None."""
    oidc_client = make_client(hass)
    broken_discovery = AsyncMock(side_effect=OIDCDiscoveryInvalid())

    with patch.object(
        oidc_client, "_fetch_discovery_document", new=broken_discovery
    ):
        url = await oidc_client.async_get_authorization_url(
            "https://example.com/callback", "state-1"
        )

    assert url is None
@pytest.mark.asyncio
async def test_complete_token_flow_omits_code_verifier_when_pkce_disabled(
    hass: HomeAssistant,
):
    """With PKCE disabled, no ``code_verifier`` is sent to the token endpoint."""
    oidc_client = make_client(hass, features={"disable_rfc7636": True})
    oidc_client.flows["state-3"] = {"code_verifier": "verifier", "nonce": "nonce"}

    token_request = AsyncMock(
        return_value={"id_token": "id", "access_token": "access"}
    )
    resolved_user = {
        "sub": "abc",
        "display_name": "n",
        "username": "u",
        "role": "system-users",
    }

    with (
        patch.object(
            oidc_client,
            "_fetch_discovery_document",
            new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
        ),
        patch.object(oidc_client, "_make_token_request", new=token_request),
        patch.object(
            oidc_client,
            "_parse_id_token",
            new=AsyncMock(return_value={"sub": "abc", "nonce": "nonce"}),
        ),
        patch.object(
            oidc_client,
            "parse_user_details",
            new=AsyncMock(return_value=resolved_user),
        ),
    ):
        outcome = await oidc_client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-3"
        )

    assert outcome is not None
    # Second positional argument of _make_token_request is the form parameters.
    sent_params = token_request.await_args.args[1]
    assert "code_verifier" not in sent_params
@pytest.mark.asyncio
async def test_http_raise_for_status_noop_on_ok_response():
    """A response with ``ok`` set to True passes through without raising."""
    healthy_response = MagicMock(ok=True)
    await http_raise_for_status(healthy_response)
@pytest.mark.asyncio
async def test_http_raise_for_status_raises_http_client_error_with_body():
    """A failed response raises HTTPClientError whose text includes status and body."""
    failed_response = MagicMock(
        ok=False,
        reason="Bad Request",
        status=400,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="problem details"),
    )

    with pytest.raises(HTTPClientError) as raised:
        await http_raise_for_status(failed_response)

    error_text = str(raised.value)
    assert "400 (Bad Request)" in error_text
    assert "problem details" in error_text
@pytest.mark.asyncio
async def test_get_http_session_reuses_existing_session(hass: HomeAssistant):
    """An already-assigned ``http_session`` is returned as the same object."""
    oidc_client = make_client(hass)
    cached_session = MagicMock()
    oidc_client.http_session = cached_session

    assert await oidc_client._get_http_session() is cached_session
@pytest.mark.asyncio
async def test_get_http_session_applies_tls_verify_flag(hass: HomeAssistant):
    """``network.tls_verify`` is forwarded to the TCP connector as ``verify_ssl``."""
    oidc_client = make_client(hass, network={"tls_verify": False})

    with (
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
            return_value=MagicMock(),
        ) as connector_factory,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
            return_value=MagicMock(),
        ),
    ):
        await oidc_client._get_http_session()

    connector_factory.assert_called_once_with(verify_ssl=False)
@pytest.mark.asyncio
async def test_get_http_session_uses_custom_ca_path(hass: HomeAssistant):
    """A configured CA path yields an SSL context built via the executor."""
    oidc_client = make_client(
        hass,
        network={"tls_verify": True, "tls_ca_path": "/tmp/test-ca.pem"},
    )
    # Stand-in for the context; the executor is mocked, so no file is read.
    stub_ssl_context = object()

    with (
        patch.object(
            hass.loop,
            "run_in_executor",
            new=AsyncMock(return_value=stub_ssl_context),
        ) as executor_call,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
            return_value=MagicMock(),
        ) as connector_factory,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
            return_value=MagicMock(),
        ),
    ):
        await oidc_client._get_http_session()

    executor_call.assert_awaited_once()
    connector_factory.assert_called_once_with(verify_ssl=True, ssl=stub_ssl_context)
@pytest.mark.asyncio
async def test_make_token_request_returns_json_on_success(hass: HomeAssistant):
    """A successful token POST returns the decoded JSON body."""
    oidc_client = make_client(hass)
    ok_response = MagicMock(
        ok=True,
        json=AsyncMock(return_value={"access_token": "token"}),
    )
    # session.post(...) is used as an async context manager yielding the response.
    post_ctx = AsyncMock()
    post_ctx.__aenter__.return_value = ok_response
    fake_session = MagicMock()
    fake_session.post.return_value = post_ctx

    with patch.object(
        oidc_client, "_get_http_session", new=AsyncMock(return_value=fake_session)
    ):
        body = await oidc_client._make_token_request(
            "https://issuer/token", {"code": "abc"}
        )

    assert body == {"access_token": "token"}
@pytest.mark.asyncio
async def test_make_token_request_raises_invalid_on_non_400_http_error(
    hass: HomeAssistant,
):
    """A 500 from the token endpoint surfaces as OIDCTokenResponseInvalid."""
    oidc_client = make_client(hass)
    server_error = MagicMock(
        ok=False,
        reason="Server Error",
        status=500,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="boom"),
    )
    post_ctx = AsyncMock()
    post_ctx.__aenter__.return_value = server_error
    fake_session = MagicMock()
    fake_session.post.return_value = post_ctx

    with (
        patch.object(
            oidc_client, "_get_http_session", new=AsyncMock(return_value=fake_session)
        ),
        pytest.raises(OIDCTokenResponseInvalid),
    ):
        await oidc_client._make_token_request("https://issuer/token", {"code": "abc"})
@pytest.mark.asyncio
async def test_get_userinfo_returns_json_on_success(hass: HomeAssistant):
    """A successful userinfo GET returns the decoded JSON body."""
    oidc_client = make_client(hass)
    ok_response = MagicMock(ok=True, json=AsyncMock(return_value={"sub": "abc"}))
    # session.get(...) is used as an async context manager yielding the response.
    get_ctx = AsyncMock()
    get_ctx.__aenter__.return_value = ok_response
    fake_session = MagicMock()
    fake_session.get.return_value = get_ctx

    with patch.object(
        oidc_client, "_get_http_session", new=AsyncMock(return_value=fake_session)
    ):
        body = await oidc_client._get_userinfo("https://issuer/userinfo", "access")

    assert body == {"sub": "abc"}
@pytest.mark.asyncio
async def test_get_userinfo_raises_invalid_on_http_error(hass: HomeAssistant):
    """A 503 from the userinfo endpoint surfaces as OIDCUserinfoInvalid."""
    oidc_client = make_client(hass)
    unavailable = MagicMock(
        ok=False,
        reason="Unavailable",
        status=503,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="oops"),
    )
    get_ctx = AsyncMock()
    get_ctx.__aenter__.return_value = unavailable
    fake_session = MagicMock()
    fake_session.get.return_value = get_ctx

    with (
        patch.object(
            oidc_client, "_get_http_session", new=AsyncMock(return_value=fake_session)
        ),
        pytest.raises(OIDCUserinfoInvalid),
    ):
        await oidc_client._get_userinfo("https://issuer/userinfo", "access")