Reimplement UI injection (#236)

This commit is contained in:
Christiaan Goossens
2026-04-13 22:51:31 +02:00
committed by GitHub
parent fdc93e2719
commit fd3643685d
36 changed files with 3772 additions and 1114 deletions

View File

@@ -1,90 +0,0 @@
"""Tests for the code store"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
from homeassistant.core import HomeAssistant
import pytest
from auth_oidc.stores.code_store import CodeStore
@pytest.mark.asyncio
async def test_code_store_generate_and_receive_code(hass: HomeAssistant):
"""Test generating and receiving a code."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
code_store = CodeStore(hass)
# Simulate loading with empty data
store_mock.async_load.return_value = {}
await code_store.async_load()
assert code_store.get_data() == {}
user_info = {"sub": "user1", "name": "Test User"}
code = await code_store.async_generate_code_for_userinfo(user_info)
assert code in code_store.get_data()
# Should return user_info and remove the code
with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock:
dt_mock.utcnow.return_value = datetime.now(timezone.utc)
dt_mock.fromisoformat.side_effect = datetime.fromisoformat
result = await code_store.receive_userinfo_for_code(code)
assert result == user_info
assert code not in code_store.get_data()
@pytest.mark.asyncio
async def test_code_store_expired_code(hass):
"""Test that expired codes return None."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
code_store = CodeStore(hass)
store_mock.async_load.return_value = {}
await code_store.async_load()
assert code_store.get_data() == {}
user_info = {"sub": "user2", "name": "Expired User"}
code = await code_store.async_generate_code_for_userinfo(user_info)
# Patch expiration to be in the past
code_store.get_data()[code]["expiration"] = (
datetime.now(timezone.utc) - timedelta(minutes=10)
).isoformat()
with patch("custom_components.auth_oidc.stores.code_store.datetime") as dt_mock:
dt_mock.utcnow.return_value = datetime.now(timezone.utc)
dt_mock.fromisoformat.side_effect = datetime.fromisoformat
result = await code_store.receive_userinfo_for_code(code)
assert result is None
assert code not in code_store.get_data()
@pytest.mark.asyncio
async def test_code_store_data_not_loaded(hass):
"""Test that using the store before loading raises RuntimeError."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
code_store = CodeStore(hass)
# Data is not loaded yet, should result in RuntimeError
with pytest.raises(RuntimeError):
await code_store.async_generate_code_for_userinfo({"sub": "user3"})
with pytest.raises(RuntimeError):
await code_store.receive_userinfo_for_code("123456")
@pytest.mark.asyncio
async def test_code_store_generate_code_length(hass):
"""Test that generated codes are 6 digits."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
code_store = CodeStore(hass)
store_mock.async_load.return_value = {}
await code_store.async_load()
assert code_store.get_data() == {}
user_info = {"sub": "user4"}
code = await code_store.async_generate_code_for_userinfo(user_info)
assert len(code) == 6
assert code.isdigit()

View File

@@ -1,6 +1,10 @@
"""Tests for the Auth Provider registration in HA"""
from urllib.parse import urlparse, parse_qs
import base64
import re
from types import SimpleNamespace
from urllib.parse import parse_qs, unquote, urlparse
from unittest.mock import patch
import pytest
from homeassistant.core import HomeAssistant
@@ -19,6 +23,8 @@ from custom_components.auth_oidc.config.const import (
)
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example"
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
"""Set up the auth_oidc component."""
@@ -45,23 +51,63 @@ async def test_setup_success_auth_provider_registration(hass: HomeAssistant):
auth_providers = hass.auth.get_auth_providers(DOMAIN)
assert len(auth_providers) == 1
# Public auth-provider contract: OIDC provider does not support HA MFA
assert auth_providers[0].support_mfa is False
async def login_user(hass: HomeAssistant, code: str):
"""Helper to login a user."""
@pytest.mark.asyncio
async def test_provider_ip_fallback_fails_closed_without_request_context(
hass: HomeAssistant,
):
"""Provider should not invent a shared IP when request context is missing."""
await setup(
hass,
{
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
},
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
flow = await provider.async_login_flow({})
result = await flow.async_step_init({"code": code})
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] is not None
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = None
assert provider._resolve_ip() is None
data = result["data"]
sub = data["sub"]
@pytest.mark.asyncio
async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssistant):
"""Cookie header should include Secure when HTTPS is in use."""
await setup(
hass,
{
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
},
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
cookie_header = provider.get_cookie_header("state-id", secure=True)["set-cookie"]
assert "SameSite=Strict" in cookie_header
assert "HttpOnly" in cookie_header
assert "Secure" in cookie_header
async def login_user(hass: HomeAssistant, state_id: str):
"""Helper to login a user from the stored OIDC state."""
provider = hass.auth.get_auth_providers(DOMAIN)[0]
# This helper runs outside an HTTP request, so pass the known local test IP.
sub = await provider.async_get_subject(state_id, "127.0.0.1")
assert sub == MockOIDCServer.get_final_subject()
# Get credentials
credentials = await provider.async_get_or_create_credentials(data)
credentials = await provider.async_get_or_create_credentials({"sub": sub})
assert credentials is not None
assert credentials.data["sub"] == sub
@@ -70,36 +116,49 @@ async def login_user(hass: HomeAssistant, code: str):
return user
async def get_login_code(hass: HomeAssistant, hass_client):
"""Helper to get a login code."""
async def get_login_state(hass: HomeAssistant, hass_client):
"""Helper to complete the browser login flow and return the OIDC state id."""
client = await hass_client()
redirect_uri = FAKE_REDIR_URL
encoded_redirect_uri = base64.b64encode(redirect_uri.encode("utf-8")).decode(
"utf-8"
)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded_redirect_uri}",
allow_redirects=False,
)
assert resp.status == 200
state_id = resp.cookies["auth_oidc_state"].value
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 302
location = resp.headers["Location"]
parsed_url = urlparse(location)
assert resp.status == 200
html = await resp.text()
match = re.search(r'decodeURIComponent\("([^"]+)"\)', html)
assert match is not None
auth_url = unquote(match.group(1))
parsed_url = urlparse(auth_url)
query_params = parse_qs(parsed_url.query)
state = query_params["state"][0]
assert query_params["state"][0] == state_id
session = async_get_clientsession(hass)
resp = session.get(location, allow_redirects=False)
resp = session.get(auth_url, allow_redirects=False)
assert resp.status == 200
# Mock OIDC returns JSON
json_parsed = await resp.json()
assert "code" in json_parsed and json_parsed["code"]
code = json_parsed["code"]
client = await hass_client()
resp = await client.get(
f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False
f"/auth/oidc/callback?code={code}&state={state_id}", allow_redirects=False
)
assert resp.status == 302
location = resp.headers["Location"]
assert "/auth/oidc/finish?code=" in location
assert resp.headers["Location"].endswith("/auth/oidc/finish")
# Get the code from the finish URL
code = location.split("code=")[1]
return code
return state_id
@pytest.mark.asyncio
@@ -120,16 +179,16 @@ async def test_full_login(hass: HomeAssistant, hass_client):
with mock_oidc_responses():
# Actually start the login and get a code
code = await get_login_code(hass, hass_client)
state_id = await get_login_state(hass, hass_client)
# Use the code to login directly with the registered auth provider
# Use the stored state to login directly with the registered auth provider
# Inspired by tests for the built-in providers
user = await login_user(hass, code)
user = await login_user(hass, state_id)
assert user.name == "Test Name"
# Login again to see if we trigger the re-use path
code2 = await get_login_code(hass, hass_client)
user2 = await login_user(hass, code2)
state_id2 = await get_login_state(hass, hass_client)
user2 = await login_user(hass, state_id2)
assert user2.id == user.id
@@ -161,10 +220,10 @@ async def test_login_with_linking(hass: HomeAssistant, hass_client):
await hass.auth.async_link_user(user, credential)
# Actually start the login and get a code
code = await get_login_code(hass, hass_client)
state_id = await get_login_state(hass, hass_client)
# Use the code to login directly with the registered auth provider
user2 = await login_user(hass, code)
# Use the stored state to login directly with the registered auth provider
user2 = await login_user(hass, state_id)
assert user2.id == user.id # Assert that the user was linked
@@ -187,8 +246,8 @@ async def test_login_with_person_create(hass: HomeAssistant, hass_client):
await async_setup_component(hass, PERSON_DOMAIN, {})
with mock_oidc_responses():
code = await get_login_code(hass, hass_client)
user = await login_user(hass, code)
state_id = await get_login_state(hass, hass_client)
user = await login_user(hass, state_id)
assert user.is_active
# Find the person associated to this user using the PersonRegistry API
@@ -200,6 +259,36 @@ async def test_login_with_person_create(hass: HomeAssistant, hass_client):
assert person["user_id"] == user.id
@pytest.mark.asyncio
async def test_login_without_person_create_does_not_create_person(
hass: HomeAssistant, hass_client
):
"""Test that person creation can be disabled."""
await setup(
hass,
{
CLIENT_ID: "dummy",
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False,
},
},
True,
)
await async_setup_component(hass, PERSON_DOMAIN, {})
with mock_oidc_responses():
state_id = await get_login_state(hass, hass_client)
user = await login_user(hass, state_id)
assert user.is_active
person_store = hass.data[PERSON_DOMAIN][1]
persons = person_store.async_items()
assert len(persons) == 0
@pytest.mark.asyncio
async def test_login_shows_form(hass: HomeAssistant):
"""Test a login"""
@@ -220,10 +309,38 @@ async def test_login_shows_form(hass: HomeAssistant):
flow = await provider.async_login_flow({})
result = await flow.async_step_init({})
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "mfa"
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "no_oidc_cookie_found"
# Attempt an invalid code
result = await flow.async_step_init({"code": "invalid"})
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {"base": "invalid_auth"}
@pytest.mark.asyncio
async def test_login_with_invalid_cookie_aborts(hass: HomeAssistant):
"""A cookie that does not map to a valid state should fail closed."""
await setup(
hass,
{
CLIENT_ID: "dummy",
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False,
},
},
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
flow = await provider.async_login_flow({})
fake_request = SimpleNamespace(
cookies={"auth_oidc_state": "missing-state"}, remote="127.0.0.1"
)
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = fake_request
result = await flow.async_step_init({})
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "no_oidc_cookie_found"

View File

@@ -1,287 +0,0 @@
"""Tests for the OIDC client"""
from urllib.parse import urlparse, parse_qs
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from auth_oidc import DOMAIN
from auth_oidc.tools.oidc_client import OIDCDiscoveryClient, OIDCDiscoveryInvalid
from auth_oidc.config.const import (
DISCOVERY_URL,
CLIENT_ID,
)
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
EXAMPLE_CLIENT_ID = "dummyclient"
async def setup(hass: HomeAssistant):
"""Set up the integration within Home Assistant"""
mock_config = {
DOMAIN: {
CLIENT_ID: EXAMPLE_CLIENT_ID,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
}
}
result = await async_setup_component(hass, DOMAIN, mock_config)
assert result
@pytest.mark.asyncio
async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
"""Test that one full OIDC flow works if OIDC is mocked."""
await setup(hass)
with mock_oidc_responses():
# Start by going to /auth/oidc/redirect
client = await hass_client()
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 302
assert resp.headers["Location"].startswith(MockOIDCServer.get_authorize_url())
# Parse the location header and test the query params for correctness
location = resp.headers["Location"]
parsed_url = urlparse(location)
query_params = parse_qs(parsed_url.query)
assert "response_type" in query_params and query_params.get(
"response_type"
) == ["code"]
assert "client_id" in query_params and query_params.get("client_id") == [
EXAMPLE_CLIENT_ID
]
assert "scope" in query_params and query_params.get("scope") == [
"openid profile groups"
]
assert "state" in query_params and query_params["state"]
state = query_params["state"][0]
assert len(state) >= 16 # Ensure state is sufficiently long
assert (
"redirect_uri" in query_params
and query_params["redirect_uri"]
and query_params["redirect_uri"][0].endswith("/auth/oidc/callback")
) # TODO: Also test that the URL itself is correct
assert "nonce" in query_params and query_params["nonce"]
assert "code_challenge_method" in query_params and query_params.get(
"code_challenge_method"
) == ["S256"]
assert "code_challenge" in query_params and query_params["code_challenge"]
session = async_get_clientsession(hass)
resp = session.get(location, allow_redirects=False)
assert resp.status == 200
json_parsed = await resp.json()
assert "code" in json_parsed and json_parsed["code"]
# Now go back to the callback with a sample code
code = json_parsed["code"]
client = await hass_client()
resp = await client.get(
f"/auth/oidc/callback?code={code}&state={state}", allow_redirects=False
)
# TODO: Test if logged text contains our login
# TODO: Test if the code actually works
assert resp.status == 302
assert "/auth/oidc/finish?code=" in resp.headers["Location"]
async def discovery_test_through_redirect(
hass_client, caplog, scenario: str, match_log_line: str
):
"""Test that discovery document retrieval fails gracefully through redirect endpoint."""
with mock_oidc_responses(scenario):
# Start by going to /auth/oidc/redirect
client = await hass_client()
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
# Find matching log line
assert match_log_line in caplog.text
# Assert that we get a 200 response with an error message
assert resp.status == 200
text = await resp.text()
assert "Integration is misconfigured, discovery could not be obtained." in text
async def direct_discovery_test(
hass: HomeAssistant,
scenario: str,
match_type: str,
match_log_line: str | None = None,
):
"""Test that discovery document retrieval fails with nice error directly."""
with mock_oidc_responses(scenario):
session = async_get_clientsession(hass)
client = OIDCDiscoveryClient(
MockOIDCServer.get_discovery_url(),
session,
{
"id_token_signing_alg": "RS256",
},
)
with pytest.raises(OIDCDiscoveryInvalid) as exc_info:
await client.fetch_discovery_document()
assert exc_info.value.type == match_type
assert exc_info.value.get_detail_string().startswith("type: " + match_type)
if match_log_line:
assert match_log_line in exc_info.value.get_detail_string()
@pytest.mark.asyncio
async def test_discovery_failures(hass: HomeAssistant, hass_client, caplog):
"""Test that discovery document retrieval fails gracefully."""
await setup(hass)
# Empty scenario
await discovery_test_through_redirect(
hass_client, caplog, "empty", "is missing required endpoint: issuer"
)
await direct_discovery_test(hass, "empty", "missing_endpoint", "endpoint: issuer")
# Missing authorization_endpoint
await discovery_test_through_redirect(
hass_client,
caplog,
"only_issuer",
"is missing required endpoint: authorization_endpoint",
)
await direct_discovery_test(
hass, "only_issuer", "missing_endpoint", "endpoint: authorization_endpoint"
)
# Missing token_endpoint
await discovery_test_through_redirect(
hass_client,
caplog,
"missing_token",
"is missing required endpoint: token_endpoint",
)
await direct_discovery_test(
hass, "missing_token", "missing_endpoint", "endpoint: token_endpoint"
)
# Missing jwks_uri
await discovery_test_through_redirect(
hass_client,
caplog,
"missing_jwks",
"is missing required endpoint: jwks_uri",
)
await direct_discovery_test(
hass, "missing_jwks", "missing_endpoint", "endpoint: jwks_uri"
)
# Invalid response_modes_supported
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_response_modes",
"does not support required 'query' response mode, only supports: ['post']",
)
await direct_discovery_test(
hass, "invalid_response_modes", "does_not_support_response_mode", "post"
)
# Invalid grant_types supported
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_grant_types",
"does not support required 'authorization_code' grant type, only supports: ['refresh_token']",
)
await direct_discovery_test(
hass, "invalid_grant_types", "does_not_support_grant_type", "refresh_token"
)
# Invalid response types
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_response_types",
"does not support required 'code' response type, only supports: ['token']",
)
await direct_discovery_test(
hass, "invalid_response_types", "does_not_support_response_type", "token"
)
# Invalid code_challenge types
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_code_challenge_types",
"does not support required 'S256' code challenge method, only supports: ['plain']",
)
await direct_discovery_test(
hass,
"invalid_code_challenge_types",
"does_not_support_required_code_challenge_method",
"plain",
)
# Invalid id_token_signing alg
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_id_token_signing_alg",
"does not have 'id_token_signing_alg_values_supported' field",
)
await direct_discovery_test(
hass, "invalid_id_token_signing_alg", "missing_id_token_signing_alg_values"
)
# Not matching id_token_signing alg
await discovery_test_through_redirect(
hass_client,
caplog,
"wrong_id_token_signing_alg",
"does not support requested id_token_signing_alg 'RS256', only supports: ['HS256']",
)
await direct_discovery_test(
hass,
"wrong_id_token_signing_alg",
"does_not_support_id_token_signing_alg",
"requested: RS256, supported: ['HS256']",
)
# Invalid URL
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_url",
"has invalid URL in endpoint: jwks_uri (/jwks)",
)
await direct_discovery_test(
hass,
"invalid_url",
"invalid_endpoint",
"endpoint: jwks_uri, url: /jwks",
)
@pytest.mark.asyncio
async def test_direct_jwks_fetch(hass: HomeAssistant):
"""Test direct fetch of JWKS."""
with mock_oidc_responses():
session = async_get_clientsession(hass)
client = OIDCDiscoveryClient(
MockOIDCServer.get_discovery_url(),
session,
{
"id_token_signing_alg": "RS256",
},
)
await client.fetch_discovery_document()
jwks = await client.fetch_jwks()
assert "keys" in jwks

View File

@@ -0,0 +1,690 @@
"""Tests for the OIDC client"""
import base64
import asyncio
import re
from unittest.mock import AsyncMock, patch
from urllib.parse import parse_qs, unquote, urlparse, urlencode
import pytest
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.auth_oidc import DOMAIN
from custom_components.auth_oidc.tools.oidc_client import (
OIDCDiscoveryClient,
OIDCDiscoveryInvalid,
)
from custom_components.auth_oidc.config.const import (
DISCOVERY_URL,
CLIENT_ID,
)
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
EXAMPLE_CLIENT_ID = "http://example.com/"
WEB_CLIENT_ID = "https://example.com"
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
# Helper functions
def encode_redirect_uri(redirect_uri: str) -> str:
"""Helper to encode redirect URI for welcome page."""
return base64.b64encode(redirect_uri.encode("utf-8")).decode("utf-8")
def create_redirect_uri(client_id: str) -> str:
"""Create a redirect URI for Home Assistant Android app."""
params = {
"response_type": "code",
"redirect_uri": client_id,
"client_id": client_id,
"state": "example",
}
return f"http://example.com/auth/authorize?{urlencode(params)}"
async def get_welcome_for_client(client, redirect_uri: str) -> tuple[str, str, int]:
"""Go to welcome page and return state cookie, HTML content, and status.
Returns:
Tuple of (state_id, html_content, status_code)
"""
encoded_uri = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded_uri}",
allow_redirects=False,
)
state = resp.cookies["auth_oidc_state"].value
html = await resp.text() if resp.status == 200 else ""
return state, html, resp.status
async def get_redirect_auth_url(client) -> str:
"""Go to redirect page and extract the authorization URL.
Returns:
The full authorization URL to send to the OIDC provider
"""
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 200
html = await resp.text()
match = re.search(r'decodeURIComponent\("([^"]+)"\)', html)
assert match is not None, "Authorization URL not found in redirect page"
return unquote(match.group(1))
async def complete_callback_and_finish(client, code: str, state: str):
"""Complete the callback and finish flow.
Returns:
The state_id cookie value after completion
"""
resp = await client.get(
f"/auth/oidc/callback?code={code}&state={state}",
allow_redirects=False,
)
assert resp.status == 302
assert resp.headers["Location"].endswith("/auth/oidc/finish")
resp_finish = await client.get("/auth/oidc/finish", allow_redirects=False)
assert resp_finish.status == 200
finish_html = await resp_finish.text()
assert 'id="continue-on-this-device"' in finish_html
assert 'id="device-code-input"' in finish_html
assert 'id="approve-login-button"' in finish_html
async def verify_back_redirect(client, expected_redirect_uri: str):
"""Verify that POST to finish without body redirects back to the original redirect_uri."""
resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False)
assert resp_finish_post.status == 302
assert (
resp_finish_post.headers["Location"]
== unquote(expected_redirect_uri) + "&storeToken=true&skip_oidc_redirect=true"
)
async def listen_for_sse_events(
resp_sse,
expected_event: str,
timeout_seconds: int = 5,
) -> list[str]:
"""Listen for SSE events and return once the expected event is received.
Args:
resp_sse: The SSE response stream
expected_event: The event type to listen for (e.g., "waiting" or "ready")
timeout_seconds: Maximum time to wait for the event
Returns:
List of received event lines
"""
if resp_sse is None:
raise ValueError("resp_sse cannot be None")
received_events = []
async def stream_reader():
try:
async for line in resp_sse.content:
decoded_line = line.decode("utf-8").strip()
if not decoded_line:
continue
received_events.append(decoded_line)
# Check if this is an event line
if decoded_line.startswith("event:"):
event_type = decoded_line.split(":", 1)[1].strip()
if event_type == expected_event:
# Found the expected event, return successfully.
return True
# Device SSE may emit multiple waiting events before ready.
if expected_event == "ready" and event_type == "waiting":
continue
raise AssertionError(
f"Unexpected event type '{event_type}'. Expected: {expected_event}"
)
except asyncio.CancelledError:
pass
return False
try:
result = await asyncio.wait_for(stream_reader(), timeout=timeout_seconds)
if result:
return received_events
except asyncio.TimeoutError as exc:
raise AssertionError(
f"Timeout after {timeout_seconds}s waiting for '{expected_event}' event"
) from exc
raise AssertionError(f"Failed to receive '{expected_event}' event")
async def setup(hass: HomeAssistant):
"""Set up the integration within Home Assistant"""
mock_config = {
DOMAIN: {
CLIENT_ID: EXAMPLE_CLIENT_ID,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
}
}
result = await async_setup_component(hass, DOMAIN, mock_config)
assert result
# Actual tests
@pytest.mark.asyncio
async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
"""Test that one full OIDC flow works if OIDC is mocked."""
await setup(hass)
with mock_oidc_responses():
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
# Go to welcome and get state cookie
state, _, status = await get_welcome_for_client(client, redirect_uri)
assert status == 200
assert state is not None
# Get authorization URL from redirect page
authorization_url = await get_redirect_auth_url(client)
assert authorization_url.startswith(MockOIDCServer.get_authorize_url())
# Parse the rendered redirect URL and test the query params for correctness
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
assert "response_type" in query_params and query_params.get(
"response_type"
) == ["code"]
assert "client_id" in query_params and query_params.get("client_id") == [
EXAMPLE_CLIENT_ID
]
assert "scope" in query_params and query_params.get("scope") == [
"openid profile groups"
]
assert "state" in query_params and query_params["state"]
assert query_params["state"][0] == state
assert len(query_params["state"][0]) >= 16 # Ensure state is sufficiently long
assert (
"redirect_uri" in query_params
and query_params["redirect_uri"]
and query_params["redirect_uri"][0].endswith("/auth/oidc/callback")
)
assert "nonce" in query_params and query_params["nonce"]
assert "code_challenge_method" in query_params and query_params.get(
"code_challenge_method"
) == ["S256"]
assert "code_challenge" in query_params and query_params["code_challenge"]
session = async_get_clientsession(hass)
resp = session.get(authorization_url, allow_redirects=False)
assert resp.status == 200
# JSON response from mock server, normally would be interactive
json_parsed = await resp.json()
assert "code" in json_parsed and json_parsed["code"]
# Now go back to the callback with a sample code
code = json_parsed["code"]
await complete_callback_and_finish(client, code, state)
# POST to finish without any POST body should result in 302 back to the original redirect_uri
await verify_back_redirect(client, redirect_uri)
async def discovery_test_through_redirect(
hass_client, caplog, scenario: str, match_log_line: str
):
"""Test that discovery document retrieval fails gracefully through redirect endpoint."""
with mock_oidc_responses(scenario):
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
await client.get(
f"/auth/oidc/welcome?redirect_uri={encode_redirect_uri(redirect_uri)}",
allow_redirects=False,
)
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
# Find matching log line
assert match_log_line in caplog.text
# Assert that we get an error response with an error message
assert resp.status == 500
text = await resp.text()
assert "Integration is misconfigured, discovery could not be obtained." in text
async def direct_discovery_test(
hass: HomeAssistant,
scenario: str,
match_type: str,
match_log_line: str | None = None,
):
"""Test that discovery document retrieval fails with nice error directly."""
with mock_oidc_responses(scenario):
session = async_get_clientsession(hass)
client = OIDCDiscoveryClient(
MockOIDCServer.get_discovery_url(),
session,
{
"id_token_signing_alg": "RS256",
},
)
with pytest.raises(OIDCDiscoveryInvalid) as exc_info:
await client.fetch_discovery_document()
assert exc_info.value.type == match_type
assert exc_info.value.get_detail_string().startswith("type: " + match_type)
if match_log_line:
assert match_log_line in exc_info.value.get_detail_string()
@pytest.mark.asyncio
async def test_discovery_failures(hass: HomeAssistant, hass_client, caplog):
"""Test that discovery document retrieval fails gracefully."""
await setup(hass)
# Empty scenario
await discovery_test_through_redirect(
hass_client, caplog, "empty", "is missing required endpoint: issuer"
)
await direct_discovery_test(hass, "empty", "missing_endpoint", "endpoint: issuer")
# Missing authorization_endpoint
await discovery_test_through_redirect(
hass_client,
caplog,
"only_issuer",
"is missing required endpoint: authorization_endpoint",
)
await direct_discovery_test(
hass, "only_issuer", "missing_endpoint", "endpoint: authorization_endpoint"
)
# Missing token_endpoint
await discovery_test_through_redirect(
hass_client,
caplog,
"missing_token",
"is missing required endpoint: token_endpoint",
)
await direct_discovery_test(
hass, "missing_token", "missing_endpoint", "endpoint: token_endpoint"
)
# Missing jwks_uri
await discovery_test_through_redirect(
hass_client,
caplog,
"missing_jwks",
"is missing required endpoint: jwks_uri",
)
await direct_discovery_test(
hass, "missing_jwks", "missing_endpoint", "endpoint: jwks_uri"
)
# Invalid response_modes_supported
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_response_modes",
"does not support required 'query' response mode, only supports: ['post']",
)
await direct_discovery_test(
hass, "invalid_response_modes", "does_not_support_response_mode", "post"
)
# Invalid grant_types supported
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_grant_types",
"does not support required 'authorization_code' grant type, only supports: ['refresh_token']",
)
await direct_discovery_test(
hass, "invalid_grant_types", "does_not_support_grant_type", "refresh_token"
)
# Invalid response types
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_response_types",
"does not support required 'code' response type, only supports: ['token']",
)
await direct_discovery_test(
hass, "invalid_response_types", "does_not_support_response_type", "token"
)
# Invalid code_challenge types
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_code_challenge_types",
"does not support required 'S256' code challenge method, only supports: ['plain']",
)
await direct_discovery_test(
hass,
"invalid_code_challenge_types",
"does_not_support_required_code_challenge_method",
"plain",
)
# Invalid id_token_signing alg
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_id_token_signing_alg",
"does not have 'id_token_signing_alg_values_supported' field",
)
await direct_discovery_test(
hass, "invalid_id_token_signing_alg", "missing_id_token_signing_alg_values"
)
# Not matching id_token_signing alg
await discovery_test_through_redirect(
hass_client,
caplog,
"wrong_id_token_signing_alg",
"does not support requested id_token_signing_alg 'RS256', only supports: ['HS256']",
)
await direct_discovery_test(
hass,
"wrong_id_token_signing_alg",
"does_not_support_id_token_signing_alg",
"requested: RS256, supported: ['HS256']",
)
# Invalid URL
await discovery_test_through_redirect(
hass_client,
caplog,
"invalid_url",
"has invalid URL in endpoint: jwks_uri (/jwks)",
)
await direct_discovery_test(
hass,
"invalid_url",
"invalid_endpoint",
"endpoint: jwks_uri, url: /jwks",
)
@pytest.mark.asyncio
async def test_direct_jwks_fetch(hass: HomeAssistant):
"""Test direct fetch of JWKS."""
with mock_oidc_responses():
session = async_get_clientsession(hass)
client = OIDCDiscoveryClient(
MockOIDCServer.get_discovery_url(),
session,
{
"id_token_signing_alg": "RS256",
},
)
await client.fetch_discovery_document()
jwks = await client.fetch_jwks()
assert "keys" in jwks
@pytest.mark.asyncio
async def test_device_login_flow_two_browsers(hass: HomeAssistant, hass_client):
"""Test device login flow with two separate browser sessions.
This simulates:
- Mobile device (Device 1) generating a device code and waiting via SSE
- Desktop browser (Device 2) completing full OAuth flow and linking the code
- Mobile device receiving ready event after code is linked
"""
await setup(hass)
with mock_oidc_responses():
# ==================== DEVICE 1: Mobile ====================
# Mobile client starts the login flow
mobile_client = await hass_client()
mobile_redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
mobile_state, mobile_html, status = await get_welcome_for_client(
mobile_client, mobile_redirect_uri
)
assert status == 200
assert mobile_state is not None
assert 'id="device-instructions"' in mobile_html
assert 'id="device-code"' in mobile_html
# Extract device code from the welcome page.
# The code is rendered in a div with id="device-code".
device_code_match = re.search(
r'id=["\']device-code["\'][^>]*>\s*([^<\s]+)\s*<',
mobile_html,
)
assert device_code_match is not None, (
"Device code should be generated for mobile client"
)
mobile_device_code = device_code_match.group(1)
assert len(mobile_device_code) > 0
# ==================== DEVICE 2: Desktop ====================
# Desktop client in a separate session
desktop_client = await hass_client()
desktop_redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
desktop_state, _, status = await get_welcome_for_client(
desktop_client, desktop_redirect_uri
)
assert status in [200, 302]
assert desktop_state is not None
# Desktop goes through redirect to get the authorization URL
authorization_url = await get_redirect_auth_url(desktop_client)
assert authorization_url.startswith(MockOIDCServer.get_authorize_url())
# Desktop gets the authorization code from OIDC provider
session = async_get_clientsession(hass)
resp_auth = session.get(authorization_url, allow_redirects=False)
assert resp_auth.status == 200
json_auth = await resp_auth.json()
assert "code" in json_auth
desktop_code = json_auth["code"]
await complete_callback_and_finish(desktop_client, desktop_code, desktop_state)
# ==================== Mobile Device Finalizes Flow ====================
# Mobile device polls SSE and keeps the connection open throughout
resp_sse = await mobile_client.get(
"/auth/oidc/device-sse", allow_redirects=False
)
assert resp_sse.status == 200
# Listen for waiting events for up to 5 seconds
await listen_for_sse_events(resp_sse, "waiting", timeout_seconds=5)
# Actually submit the mobile code using POST
resp_code = await desktop_client.post(
"/auth/oidc/finish",
data={"device_code": mobile_device_code},
allow_redirects=False,
)
assert resp_code.status == 200
assert resp_code.headers.get("Content-Type", "").startswith("text/html")
html_code = await resp_code.text()
assert 'id="mobile-success-message"' in html_code
assert 'id="restart-login-button"' in html_code
# ==================== Mobile Device Receives Ready Event ====================
# After desktop flow is completed, mobile SSE should receive a ready event on same connection
await listen_for_sse_events(resp_sse, "ready", timeout_seconds=5)
# POST to finish without any POST body should result in 302 back to the original redirect_uri
await verify_back_redirect(mobile_client, mobile_redirect_uri)
@pytest.mark.asyncio
async def test_finish_rejects_device_code_when_state_not_ready(
hass: HomeAssistant, hass_client
):
"""Submitting a device code must fail if callback did not complete for this browser."""
await setup(hass)
with mock_oidc_responses():
# Device session that owns the device code.
mobile_client = await hass_client()
mobile_redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
_, mobile_html, status = await get_welcome_for_client(
mobile_client, mobile_redirect_uri
)
assert status == 200
device_code_match = re.search(
r'id=["\']device-code["\'][^>]*>\s*([^<\s]+)\s*<',
mobile_html,
)
assert device_code_match is not None
mobile_device_code = device_code_match.group(1)
# Separate browser starts but does not complete callback flow.
desktop_client = await hass_client()
desktop_redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
_, _, desktop_status = await get_welcome_for_client(
desktop_client, desktop_redirect_uri
)
assert desktop_status in [200, 302]
# Negative branch: try to finalize before desktop state has user info.
resp = await desktop_client.post(
"/auth/oidc/finish",
data={"device_code": mobile_device_code},
allow_redirects=False,
)
assert resp.status == 400
text = await resp.text()
assert "Failed to link state to device code" in text
@pytest.mark.asyncio
async def test_callback_shows_error_if_userinfo_save_fails(
hass: HomeAssistant, hass_client
):
"""Callback should return error page when state save fails after successful token flow."""
await setup(hass)
with (
mock_oidc_responses(),
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_save_user_info",
new=AsyncMock(return_value=False),
),
):
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
state, _, status = await get_welcome_for_client(client, redirect_uri)
assert status == 200
authorization_url = await get_redirect_auth_url(client)
session = async_get_clientsession(hass)
resp_auth = session.get(authorization_url, allow_redirects=False)
json_auth = await resp_auth.json()
resp = await client.get(
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
allow_redirects=False,
)
assert resp.status == 500
text = await resp.text()
assert "Failed to save user information, session probably expired." in text
@pytest.mark.asyncio
async def test_callback_rejects_nonce_mismatch(hass: HomeAssistant, hass_client):
"""Callback should fail closed when the returned nonce does not match the stored flow nonce."""
await setup(hass)
with (
mock_oidc_responses(),
patch(
"custom_components.auth_oidc.tools.oidc_client.OIDCClient._parse_id_token",
new=AsyncMock(
return_value={
"sub": "test-user",
"nonce": "mismatched-nonce",
"name": "Test Name",
"preferred_username": "testuser",
"groups": [],
}
),
),
):
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
state, _, status = await get_welcome_for_client(client, redirect_uri)
assert status == 200
authorization_url = await get_redirect_auth_url(client)
session = async_get_clientsession(hass)
resp_auth = session.get(authorization_url, allow_redirects=False)
json_auth = await resp_auth.json()
resp = await client.get(
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
allow_redirects=False,
)
assert resp.status == 500
text = await resp.text()
assert "Failed to get user details" in text
@pytest.mark.asyncio
async def test_callback_replay_is_rejected(hass: HomeAssistant, hass_client):
"""A callback replay with the same state should be rejected after first successful use."""
await setup(hass)
with mock_oidc_responses():
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
state, _, status = await get_welcome_for_client(client, redirect_uri)
assert status == 200
authorization_url = await get_redirect_auth_url(client)
session = async_get_clientsession(hass)
resp_auth = session.get(authorization_url, allow_redirects=False)
json_auth = await resp_auth.json()
code = json_auth["code"]
# First callback should succeed.
first = await client.get(
f"/auth/oidc/callback?code={code}&state={state}",
allow_redirects=False,
)
assert first.status == 302
# Replay should fail because the state flow has already been consumed.
replay = await client.get(
f"/auth/oidc/callback?code={code}&state={state}",
allow_redirects=False,
)
assert replay.status == 500
replay_text = await replay.text()
assert "Failed to get user details" in replay_text

View File

@@ -0,0 +1,690 @@
"""Unit tests for OIDC client token and security behavior."""
# pylint: disable=protected-access
import hashlib
import json
import base64
import time
from urllib.parse import parse_qs, urlparse
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from homeassistant.core import HomeAssistant
from joserfc import errors as joserfc_errors, jwt, jwk
from custom_components.auth_oidc.tools.oidc_client import (
HTTPClientError,
OIDCClient,
OIDCDiscoveryInvalid,
OIDCIdTokenSigningAlgorithmInvalid,
OIDCTokenResponseInvalid,
OIDCUserinfoInvalid,
http_raise_for_status,
)
def make_client(hass: HomeAssistant, **kwargs) -> OIDCClient:
"""Build an OIDC client with explicit defaults for unit testing."""
return OIDCClient(
hass=hass,
discovery_url="https://issuer/.well-known/openid-configuration",
client_id="test-client",
scope="openid profile",
features=kwargs.pop("features", {}),
claims=kwargs.pop("claims", {}),
roles=kwargs.pop("roles", {}),
network=kwargs.pop("network", {}),
**kwargs,
)
def make_jwt(
header: dict | None,
payload: dict | None = None,
signature: str = "sig",
) -> str:
"""Build a compact JWT string for parser-focused tests."""
def _b64url_json(data: dict) -> str:
encoded = json.dumps(data, separators=(",", ":")).encode("utf-8")
return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8")
protected = _b64url_json(header) if header is not None else ""
claims = _b64url_json(payload or {"sub": "subject"})
return f"{protected}.{claims}.{signature}"
def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
"""Build a real HS256 signed JWT for parser validation tests."""
jwk_obj = jwk.import_key(
{
"kty": "oct",
"k": base64.urlsafe_b64encode(secret.encode()).decode().rstrip("="),
"alg": "HS256",
}
)
return jwt.encode({"alg": "HS256"}, claims, jwk_obj)
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_missing_state(hass: HomeAssistant):
"""Flow state must exist; missing state should fail closed."""
client = make_client(hass)
result = await client.async_complete_token_flow(
"https://example.com/callback", "code", "missing-state"
)
assert result is None
@pytest.mark.asyncio
async def test_complete_token_flow_rejects_nonce_mismatch(hass: HomeAssistant):
"""Nonce mismatch should reject the token flow."""
client = make_client(hass)
client.flows["state-1"] = {"code_verifier": "verifier", "nonce": "expected"}
with (
patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
),
patch.object(
client,
"_make_token_request",
new=AsyncMock(return_value={"id_token": "id", "access_token": "access"}),
),
patch.object(
client,
"_parse_id_token",
new=AsyncMock(return_value={"sub": "abc", "nonce": "wrong"}),
),
):
result = await client.async_complete_token_flow(
"https://example.com/callback", "code", "state-1"
)
assert result is None
assert "state-1" not in client.flows
@pytest.mark.asyncio
async def test_complete_token_flow_handles_token_request_failure(hass: HomeAssistant):
"""Token endpoint failures should return None to caller."""
client = make_client(hass)
client.flows["state-2"] = {"code_verifier": "verifier", "nonce": "nonce"}
with (
patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
),
patch.object(
client,
"_make_token_request",
new=AsyncMock(side_effect=OIDCTokenResponseInvalid()),
),
):
result = await client.async_complete_token_flow(
"https://example.com/callback", "code", "state-2"
)
assert result is None
@pytest.mark.asyncio
async def test_parse_user_details_handles_non_list_groups(hass: HomeAssistant):
"""Non-list groups should not accidentally grant roles."""
client = make_client(hass, roles={"user": "users", "admin": "admins"})
with patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"issuer": "https://issuer"}),
):
details = await client.parse_user_details(
{
"sub": "subject",
"name": "Display Name",
"preferred_username": "username",
"groups": "admins",
},
"access-token",
)
assert details["role"] == "invalid"
assert details["display_name"] == "Display Name"
assert details["username"] == "username"
@pytest.mark.asyncio
async def test_parse_user_details_uses_userinfo_for_missing_claims(
hass: HomeAssistant,
):
"""Missing claims in id_token should be filled from userinfo when available."""
client = make_client(hass)
with (
patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(
return_value={
"issuer": "https://issuer",
"userinfo_endpoint": "https://issuer/userinfo",
}
),
),
patch.object(
client,
"_get_userinfo",
new=AsyncMock(
return_value={
"name": "From UserInfo",
"preferred_username": "userinfo-user",
"groups": ["admins"],
}
),
),
):
details = await client.parse_user_details({"sub": "subject"}, "access-token")
expected_sub = hashlib.sha256("https://issuer.subject".encode("utf-8")).hexdigest()
assert details["sub"] == expected_sub
assert details["display_name"] == "From UserInfo"
assert details["username"] == "userinfo-user"
assert details["role"] == "system-admin"
@pytest.mark.asyncio
async def test_parse_user_details_assigns_system_users_role(hass: HomeAssistant):
"""Configured user role should map to system-users when group is present."""
client = make_client(hass, roles={"user": "users", "admin": "admins"})
with patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"issuer": "https://issuer"}),
):
details = await client.parse_user_details(
{
"sub": "subject",
"name": "Display Name",
"preferred_username": "username",
"groups": ["users"],
},
"access-token",
)
assert details["role"] == "system-users"
@pytest.mark.asyncio
async def test_parse_user_details_admin_role_overrides_user_role(
hass: HomeAssistant,
):
"""Admin group should take precedence when both user and admin groups are present."""
client = make_client(hass, roles={"user": "users", "admin": "admins"})
with patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"issuer": "https://issuer"}),
):
details = await client.parse_user_details(
{
"sub": "subject",
"name": "Display Name",
"preferred_username": "username",
"groups": ["users", "admins"],
},
"access-token",
)
assert details["role"] == "system-admin"
@pytest.mark.asyncio
async def test_get_authorization_url_omits_pkce_when_disabled(
hass: HomeAssistant,
):
"""Authorization URL should omit PKCE params when compatibility mode disables PKCE."""
client = make_client(hass, features={"disable_rfc7636": True})
with patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(
return_value={"authorization_endpoint": "https://issuer/authorize"}
),
):
url = await client.async_get_authorization_url(
"https://example.com/callback", "state-xyz"
)
assert url is not None
parsed = urlparse(url)
query = parse_qs(parsed.query)
assert query["state"] == ["state-xyz"]
assert "nonce" in query
assert "code_challenge" not in query
assert "code_challenge_method" not in query
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_missing(hass: HomeAssistant):
"""ID token without kid should be rejected."""
client = make_client(hass)
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt({"alg": "RS256"})
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": []}),
):
parsed = await client._parse_id_token(token)
assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_kid_not_found(hass: HomeAssistant):
"""ID token with unknown kid should be rejected."""
client = make_client(hass)
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt({"alg": "RS256", "kid": "missing"})
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": [{"kid": "other"}]}),
):
parsed = await client._parse_id_token(token)
assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_rejects_hs_without_client_secret(hass: HomeAssistant):
"""HMAC-signed id_token requires client_secret and must fail otherwise."""
client = make_client(hass, id_token_signing_alg="HS256")
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt({"alg": "HS256"})
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": []}),
):
with pytest.raises(OIDCIdTokenSigningAlgorithmInvalid):
await client._parse_id_token(token)
@pytest.mark.asyncio
async def test_parse_id_token_returns_none_when_decode_fails_jose(hass: HomeAssistant):
"""Jose decode/verification failures should be handled without raising to callers."""
client = make_client(hass)
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt({"alg": "RS256", "kid": "kid1"})
with (
patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": [{"kid": "kid1", "kty": "RSA"}]}),
),
patch(
"custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
return_value=object(),
),
patch(
"custom_components.auth_oidc.tools.oidc_client.jwt.decode",
side_effect=joserfc_errors.JoseError("bad token"),
),
):
parsed = await client._parse_id_token(token)
assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_rejects_wrong_signing_algorithm(hass: HomeAssistant):
"""ID token signed with unexpected alg should be rejected."""
client = make_client(hass, id_token_signing_alg="RS256")
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt({"alg": "HS256"})
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": []}),
):
with pytest.raises(OIDCIdTokenSigningAlgorithmInvalid):
await client._parse_id_token(token)
@pytest.mark.asyncio
async def test_parse_id_token_rejects_missing_header(hass: HomeAssistant):
"""ID token without protected header should be rejected."""
client = make_client(hass)
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
token = make_jwt(None)
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": []}),
):
parsed = await client._parse_id_token(token)
assert parsed is None
@pytest.mark.asyncio
async def test_parse_id_token_rejects_invalid_registered_claims(hass: HomeAssistant):
"""Invalid aud/iss/sub style claim validation should fail closed."""
hs_secret = "top-secret-value"
client = make_client(
hass,
id_token_signing_alg="HS256",
client_secret=hs_secret,
)
client.discovery_document = {
"issuer": "https://issuer",
"jwks_uri": "https://issuer/jwks",
}
now = int(time.time())
token = make_signed_hs256_jwt(
hs_secret,
{
"sub": "abc",
"aud": "wrong-audience",
"iss": "https://wrong-issuer",
"nbf": now,
"iat": now,
"exp": now + 3600,
},
)
with patch.object(
client,
"_fetch_jwks",
new=AsyncMock(return_value={"keys": []}),
):
parsed = await client._parse_id_token(token)
assert parsed is None
@pytest.mark.asyncio
async def test_get_authorization_url_returns_none_when_discovery_fails(
hass: HomeAssistant,
):
"""Discovery failures should return None from authorization URL generation."""
client = make_client(hass)
with patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(side_effect=OIDCDiscoveryInvalid()),
):
url = await client.async_get_authorization_url(
"https://example.com/callback", "state-1"
)
assert url is None
@pytest.mark.asyncio
async def test_complete_token_flow_omits_code_verifier_when_pkce_disabled(
hass: HomeAssistant,
):
"""When PKCE is disabled, token request should omit code_verifier."""
client = make_client(hass, features={"disable_rfc7636": True})
client.flows["state-3"] = {"code_verifier": "verifier", "nonce": "nonce"}
with (
patch.object(
client,
"_fetch_discovery_document",
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
),
patch.object(
client,
"_make_token_request",
new=AsyncMock(return_value={"id_token": "id", "access_token": "access"}),
) as make_token_request,
patch.object(
client,
"_parse_id_token",
new=AsyncMock(return_value={"sub": "abc", "nonce": "nonce"}),
),
patch.object(
client,
"parse_user_details",
new=AsyncMock(
return_value={
"sub": "abc",
"display_name": "n",
"username": "u",
"role": "system-users",
}
),
),
):
result = await client.async_complete_token_flow(
"https://example.com/callback", "code", "state-3"
)
assert result is not None
token_params = make_token_request.await_args.args[1]
assert "code_verifier" not in token_params
@pytest.mark.asyncio
async def test_http_raise_for_status_noop_on_ok_response():
"""Status helper should not raise for successful responses."""
response = MagicMock()
response.ok = True
await http_raise_for_status(response)
@pytest.mark.asyncio
async def test_http_raise_for_status_raises_http_client_error_with_body():
"""Status helper should include response body in raised exception."""
response = MagicMock()
response.ok = False
response.reason = "Bad Request"
response.status = 400
response.request_info = MagicMock()
response.history = ()
response.headers = {}
response.text = AsyncMock(return_value="problem details")
with pytest.raises(HTTPClientError) as exc_info:
await http_raise_for_status(response)
assert "400 (Bad Request)" in str(exc_info.value)
assert "problem details" in str(exc_info.value)
@pytest.mark.asyncio
async def test_get_http_session_reuses_existing_session(hass: HomeAssistant):
"""Session helper should return existing session when already created."""
client = make_client(hass)
existing_session = MagicMock()
client.http_session = existing_session
session = await client._get_http_session()
assert session is existing_session
@pytest.mark.asyncio
async def test_get_http_session_applies_tls_verify_flag(hass: HomeAssistant):
"""Session helper should pass tls_verify setting into TCP connector."""
client = make_client(hass, network={"tls_verify": False})
with (
patch(
"custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
return_value=MagicMock(),
) as tcp_connector,
patch(
"custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
return_value=MagicMock(),
),
):
await client._get_http_session()
tcp_connector.assert_called_once_with(verify_ssl=False)
@pytest.mark.asyncio
async def test_get_http_session_uses_custom_ca_path(hass: HomeAssistant):
"""Session helper should create SSL context when custom CA path is configured."""
client = make_client(
hass,
network={"tls_verify": True, "tls_ca_path": "/tmp/test-ca.pem"},
)
fake_ssl_context = object()
with (
patch.object(
hass.loop,
"run_in_executor",
new=AsyncMock(return_value=fake_ssl_context),
) as run_in_executor,
patch(
"custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
return_value=MagicMock(),
) as tcp_connector,
patch(
"custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
return_value=MagicMock(),
),
):
await client._get_http_session()
run_in_executor.assert_awaited_once()
tcp_connector.assert_called_once_with(verify_ssl=True, ssl=fake_ssl_context)
@pytest.mark.asyncio
async def test_make_token_request_returns_json_on_success(hass: HomeAssistant):
"""Token request helper should return JSON payload for successful responses."""
client = make_client(hass)
response = MagicMock()
response.ok = True
response.json = AsyncMock(return_value={"access_token": "token"})
context_manager = AsyncMock()
context_manager.__aenter__.return_value = response
session = MagicMock()
session.post.return_value = context_manager
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
payload = await client._make_token_request(
"https://issuer/token", {"code": "abc"}
)
assert payload == {"access_token": "token"}
@pytest.mark.asyncio
async def test_make_token_request_raises_invalid_on_non_400_http_error(
hass: HomeAssistant,
):
"""Token request helper should map upstream HTTP errors to OIDCTokenResponseInvalid."""
client = make_client(hass)
response = MagicMock()
response.ok = False
response.reason = "Server Error"
response.status = 500
response.request_info = MagicMock()
response.history = ()
response.headers = {}
response.text = AsyncMock(return_value="boom")
context_manager = AsyncMock()
context_manager.__aenter__.return_value = response
session = MagicMock()
session.post.return_value = context_manager
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
with pytest.raises(OIDCTokenResponseInvalid):
await client._make_token_request("https://issuer/token", {"code": "abc"})
@pytest.mark.asyncio
async def test_get_userinfo_returns_json_on_success(hass: HomeAssistant):
"""Userinfo helper should return JSON payload for successful responses."""
client = make_client(hass)
response = MagicMock()
response.ok = True
response.json = AsyncMock(return_value={"sub": "abc"})
context_manager = AsyncMock()
context_manager.__aenter__.return_value = response
session = MagicMock()
session.get.return_value = context_manager
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
payload = await client._get_userinfo("https://issuer/userinfo", "access")
assert payload == {"sub": "abc"}
@pytest.mark.asyncio
async def test_get_userinfo_raises_invalid_on_http_error(hass: HomeAssistant):
"""Userinfo helper should map upstream HTTP errors to OIDCUserinfoInvalid."""
client = make_client(hass)
response = MagicMock()
response.ok = False
response.reason = "Unavailable"
response.status = 503
response.request_info = MagicMock()
response.history = ()
response.headers = {}
response.text = AsyncMock(return_value="oops")
context_manager = AsyncMock()
context_manager.__aenter__.return_value = response
session = MagicMock()
session.get.return_value = context_manager
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
with pytest.raises(OIDCUserinfoInvalid):
await client._get_userinfo("https://issuer/userinfo", "access")

View File

@@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from custom_components.auth_oidc import DOMAIN, async_setup_entry
from custom_components.auth_oidc import DOMAIN
from custom_components.auth_oidc.config.const import (
OIDC_PROVIDERS,
CLIENT_ID,
@@ -170,12 +170,6 @@ async def test_full_config_flow_success(hass: HomeAssistant):
assert len(entries) == 1
assert entries[0].data == expected_data
# You can also assert that `async_setup_entry` was called for this entry
# (assuming it's mocked or you let it run if it's simple)
# The PHCC `hass` fixture automatically mocks `async_setup_entry`
# and `async_unload_entry` for you, making it easy to test that they're called.
assert await async_setup_entry(hass, entries[0]) is True
@pytest.mark.asyncio
async def test_options_flow_success(hass: HomeAssistant):
@@ -362,3 +356,294 @@ async def test_reconfigure_flow_success(hass: HomeAssistant):
assert len(entries) == 1
assert entries[0].data[CLIENT_ID] == new_client_id
assert entries[0].data[CLIENT_SECRET] == new_client_secret
@pytest.mark.asyncio
async def test_reconfigure_flow_rejects_invalid_client_id(hass: HomeAssistant):
"""Reconfigure should keep the form open when the client ID is invalid."""
initial_data = {
"provider": "authentik",
CLIENT_ID: DEMO_CLIENT_ID,
CLIENT_SECRET: DEMO_CLIENT_SECRET,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
FEATURES: {
FEATURES_AUTOMATIC_USER_LINKING: False,
FEATURES_AUTOMATIC_PERSON_CREATION: True,
FEATURES_INCLUDE_GROUPS_SCOPE: True,
},
CLAIMS: {
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
},
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
}
entry = config_entries.ConfigEntry(
version=1,
minor_version=0,
domain=DOMAIN,
title=OIDC_PROVIDERS["authentik"]["name"],
data=initial_data,
source=config_entries.SOURCE_USER,
entry_id="1",
unique_id="test_unique_id",
options={},
pref_disable_new_entities=False,
pref_disable_polling=False,
discovery_keys=[],
subentries_data=None,
)
await hass.config_entries.async_add(entry)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": config_entries.SOURCE_RECONFIGURE,
"entry_id": entry.entry_id,
},
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"client_id": " ", "client_secret": DEMO_CLIENT_SECRET},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "reconfigure"
assert result["errors"]["client_id"] == "invalid_client_id"
@pytest.mark.asyncio
async def test_reconfigure_flow_keeps_client_secret_when_blank(hass: HomeAssistant):
"""Submitting a blank secret should keep the existing client secret."""
initial_data = {
"provider": "authentik",
CLIENT_ID: DEMO_CLIENT_ID,
CLIENT_SECRET: DEMO_CLIENT_SECRET,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
FEATURES: {
FEATURES_AUTOMATIC_USER_LINKING: False,
FEATURES_AUTOMATIC_PERSON_CREATION: True,
FEATURES_INCLUDE_GROUPS_SCOPE: True,
},
CLAIMS: {
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
},
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
}
entry = config_entries.ConfigEntry(
version=1,
minor_version=0,
domain=DOMAIN,
title=OIDC_PROVIDERS["authentik"]["name"],
data=initial_data,
source=config_entries.SOURCE_USER,
entry_id="1",
unique_id="test_unique_id",
options={},
pref_disable_new_entities=False,
pref_disable_polling=False,
discovery_keys=[],
subentries_data=None,
)
await hass.config_entries.async_add(entry)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": config_entries.SOURCE_RECONFIGURE,
"entry_id": entry.entry_id,
},
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"client_id": DEMO_CLIENT_ID, "client_secret": ""},
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "reconfigure_successful"
updated_entry = hass.config_entries.async_get_entry(entry.entry_id)
assert updated_entry is not None
assert updated_entry.data[CLIENT_SECRET] == DEMO_CLIENT_SECRET
@pytest.mark.asyncio
async def test_validation_actions_route_to_other_steps(hass: HomeAssistant):
"""Validation actions should route to the requested flow step."""
with mock_oidc_responses():
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"provider": "authentik"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"discovery_url": MockOIDCServer.get_discovery_url()},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "validate_connection"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"action": "fix_discovery"}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "discovery_url"
with mock_oidc_responses():
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"provider": "authentik"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"discovery_url": MockOIDCServer.get_discovery_url()},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "validate_connection"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"action": "change_provider"}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user"
@pytest.mark.asyncio
async def test_user_flow_aborts_when_yaml_configured(hass: HomeAssistant):
"""The user flow should abort when YAML config already owns the provider."""
hass.data[DOMAIN] = {"yaml_config": {"client_id": DEMO_CLIENT_ID}}
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "yaml_configured"
@pytest.mark.asyncio
async def test_user_flow_aborts_when_entry_already_exists(hass: HomeAssistant):
"""The flow should not create a second OIDC config entry."""
entry = config_entries.ConfigEntry(
version=1,
minor_version=0,
domain=DOMAIN,
title=OIDC_PROVIDERS["authentik"]["name"],
data={"provider": "authentik"},
source=config_entries.SOURCE_USER,
entry_id="1",
unique_id="test_unique_id",
options={},
pref_disable_new_entities=False,
pref_disable_polling=False,
discovery_keys=[],
subentries_data=None,
)
await hass.config_entries.async_add(entry)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "single_instance_allowed"
@pytest.mark.asyncio
async def test_discovery_url_validation_rejects_invalid_url(hass: HomeAssistant):
"""Discovery URL validation should reject malformed inputs."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"provider": "authentik"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"discovery_url": "not-a-valid-oidc-url"}
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "discovery_url"
assert result["errors"]["discovery_url"] == "invalid_url_format"
@pytest.mark.asyncio
async def test_generic_provider_skips_groups_config(hass: HomeAssistant):
"""Providers without group support should go straight to user linking."""
with mock_oidc_responses():
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"provider": "generic"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"discovery_url": MockOIDCServer.get_discovery_url()},
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"action": "continue"}
)
assert result["step_id"] == "client_config"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"client_id": DEMO_CLIENT_ID, "client_secret": DEMO_CLIENT_SECRET},
)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "user_linking"
@pytest.mark.asyncio
async def test_groups_disabled_skips_roles_and_creates_entry(hass: HomeAssistant):
"""Disabling groups should skip role configuration and omit roles from entry data."""
with mock_oidc_responses():
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"provider": "authentik"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"discovery_url": MockOIDCServer.get_discovery_url()},
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"action": "continue"}
)
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{"client_id": DEMO_CLIENT_ID, "client_secret": DEMO_CLIENT_SECRET},
)
assert result["step_id"] == "groups_config"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"enable_groups": False}
)
assert result["step_id"] == "user_linking"
result = await hass.config_entries.flow.async_configure(
result["flow_id"], {"enable_user_linking": True}
)
assert result["type"] == FlowResultType.CREATE_ENTRY
assert "roles" not in result["data"]
assert result["data"][FEATURES][FEATURES_INCLUDE_GROUPS_SCOPE] is False

