Fix regression of storeToken parameter (#248)

* Try a different method to set ?storeToken

* Formatting

* Only insert storeToken on web client & fix tests
This commit is contained in:
Christiaan Goossens
2026-04-15 12:07:19 +02:00
committed by GitHub
parent 0ca300c385
commit 07c1e3a4c4
6 changed files with 235 additions and 67 deletions

View File

@@ -61,11 +61,9 @@ class OIDCFinishView(HomeAssistantView):
if "?" in redirect_uri: if "?" in redirect_uri:
separator = "&" separator = "&"
# Redirect to this new URL for login # Redirect to this new URL for login, make sure to skip OIDC to prevent loops
new_url = ( redirect_uri = f"{redirect_uri}{separator}skip_oidc_redirect=true"
redirect_uri + separator + "storeToken=true&skip_oidc_redirect=true" raise web.HTTPFound(location=redirect_uri)
)
raise web.HTTPFound(location=new_url)
# Check if we can link this device # Check if we can link this device
linked = await self.oidc_provider.async_link_state_to_code( linked = await self.oidc_provider.async_link_state_to_code(

View File

@@ -113,9 +113,12 @@ class OIDCInjectedAuthPage(HomeAssistantView):
@staticmethod @staticmethod
def _should_do_oidc_redirect(req: web.Request) -> bool: def _should_do_oidc_redirect(req: web.Request) -> bool:
"""Check if we should redirect to the OIDC flow.""" """Check if we should redirect to the OIDC flow."""
# Set when we return from finish
if req.query.get("skip_oidc_redirect") == "true": if req.query.get("skip_oidc_redirect") == "true":
return False return False
# Set whenever you directly do /?skip_oidc_redirect=true,
# for example when you click the "other" button on the welcome screen
redirect_uri = req.query.get("redirect_uri") redirect_uri = req.query.get("redirect_uri")
if not redirect_uri: if not redirect_uri:
return False return False

View File

@@ -1,8 +1,9 @@
"""Welcome route to show the user the OIDC login button and give instructions.""" """Welcome route to show the user the OIDC login button and give instructions."""
from ast import List
import base64 import base64
import binascii import binascii
from urllib.parse import urlparse, parse_qs, unquote from urllib.parse import urlparse, parse_qs, unquote, urlencode
from aiohttp import web from aiohttp import web
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from ..tools.helpers import error_response, get_url, template_response from ..tools.helpers import error_response, get_url, template_response
@@ -30,13 +31,46 @@ class OIDCWelcomeView(HomeAssistantView):
self.force_https = force_https self.force_https = force_https
self.has_other_auth_providers = has_other_auth_providers self.has_other_auth_providers = has_other_auth_providers
def determine_if_mobile(self, redirect_uri: str) -> bool: async def _process_url(self, redirect_uri: str) -> List[str, bool]:
"""Determine if the client is a mobile client based on the redirect_uri.""" """Processes the redirect URI to determine if we need setTokens and if this is mobile."""
oauth2_url = urlparse(redirect_uri) # decodeURIComponent(btoa(...)) -> unquote first, then base64 decode
client_id = parse_qs(oauth2_url.query).get("client_id") redirect_uri = base64.b64decode(unquote(redirect_uri), validate=True).decode(
"utf-8"
)
# If the client_id starts with https://home-assistant.io/ we assume it's a mobile client oauth2_url = urlparse(redirect_uri)
return bool(client_id and client_id[0].startswith("https://home-assistant.io/")) oauth2_query = parse_qs(oauth2_url.query)
client_id = oauth2_query.get("client_id")[0]
original_redirect_uri = oauth2_query.get("redirect_uri")[0]
# If the client_id starts with https://home-assistant.io/
# we assume it's a mobile client
# Android = https://home-assistant.io/Android,
# iOS = https://home-assistant.io/iOS
is_mobile = client_id.startswith("https://home-assistant.io/")
# Check if we appear to be signing in to the web version,
# for which we want to store tokens.
# We don't want to set storeTokens on sign-in to Google for instance
base_url = get_url("/", self.force_https)
is_web_client = original_redirect_uri.startswith(base_url)
if is_web_client:
# Adjust the original_redirect_uri to include the storeTokens parameter
separator = "?"
if "?" in original_redirect_uri:
separator = "&"
original_redirect_uri = f"{original_redirect_uri}{separator}storeToken=true"
oauth2_query.update({"redirect_uri": original_redirect_uri})
# Create new redirect_uri with the updated query parameters
new_oauth2_url = oauth2_url._replace(
query=urlencode(oauth2_query, doseq=True)
)
redirect_uri = new_oauth2_url.geturl()
return redirect_uri, is_mobile
async def get(self, req: web.Request) -> web.Response: async def get(self, req: web.Request) -> web.Response:
"""Receive response.""" """Receive response."""
@@ -44,23 +78,26 @@ class OIDCWelcomeView(HomeAssistantView):
# Get the query parameter with the redirect_uri # Get the query parameter with the redirect_uri
redirect_uri = req.query.get("redirect_uri") redirect_uri = req.query.get("redirect_uri")
# If set, determine if this is a mobile client based on the redirect_uri, # Do some processing on the redirect_uri to correct it
# otherwise assume it's not mobile # and determine if this is a mobile client.
if redirect_uri: if redirect_uri:
try: try:
# decodeURIComponent(btoa(...)) -> unquote first, then base64 decode redirect_uri, is_mobile = await self._process_url(redirect_uri)
redirect_uri = base64.b64decode( except (
unquote(redirect_uri), validate=True binascii.Error,
).decode("utf-8") UnicodeDecodeError,
is_mobile = self.determine_if_mobile(redirect_uri) ValueError,
except (binascii.Error, UnicodeDecodeError, ValueError): KeyError,
TypeError,
):
return await error_response( return await error_response(
"Invalid redirect_uri, please restart login." "Invalid redirect_uri, please restart login."
) )
else: else:
# Backwards compatibility with older versions that directly go to /auth/oidc/welcome # Backwards compatibility with older versions that directly go to /auth/oidc/welcome
# If not set, redirect back to the main page and assume that this is a web client # If not set, redirect back to the main page and assume that this is a web client
redirect_uri = get_url("/", self.force_https) redirect_uri = get_url("/?storeToken=true", self.force_https)
is_mobile = False is_mobile = False
# Create OIDC state with the redirect_uri so we can use it later in the flow # Create OIDC state with the redirect_uri so we can use it later in the flow

View File

@@ -103,10 +103,10 @@ async def verify_back_redirect(client, expected_redirect_uri: str):
"""Verify that POST to finish without body redirects back to the original redirect_uri.""" """Verify that POST to finish without body redirects back to the original redirect_uri."""
resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False) resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False)
assert resp_finish_post.status == 302 assert resp_finish_post.status == 302
assert (
resp_finish_post.headers["Location"] location = resp_finish_post.headers["Location"]
== unquote(expected_redirect_uri) + "&storeToken=true&skip_oidc_redirect=true" assert location.startswith(unquote(expected_redirect_uri))
) assert "skip_oidc_redirect=true" in location
async def listen_for_sse_events( async def listen_for_sse_events(

View File

@@ -86,7 +86,9 @@ def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
return jwt.encode({"alg": "HS256"}, claims, jwk_obj) return jwt.encode({"alg": "HS256"}, claims, jwk_obj)
def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[str, dict]: def build_real_signed_token(
algorithm: str, claims: dict, secret: str
) -> tuple[str, dict]:
"""Build a real signed token and matching JWKS payload for a given algorithm.""" """Build a real signed token and matching JWKS payload for a given algorithm."""
if algorithm.startswith("HS"): if algorithm.startswith("HS"):
signing_key = jwk.import_key( signing_key = jwk.import_key(
@@ -96,7 +98,9 @@ def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[
"alg": algorithm, "alg": algorithm,
} }
) )
token = jwt.encode({"alg": algorithm}, claims, signing_key, algorithms=[algorithm]) token = jwt.encode(
{"alg": algorithm}, claims, signing_key, algorithms=[algorithm]
)
return token, {"keys": []} return token, {"keys": []}
if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"): if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"):
@@ -115,7 +119,11 @@ def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[
) )
elif algorithm in ("Ed25519", "Ed448"): elif algorithm in ("Ed25519", "Ed448"):
key = jwk.generate_key( key = jwk.generate_key(
"OKP", algorithm, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True "OKP",
algorithm,
{"alg": algorithm, "use": "sig"},
private=True,
auto_kid=True,
) )
else: else:
raise ValueError(f"Unsupported test algorithm: {algorithm}") raise ValueError(f"Unsupported test algorithm: {algorithm}")

