Fix regression of storeToken parameter (#248)
* Try a different method to set ?storeToken * Formatting * Only insert storeToken on web client & fix tests
This commit is contained in:
committed by
GitHub
parent
0ca300c385
commit
07c1e3a4c4
@@ -61,11 +61,9 @@ class OIDCFinishView(HomeAssistantView):
|
|||||||
if "?" in redirect_uri:
|
if "?" in redirect_uri:
|
||||||
separator = "&"
|
separator = "&"
|
||||||
|
|
||||||
# Redirect to this new URL for login
|
# Redirect to this new URL for login, make sure to skip OIDC to prevent loops
|
||||||
new_url = (
|
redirect_uri = f"{redirect_uri}{separator}skip_oidc_redirect=true"
|
||||||
redirect_uri + separator + "storeToken=true&skip_oidc_redirect=true"
|
raise web.HTTPFound(location=redirect_uri)
|
||||||
)
|
|
||||||
raise web.HTTPFound(location=new_url)
|
|
||||||
|
|
||||||
# Check if we can link this device
|
# Check if we can link this device
|
||||||
linked = await self.oidc_provider.async_link_state_to_code(
|
linked = await self.oidc_provider.async_link_state_to_code(
|
||||||
|
|||||||
@@ -113,9 +113,12 @@ class OIDCInjectedAuthPage(HomeAssistantView):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _should_do_oidc_redirect(req: web.Request) -> bool:
|
def _should_do_oidc_redirect(req: web.Request) -> bool:
|
||||||
"""Check if we should redirect to the OIDC flow."""
|
"""Check if we should redirect to the OIDC flow."""
|
||||||
|
# Set when we return from finish
|
||||||
if req.query.get("skip_oidc_redirect") == "true":
|
if req.query.get("skip_oidc_redirect") == "true":
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Set whenever you directly do /?skip_oidc_redirect=true,
|
||||||
|
# for example when you click the "other" button on the welcome screen
|
||||||
redirect_uri = req.query.get("redirect_uri")
|
redirect_uri = req.query.get("redirect_uri")
|
||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Welcome route to show the user the OIDC login button and give instructions."""
|
"""Welcome route to show the user the OIDC login button and give instructions."""
|
||||||
|
|
||||||
|
from ast import List
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
from urllib.parse import urlparse, parse_qs, unquote
|
from urllib.parse import urlparse, parse_qs, unquote, urlencode
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from ..tools.helpers import error_response, get_url, template_response
|
from ..tools.helpers import error_response, get_url, template_response
|
||||||
@@ -30,13 +31,46 @@ class OIDCWelcomeView(HomeAssistantView):
|
|||||||
self.force_https = force_https
|
self.force_https = force_https
|
||||||
self.has_other_auth_providers = has_other_auth_providers
|
self.has_other_auth_providers = has_other_auth_providers
|
||||||
|
|
||||||
def determine_if_mobile(self, redirect_uri: str) -> bool:
|
async def _process_url(self, redirect_uri: str) -> List[str, bool]:
|
||||||
"""Determine if the client is a mobile client based on the redirect_uri."""
|
"""Processes the redirect URI to determine if we need setTokens and if this is mobile."""
|
||||||
oauth2_url = urlparse(redirect_uri)
|
# decodeURIComponent(btoa(...)) -> unquote first, then base64 decode
|
||||||
client_id = parse_qs(oauth2_url.query).get("client_id")
|
redirect_uri = base64.b64decode(unquote(redirect_uri), validate=True).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
# If the client_id starts with https://home-assistant.io/ we assume it's a mobile client
|
oauth2_url = urlparse(redirect_uri)
|
||||||
return bool(client_id and client_id[0].startswith("https://home-assistant.io/"))
|
oauth2_query = parse_qs(oauth2_url.query)
|
||||||
|
client_id = oauth2_query.get("client_id")[0]
|
||||||
|
original_redirect_uri = oauth2_query.get("redirect_uri")[0]
|
||||||
|
|
||||||
|
# If the client_id starts with https://home-assistant.io/
|
||||||
|
# we assume it's a mobile client
|
||||||
|
# Android = https://home-assistant.io/Android,
|
||||||
|
# iOS = https://home-assistant.io/iOS
|
||||||
|
is_mobile = client_id.startswith("https://home-assistant.io/")
|
||||||
|
|
||||||
|
# Check if we appear to be signing in to the web version,
|
||||||
|
# for which we want to store tokens.
|
||||||
|
# We don't want to set storeTokens on sign-in to Google for instance
|
||||||
|
base_url = get_url("/", self.force_https)
|
||||||
|
is_web_client = original_redirect_uri.startswith(base_url)
|
||||||
|
|
||||||
|
if is_web_client:
|
||||||
|
# Adjust the original_redirect_uri to include the storeTokens parameter
|
||||||
|
separator = "?"
|
||||||
|
if "?" in original_redirect_uri:
|
||||||
|
separator = "&"
|
||||||
|
|
||||||
|
original_redirect_uri = f"{original_redirect_uri}{separator}storeToken=true"
|
||||||
|
oauth2_query.update({"redirect_uri": original_redirect_uri})
|
||||||
|
|
||||||
|
# Create new redirect_uri with the updated query parameters
|
||||||
|
new_oauth2_url = oauth2_url._replace(
|
||||||
|
query=urlencode(oauth2_query, doseq=True)
|
||||||
|
)
|
||||||
|
redirect_uri = new_oauth2_url.geturl()
|
||||||
|
|
||||||
|
return redirect_uri, is_mobile
|
||||||
|
|
||||||
async def get(self, req: web.Request) -> web.Response:
|
async def get(self, req: web.Request) -> web.Response:
|
||||||
"""Receive response."""
|
"""Receive response."""
|
||||||
@@ -44,23 +78,26 @@ class OIDCWelcomeView(HomeAssistantView):
|
|||||||
# Get the query parameter with the redirect_uri
|
# Get the query parameter with the redirect_uri
|
||||||
redirect_uri = req.query.get("redirect_uri")
|
redirect_uri = req.query.get("redirect_uri")
|
||||||
|
|
||||||
# If set, determine if this is a mobile client based on the redirect_uri,
|
# Do some processing on the redirect_uri to correct it
|
||||||
# otherwise assume it's not mobile
|
# and determine if this is a mobile client.
|
||||||
if redirect_uri:
|
if redirect_uri:
|
||||||
try:
|
try:
|
||||||
# decodeURIComponent(btoa(...)) -> unquote first, then base64 decode
|
redirect_uri, is_mobile = await self._process_url(redirect_uri)
|
||||||
redirect_uri = base64.b64decode(
|
except (
|
||||||
unquote(redirect_uri), validate=True
|
binascii.Error,
|
||||||
).decode("utf-8")
|
UnicodeDecodeError,
|
||||||
is_mobile = self.determine_if_mobile(redirect_uri)
|
ValueError,
|
||||||
except (binascii.Error, UnicodeDecodeError, ValueError):
|
KeyError,
|
||||||
|
TypeError,
|
||||||
|
):
|
||||||
return await error_response(
|
return await error_response(
|
||||||
"Invalid redirect_uri, please restart login."
|
"Invalid redirect_uri, please restart login."
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Backwards compatibility with older versions that directly go to /auth/oidc/welcome
|
# Backwards compatibility with older versions that directly go to /auth/oidc/welcome
|
||||||
# If not set, redirect back to the main page and assume that this is a web client
|
# If not set, redirect back to the main page and assume that this is a web client
|
||||||
redirect_uri = get_url("/", self.force_https)
|
redirect_uri = get_url("/?storeToken=true", self.force_https)
|
||||||
is_mobile = False
|
is_mobile = False
|
||||||
|
|
||||||
# Create OIDC state with the redirect_uri so we can use it later in the flow
|
# Create OIDC state with the redirect_uri so we can use it later in the flow
|
||||||
|
|||||||
@@ -103,10 +103,10 @@ async def verify_back_redirect(client, expected_redirect_uri: str):
|
|||||||
"""Verify that POST to finish without body redirects back to the original redirect_uri."""
|
"""Verify that POST to finish without body redirects back to the original redirect_uri."""
|
||||||
resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False)
|
resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False)
|
||||||
assert resp_finish_post.status == 302
|
assert resp_finish_post.status == 302
|
||||||
assert (
|
|
||||||
resp_finish_post.headers["Location"]
|
location = resp_finish_post.headers["Location"]
|
||||||
== unquote(expected_redirect_uri) + "&storeToken=true&skip_oidc_redirect=true"
|
assert location.startswith(unquote(expected_redirect_uri))
|
||||||
)
|
assert "skip_oidc_redirect=true" in location
|
||||||
|
|
||||||
|
|
||||||
async def listen_for_sse_events(
|
async def listen_for_sse_events(
|
||||||
|
|||||||
@@ -86,7 +86,9 @@ def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
|
|||||||
return jwt.encode({"alg": "HS256"}, claims, jwk_obj)
|
return jwt.encode({"alg": "HS256"}, claims, jwk_obj)
|
||||||
|
|
||||||
|
|
||||||
def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[str, dict]:
|
def build_real_signed_token(
|
||||||
|
algorithm: str, claims: dict, secret: str
|
||||||
|
) -> tuple[str, dict]:
|
||||||
"""Build a real signed token and matching JWKS payload for a given algorithm."""
|
"""Build a real signed token and matching JWKS payload for a given algorithm."""
|
||||||
if algorithm.startswith("HS"):
|
if algorithm.startswith("HS"):
|
||||||
signing_key = jwk.import_key(
|
signing_key = jwk.import_key(
|
||||||
@@ -96,7 +98,9 @@ def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[
|
|||||||
"alg": algorithm,
|
"alg": algorithm,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
token = jwt.encode({"alg": algorithm}, claims, signing_key, algorithms=[algorithm])
|
token = jwt.encode(
|
||||||
|
{"alg": algorithm}, claims, signing_key, algorithms=[algorithm]
|
||||||
|
)
|
||||||
return token, {"keys": []}
|
return token, {"keys": []}
|
||||||
|
|
||||||
if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"):
|
if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"):
|
||||||
@@ -115,7 +119,11 @@ def build_real_signed_token(algorithm: str, claims: dict, secret: str) -> tuple[
|
|||||||
)
|
)
|
||||||
elif algorithm in ("Ed25519", "Ed448"):
|
elif algorithm in ("Ed25519", "Ed448"):
|
||||||
key = jwk.generate_key(
|
key = jwk.generate_key(
|
||||||
"OKP", algorithm, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True
|
"OKP",
|
||||||
|
algorithm,
|
||||||
|
{"alg": algorithm, "use": "sig"},
|
||||||
|
private=True,
|
||||||
|
auto_kid=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported test algorithm: {algorithm}")
|
raise ValueError(f"Unsupported test algorithm: {algorithm}")
|
||||||
|
|||||||
@@ -2,9 +2,11 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
from urllib.parse import parse_qs, quote, unquote, urlparse
|
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID
|
from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID
|
||||||
|
|
||||||
|
from pytest_homeassistant_custom_component.typing import ClientSessionGenerator
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
@@ -17,14 +19,19 @@ from custom_components.auth_oidc.endpoints.injected_auth_page import (
|
|||||||
frontend_injection,
|
frontend_injection,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
WEB_CLIENT_ID = "https://example.com"
|
|
||||||
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
||||||
|
|
||||||
|
|
||||||
def create_redirect_uri(client_id: str) -> str:
|
def create_redirect_uri(client_id: str) -> str:
|
||||||
"""Build a redirect URI that includes a client_id query parameter."""
|
"""Build a redirect URI that includes a client_id query parameter."""
|
||||||
return f"http://example.com/auth/authorize?client_id={client_id}"
|
params = {
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": client_id,
|
||||||
|
"client_id": client_id,
|
||||||
|
"state": "example",
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"http://example.com/auth/authorize?{urlencode(params)}"
|
||||||
|
|
||||||
|
|
||||||
def encode_redirect_uri(redirect_uri: str) -> str:
|
def encode_redirect_uri(redirect_uri: str) -> str:
|
||||||
@@ -63,7 +70,9 @@ async def setup_mock_authorize_route(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_page_registration(hass: HomeAssistant, hass_client):
|
async def test_welcome_page_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Test that welcome page is present."""
|
"""Test that welcome page is present."""
|
||||||
|
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -74,7 +83,9 @@ async def test_welcome_page_registration(hass: HomeAssistant, hass_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redirect_page_registration(hass: HomeAssistant, hass_client):
|
async def test_redirect_page_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Test that redirect page can be reached."""
|
"""Test that redirect page can be reached."""
|
||||||
|
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -89,7 +100,7 @@ async def test_redirect_page_registration(hass: HomeAssistant, hass_client):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_rejects_invalid_encoded_redirect_uri(
|
async def test_welcome_rejects_invalid_encoded_redirect_uri(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Welcome should reject malformed base64 redirect_uri values."""
|
"""Welcome should reject malformed base64 redirect_uri values."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -104,12 +115,104 @@ async def test_welcome_rejects_invalid_encoded_redirect_uri(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_sets_secure_state_cookie_flags(hass: HomeAssistant, hass_client):
|
@pytest.mark.parametrize(
|
||||||
|
"redirect_uri",
|
||||||
|
[
|
||||||
|
"http://example.com/auth/authorize?client_id=https://example.com",
|
||||||
|
"http://example.com/auth/authorize?redirect_uri=https://example.com",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_welcome_rejects_redirect_uris_missing_required_query_params(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator, redirect_uri: str
|
||||||
|
):
|
||||||
|
"""Welcome should reject redirect URIs that decode but are incomplete."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Invalid redirect_uri, please restart login." in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("client_id", "should_store_token", "is_mobile"),
|
||||||
|
[
|
||||||
|
("", True, False),
|
||||||
|
(MOBILE_CLIENT_ID, False, True),
|
||||||
|
("https://random.example", False, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_welcome_only_adds_store_token_for_web_clients(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
client_id: str,
|
||||||
|
should_store_token: bool,
|
||||||
|
is_mobile: bool,
|
||||||
|
):
|
||||||
|
"""Welcome should only append storeToken for clients aligned with the base URL."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
captured_redirect_uri = {}
|
||||||
|
|
||||||
|
async def fake_create_state(state_redirect_uri: str, *_args):
|
||||||
|
captured_redirect_uri["value"] = state_redirect_uri
|
||||||
|
return "state-id"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_create_state",
|
||||||
|
new=AsyncMock(side_effect=fake_create_state),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_generate_device_code",
|
||||||
|
new=AsyncMock(return_value="123456"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
if client_id == "":
|
||||||
|
# If not present, set it to the root URL to
|
||||||
|
# emulate the normal website/Lovelace/dashboard
|
||||||
|
client_id = str(client.make_url("/?test=true"))
|
||||||
|
|
||||||
|
redirect_uri = create_redirect_uri(client_id)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status in (200, 302)
|
||||||
|
assert "value" in captured_redirect_uri
|
||||||
|
|
||||||
|
parsed_state_redirect = urlparse(captured_redirect_uri["value"])
|
||||||
|
state_redirect_query = parse_qs(parsed_state_redirect.query)
|
||||||
|
nested_redirect_uri = unquote(state_redirect_query["redirect_uri"][0])
|
||||||
|
|
||||||
|
if should_store_token:
|
||||||
|
assert "storeToken=true" in nested_redirect_uri
|
||||||
|
else:
|
||||||
|
assert "storeToken=true" not in nested_redirect_uri
|
||||||
|
|
||||||
|
if is_mobile:
|
||||||
|
assert "https://home-assistant.io/" in nested_redirect_uri
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_sets_secure_state_cookie_flags(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Welcome should set secure cookie flags for the OIDC state cookie."""
|
"""Welcome should set secure cookie flags for the OIDC state cookie."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
@@ -129,7 +232,7 @@ async def test_welcome_sets_secure_state_cookie_flags(hass: HomeAssistant, hass_
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_mobile_device_code_generation_failure(
|
async def test_welcome_mobile_device_code_generation_failure(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Welcome should error if device code generation fails for mobile clients."""
|
"""Welcome should error if device code generation fails for mobile clients."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -154,13 +257,13 @@ async def test_welcome_mobile_device_code_generation_failure(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist(
|
async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Welcome should render fallback auth link when other providers are present."""
|
"""Welcome should render fallback auth link when other providers are present."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -175,7 +278,7 @@ async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_welcome_desktop_auto_redirects_without_other_providers(
|
async def test_welcome_desktop_auto_redirects_without_other_providers(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Welcome should auto-redirect desktop clients when no other providers exist."""
|
"""Welcome should auto-redirect desktop clients when no other providers exist."""
|
||||||
|
|
||||||
@@ -184,7 +287,7 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
|
|||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -196,7 +299,7 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redirect_without_cookie_goes_to_welcome(
|
async def test_redirect_without_cookie_goes_to_welcome(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Redirect endpoint should bounce to welcome when no state cookie exists."""
|
"""Redirect endpoint should bounce to welcome when no state cookie exists."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -209,13 +312,13 @@ async def test_redirect_without_cookie_goes_to_welcome(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redirect_shows_error_on_oidc_runtime_error(
|
async def test_redirect_shows_error_on_oidc_runtime_error(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Redirect should show a configuration error when OIDC URL generation raises."""
|
"""Redirect should show a configuration error when OIDC URL generation raises."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -237,13 +340,13 @@ async def test_redirect_shows_error_on_oidc_runtime_error(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_redirect_shows_error_when_auth_url_empty(
|
async def test_redirect_shows_error_when_auth_url_empty(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Redirect should show error page if OIDC returns no authorization URL."""
|
"""Redirect should show error page if OIDC returns no authorization URL."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -264,7 +367,9 @@ async def test_redirect_shows_error_when_auth_url_empty(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_callback_registration(hass: HomeAssistant, hass_client):
|
async def test_callback_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Test that callback page is reachable."""
|
"""Test that callback page is reachable."""
|
||||||
|
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -275,12 +380,14 @@ async def test_callback_registration(hass: HomeAssistant, hass_client):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_callback_rejects_missing_code_or_state(hass: HomeAssistant, hass_client):
|
async def test_callback_rejects_missing_code_or_state(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Callback must reject requests missing either code or state."""
|
"""Callback must reject requests missing either code or state."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -304,12 +411,14 @@ async def test_callback_rejects_missing_code_or_state(hass: HomeAssistant, hass_
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_callback_rejects_state_mismatch(hass: HomeAssistant, hass_client):
|
async def test_callback_rejects_state_mismatch(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Callback must reject state mismatch to protect against CSRF."""
|
"""Callback must reject state mismatch to protect against CSRF."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -327,13 +436,13 @@ async def test_callback_rejects_state_mismatch(hass: HomeAssistant, hass_client)
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_callback_rejects_when_user_details_fetch_fails(
|
async def test_callback_rejects_when_user_details_fetch_fails(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Callback should error when token exchange/userinfo retrieval fails."""
|
"""Callback should error when token exchange/userinfo retrieval fails."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -357,12 +466,14 @@ async def test_callback_rejects_when_user_details_fetch_fails(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_callback_rejects_invalid_role(hass: HomeAssistant, hass_client):
|
async def test_callback_rejects_invalid_role(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Callback should reject users marked with invalid role."""
|
"""Callback should reject users marked with invalid role."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -395,7 +506,10 @@ async def test_callback_rejects_invalid_role(hass: HomeAssistant, hass_client):
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_finish_requires_state_cookie(
|
async def test_finish_requires_state_cookie(
|
||||||
hass: HomeAssistant, hass_client, method: str, data: dict | None
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
method: str,
|
||||||
|
data: dict | None,
|
||||||
):
|
):
|
||||||
"""Finish endpoint should require the OIDC state cookie for both GET and POST."""
|
"""Finish endpoint should require the OIDC state cookie for both GET and POST."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -412,12 +526,14 @@ async def test_finish_requires_state_cookie(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_finish_post_rejects_invalid_state(hass: HomeAssistant, hass_client):
|
async def test_finish_post_rejects_invalid_state(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Finish POST should error when the state cookie does not resolve to redirect_uri."""
|
"""Finish POST should error when the state cookie does not resolve to redirect_uri."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
encoded = encode_redirect_uri(redirect_uri)
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
resp_welcome = await client.get(
|
resp_welcome = await client.get(
|
||||||
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
@@ -435,7 +551,9 @@ async def test_finish_post_rejects_invalid_state(hass: HomeAssistant, hass_clien
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_device_sse_requires_state_cookie(hass: HomeAssistant, hass_client):
|
async def test_device_sse_requires_state_cookie(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""SSE endpoint should reject requests without state cookie."""
|
"""SSE endpoint should reject requests without state cookie."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
@@ -447,7 +565,7 @@ async def test_device_sse_requires_state_cookie(hass: HomeAssistant, hass_client
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_device_sse_emits_expired_for_unknown_state(
|
async def test_device_sse_emits_expired_for_unknown_state(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""SSE should emit expired when the state can no longer be resolved."""
|
"""SSE should emit expired when the state can no longer be resolved."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -472,7 +590,9 @@ async def test_device_sse_emits_expired_for_unknown_state(
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_device_sse_emits_timeout(hass: HomeAssistant, hass_client):
|
async def test_device_sse_emits_timeout(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""SSE should emit timeout if the polling window is exceeded."""
|
"""SSE should emit timeout if the polling window is exceeded."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
@@ -510,7 +630,7 @@ async def test_device_sse_emits_timeout(hass: HomeAssistant, hass_client):
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_device_sse_handles_runtime_error_and_returns_cleanly(
|
async def test_device_sse_handles_runtime_error_and_returns_cleanly(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""SSE should swallow runtime errors from stream loop and finish response."""
|
"""SSE should swallow runtime errors from stream loop and finish response."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -540,7 +660,7 @@ async def test_device_sse_handles_runtime_error_and_returns_cleanly(
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_device_sse_ignores_write_eof_connection_reset(
|
async def test_device_sse_ignores_write_eof_connection_reset(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""SSE should ignore ConnectionResetError while closing the stream."""
|
"""SSE should ignore ConnectionResetError while closing the stream."""
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
@@ -570,7 +690,9 @@ async def test_device_sse_ignores_write_eof_connection_reset(
|
|||||||
|
|
||||||
# Test the frontend injection
|
# Test the frontend injection
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_frontend_injection(hass: HomeAssistant, hass_client):
|
async def test_frontend_injection(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
"""Test that frontend injection works."""
|
"""Test that frontend injection works."""
|
||||||
|
|
||||||
# Because there is no frontend in the test setup,
|
# Because there is no frontend in the test setup,
|
||||||
@@ -638,7 +760,7 @@ async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
|
async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
|
||||||
hass: HomeAssistant, hass_client
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
):
|
):
|
||||||
"""Injected auth page should redirect into OIDC when skip flags are absent."""
|
"""Injected auth page should redirect into OIDC when skip flags are absent."""
|
||||||
|
|
||||||
@@ -646,7 +768,7 @@ async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
|
|||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
encoded_redirect_uri = quote(create_redirect_uri(WEB_CLIENT_ID), safe="")
|
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
|
||||||
|
|
||||||
resp = await client.get(
|
resp = await client.get(
|
||||||
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
|
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
|
||||||
|
|||||||
Reference in New Issue
Block a user