* Try a different method to set ?storeToken * Formatting * Only insert storeToken on web client & fix tests
819 lines
25 KiB
Python
819 lines
25 KiB
Python
"""Unit tests for OIDC client token and security behavior."""
|
|
|
|
# pylint: disable=protected-access
|
|
|
|
import hashlib
|
|
import json
|
|
import base64
|
|
import time
|
|
from urllib.parse import parse_qs, urlparse
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from homeassistant.core import HomeAssistant
|
|
from joserfc import errors as joserfc_errors, jwt, jwk
|
|
|
|
from custom_components.auth_oidc.tools.oidc_client import (
|
|
HTTPClientError,
|
|
OIDCClient,
|
|
OIDCDiscoveryInvalid,
|
|
OIDCIdTokenSigningAlgorithmInvalid,
|
|
OIDCTokenResponseInvalid,
|
|
OIDCUserinfoInvalid,
|
|
http_raise_for_status,
|
|
)
|
|
|
|
# Every JWS algorithm the ID-token signing tests are parametrized over, grouped
# by the key family build_real_signed_token uses for it.
# Source list: https://jose.authlib.org/en/guide/algorithms/#json-web-signature
ALL_ID_TOKEN_SIGNING_ALGORITHMS = (
    # HMAC, shared-secret "oct" keys
    "HS256", "HS384", "HS512",
    # RSA keys
    "RS256", "RS384", "RS512",
    # EC keys (P-256 / P-384 / P-521)
    "ES256", "ES384", "ES512",
    # RSA keys, PSS padding
    "PS256", "PS384", "PS512",
    # EC key on secp256k1
    "ES256K",
    # OKP keys
    "Ed25519", "Ed448",
)
|
|
|
|
|
|
def make_client(hass: HomeAssistant, **kwargs) -> OIDCClient:
    """Create an ``OIDCClient`` pointed at a fixed test issuer.

    ``features``, ``claims``, ``roles`` and ``network`` each default to a fresh
    empty dict unless supplied. Any other keyword argument (for example
    ``client_secret`` or ``id_token_signing_alg``) is forwarded unchanged.
    """
    dict_options = {
        option: kwargs.pop(option, {})
        for option in ("features", "claims", "roles", "network")
    }
    return OIDCClient(
        hass=hass,
        discovery_url="https://issuer/.well-known/openid-configuration",
        client_id="test-client",
        scope="openid profile",
        **dict_options,
        **kwargs,
    )
|
|
|
|
|
|
def make_jwt(
|
|
header: dict | None,
|
|
payload: dict | None = None,
|
|
signature: str = "sig",
|
|
) -> str:
|
|
"""Build a compact JWT string for parser-focused tests."""
|
|
|
|
def _b64url_json(data: dict) -> str:
|
|
encoded = json.dumps(data, separators=(",", ":")).encode("utf-8")
|
|
return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8")
|
|
|
|
protected = _b64url_json(header) if header is not None else ""
|
|
claims = _b64url_json(payload or {"sub": "subject"})
|
|
return f"{protected}.{claims}.{signature}"
|
|
|
|
|
|
def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
    """Build a real HS256-signed JWT for parser validation tests.

    Delegates to ``build_real_signed_token``. The shared-secret key is then
    derived the same way as in the algorithm-parametrized tests (UTF-8 secret,
    unpadded base64url ``k``), rather than through a second copy of that logic.

    Args:
        secret: Shared HMAC secret, normally the client secret under test.
        claims: Claims to place in the token payload.

    Returns:
        The compact serialized token.
    """
    token, _jwks = build_real_signed_token("HS256", claims, secret)
    return token
