Implement trusted_networks support (#283)

* Implement bypass for trusted_networks

* Trusted Network tests

* Test cleanup

* Improve integration tests

* Defensive programming

* Fix wrong import issue
This commit is contained in:
Christiaan Goossens
2026-05-01 14:03:14 +02:00
committed by GitHub
parent 04abb0fdb3
commit c7370ed266
6 changed files with 386 additions and 54 deletions

View File

@@ -90,24 +90,59 @@ async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
return False return False
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str): async def _register_oidc_provider(hass: HomeAssistant, my_config: dict):
"""Set up the OIDC provider with the given configuration.""" """Register the OIDC provider in Home Assistant's auth system."""
providers = OrderedDict()
# Use private APIs until there is a real auth platform # Use private APIs until there is a real auth platform
# pylint: disable=protected-access # pylint: disable=protected-access
providers = OrderedDict()
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config) provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
existing_auth_providers = hass.auth._providers.copy()
_LOGGER.debug("Current auth providers: %s", list(existing_auth_providers.keys()))
has_other_auth_providers = len(existing_auth_providers) > 0
has_trusted_networks_provider_first = False
if has_other_auth_providers:
# Pop the first provider from the existing providers to check if it's trusted_networks
first_provider_key, first_provider_obj = next(
iter(existing_auth_providers.items())
)
existing_auth_providers.pop(first_provider_key)
if first_provider_key[0] == "trusted_networks":
_LOGGER.info(
"Trusted Networks provider detected as the first auth provider. "
+ "Keeping registration order intact."
)
providers[first_provider_key] = first_provider_obj
has_trusted_networks_provider_first = True
else:
# Reset back to what we had before
existing_auth_providers = hass.auth._providers.copy()
# Register OIDC at the start of the array
# OIDC needs to be first because it needs to process the login cookie after sign-in
providers[(provider.type, provider.id)] = provider providers[(provider.type, provider.id)] = provider
# Get current provider count # Add back any other providers that were already registered
has_other_auth_providers = len(hass.auth._providers) > 0 providers.update(existing_auth_providers)
providers.update(hass.auth._providers) _LOGGER.debug("Final auth providers: %s", list(providers.values()))
hass.auth._providers = providers hass.auth._providers = providers
# pylint: enable=protected-access # pylint: enable=protected-access
_LOGGER.info("Registered OIDC provider") _LOGGER.info("Registered OIDC provider")
return provider, has_other_auth_providers, has_trusted_networks_provider_first
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
"""Set up the OIDC provider with the given configuration."""
(
provider,
has_other_auth_providers,
has_trusted_networks_provider_first,
) = await _register_oidc_provider(hass, my_config)
# Set the correct scopes # Set the correct scopes
# Always use 'openid' & 'profile' as they are specified in the OIDC spec # Always use 'openid' & 'profile' as they are specified in the OIDC spec
@@ -179,6 +214,8 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
_LOGGER.info("Registered OIDC views") _LOGGER.info("Registered OIDC views")
# Inject OIDC code into the frontend for /auth/authorize for automatic redirect # Inject OIDC code into the frontend for /auth/authorize for automatic redirect
await OIDCInjectedAuthPage.inject(hass, force_https) await OIDCInjectedAuthPage.inject(
hass, provider, force_https, has_trusted_networks_provider_first
)
return True return True

View File