View File

@@ -1,12 +1,9 @@
"""Tests for the registered webpages"""
import base64
import os
from auth_oidc.config.const import (
DISCOVERY_URL,
CLIENT_ID,
FEATURES,
FEATURES_DISABLE_FRONTEND_INJECTION,
)
from unittest.mock import AsyncMock, MagicMock, patch
from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID
import pytest
from homeassistant.core import HomeAssistant
@@ -14,62 +11,239 @@ from homeassistant.setup import async_setup_component
from homeassistant.components.http import StaticPathConfig, DOMAIN as HTTP_DOMAIN
from custom_components.auth_oidc import DOMAIN
from custom_components.auth_oidc.endpoints.injected_auth_page import (
OIDCInjectedAuthPage,
frontend_injection,
)
async def setup(hass: HomeAssistant, enable_frontend_changes: bool = None):
WEB_CLIENT_ID = "https://example.com"
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
def create_redirect_uri(client_id: str) -> str:
"""Build a redirect URI that includes a client_id query parameter."""
return f"http://example.com/auth/authorize?client_id={client_id}"
def encode_redirect_uri(redirect_uri: str) -> str:
"""Encode redirect_uri in the same way as frontend btoa()."""
return base64.b64encode(redirect_uri.encode("utf-8")).decode("utf-8")
async def setup(
hass: HomeAssistant,
):
mock_config = {
DOMAIN: {
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
FEATURES: {
FEATURES_DISABLE_FRONTEND_INJECTION: not enable_frontend_changes
},
}
}
if enable_frontend_changes is None:
del mock_config[DOMAIN][FEATURES][FEATURES_DISABLE_FRONTEND_INJECTION]
result = await async_setup_component(hass, DOMAIN, mock_config)
assert result
@pytest.mark.asyncio
async def test_welcome_page_registration(hass: HomeAssistant, hass_client):
"""Test that welcome page is present if frontend changes are disabled."""
"""Test that welcome page is present."""
await setup(hass, enable_frontend_changes=False)
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/oidc/welcome", allow_redirects=False)
assert resp.status == 200
@pytest.mark.asyncio
async def test_welcome_page_registration_with_changes(hass: HomeAssistant, hass_client):
"""Test that welcome page is redirect if frontend changes are enabled."""
await setup(hass, enable_frontend_changes=True)
client = await hass_client()
resp = await client.get("/auth/oidc/welcome", allow_redirects=False)
assert resp.status == 307
@pytest.mark.asyncio
async def test_redirect_page_registration(hass: HomeAssistant, hass_client):
"""Test that redirect page shows OIDC misconfiguration error if OIDC server is not reachable."""
"""Test that redirect page can be reached."""
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 200
text = await resp.text()
assert "Integration is misconfigured" in text
assert resp.status == 302
resp2 = await client.post("/auth/oidc/redirect", allow_redirects=False)
assert resp2.status == 200
assert resp2.status == 302
@pytest.mark.asyncio
async def test_welcome_rejects_invalid_encoded_redirect_uri(
hass: HomeAssistant, hass_client
):
"""Welcome should reject malformed base64 redirect_uri values."""
await setup(hass)
client = await hass_client()
resp = await client.get(
"/auth/oidc/welcome?redirect_uri=%25%25%25",
allow_redirects=False,
)
assert resp.status == 400
assert "Invalid redirect_uri, please restart login." in await resp.text()
@pytest.mark.asyncio
async def test_welcome_sets_strict_state_cookie_flags(hass: HomeAssistant, hass_client):
"""Welcome should set secure cookie flags for the OIDC state cookie."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status in (200, 302)
assert "auth_oidc_state" in resp.cookies
set_cookie = resp.headers.get("Set-Cookie", "")
assert "Path=/auth/" in set_cookie
assert "SameSite=Strict" in set_cookie
assert "HttpOnly" in set_cookie
assert "Max-Age=300" in set_cookie
@pytest.mark.asyncio
async def test_welcome_mobile_device_code_generation_failure(
hass: HomeAssistant, hass_client
):
"""Welcome should error if device code generation fails for mobile clients."""
await setup(hass)
with patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_generate_device_code",
new=AsyncMock(return_value=None),
):
client = await hass_client()
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status == 500
assert (
"Failed to generate device code, please restart login." in await resp.text()
)
@pytest.mark.asyncio
async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist(
hass: HomeAssistant, hass_client
):
"""Welcome should render fallback auth link when other providers are present."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status == 200
text = await resp.text()
assert 'id="login-button"' in text
assert 'id="alternative-sign-in-link"' in text
assert "skip_oidc_redirect=true" in text
@pytest.mark.asyncio
async def test_welcome_desktop_auto_redirects_without_other_providers(
hass: HomeAssistant, hass_client
):
"""Welcome should auto-redirect desktop clients when no other providers exist."""
# pylint: disable=protected-access
hass.auth._providers = [] # Clear initial providers out
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status == 302
assert "/auth/oidc/redirect" in resp.headers["Location"]
@pytest.mark.asyncio
async def test_redirect_without_cookie_goes_to_welcome(
hass: HomeAssistant, hass_client
):
"""Redirect endpoint should bounce to welcome when no state cookie exists."""
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 302
assert "/auth/oidc/welcome" in resp.headers["Location"]
@pytest.mark.asyncio
async def test_redirect_shows_error_on_oidc_runtime_error(
hass: HomeAssistant, hass_client
):
"""Redirect should show a configuration error when OIDC URL generation raises."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status in (200, 302)
with patch(
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_get_authorization_url",
new=AsyncMock(side_effect=RuntimeError("broken discovery")),
):
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 500
assert (
"Integration is misconfigured, discovery could not be obtained."
in await resp.text()
)
@pytest.mark.asyncio
async def test_redirect_shows_error_when_auth_url_empty(
hass: HomeAssistant, hass_client
):
"""Redirect should show error page if OIDC returns no authorization URL."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status in (200, 302)
with patch(
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_get_authorization_url",
new=AsyncMock(return_value=None),
):
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 500
assert (
"Integration is misconfigured, discovery could not be obtained."
in await resp.text()
)
@pytest.mark.asyncio
@@ -80,45 +254,301 @@ async def test_callback_registration(hass: HomeAssistant, hass_client):
client = await hass_client()
resp = await client.get("/auth/oidc/callback", allow_redirects=False)
assert resp.status == 200
assert resp.status == 400
@pytest.mark.asyncio
async def test_finish_registration(hass: HomeAssistant, hass_client):
"""Test that finish page is reachable."""
async def test_callback_rejects_missing_code_or_state(hass: HomeAssistant, hass_client):
"""Callback must reject requests missing either code or state."""
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/oidc/finish", allow_redirects=False)
assert resp.status == 200
text = await resp.text()
# Should miss the code parameter if called without it
assert "Missing code" in text
resp2 = await client.get("/auth/oidc/finish?code=123456", allow_redirects=False)
assert resp2.status == 200
text2 = await resp2.text()
assert "Missing code" not in text2
assert "123456" in text2
@pytest.mark.asyncio
async def test_finish_post(hass: HomeAssistant, hass_client):
"""Test that finish page works with POST."""
await setup(hass)
client = await hass_client()
resp = await client.post("/auth/oidc/finish", data={}, allow_redirects=False)
assert resp.status == 500
resp2 = await client.post(
"/auth/oidc/finish", data={"code": "456888"}, allow_redirects=False
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp2.status == 302
assert resp2.headers["Location"] == "/?storeToken=true"
assert resp2.cookies["auth_oidc_code"].value == "456888"
state = resp_welcome.cookies["auth_oidc_state"].value
resp_missing_code = await client.get(
f"/auth/oidc/callback?state={state}",
allow_redirects=False,
)
assert resp_missing_code.status == 400
assert "Missing code or state parameter." in await resp_missing_code.text()
resp_missing_state = await client.get(
"/auth/oidc/callback?code=testcode",
allow_redirects=False,
)
assert resp_missing_state.status == 400
assert "Missing code or state parameter." in await resp_missing_state.text()
@pytest.mark.asyncio
async def test_callback_rejects_state_mismatch(hass: HomeAssistant, hass_client):
"""Callback must reject state mismatch to protect against CSRF."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
state = resp_welcome.cookies["auth_oidc_state"].value
resp = await client.get(
f"/auth/oidc/callback?code=testcode&state={state}-other",
allow_redirects=False,
)
assert resp.status == 400
assert "State parameter does not match, possible CSRF attack." in await resp.text()
@pytest.mark.asyncio
async def test_callback_rejects_when_user_details_fetch_fails(
hass: HomeAssistant, hass_client
):
"""Callback should error when token exchange/userinfo retrieval fails."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
state = resp_welcome.cookies["auth_oidc_state"].value
with patch(
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_complete_token_flow",
new=AsyncMock(return_value=None),
):
resp = await client.get(
f"/auth/oidc/callback?code=testcode&state={state}",
allow_redirects=False,
)
assert resp.status == 500
assert (
"Failed to get user details, see Home Assistant logs for more information."
in await resp.text()
)
@pytest.mark.asyncio
async def test_callback_rejects_invalid_role(hass: HomeAssistant, hass_client):
"""Callback should reject users marked with invalid role."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
state = resp_welcome.cookies["auth_oidc_state"].value
with patch(
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_complete_token_flow",
new=AsyncMock(return_value={"sub": "abc", "role": "invalid"}),
):
resp = await client.get(
f"/auth/oidc/callback?code=testcode&state={state}",
allow_redirects=False,
)
assert resp.status == 403
assert (
"User is not in the correct group to access Home Assistant"
in await resp.text()
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
("method", "data"),
[
("get", None),
("post", {}),
("post", {"device_code": "456888"}),
],
)
async def test_finish_requires_state_cookie(
hass: HomeAssistant, hass_client, method: str, data: dict | None
):
"""Finish endpoint should require the OIDC state cookie for both GET and POST."""
await setup(hass)
client = await hass_client()
request = getattr(client, method)
if data is None:
resp = await request("/auth/oidc/finish", allow_redirects=False)
else:
resp = await request("/auth/oidc/finish", data=data, allow_redirects=False)
assert resp.status == 400
assert "Missing state cookie" in await resp.text()
@pytest.mark.asyncio
async def test_finish_post_rejects_invalid_state(hass: HomeAssistant, hass_client):
"""Finish POST should error when the state cookie does not resolve to redirect_uri."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status in (200, 302)
with patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
new=AsyncMock(return_value=None),
):
resp = await client.post("/auth/oidc/finish", allow_redirects=False)
assert resp.status == 400
assert "Invalid state, please restart login." in await resp.text()
@pytest.mark.asyncio
async def test_device_sse_requires_state_cookie(hass: HomeAssistant, hass_client):
"""SSE endpoint should reject requests without state cookie."""
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
assert resp.status == 400
assert "Missing session cookie" in await resp.text()
@pytest.mark.asyncio
async def test_device_sse_emits_expired_for_unknown_state(
hass: HomeAssistant, hass_client
):
"""SSE should emit expired when the state can no longer be resolved."""
await setup(hass)
with patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
new=AsyncMock(return_value=None),
):
client = await hass_client()
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status == 200
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
assert resp.status == 200
payload = await resp.text()
assert "event: expired" in payload
@pytest.mark.asyncio
async def test_device_sse_emits_timeout(hass: HomeAssistant, hass_client):
"""SSE should emit timeout if the polling window is exceeded."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status == 200
fake_loop = MagicMock()
fake_loop.time.side_effect = [0, 301]
with (
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
new=AsyncMock(return_value=redirect_uri),
),
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_is_state_ready",
new=AsyncMock(return_value=False),
),
patch(
"custom_components.auth_oidc.endpoints.device_sse.asyncio.get_running_loop",
return_value=fake_loop,
),
):
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
assert resp.status == 200
payload = await resp.text()
assert "event: timeout" in payload
@pytest.mark.asyncio
async def test_device_sse_handles_runtime_error_and_returns_cleanly(
hass: HomeAssistant, hass_client
):
"""SSE should swallow runtime errors from stream loop and finish response."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status == 200
with (
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
new=AsyncMock(return_value=redirect_uri),
),
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_is_state_ready",
new=AsyncMock(side_effect=RuntimeError("disconnect")),
),
):
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
assert resp.status == 200
@pytest.mark.asyncio
async def test_device_sse_ignores_write_eof_connection_reset(
hass: HomeAssistant, hass_client
):
"""SSE should ignore ConnectionResetError while closing the stream."""
await setup(hass)
client = await hass_client()
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp_welcome.status == 200
with (
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
new=AsyncMock(return_value=None),
),
patch(
"custom_components.auth_oidc.endpoints.device_sse.web.StreamResponse.write_eof",
new=AsyncMock(side_effect=ConnectionResetError),
),
):
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
assert resp.status == 200
# Test the frontend injection
@@ -141,7 +571,7 @@ async def test_frontend_injection(hass: HomeAssistant, hass_client):
]
)
await setup(hass, enable_frontend_changes=True)
await setup(hass)
client = await hass_client()
resp = await client.get("/auth/authorize", allow_redirects=False)
@@ -149,4 +579,52 @@ async def test_frontend_injection(hass: HomeAssistant, hass_client):
text = await resp.text()
assert "<script src='/auth/oidc/static/injection.js" in text
assert 'window.sso_name = "OpenID Connect (SSO)";' in text
@pytest.mark.asyncio
async def test_frontend_injection_logs_and_returns_when_route_handler_is_unexpected(
hass: HomeAssistant, caplog
):
"""frontend_injection should log and return if the GET handler shape is unexpected."""
await async_setup_component(hass, HTTP_DOMAIN, {})
class FakeRoute:
method = "GET"
handler = object()
class FakeResource:
canonical = "/auth/authorize"
def __init__(self):
self.prefix = None
def add_prefix(self, prefix):
self.prefix = prefix
def __iter__(self):
return iter([FakeRoute()])
with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]):
await frontend_injection(hass)
assert "Unexpected route handler type" in caplog.text
assert (
"Failed to find GET route for /auth/authorize, cannot inject OIDC frontend code"
in caplog.text
)
@pytest.mark.asyncio
async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog):
"""OIDCInjectedAuthPage.inject should swallow unexpected injection errors."""
await async_setup_component(hass, HTTP_DOMAIN, {})
with patch(
"custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection",
side_effect=RuntimeError("boom"),
):
await OIDCInjectedAuthPage.inject(hass)
assert "Failed to inject OIDC auth page: boom" in caplog.text