View File

@@ -2,9 +2,11 @@
import base64 import base64
import os import os
from urllib.parse import parse_qs, quote, unquote, urlparse from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID
from pytest_homeassistant_custom_component.typing import ClientSessionGenerator
import pytest import pytest
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@@ -17,14 +19,19 @@ from custom_components.auth_oidc.endpoints.injected_auth_page import (
frontend_injection, frontend_injection,
) )
WEB_CLIENT_ID = "https://example.com"
MOBILE_CLIENT_ID = "https://home-assistant.io/Android" MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
def create_redirect_uri(client_id: str) -> str: def create_redirect_uri(client_id: str) -> str:
"""Build a redirect URI that includes a client_id query parameter.""" """Build a redirect URI that includes a client_id query parameter."""
return f"http://example.com/auth/authorize?client_id={client_id}" params = {
"response_type": "code",
"redirect_uri": client_id,
"client_id": client_id,
"state": "example",
}
return f"http://example.com/auth/authorize?{urlencode(params)}"
def encode_redirect_uri(redirect_uri: str) -> str: def encode_redirect_uri(redirect_uri: str) -> str:
@@ -63,7 +70,9 @@ async def setup_mock_authorize_route(hass: HomeAssistant) -> None:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_page_registration(hass: HomeAssistant, hass_client): async def test_welcome_page_registration(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Test that welcome page is present.""" """Test that welcome page is present."""
await setup(hass) await setup(hass)
@@ -74,7 +83,9 @@ async def test_welcome_page_registration(hass: HomeAssistant, hass_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redirect_page_registration(hass: HomeAssistant, hass_client): async def test_redirect_page_registration(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Test that redirect page can be reached.""" """Test that redirect page can be reached."""
await setup(hass) await setup(hass)
@@ -89,7 +100,7 @@ async def test_redirect_page_registration(hass: HomeAssistant, hass_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_rejects_invalid_encoded_redirect_uri( async def test_welcome_rejects_invalid_encoded_redirect_uri(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Welcome should reject malformed base64 redirect_uri values.""" """Welcome should reject malformed base64 redirect_uri values."""
await setup(hass) await setup(hass)
@@ -104,12 +115,104 @@ async def test_welcome_rejects_invalid_encoded_redirect_uri(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_sets_secure_state_cookie_flags(hass: HomeAssistant, hass_client): @pytest.mark.parametrize(
"redirect_uri",
[
"http://example.com/auth/authorize?client_id=https://example.com",
"http://example.com/auth/authorize?redirect_uri=https://example.com",
],
)
async def test_welcome_rejects_redirect_uris_missing_required_query_params(
hass: HomeAssistant, hass_client: ClientSessionGenerator, redirect_uri: str
):
"""Welcome should reject redirect URIs that decode but are incomplete."""
await setup(hass)
client = await hass_client()
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status == 400
assert "Invalid redirect_uri, please restart login." in await resp.text()
@pytest.mark.asyncio
@pytest.mark.parametrize(
("client_id", "should_store_token", "is_mobile"),
[
("", True, False),
(MOBILE_CLIENT_ID, False, True),
("https://random.example", False, False),
],
)
async def test_welcome_only_adds_store_token_for_web_clients(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
client_id: str,
should_store_token: bool,
is_mobile: bool,
):
"""Welcome should only append storeToken for clients aligned with the base URL."""
await setup(hass)
captured_redirect_uri = {}
async def fake_create_state(state_redirect_uri: str, *_args):
captured_redirect_uri["value"] = state_redirect_uri
return "state-id"
with (
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_create_state",
new=AsyncMock(side_effect=fake_create_state),
),
patch(
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_generate_device_code",
new=AsyncMock(return_value="123456"),
),
):
client = await hass_client()
if client_id == "":
# If not present, set it to the root URL to
# emulate the normal website/Lovelace/dashboard
client_id = str(client.make_url("/?test=true"))
redirect_uri = create_redirect_uri(client_id)
encoded = encode_redirect_uri(redirect_uri)
resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}",
allow_redirects=False,
)
assert resp.status in (200, 302)
assert "value" in captured_redirect_uri
parsed_state_redirect = urlparse(captured_redirect_uri["value"])
state_redirect_query = parse_qs(parsed_state_redirect.query)
nested_redirect_uri = unquote(state_redirect_query["redirect_uri"][0])
if should_store_token:
assert "storeToken=true" in nested_redirect_uri
else:
assert "storeToken=true" not in nested_redirect_uri
if is_mobile:
assert "https://home-assistant.io/" in nested_redirect_uri
@pytest.mark.asyncio
async def test_welcome_sets_secure_state_cookie_flags(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Welcome should set secure cookie flags for the OIDC state cookie.""" """Welcome should set secure cookie flags for the OIDC state cookie."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp = await client.get( resp = await client.get(
@@ -129,7 +232,7 @@ async def test_welcome_sets_secure_state_cookie_flags(hass: HomeAssistant, hass_
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_mobile_device_code_generation_failure( async def test_welcome_mobile_device_code_generation_failure(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Welcome should error if device code generation fails for mobile clients.""" """Welcome should error if device code generation fails for mobile clients."""
await setup(hass) await setup(hass)
@@ -154,13 +257,13 @@ async def test_welcome_mobile_device_code_generation_failure(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist( async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Welcome should render fallback auth link when other providers are present.""" """Welcome should render fallback auth link when other providers are present."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp = await client.get( resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -175,7 +278,7 @@ async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_welcome_desktop_auto_redirects_without_other_providers( async def test_welcome_desktop_auto_redirects_without_other_providers(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Welcome should auto-redirect desktop clients when no other providers exist.""" """Welcome should auto-redirect desktop clients when no other providers exist."""
@@ -184,7 +287,7 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp = await client.get( resp = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -196,7 +299,7 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redirect_without_cookie_goes_to_welcome( async def test_redirect_without_cookie_goes_to_welcome(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Redirect endpoint should bounce to welcome when no state cookie exists.""" """Redirect endpoint should bounce to welcome when no state cookie exists."""
await setup(hass) await setup(hass)
@@ -209,13 +312,13 @@ async def test_redirect_without_cookie_goes_to_welcome(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redirect_shows_error_on_oidc_runtime_error( async def test_redirect_shows_error_on_oidc_runtime_error(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Redirect should show a configuration error when OIDC URL generation raises.""" """Redirect should show a configuration error when OIDC URL generation raises."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -237,13 +340,13 @@ async def test_redirect_shows_error_on_oidc_runtime_error(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_redirect_shows_error_when_auth_url_empty( async def test_redirect_shows_error_when_auth_url_empty(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Redirect should show error page if OIDC returns no authorization URL.""" """Redirect should show error page if OIDC returns no authorization URL."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -264,7 +367,9 @@ async def test_redirect_shows_error_when_auth_url_empty(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_callback_registration(hass: HomeAssistant, hass_client): async def test_callback_registration(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Test that callback page is reachable.""" """Test that callback page is reachable."""
await setup(hass) await setup(hass)
@@ -275,12 +380,14 @@ async def test_callback_registration(hass: HomeAssistant, hass_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_callback_rejects_missing_code_or_state(hass: HomeAssistant, hass_client): async def test_callback_rejects_missing_code_or_state(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Callback must reject requests missing either code or state.""" """Callback must reject requests missing either code or state."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -304,12 +411,14 @@ async def test_callback_rejects_missing_code_or_state(hass: HomeAssistant, hass_
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_callback_rejects_state_mismatch(hass: HomeAssistant, hass_client): async def test_callback_rejects_state_mismatch(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Callback must reject state mismatch to protect against CSRF.""" """Callback must reject state mismatch to protect against CSRF."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -327,13 +436,13 @@ async def test_callback_rejects_state_mismatch(hass: HomeAssistant, hass_client)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_callback_rejects_when_user_details_fetch_fails( async def test_callback_rejects_when_user_details_fetch_fails(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Callback should error when token exchange/userinfo retrieval fails.""" """Callback should error when token exchange/userinfo retrieval fails."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -357,12 +466,14 @@ async def test_callback_rejects_when_user_details_fetch_fails(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_callback_rejects_invalid_role(hass: HomeAssistant, hass_client): async def test_callback_rejects_invalid_role(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Callback should reject users marked with invalid role.""" """Callback should reject users marked with invalid role."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -395,7 +506,10 @@ async def test_callback_rejects_invalid_role(hass: HomeAssistant, hass_client):
], ],
) )
async def test_finish_requires_state_cookie( async def test_finish_requires_state_cookie(
hass: HomeAssistant, hass_client, method: str, data: dict | None hass: HomeAssistant,
hass_client: ClientSessionGenerator,
method: str,
data: dict | None,
): ):
"""Finish endpoint should require the OIDC state cookie for both GET and POST.""" """Finish endpoint should require the OIDC state cookie for both GET and POST."""
await setup(hass) await setup(hass)
@@ -412,12 +526,14 @@ async def test_finish_requires_state_cookie(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finish_post_rejects_invalid_state(hass: HomeAssistant, hass_client): async def test_finish_post_rejects_invalid_state(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Finish POST should error when the state cookie does not resolve to redirect_uri.""" """Finish POST should error when the state cookie does not resolve to redirect_uri."""
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID) redirect_uri = create_redirect_uri(client.make_url("/"))
encoded = encode_redirect_uri(redirect_uri) encoded = encode_redirect_uri(redirect_uri)
resp_welcome = await client.get( resp_welcome = await client.get(
f"/auth/oidc/welcome?redirect_uri={encoded}", f"/auth/oidc/welcome?redirect_uri={encoded}",
@@ -435,7 +551,9 @@ async def test_finish_post_rejects_invalid_state(hass: HomeAssistant, hass_clien
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_device_sse_requires_state_cookie(hass: HomeAssistant, hass_client): async def test_device_sse_requires_state_cookie(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""SSE endpoint should reject requests without state cookie.""" """SSE endpoint should reject requests without state cookie."""
await setup(hass) await setup(hass)
@@ -447,7 +565,7 @@ async def test_device_sse_requires_state_cookie(hass: HomeAssistant, hass_client
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_device_sse_emits_expired_for_unknown_state( async def test_device_sse_emits_expired_for_unknown_state(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""SSE should emit expired when the state can no longer be resolved.""" """SSE should emit expired when the state can no longer be resolved."""
await setup(hass) await setup(hass)
@@ -472,7 +590,9 @@ async def test_device_sse_emits_expired_for_unknown_state(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_device_sse_emits_timeout(hass: HomeAssistant, hass_client): async def test_device_sse_emits_timeout(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""SSE should emit timeout if the polling window is exceeded.""" """SSE should emit timeout if the polling window is exceeded."""
await setup(hass) await setup(hass)
@@ -510,7 +630,7 @@ async def test_device_sse_emits_timeout(hass: HomeAssistant, hass_client):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_device_sse_handles_runtime_error_and_returns_cleanly( async def test_device_sse_handles_runtime_error_and_returns_cleanly(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""SSE should swallow runtime errors from stream loop and finish response.""" """SSE should swallow runtime errors from stream loop and finish response."""
await setup(hass) await setup(hass)
@@ -540,7 +660,7 @@ async def test_device_sse_handles_runtime_error_and_returns_cleanly(
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_device_sse_ignores_write_eof_connection_reset( async def test_device_sse_ignores_write_eof_connection_reset(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""SSE should ignore ConnectionResetError while closing the stream.""" """SSE should ignore ConnectionResetError while closing the stream."""
await setup(hass) await setup(hass)
@@ -570,7 +690,9 @@ async def test_device_sse_ignores_write_eof_connection_reset(
# Test the frontend injection # Test the frontend injection
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_frontend_injection(hass: HomeAssistant, hass_client): async def test_frontend_injection(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Test that frontend injection works.""" """Test that frontend injection works."""
# Because there is no frontend in the test setup, # Because there is no frontend in the test setup,
@@ -638,7 +760,7 @@ async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_injected_auth_page_redirects_to_welcome_when_not_skipped( async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
hass: HomeAssistant, hass_client hass: HomeAssistant, hass_client: ClientSessionGenerator
): ):
"""Injected auth page should redirect into OIDC when skip flags are absent.""" """Injected auth page should redirect into OIDC when skip flags are absent."""
@@ -646,7 +768,7 @@ async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
encoded_redirect_uri = quote(create_redirect_uri(WEB_CLIENT_ID), safe="") encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
resp = await client.get( resp = await client.get(
f"/auth/authorize?redirect_uri={encoded_redirect_uri}", f"/auth/authorize?redirect_uri={encoded_redirect_uri}",