@@ -11,6 +11,7 @@ from homeassistant.components.http import HomeAssistantView, StaticPathConfig
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .welcome import PATH as WELCOME_PATH from .welcome import PATH as WELCOME_PATH
from ..provider import OpenIDAuthProvider
from ..tools.helpers import get_url from ..tools.helpers import get_url
PATH = "/auth/authorize" PATH = "/auth/authorize"
@@ -24,7 +25,12 @@ async def read_file(path: str) -> str:
return await f.read() return await f.read()
async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None: async def frontend_injection(
hass: HomeAssistant,
provider: OpenIDAuthProvider,
force_https: bool,
has_trusted_networks_provider_first: bool,
) -> None:
"""Inject new frontend code into /auth/authorize.""" """Inject new frontend code into /auth/authorize."""
router = hass.http.app.router router = hass.http.app.router
frontend_path = None frontend_path = None
@@ -81,7 +87,11 @@ async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None:
) )
# If everything is succesful, register a fake view that just returns the modified HTML # If everything is succesful, register a fake view that just returns the modified HTML
hass.http.register_view(OIDCInjectedAuthPage(frontend_code, force_https)) hass.http.register_view(
OIDCInjectedAuthPage(
frontend_code, provider, force_https, has_trusted_networks_provider_first
)
)
_LOGGER.info("Performed OIDC frontend injection") _LOGGER.info("Performed OIDC frontend injection")
@@ -92,21 +102,36 @@ class OIDCInjectedAuthPage(HomeAssistantView):
url = PATH url = PATH
name = "auth:oidc:authorize_page" name = "auth:oidc:authorize_page"
def __init__(self, html: str, force_https: bool) -> None: def __init__(
self,
html: str,
provider: OpenIDAuthProvider,
force_https: bool,
has_trusted_networks_provider_first: bool,
) -> None:
"""Initialize the injected auth page.""" """Initialize the injected auth page."""
self.html = html self.html = html
self.provider = provider
self.force_https = force_https self.force_https = force_https
self.has_trusted_networks_provider_first = has_trusted_networks_provider_first
@staticmethod @staticmethod
async def inject(hass: HomeAssistant, force_https: bool) -> None: async def inject(
hass: HomeAssistant,
provider: OpenIDAuthProvider,
force_https: bool,
has_trusted_networks_provider_first: bool,
) -> None:
"""Inject the OIDC auth page into the frontend.""" """Inject the OIDC auth page into the frontend."""
try: try:
await frontend_injection(hass, force_https) await frontend_injection(
hass, provider, force_https, has_trusted_networks_provider_first
)
except Exception as e: # pylint: disable=broad-except except Exception as e: # pylint: disable=broad-except
_LOGGER.error("Failed to inject OIDC auth page: %s", e) _LOGGER.error("Failed to inject OIDC auth page: %s", e)
@staticmethod def _should_do_oidc_redirect(self, req: web.Request) -> bool:
def _should_do_oidc_redirect(req: web.Request) -> bool:
"""Check if we should redirect to the OIDC flow.""" """Check if we should redirect to the OIDC flow."""
# Set when we return from finish # Set when we return from finish
if req.query.get("skip_oidc_redirect") == "true": if req.query.get("skip_oidc_redirect") == "true":
@@ -118,6 +143,13 @@ class OIDCInjectedAuthPage(HomeAssistantView):
if not redirect_uri: if not redirect_uri:
return False return False
# Check if we are on a trusted network if we have trusted networks registered first
if (
self.has_trusted_networks_provider_first
and self.provider.is_trusted_network_host()
):
return False
# Handle both encoded and plain redirect_uri values. # Handle both encoded and plain redirect_uri values.
decoded_redirect_uri = unquote(redirect_uri) decoded_redirect_uri = unquote(redirect_uri)
return "skip_oidc_redirect=true" not in decoded_redirect_uri return "skip_oidc_redirect=true" not in decoded_redirect_uri

View File