View File

@@ -19,28 +19,25 @@ async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool
@pytest.mark.asyncio
async def test_setup_success_yaml(hass: HomeAssistant):
"""Test successful setup of a YAML configuration."""
await setup(
hass,
@pytest.mark.parametrize(
"config",
[
{
"client_id": "dummy",
"discovery_url": "https://example.com/.well-known/openid-configuration",
},
True,
)
@pytest.mark.asyncio
async def test_setup_success_yaml_with_optional(hass: HomeAssistant):
"""Test successful setup of a YAML configuration with optional parameters."""
await setup(
hass,
{
"client_id": "dummy",
"discovery_url": "https://example.com/.well-known/openid-configuration",
ADDITIONAL_SCOPES: "email phone",
},
],
)
async def test_setup_success_yaml(hass: HomeAssistant, config: dict):
"""YAML setup should succeed for minimal and optional-scope configurations."""
await setup(
hass,
config,
True,
)

View File

@@ -1,11 +1,21 @@
"""Tests for the helpers and validation tools"""
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp.test_utils import make_mocked_request
from aiohttp import web
from custom_components.auth_oidc.tools.helpers import get_url, get_view
from custom_components.auth_oidc.tools.helpers import (
STATE_COOKIE_NAME,
error_response,
get_state_id,
get_url,
get_valid_state_id,
get_view,
html_response,
template_response,
)
from custom_components.auth_oidc.tools.validation import (
validate_client_id,
sanitize_client_secret,
@@ -38,6 +48,85 @@ async def test_get_view():
assert data.startswith("<!DOCTYPE html>")
@pytest.mark.asyncio
async def test_get_state_id():
"""State cookie helper should return cookie value when present."""
request = make_mocked_request(
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=abc"}
)
assert get_state_id(request) == "abc"
request_without_cookie = make_mocked_request("GET", "/")
assert get_state_id(request_without_cookie) is None
@pytest.mark.asyncio
async def test_get_valid_state_id():
"""Valid-state helper should return only existing and valid cookie states."""
provider = MagicMock()
provider.async_is_state_valid = AsyncMock(return_value=True)
request = make_mocked_request(
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=state-1"}
)
state_id = await get_valid_state_id(request, provider)
assert state_id == "state-1"
provider.async_is_state_valid.assert_awaited_once_with("state-1")
@pytest.mark.asyncio
async def test_get_valid_state_id_invalid_or_missing_cookie():
"""Valid-state helper should reject missing and invalid states."""
provider = MagicMock()
provider.async_is_state_valid = AsyncMock(return_value=False)
request = make_mocked_request(
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=state-2"}
)
assert await get_valid_state_id(request, provider) is None
provider.async_is_state_valid.assert_awaited_once_with("state-2")
request_without_cookie = make_mocked_request("GET", "/")
provider.async_is_state_valid.reset_mock()
assert await get_valid_state_id(request_without_cookie, provider) is None
provider.async_is_state_valid.assert_not_called()
@pytest.mark.asyncio
async def test_html_response_and_template_helpers():
"""Response helpers should preserve status and render HTML views."""
response = html_response("<p>ok</p>", status=418)
assert isinstance(response, web.Response)
assert response.status == 418
assert response.content_type == "text/html"
assert response.text == "<p>ok</p>"
with patch(
"custom_components.auth_oidc.tools.helpers.get_view",
new=AsyncMock(return_value="<p>rendered</p>"),
) as mocked_get_view:
rendered = await template_response("welcome", {"name": "OIDC"})
assert rendered.status == 200
assert rendered.text == "<p>rendered</p>"
mocked_get_view.assert_awaited_once_with("welcome", {"name": "OIDC"})
@pytest.mark.asyncio
async def test_error_response():
"""Error response helper should render the shared error template with status."""
with patch(
"custom_components.auth_oidc.tools.helpers.get_view",
new=AsyncMock(return_value="<p>error</p>"),
) as mocked_get_view:
rendered = await error_response("boom", status=500)
assert rendered.status == 500
assert rendered.text == "<p>error</p>"
mocked_get_view.assert_awaited_once_with("error", {"error": "boom"})
@pytest.mark.asyncio
async def test_validate_url():
"""Test the validate_url helper."""

View File

@@ -0,0 +1,53 @@
"""Tests for the provider catalog helpers."""
import pytest
from custom_components.auth_oidc.config.const import OIDC_PROVIDERS, REPO_ROOT_URL
from custom_components.auth_oidc.config.provider_catalog import (
get_provider_config,
get_provider_docs_url,
get_provider_name,
)
@pytest.mark.parametrize(
("provider_key", "expected_name", "expected_supports_groups"),
[
("authentik", "Authentik", True),
("generic", "OpenID Connect (SSO)", False),
],
)
def test_get_provider_config_and_name(
provider_key, expected_name, expected_supports_groups
):
"""Known providers should resolve to their configured metadata."""
config = get_provider_config(provider_key)
assert config == OIDC_PROVIDERS[provider_key]
assert get_provider_name(provider_key) == expected_name
assert config["supports_groups"] is expected_supports_groups
@pytest.mark.parametrize("provider_key", [None, "unknown", ""])
def test_provider_fallbacks(provider_key):
"""Unknown providers should fall back to neutral defaults."""
assert get_provider_config(provider_key or "unknown") == {}
assert get_provider_name(provider_key) == "Unknown Provider"
assert (
get_provider_docs_url(provider_key) == f"{REPO_ROOT_URL}/docs/configuration.md"
)
@pytest.mark.parametrize(
("provider_key", "expected_suffix"),
[
("authentik", "/docs/provider-configurations/authentik.md"),
("authelia", "/docs/provider-configurations/authelia.md"),
("pocketid", "/docs/provider-configurations/pocket-id.md"),
("kanidm", "/docs/provider-configurations/kanidm.md"),
("microsoft", "/docs/provider-configurations/microsoft-entra.md"),
],
)
def test_provider_docs_urls(provider_key, expected_suffix):
"""Known providers should point to provider-specific docs."""
assert get_provider_docs_url(provider_key) == f"{REPO_ROOT_URL}{expected_suffix}"

260
tests/test_state_store.py Normal file
View File

@@ -0,0 +1,260 @@
"""Tests for the state store."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, patch
import pytest
from homeassistant.core import HomeAssistant
from auth_oidc.stores.state_store import MAX_DEVICE_CODE_ATTEMPTS, StateStore
TEST_IP = "127.0.0.1"
@pytest.mark.asyncio
async def test_state_store_generate_and_receive_state(hass: HomeAssistant):
"""Test creating a state, storing user info, and receiving it once."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
store_mock.async_load.return_value = {}
await state_store.async_load()
assert state_store.get_data() == {}
redirect_uri = "https://example.com/callback"
state_id = await state_store.async_create_state_from_url(redirect_uri, TEST_IP)
assert state_id in state_store.get_data()
assert (
await state_store.async_get_redirect_uri_for_state(state_id, TEST_IP)
== redirect_uri
)
user_info = {
"sub": "user1",
"display_name": "Test User",
"username": "testuser",
"role": "system-users",
}
assert (
await state_store.async_add_userinfo_to_state(state_id, user_info) is True
)
assert state_id in state_store.get_data()
assert await state_store.async_is_state_ready(state_id, TEST_IP) is True
assert state_id in state_store.get_data()
result = await state_store.async_receive_userinfo_for_state(state_id, TEST_IP)
assert result == user_info
assert state_id not in state_store.get_data()
@pytest.mark.asyncio
async def test_state_store_generate_code_and_link_state(hass: HomeAssistant):
"""Test generating a device code and linking another state to it."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
store_mock.async_load.return_value = {}
await state_store.async_load()
donor_state = await state_store.async_create_state_from_url(
"https://example.com/donor", TEST_IP
)
target_state = await state_store.async_create_state_from_url(
"https://example.com/target", TEST_IP
)
code = await state_store.async_generate_code_for_state(target_state)
assert code is not None
assert len(code) == 6
assert code.isdigit()
user_info = {
"sub": "user2",
"display_name": "Device User",
"username": "deviceuser",
"role": "system-admin",
}
assert (
await state_store.async_add_userinfo_to_state(donor_state, user_info)
is True
)
assert donor_state in state_store.get_data()
assert (
await state_store.async_link_state_to_code(donor_state, code, TEST_IP)
is True
)
assert donor_state not in state_store.get_data()
assert await state_store.async_is_state_ready(target_state, TEST_IP) is True
assert target_state in state_store.get_data()
assert (
await state_store.async_receive_userinfo_for_state(target_state, TEST_IP)
== user_info
)
assert target_state not in state_store.get_data()
@pytest.mark.asyncio
async def test_state_store_link_state_returns_false_for_wrong_code(hass: HomeAssistant):
"""Test linking fails when the device code does not match any state."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
store_mock.async_load.return_value = {}
await state_store.async_load()
donor_state = await state_store.async_create_state_from_url(
"https://example.com/donor", TEST_IP
)
target_state = await state_store.async_create_state_from_url(
"https://example.com/target", TEST_IP
)
await state_store.async_generate_code_for_state(target_state)
user_info = {
"sub": "user3",
"display_name": "Wrong Code User",
"username": "wrongcode",
"role": "system-users",
}
assert (
await state_store.async_add_userinfo_to_state(donor_state, user_info)
is True
)
assert (
await state_store.async_link_state_to_code(donor_state, "000000", TEST_IP)
is False
)
assert donor_state in state_store.get_data()
assert await state_store.async_is_state_ready(target_state, TEST_IP) is False
@pytest.mark.asyncio
async def test_state_store_throttles_device_code_link_attempts(hass: HomeAssistant):
"""Test that repeated wrong device codes are throttled per state."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
store_mock.async_load.return_value = {}
await state_store.async_load()
donor_state = await state_store.async_create_state_from_url(
"https://example.com/donor", TEST_IP
)
target_state = await state_store.async_create_state_from_url(
"https://example.com/target", TEST_IP
)
code = await state_store.async_generate_code_for_state(target_state)
assert code is not None
user_info = {
"sub": "user-throttle",
"display_name": "Throttle User",
"username": "throttle",
"role": "system-users",
}
assert await state_store.async_add_userinfo_to_state(donor_state, user_info)
for _ in range(MAX_DEVICE_CODE_ATTEMPTS):
assert (
await state_store.async_link_state_to_code(
donor_state, "000000", TEST_IP
)
is False
)
assert (
await state_store.async_link_state_to_code(donor_state, code, TEST_IP)
is False
)
@pytest.mark.asyncio
async def test_state_store_expired_state(hass: HomeAssistant):
"""Test that expired states are treated as invalid."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
store_mock.async_load.return_value = {}
await state_store.async_load()
state_id = await state_store.async_create_state_from_url(
"https://example.com/expired", TEST_IP
)
state_store.get_data()[state_id]["expiration"] = (
datetime.now(timezone.utc) - timedelta(minutes=10)
).isoformat()
assert (
await state_store.async_get_redirect_uri_for_state(state_id, TEST_IP)
is None
)
assert await state_store.async_is_state_ready(state_id, TEST_IP) is False
assert (
await state_store.async_receive_userinfo_for_state(state_id, TEST_IP)
is None
)
@pytest.mark.asyncio
async def test_state_store_data_not_loaded(hass: HomeAssistant):
"""Test that using the store before loading raises RuntimeError."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
with pytest.raises(RuntimeError):
await state_store.async_create_state_from_url(
"https://example.com", TEST_IP
)
with pytest.raises(RuntimeError):
await state_store.async_generate_code_for_state("state")
with pytest.raises(RuntimeError):
await state_store.async_add_userinfo_to_state(
"state",
{
"sub": "user4",
"display_name": "Not Loaded",
"username": "notloaded",
"role": "system-users",
},
)
with pytest.raises(RuntimeError):
await state_store.async_get_redirect_uri_for_state("state", TEST_IP)
with pytest.raises(RuntimeError):
await state_store.async_is_state_ready("state", TEST_IP)
with pytest.raises(RuntimeError):
await state_store.async_link_state_to_code("state", "123456", TEST_IP)
with pytest.raises(RuntimeError):
await state_store.async_receive_userinfo_for_state("state", TEST_IP)
@pytest.mark.asyncio
async def test_state_store_missing_keys(hass: HomeAssistant):
"""Test that missing keys raise correct responses."""
store_mock = AsyncMock()
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
state_store = StateStore(hass)
# async_generate_code_for_state returns None if state_id is not found
store_mock.async_load.return_value = {}
await state_store.async_load()
assert await state_store.async_generate_code_for_state("nonexistent") is None
# async_add_userinfo_to_state returns False if state_id is not found
user_info = {
"sub": "user5",
"display_name": "Missing Keys",
"username": "missingkeys",
"role": "system-users",
}
assert (
await state_store.async_add_userinfo_to_state("nonexistent", user_info)
is False
)