|
|
|
|
|
|
def build_real_signed_token(
    algorithm: str, claims: dict, secret: str
) -> tuple[str, dict]:
    """Sign *claims* with *algorithm* and return ``(token, jwks)``.

    ``HS*`` algorithms sign with an ``oct`` key derived from *secret*. They
    return an empty JWKS because HMAC tokens are checked against the client
    secret. Every other supported algorithm generates a fresh private key with
    an auto-generated ``kid``. It returns a JWKS containing only the matching
    public key.

    Raises:
        ValueError: If *algorithm* is not ``HS*`` and has no entry in the
            key-spec table below.
    """
    if algorithm.startswith("HS"):
        encoded_secret = (
            base64.urlsafe_b64encode(secret.encode()).decode().rstrip("=")
        )
        hmac_key = jwk.import_key(
            {"kty": "oct", "k": encoded_secret, "alg": algorithm}
        )
        hmac_token = jwt.encode(
            {"alg": algorithm}, claims, hmac_key, algorithms=[algorithm]
        )
        return hmac_token, {"keys": []}

    # algorithm -> (JWK key type, RSA modulus size in bits or curve name)
    key_specs = {
        "RS256": ("RSA", 2048),
        "RS384": ("RSA", 2048),
        "RS512": ("RSA", 2048),
        "PS256": ("RSA", 2048),
        "PS384": ("RSA", 2048),
        "PS512": ("RSA", 2048),
        "ES256": ("EC", "P-256"),
        "ES384": ("EC", "P-384"),
        "ES512": ("EC", "P-521"),
        "ES256K": ("EC", "secp256k1"),
        "Ed25519": ("OKP", "Ed25519"),
        "Ed448": ("OKP", "Ed448"),
    }
    if algorithm not in key_specs:
        raise ValueError(f"Unsupported test algorithm: {algorithm}")

    key_type, size_or_curve = key_specs[algorithm]
    private_key = jwk.generate_key(
        key_type,
        size_or_curve,
        {"alg": algorithm, "use": "sig"},
        private=True,
        auto_kid=True,
    )
    signed = jwt.encode(
        {"alg": algorithm, "kid": private_key.kid},
        claims,
        private_key,
        algorithms=[algorithm],
    )
    return signed, {"keys": [private_key.as_dict(private=False)]}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_missing_state(hass: HomeAssistant):
    """A state with no stored flow must fail closed and yield ``None``."""
    oidc_client = make_client(hass)
    redirect_uri = "https://example.com/callback"
    unknown_state = "missing-state"

    outcome = await oidc_client.async_complete_token_flow(
        redirect_uri, "code", unknown_state
    )

    assert outcome is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_nonce_mismatch(hass: HomeAssistant):
    """An ID-token nonce that differs from the stored one rejects the flow.

    The pending flow for the state must also be removed.
    """
    client = make_client(hass)
    client.flows["state-1"] = {"code_verifier": "verifier", "nonce": "expected"}

    discovery = AsyncMock(return_value={"token_endpoint": "https://issuer/token"})
    token_request = AsyncMock(
        return_value={"id_token": "id", "access_token": "access"}
    )
    id_token_parser = AsyncMock(return_value={"sub": "abc", "nonce": "wrong"})

    with (
        patch.object(client, "_fetch_discovery_document", new=discovery),
        patch.object(client, "_make_token_request", new=token_request),
        patch.object(client, "_parse_id_token", new=id_token_parser),
    ):
        outcome = await client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-1"
        )

    assert outcome is None
    assert "state-1" not in client.flows
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_complete_token_flow_handles_token_request_failure(hass: HomeAssistant):
    """A token-endpoint error reaches the caller as ``None``, not an exception."""
    client = make_client(hass)
    client.flows["state-2"] = {"code_verifier": "verifier", "nonce": "nonce"}

    discovery = AsyncMock(return_value={"token_endpoint": "https://issuer/token"})
    failing_token_request = AsyncMock(side_effect=OIDCTokenResponseInvalid())

    with (
        patch.object(client, "_fetch_discovery_document", new=discovery),
        patch.object(client, "_make_token_request", new=failing_token_request),
    ):
        outcome = await client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-2"
        )

    assert outcome is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_user_details_handles_non_list_groups(hass: HomeAssistant):
    """A ``groups`` claim that is a plain string must not grant any role."""
    client = make_client(hass, roles={"user": "users", "admin": "admins"})
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        # A bare string equal to the admin group name, not a list of groups.
        "groups": "admins",
    }
    discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(client, "_fetch_discovery_document", new=discovery):
        details = await client.parse_user_details(id_token_claims, "access-token")

    assert details["role"] == "invalid"
    assert details["display_name"] == "Display Name"
    assert details["username"] == "username"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_user_details_uses_userinfo_for_missing_claims(
    hass: HomeAssistant,
):
    """Claims absent from the ID token are filled in from the userinfo endpoint."""
    client = make_client(hass)
    discovery = AsyncMock(
        return_value={
            "issuer": "https://issuer",
            "userinfo_endpoint": "https://issuer/userinfo",
        }
    )
    userinfo = AsyncMock(
        return_value={
            "name": "From UserInfo",
            "preferred_username": "userinfo-user",
            "groups": ["admins"],
        }
    )

    with (
        patch.object(client, "_fetch_discovery_document", new=discovery),
        patch.object(client, "_get_userinfo", new=userinfo),
    ):
        details = await client.parse_user_details({"sub": "subject"}, "access-token")

    # The test expects the exposed subject to be SHA-256 of "<issuer>.<sub>".
    assert (
        details["sub"]
        == hashlib.sha256("https://issuer.subject".encode("utf-8")).hexdigest()
    )
    assert details["display_name"] == "From UserInfo"
    assert details["username"] == "userinfo-user"
    # No roles are configured here, yet membership of "admins" is expected to
    # map to the admin role.
    assert details["role"] == "system-admin"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_user_details_assigns_system_users_role(hass: HomeAssistant):
    """Membership of the configured user group maps to ``system-users``."""
    client = make_client(hass, roles={"user": "users", "admin": "admins"})
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        "groups": ["users"],
    }
    discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(client, "_fetch_discovery_document", new=discovery):
        details = await client.parse_user_details(id_token_claims, "access-token")

    assert details["role"] == "system-users"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_user_details_admin_role_overrides_user_role(
    hass: HomeAssistant,
):
    """Being in both the user and the admin group resolves to the admin role."""
    client = make_client(hass, roles={"user": "users", "admin": "admins"})
    id_token_claims = {
        "sub": "subject",
        "name": "Display Name",
        "preferred_username": "username",
        "groups": ["users", "admins"],
    }
    discovery = AsyncMock(return_value={"issuer": "https://issuer"})

    with patch.object(client, "_fetch_discovery_document", new=discovery):
        details = await client.parse_user_details(id_token_claims, "access-token")

    assert details["role"] == "system-admin"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_authorization_url_omits_pkce_when_disabled(
    hass: HomeAssistant,
):
    """With ``disable_rfc7636`` set, the authorization URL carries no PKCE params."""
    client = make_client(hass, features={"disable_rfc7636": True})
    discovery = AsyncMock(
        return_value={"authorization_endpoint": "https://issuer/authorize"}
    )

    with patch.object(client, "_fetch_discovery_document", new=discovery):
        url = await client.async_get_authorization_url(
            "https://example.com/callback", "state-xyz"
        )

    assert url is not None
    params = parse_qs(urlparse(url).query)

    assert params["state"] == ["state-xyz"]
    assert "nonce" in params
    for pkce_param in ("code_challenge", "code_challenge_method"):
        assert pkce_param not in params
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_missing(hass: HomeAssistant):
    """An asymmetric ID token whose header has no ``kid`` is rejected."""
    client = make_client(hass)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token_without_kid = make_jwt({"alg": "RS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with patch.object(client, "_fetch_jwks", new=empty_jwks):
        parsed = await client._parse_id_token(token_without_kid)

    assert parsed is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_not_found(hass: HomeAssistant):
    """An ID token whose ``kid`` is absent from the JWKS is rejected."""
    client = make_client(hass)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token_with_unknown_kid = make_jwt({"alg": "RS256", "kid": "missing"})
    unrelated_jwks = AsyncMock(return_value={"keys": [{"kid": "other"}]})

    with patch.object(client, "_fetch_jwks", new=unrelated_jwks):
        parsed = await client._parse_id_token(token_with_unknown_kid)

    assert parsed is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_rejects_hs_without_client_secret(hass: HomeAssistant):
    """HS256 is configured but no client secret is set, so parsing must raise."""
    client = make_client(hass, id_token_signing_alg="HS256")
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    hmac_token = make_jwt({"alg": "HS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with (
        patch.object(client, "_fetch_jwks", new=empty_jwks),
        pytest.raises(OIDCIdTokenSigningAlgorithmInvalid),
    ):
        await client._parse_id_token(hmac_token)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_decode_fails_jose(hass: HomeAssistant):
    """A ``JoseError`` from decoding is swallowed and reported as ``None``."""
    client = make_client(hass)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    token = make_jwt({"alg": "RS256", "kid": "kid1"})
    matching_jwks = AsyncMock(return_value={"keys": [{"kid": "kid1", "kty": "RSA"}]})

    with (
        patch.object(client, "_fetch_jwks", new=matching_jwks),
        # Key import always succeeds with an opaque stand-in key...
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
            return_value=object(),
        ),
        # ...so the failure comes only from jwt.decode.
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwt.decode",
            side_effect=joserfc_errors.JoseError("bad token"),
        ),
    ):
        parsed = await client._parse_id_token(token)

    assert parsed is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_rejects_wrong_signing_algorithm(hass: HomeAssistant):
    """A token using an alg other than the configured one (RS256) must raise."""
    client = make_client(hass, id_token_signing_alg="RS256")
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    unexpected_alg_token = make_jwt({"alg": "HS256"})
    empty_jwks = AsyncMock(return_value={"keys": []})

    with (
        patch.object(client, "_fetch_jwks", new=empty_jwks),
        pytest.raises(OIDCIdTokenSigningAlgorithmInvalid),
    ):
        await client._parse_id_token(unexpected_alg_token)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_rejects_missing_header(hass: HomeAssistant):
    """A token with an empty protected-header segment is rejected."""
    client = make_client(hass)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }
    headerless_token = make_jwt(None)
    empty_jwks = AsyncMock(return_value={"keys": []})

    with patch.object(client, "_fetch_jwks", new=empty_jwks):
        parsed = await client._parse_id_token(headerless_token)

    assert parsed is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_parse_id_token_rejects_invalid_registered_claims(hass: HomeAssistant):
    """A validly signed token with wrong ``aud``/``iss`` claims fails closed."""
    hs_secret = "top-secret-value"
    client = make_client(
        hass,
        id_token_signing_alg="HS256",
        client_secret=hs_secret,
    )
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }

    issued_at = int(time.time())
    mismatched_claims = {
        "sub": "abc",
        "aud": "wrong-audience",  # make_client uses client_id "test-client"
        "iss": "https://wrong-issuer",  # discovery issuer is "https://issuer"
        "nbf": issued_at,
        "iat": issued_at,
        "exp": issued_at + 3600,
    }
    signed_token = make_signed_hs256_jwt(hs_secret, mismatched_claims)
    empty_jwks = AsyncMock(return_value={"keys": []})

    with patch.object(client, "_fetch_jwks", new=empty_jwks):
        parsed = await client._parse_id_token(signed_token)

    assert parsed is None