@@ -6,7 +6,12 @@ import logging
from typing import Dict, Optional from typing import Dict, Optional
import asyncio import asyncio
from homeassistant.auth import EVENT_USER_ADDED from ipaddress import (
ip_address,
IPv4Address,
IPv6Address,
)
from homeassistant.auth import EVENT_USER_ADDED, InvalidAuthError as HAInvalidAuthError
from homeassistant.auth.providers import ( from homeassistant.auth.providers import (
AUTH_PROVIDERS, AUTH_PROVIDERS,
AuthProvider, AuthProvider,
@@ -20,7 +25,6 @@ from homeassistant.auth.providers import (
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.components import http, person from homeassistant.components import http, person
from homeassistant.exceptions import HomeAssistantError
from .config.const import ( from .config.const import (
FEATURES, FEATURES,
@@ -31,6 +35,8 @@ from .config.const import (
from .stores.state_store import StateStore from .stores.state_store import StateStore
from .tools.types import UserDetails from .tools.types import UserDetails
type IPAddress = IPv4Address | IPv6Address
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
PROVIDER_TYPE = "auth_oidc" PROVIDER_TYPE = "auth_oidc"
@@ -38,7 +44,7 @@ HASS_PROVIDER_TYPE = "homeassistant"
COOKIE_NAME = "auth_oidc_state" COOKIE_NAME = "auth_oidc_state"
class InvalidAuthError(HomeAssistantError): class InvalidAuthError(HAInvalidAuthError):
"""Raised when submitting invalid authentication.""" """Raised when submitting invalid authentication."""
@@ -114,6 +120,41 @@ class OpenIDAuthProvider(AuthProvider):
return None return None
def is_trusted_network_host(self) -> bool:
"""Check if the current request is coming from a trusted network host."""
ip = self._resolve_ip()
if ip is None:
return False
# Check if trusted networks auth provider is present
trusted_network_provider = self.hass.auth.get_auth_provider(
"trusted_networks", None
)
if not trusted_network_provider:
return False
_LOGGER.debug(
"Trusted networks present and checking if we should OIDC redirect"
)
try:
trusted_network_provider.async_validate_access(ip_address(ip))
_LOGGER.info("IP %s is in a trusted network, skipping OIDC flow", ip)
return True
except HAInvalidAuthError:
# Log the error
_LOGGER.info(
"IP %s is not in a trusted network, proceeding with OIDC flow", ip
)
return False
# Catch every other error, HA might have changed the API.
# pylint: disable=broad-exception-caught
except Exception as e:
_LOGGER.warning(
"Error while validating trusted network for IP %s: %s", ip, e
)
return False
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str: async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
"""Create a new OIDC state and return the state id.""" """Create a new OIDC state and return the state id."""
if self._state_store is None: if self._state_store is None:

View File

@@ -2,11 +2,13 @@
import base64 import base64
import re import re
from collections import OrderedDict
from types import SimpleNamespace from types import SimpleNamespace
from urllib.parse import parse_qs, unquote, urlparse from urllib.parse import parse_qs, unquote, urlparse
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from homeassistant.auth import InvalidAuthError
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
@@ -24,6 +26,10 @@ from custom_components.auth_oidc.config.const import (
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example" FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example"
DEFAULT_CONFIG = {
CLIENT_ID: "dummy",
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
}
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool: async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
@@ -40,10 +46,7 @@ async def test_setup_success_auth_provider_registration(hass: HomeAssistant):
"""Test successful setup""" """Test successful setup"""
await setup( await setup(
hass, hass,
{ DEFAULT_CONFIG,
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
},
True, True,
) )
@@ -62,10 +65,7 @@ async def test_provider_ip_fallback_fails_closed_without_request_context(
"""Provider should not invent a shared IP when request context is missing.""" """Provider should not invent a shared IP when request context is missing."""
await setup( await setup(
hass, hass,
{ DEFAULT_CONFIG,
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
},
True, True,
) )
@@ -83,10 +83,7 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis
"""Cookie header should include Secure when HTTPS is in use.""" """Cookie header should include Secure when HTTPS is in use."""
await setup( await setup(
hass, hass,
{ DEFAULT_CONFIG,
CLIENT_ID: "dummy",
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
},
True, True,
) )
@@ -98,6 +95,105 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis
assert "Secure" in cookie_header assert "Secure" in cookie_header
@pytest.mark.asyncio
async def test_provider_is_trusted_network_host_true_for_allowed_ip(
hass: HomeAssistant,
):
"""Provider should detect trusted network host when trusted provider allows the IP."""
await setup(
hass,
DEFAULT_CONFIG,
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
class TrustedNetworksAllowProvider:
def async_validate_access(self, _ip_addr):
return None
# pylint: disable=protected-access
hass.auth._providers = OrderedDict(
[
(("trusted_networks", None), TrustedNetworksAllowProvider()),
((provider.type, provider.id), provider),
]
)
# pylint: enable=protected-access
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
assert provider.is_trusted_network_host() is True
@pytest.mark.asyncio
async def test_provider_is_trusted_network_host_false_for_disallowed_ip(
hass: HomeAssistant, caplog
):
"""Provider should return False when trusted provider denies the current IP."""
await setup(
hass,
DEFAULT_CONFIG,
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
class TrustedNetworksDenyProvider:
def async_validate_access(self, _ip_addr):
raise InvalidAuthError("Not in trusted_networks")
# pylint: disable=protected-access
hass.auth._providers = OrderedDict(
[
(("trusted_networks", None), TrustedNetworksDenyProvider()),
((provider.type, provider.id), provider),
]
)
# pylint: enable=protected-access
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
assert provider.is_trusted_network_host() is False
assert any(
level >= 0
and "is not in a trusted network, proceeding with OIDC flow" in message
for _, level, message in caplog.record_tuples
)
assert not any(
level >= 0 and "Error while validating trusted network for IP" in message
for _, level, message in caplog.record_tuples
)
@pytest.mark.asyncio
async def test_provider_is_trusted_network_host_false_without_trusted_provider(
hass: HomeAssistant,
):
"""Provider should return False when trusted_networks auth provider is absent."""
await setup(
hass,
DEFAULT_CONFIG,
True,
)
provider = hass.auth.get_auth_providers(DOMAIN)[0]
# Without actually getting the IP, should also be false
assert provider.is_trusted_network_host() is False
# With the IP, should be false
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
assert provider.is_trusted_network_host() is False
async def login_user(hass: HomeAssistant, state_id: str): async def login_user(hass: HomeAssistant, state_id: str):
"""Helper to login a user from the stored OIDC state.""" """Helper to login a user from the stored OIDC state."""
@@ -167,8 +263,7 @@ async def test_full_login(hass: HomeAssistant, hass_client):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,
@@ -198,8 +293,7 @@ async def test_login_with_linking(hass: HomeAssistant, hass_client):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: True, FEATURES_AUTOMATIC_USER_LINKING: True,
@@ -233,8 +327,7 @@ async def test_login_with_person_create(hass: HomeAssistant, hass_client):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: True, FEATURES_AUTOMATIC_PERSON_CREATION: True,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,
@@ -267,8 +360,7 @@ async def test_login_without_person_create_does_not_create_person(
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,
@@ -295,8 +387,7 @@ async def test_login_shows_form(hass: HomeAssistant):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,
@@ -319,8 +410,7 @@ async def test_login_with_invalid_cookie_aborts(hass: HomeAssistant):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,
@@ -352,8 +442,7 @@ async def test_login_with_no_cookie_aborts(hass: HomeAssistant):
await setup( await setup(
hass, hass,
{ {
CLIENT_ID: "dummy", **DEFAULT_CONFIG,
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
FEATURES: { FEATURES: {
FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_PERSON_CREATION: False,
FEATURES_AUTOMATIC_USER_LINKING: False, FEATURES_AUTOMATIC_USER_LINKING: False,

View File

@@ -3,14 +3,17 @@
import base64 import base64
import asyncio import asyncio
import re import re
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from urllib.parse import parse_qs, unquote, urlparse, urlencode from urllib.parse import parse_qs, unquote, urlparse, urlencode
import pytest import pytest
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from custom_components.auth_oidc import DOMAIN from custom_components.auth_oidc import DOMAIN
from custom_components.auth_oidc.provider import COOKIE_NAME
from custom_components.auth_oidc.tools.oidc_client import ( from custom_components.auth_oidc.tools.oidc_client import (
OIDCDiscoveryClient, OIDCDiscoveryClient,
OIDCDiscoveryInvalid, OIDCDiscoveryInvalid,
@@ -248,6 +251,47 @@ async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
await verify_back_redirect(client, redirect_uri) await verify_back_redirect(client, redirect_uri)
@pytest.mark.asyncio
async def test_login_flow_init_completes_with_state_cookie(
hass: HomeAssistant, hass_client
):
"""The provider login flow init step should finalize when the auth cookie is present."""
await setup(hass)
with mock_oidc_responses():
client = await hass_client()
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
state, _, status = await get_welcome_for_client(client, redirect_uri)
assert status == 200
authorization_url = await get_redirect_auth_url(client)
session = async_get_clientsession(hass)
resp_auth = session.get(authorization_url, allow_redirects=False)
json_auth = await resp_auth.json()
resp = await client.get(
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
allow_redirects=False,
)
assert resp.status == 302
provider = hass.auth.get_auth_providers(DOMAIN)[0]
flow = await provider.async_login_flow({})
fake_request = SimpleNamespace(
cookies={COOKIE_NAME: state},
remote="127.0.0.1",
)
with patch(
"custom_components.auth_oidc.provider.http.current_request"
) as current_request:
current_request.get.return_value = fake_request
result = await flow.async_step_init({})
assert result["type"] == FlowResultType.CREATE_ENTRY
async def discovery_test_through_redirect( async def discovery_test_through_redirect(
hass_client, caplog, scenario: str, match_log_line: str hass_client, caplog, scenario: str, match_log_line: str
): ):

View File

@@ -2,6 +2,7 @@
import base64 import base64
import os import os
from collections import OrderedDict
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from auth_oidc.config.const import ( from auth_oidc.config.const import (
@@ -25,6 +26,22 @@ from custom_components.auth_oidc.endpoints.injected_auth_page import (
) )
MOBILE_CLIENT_ID = "https://home-assistant.io/Android" MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
WELCOME_PATH = "/auth/oidc/welcome"
INJECTION_SCRIPT_MARKER = "<script src='/auth/oidc/static/injection.js"
def assert_redirects_to_welcome(resp) -> None:
"""Assert a response redirects to the OIDC welcome endpoint."""
assert resp.status == 302
location = resp.headers["Location"]
parsed_location = urlparse(location)
assert parsed_location.path == WELCOME_PATH
async def assert_normal_login_screen(resp) -> None:
"""Assert we stayed on the auth page and render the injected normal login HTML."""
assert resp.status == 200
assert INJECTION_SCRIPT_MARKER in await resp.text()
def create_redirect_uri(client_id: str) -> str: def create_redirect_uri(client_id: str) -> str:
@@ -310,7 +327,9 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
"""Welcome should auto-redirect desktop clients when no other providers exist.""" """Welcome should auto-redirect desktop clients when no other providers exist."""
# pylint: disable=protected-access # pylint: disable=protected-access
hass.auth._providers = [] # Clear initial providers out hass.auth._providers = {} # Clear initial providers out
# pylint: enable=protected-access
await setup(hass) await setup(hass)
client = await hass_client() client = await hass_client()
@@ -334,7 +353,7 @@ async def test_redirect_without_cookie_goes_to_welcome(
client = await hass_client() client = await hass_client()
resp = await client.get("/auth/oidc/redirect", allow_redirects=False) resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
assert resp.status == 302 assert resp.status == 302
assert "/auth/oidc/welcome" in resp.headers["Location"] assert_redirects_to_welcome(resp)
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -730,10 +749,7 @@ async def test_frontend_injection(
client = await hass_client() client = await hass_client()
resp = await client.get("/auth/authorize", allow_redirects=False) resp = await client.get("/auth/authorize", allow_redirects=False)
assert resp.status == 200 # 200 because there is no redirect_uri await assert_normal_login_screen(resp)
text = await resp.text()
assert "<script src='/auth/oidc/static/injection.js" in text
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -760,8 +776,12 @@ async def test_frontend_injection_logs_and_returns_when_route_handler_is_unexpec
def __iter__(self): def __iter__(self):
return iter([FakeRoute()]) return iter([FakeRoute()])
provider = MagicMock()
with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]): with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]):
await frontend_injection(hass, force_https=False) await frontend_injection(
hass, provider, force_https=False, has_trusted_networks_provider_first=False
)
assert "Unexpected route handler type" in caplog.text assert "Unexpected route handler type" in caplog.text
assert ( assert (
@@ -780,7 +800,10 @@ async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog
"custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection", "custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection",
side_effect=RuntimeError("boom"), side_effect=RuntimeError("boom"),
): ):
await OIDCInjectedAuthPage.inject(hass, force_https=False) provider = MagicMock()
await OIDCInjectedAuthPage.inject(
hass, provider, force_https=False, has_trusted_networks_provider_first=False
)
assert "Failed to inject OIDC auth page: boom" in caplog.text assert "Failed to inject OIDC auth page: boom" in caplog.text
@@ -836,5 +859,71 @@ async def test_injected_auth_page_returns_original_html_when_skipped(
client = await hass_client() client = await hass_client()
response = await client.get(request_target, allow_redirects=False) response = await client.get(request_target, allow_redirects=False)
assert response.status == 200 await assert_normal_login_screen(response)
assert "<script src='/auth/oidc/static/injection.js" in await response.text()
@pytest.mark.asyncio
async def test_injected_auth_page_trusted_networks_bypass_skips_oidc_redirect(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""Trusted network hosts should bypass OIDC redirect when trusted_networks is first."""
class TrustedNetworksAllowProvider:
def async_validate_access(self, _ip_addr):
return None
# pylint: disable=protected-access
hass.auth._providers = OrderedDict(
[(("trusted_networks", None), TrustedNetworksAllowProvider())]
)
# pylint: enable=protected-access
await setup_mock_authorize_route(hass)
await setup(hass)
client = await hass_client()
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
resp = await client.get(
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
allow_redirects=False,
)
await assert_normal_login_screen(resp)
@pytest.mark.asyncio
async def test_injected_auth_page_ignores_trusted_networks_when_not_first(
hass: HomeAssistant, hass_client: ClientSessionGenerator
):
"""OIDC redirect should continue when trusted_networks is not the first provider."""
class DummyProvider:
pass
class TrustedNetworksAllowProvider:
def async_validate_access(self, _ip_addr):
return None
# Keep trusted_networks present but not first, so bypass should not apply.
# pylint: disable=protected-access
hass.auth._providers = OrderedDict(
[
(("homeassistant", None), DummyProvider()),
(("trusted_networks", None), TrustedNetworksAllowProvider()),
]
)
# pylint: enable=protected-access
await setup_mock_authorize_route(hass)
await setup(hass)
client = await hass_client()
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
resp = await client.get(
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
allow_redirects=False,
)
assert_redirects_to_welcome(resp)