|
|
|
|
|
|
@pytest.mark.asyncio
@pytest.mark.parametrize("algorithm", ALL_ID_TOKEN_SIGNING_ALGORITHMS)
async def test_parse_id_token_validates_real_signed_tokens_and_decode_inputs(
    hass: HomeAssistant, algorithm: str
):
    """Use real signatures and verify token/key/algorithm passed into joserfc.

    For each algorithm, a genuinely signed token must round-trip through
    ``_parse_id_token``. The spies check that ``jwt.decode`` is called once,
    with the raw token and an allow-list of exactly the configured algorithm.
    They also check the key handed to ``jwk.import_key``: the ``oct`` key for
    HMAC, or the single JWKS entry (matched by ``kid``) otherwise.
    """
    secret = "top-secret-value"
    is_hmac = algorithm.startswith("HS")

    client_kwargs = {"id_token_signing_alg": algorithm}
    if is_hmac:
        # HMAC tokens are verified with the client secret, not the JWKS.
        client_kwargs["client_secret"] = secret

    client = make_client(hass, **client_kwargs)
    client.discovery_document = {
        "issuer": "https://issuer",
        "jwks_uri": "https://issuer/jwks",
    }

    now = int(time.time())
    claims = {
        "sub": "subject-1",
        "aud": "test-client",
        "iss": "https://issuer",
        "nbf": now,
        "iat": now,
        "exp": now + 3600,
    }
    token, jwks_payload = build_real_signed_token(algorithm, claims, secret)

    with (
        patch.object(client, "_fetch_jwks", new=AsyncMock(return_value=jwks_payload)),
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwt.decode",
            wraps=jwt.decode,
        ) as decode_spy,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
            wraps=jwk.import_key,
        ) as import_key_spy,
    ):
        parsed = await client._parse_id_token(token)

    assert parsed == claims
    decode_spy.assert_called_once()
    assert decode_spy.call_args.args[0] == token
    assert decode_spy.call_args.kwargs["algorithms"] == [algorithm]

    import_key_spy.assert_called()
    imported_key_payload = import_key_spy.call_args.args[0]
    assert imported_key_payload["alg"] == algorithm
    if is_hmac:
        assert imported_key_payload["kty"] == "oct"
    else:
        # A kid that is merely present would also pass if the wrong key were
        # imported. The JWKS holds exactly one key, so require that key's kid.
        assert imported_key_payload["kid"] == jwks_payload["keys"][0]["kid"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_authorization_url_returns_none_when_discovery_fails(
    hass: HomeAssistant,
):
    """If the discovery document is invalid, no authorization URL is produced."""
    client = make_client(hass)
    broken_discovery = AsyncMock(side_effect=OIDCDiscoveryInvalid())

    with patch.object(client, "_fetch_discovery_document", new=broken_discovery):
        authorization_url = await client.async_get_authorization_url(
            "https://example.com/callback", "state-1"
        )

    assert authorization_url is None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_complete_token_flow_omits_code_verifier_when_pkce_disabled(
    hass: HomeAssistant,
):
    """With PKCE disabled, the token request must not send ``code_verifier``."""
    client = make_client(hass, features={"disable_rfc7636": True})
    client.flows["state-3"] = {"code_verifier": "verifier", "nonce": "nonce"}

    discovery = AsyncMock(return_value={"token_endpoint": "https://issuer/token"})
    token_request = AsyncMock(
        return_value={"id_token": "id", "access_token": "access"}
    )
    # The nonce matches the stored flow, so the exchange succeeds end to end.
    id_token_parser = AsyncMock(return_value={"sub": "abc", "nonce": "nonce"})
    user_details = AsyncMock(
        return_value={
            "sub": "abc",
            "display_name": "n",
            "username": "u",
            "role": "system-users",
        }
    )

    with (
        patch.object(client, "_fetch_discovery_document", new=discovery),
        patch.object(client, "_make_token_request", new=token_request),
        patch.object(client, "_parse_id_token", new=id_token_parser),
        patch.object(client, "parse_user_details", new=user_details),
    ):
        outcome = await client.async_complete_token_flow(
            "https://example.com/callback", "code", "state-3"
        )

    assert outcome is not None
    # Second positional argument of _make_token_request holds the form params.
    sent_params = token_request.await_args.args[1]
    assert "code_verifier" not in sent_params
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_raise_for_status_noop_on_ok_response():
    """A response flagged ``ok`` passes through without raising."""
    await http_raise_for_status(MagicMock(ok=True))
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_http_raise_for_status_raises_http_client_error_with_body():
    """A failed response raises ``HTTPClientError`` that includes the body text."""
    failing_response = MagicMock(
        ok=False,
        reason="Bad Request",
        status=400,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="problem details"),
    )

    with pytest.raises(HTTPClientError) as exc_info:
        await http_raise_for_status(failing_response)

    rendered = str(exc_info.value)
    assert "400 (Bad Request)" in rendered
    assert "problem details" in rendered
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_http_session_reuses_existing_session(hass: HomeAssistant):
    """An already-assigned ``http_session`` is returned instead of a new one."""
    client = make_client(hass)
    cached_session = MagicMock()
    client.http_session = cached_session

    returned = await client._get_http_session()

    assert returned is cached_session
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_http_session_applies_tls_verify_flag(hass: HomeAssistant):
    """``network.tls_verify`` is passed to the TCP connector as ``verify_ssl``."""
    client = make_client(hass, network={"tls_verify": False})

    with (
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
            return_value=MagicMock(),
        ) as connector_factory,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
            return_value=MagicMock(),
        ),
    ):
        await client._get_http_session()

    connector_factory.assert_called_once_with(verify_ssl=False)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_http_session_uses_custom_ca_path(hass: HomeAssistant):
    """A configured CA path yields an SSL context that is passed to the connector."""
    client = make_client(
        hass,
        network={"tls_verify": True, "tls_ca_path": "/tmp/test-ca.pem"},
    )
    # Stand-in for the SSL context. run_in_executor is replaced by this
    # AsyncMock, so no file is ever read.
    ssl_context_stub = object()
    executor = AsyncMock(return_value=ssl_context_stub)

    with (
        patch.object(hass.loop, "run_in_executor", new=executor),
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
            return_value=MagicMock(),
        ) as connector_factory,
        patch(
            "custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
            return_value=MagicMock(),
        ),
    ):
        await client._get_http_session()

    executor.assert_awaited_once()
    connector_factory.assert_called_once_with(verify_ssl=True, ssl=ssl_context_stub)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_make_token_request_returns_json_on_success(hass: HomeAssistant):
    """A successful token response is returned as its decoded JSON body."""
    client = make_client(hass)
    ok_response = MagicMock(
        ok=True,
        json=AsyncMock(return_value={"access_token": "token"}),
    )
    # session.post(...) is used as an async context manager yielding the response.
    post_ctx = AsyncMock()
    post_ctx.__aenter__.return_value = ok_response
    session = MagicMock()
    session.post.return_value = post_ctx
    get_session = AsyncMock(return_value=session)

    with patch.object(client, "_get_http_session", new=get_session):
        body = await client._make_token_request(
            "https://issuer/token", {"code": "abc"}
        )

    assert body == {"access_token": "token"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_make_token_request_raises_invalid_on_non_400_http_error(
    hass: HomeAssistant,
):
    """A non-400 upstream error (here a 500) becomes ``OIDCTokenResponseInvalid``."""
    client = make_client(hass)
    server_error = MagicMock(
        ok=False,
        reason="Server Error",
        status=500,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="boom"),
    )
    # session.post(...) is used as an async context manager yielding the response.
    post_ctx = AsyncMock()
    post_ctx.__aenter__.return_value = server_error
    session = MagicMock()
    session.post.return_value = post_ctx
    get_session = AsyncMock(return_value=session)

    with (
        patch.object(client, "_get_http_session", new=get_session),
        pytest.raises(OIDCTokenResponseInvalid),
    ):
        await client._make_token_request("https://issuer/token", {"code": "abc"})
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_userinfo_returns_json_on_success(hass: HomeAssistant):
    """A successful userinfo response is returned as its decoded JSON body."""
    client = make_client(hass)
    ok_response = MagicMock(ok=True, json=AsyncMock(return_value={"sub": "abc"}))
    # session.get(...) is used as an async context manager yielding the response.
    get_ctx = AsyncMock()
    get_ctx.__aenter__.return_value = ok_response
    session = MagicMock()
    session.get.return_value = get_ctx
    get_session = AsyncMock(return_value=session)

    with patch.object(client, "_get_http_session", new=get_session):
        body = await client._get_userinfo("https://issuer/userinfo", "access")

    assert body == {"sub": "abc"}
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_userinfo_raises_invalid_on_http_error(hass: HomeAssistant):
    """An upstream HTTP error from userinfo becomes ``OIDCUserinfoInvalid``."""
    client = make_client(hass)
    unavailable = MagicMock(
        ok=False,
        reason="Unavailable",
        status=503,
        request_info=MagicMock(),
        history=(),
        headers={},
        text=AsyncMock(return_value="oops"),
    )
    # session.get(...) is used as an async context manager yielding the response.
    get_ctx = AsyncMock()
    get_ctx.__aenter__.return_value = unavailable
    session = MagicMock()
    session.get.return_value = get_ctx
    get_session = AsyncMock(return_value=session)

    with (
        patch.object(client, "_get_http_session", new=get_session),
        pytest.raises(OIDCUserinfoInvalid),
    ):
        await client._get_userinfo("https://issuer/userinfo", "access")
|