diff --git a/custom_components/auth_oidc/__init__.py b/custom_components/auth_oidc/__init__.py index 22e8c05..9d87e0e 100644 --- a/custom_components/auth_oidc/__init__.py +++ b/custom_components/auth_oidc/__init__.py @@ -90,24 +90,59 @@ async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry): return False -async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str): - """Set up the OIDC provider with the given configuration.""" - providers = OrderedDict() - +async def _register_oidc_provider(hass: HomeAssistant, my_config: dict): + """Register the OIDC provider in Home Assistant's auth system.""" # Use private APIs until there is a real auth platform + # pylint: disable=protected-access + providers = OrderedDict() provider = OpenIDAuthProvider(hass, hass.auth._store, my_config) + existing_auth_providers = hass.auth._providers.copy() + _LOGGER.debug("Current auth providers: %s", list(existing_auth_providers.keys())) + has_other_auth_providers = len(existing_auth_providers) > 0 + has_trusted_networks_provider_first = False + + if has_other_auth_providers: + # Pop the first provider from the existing providers to check if it's trusted_networks + first_provider_key, first_provider_obj = next( + iter(existing_auth_providers.items()) + ) + existing_auth_providers.pop(first_provider_key) + + if first_provider_key[0] == "trusted_networks": + _LOGGER.info( + "Trusted Networks provider detected as the first auth provider. " + + "Keeping registration order intact." + ) + providers[first_provider_key] = first_provider_obj + has_trusted_networks_provider_first = True + else: + # Reset back to what we had before + existing_auth_providers = hass.auth._providers.copy() + + # Register OIDC at the start of the array + # OIDC needs to be first because it needs to process the login cookie after sign-in providers[(provider.type, provider.id)] = provider - # Get current provider count - has_other_auth_providers = len(hass.auth._providers) > 0 + # Add back any other providers that were already registered + providers.update(existing_auth_providers) - providers.update(hass.auth._providers) + _LOGGER.debug("Final auth providers: %s", list(providers.values())) hass.auth._providers = providers # pylint: enable=protected-access _LOGGER.info("Registered OIDC provider") + return provider, has_other_auth_providers, has_trusted_networks_provider_first + + +async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str): + """Set up the OIDC provider with the given configuration.""" + ( + provider, + has_other_auth_providers, + has_trusted_networks_provider_first, + ) = await _register_oidc_provider(hass, my_config) # Set the correct scopes # Always use 'openid' & 'profile' as they are specified in the OIDC spec @@ -179,6 +214,8 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam _LOGGER.info("Registered OIDC views") # Inject OIDC code into the frontend for /auth/authorize for automatic redirect - await OIDCInjectedAuthPage.inject(hass, force_https) + await OIDCInjectedAuthPage.inject( + hass, provider, force_https, has_trusted_networks_provider_first + ) return True diff --git a/custom_components/auth_oidc/endpoints/injected_auth_page.py b/custom_components/auth_oidc/endpoints/injected_auth_page.py index c3c2212..d800fb6 100644 --- a/custom_components/auth_oidc/endpoints/injected_auth_page.py +++ b/custom_components/auth_oidc/endpoints/injected_auth_page.py @@ -11,6 +11,7 @@ from homeassistant.components.http import HomeAssistantView, StaticPathConfig from homeassistant.core import HomeAssistant from .welcome import PATH as WELCOME_PATH +from ..provider import OpenIDAuthProvider from ..tools.helpers import get_url PATH = "/auth/authorize" @@ -24,7 +25,12 @@ async def read_file(path: str) -> str: return await f.read() -async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None: +async def frontend_injection( + hass: HomeAssistant, + provider: OpenIDAuthProvider, + force_https: bool, + has_trusted_networks_provider_first: bool, +) -> None: """Inject new frontend code into /auth/authorize.""" router = hass.http.app.router frontend_path = None @@ -81,7 +87,11 @@ async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None: ) # If everything is succesful, register a fake view that just returns the modified HTML - hass.http.register_view(OIDCInjectedAuthPage(frontend_code, force_https)) + hass.http.register_view( + OIDCInjectedAuthPage( + frontend_code, provider, force_https, has_trusted_networks_provider_first + ) + ) _LOGGER.info("Performed OIDC frontend injection") @@ -92,21 +102,36 @@ class OIDCInjectedAuthPage(HomeAssistantView): url = PATH name = "auth:oidc:authorize_page" - def __init__(self, html: str, force_https: bool) -> None: + def __init__( + self, + html: str, + provider: OpenIDAuthProvider, + force_https: bool, + has_trusted_networks_provider_first: bool, + ) -> None: """Initialize the injected auth page.""" self.html = html + self.provider = provider self.force_https = force_https + self.has_trusted_networks_provider_first = has_trusted_networks_provider_first @staticmethod - async def inject(hass: HomeAssistant, force_https: bool) -> None: + async def inject( + hass: HomeAssistant, + provider: OpenIDAuthProvider, + force_https: bool, + has_trusted_networks_provider_first: bool, + ) -> None: """Inject the OIDC auth page into the frontend.""" + try: - await frontend_injection(hass, force_https) + await frontend_injection( + hass, provider, force_https, has_trusted_networks_provider_first + ) except Exception as e: # pylint: disable=broad-except _LOGGER.error("Failed to inject OIDC auth page: %s", e) - @staticmethod - def _should_do_oidc_redirect(req: web.Request) -> bool: + def _should_do_oidc_redirect(self, req: web.Request) -> bool: """Check if we should redirect to the OIDC flow.""" # Set when we return from finish if req.query.get("skip_oidc_redirect") == "true": @@ -118,6 +143,13 @@ class OIDCInjectedAuthPage(HomeAssistantView): if not redirect_uri: return False + # Check if we are on a trusted network if we have trusted networks registered first + if ( + self.has_trusted_networks_provider_first + and self.provider.is_trusted_network_host() + ): + return False + # Handle both encoded and plain redirect_uri values. decoded_redirect_uri = unquote(redirect_uri) return "skip_oidc_redirect=true" not in decoded_redirect_uri diff --git a/custom_components/auth_oidc/provider.py b/custom_components/auth_oidc/provider.py index 8e55e77..1943bb7 100644 --- a/custom_components/auth_oidc/provider.py +++ b/custom_components/auth_oidc/provider.py @@ -6,7 +6,12 @@ import logging from typing import Dict, Optional import asyncio -from homeassistant.auth import EVENT_USER_ADDED +from ipaddress import ( + ip_address, + IPv4Address, + IPv6Address, +) +from homeassistant.auth import EVENT_USER_ADDED, InvalidAuthError as HAInvalidAuthError from homeassistant.auth.providers import ( AUTH_PROVIDERS, AuthProvider, @@ -20,7 +25,6 @@ from homeassistant.auth.providers import ( from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.core import HomeAssistant, callback from homeassistant.components import http, person -from homeassistant.exceptions import HomeAssistantError from .config.const import ( FEATURES, @@ -31,6 +35,8 @@ from .config.const import ( from .stores.state_store import StateStore from .tools.types import UserDetails +type IPAddress = IPv4Address | IPv6Address + _LOGGER = logging.getLogger(__name__) PROVIDER_TYPE = "auth_oidc" @@ -38,7 +44,7 @@ HASS_PROVIDER_TYPE = "homeassistant" COOKIE_NAME = "auth_oidc_state" -class InvalidAuthError(HomeAssistantError): +class InvalidAuthError(HAInvalidAuthError): """Raised when submitting invalid authentication.""" @@ -114,6 +120,41 @@ class OpenIDAuthProvider(AuthProvider): return None + def is_trusted_network_host(self) -> bool: + """Check if the current request is coming from a trusted network host.""" + ip = self._resolve_ip() + if ip is None: + return False + + # Check if trusted networks auth provider is present + trusted_network_provider = self.hass.auth.get_auth_provider( + "trusted_networks", None + ) + if not trusted_network_provider: + return False + + _LOGGER.debug( + "Trusted networks present and checking if we should OIDC redirect" + ) + + try: + trusted_network_provider.async_validate_access(ip_address(ip)) + _LOGGER.info("IP %s is in a trusted network, skipping OIDC flow", ip) + return True + except HAInvalidAuthError: + # Log the error + _LOGGER.info( + "IP %s is not in a trusted network, proceeding with OIDC flow", ip + ) + return False + # Catch every other error, HA might have changed the API. + # pylint: disable=broad-exception-caught + except Exception as e: + _LOGGER.warning( + "Error while validating trusted network for IP %s: %s", ip, e + ) + return False + async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str: """Create a new OIDC state and return the state id.""" if self._state_store is None: diff --git a/tests/test_hass_auth_provider.py b/tests/test_hass_auth_provider.py index baa80b5..13296d6 100644 --- a/tests/test_hass_auth_provider.py +++ b/tests/test_hass_auth_provider.py @@ -2,11 +2,13 @@ import base64 import re +from collections import OrderedDict from types import SimpleNamespace from urllib.parse import parse_qs, unquote, urlparse from unittest.mock import patch import pytest +from homeassistant.auth import InvalidAuthError from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType from homeassistant.setup import async_setup_component @@ -24,6 +26,10 @@ from custom_components.auth_oidc.config.const import ( from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example" +DEFAULT_CONFIG = { + CLIENT_ID: "dummy", + DISCOVERY_URL: MockOIDCServer.get_discovery_url(), +} async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool: @@ -40,10 +46,7 @@ async def test_setup_success_auth_provider_registration(hass: HomeAssistant): """Test successful setup""" await setup( hass, - { - CLIENT_ID: "dummy", - DISCOVERY_URL: "https://example.com/.well-known/openid-configuration", - }, + DEFAULT_CONFIG, True, ) @@ -62,10 +65,7 @@ async def test_provider_ip_fallback_fails_closed_without_request_context( """Provider should not invent a shared IP when request context is missing.""" await setup( hass, - { - CLIENT_ID: "dummy", - DISCOVERY_URL: "https://example.com/.well-known/openid-configuration", - }, + DEFAULT_CONFIG, True, ) @@ -83,10 +83,7 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis """Cookie header should include Secure when HTTPS is in use.""" await setup( hass, - { - CLIENT_ID: "dummy", - DISCOVERY_URL: "https://example.com/.well-known/openid-configuration", - }, + DEFAULT_CONFIG, True, ) @@ -98,6 +95,105 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis assert "Secure" in cookie_header +@pytest.mark.asyncio +async def test_provider_is_trusted_network_host_true_for_allowed_ip( + hass: HomeAssistant, +): + """Provider should detect trusted network host when trusted provider allows the IP.""" + await setup( + hass, + DEFAULT_CONFIG, + True, + ) + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + + class TrustedNetworksAllowProvider: + def async_validate_access(self, _ip_addr): + return None + + # pylint: disable=protected-access + hass.auth._providers = OrderedDict( + [ + (("trusted_networks", None), TrustedNetworksAllowProvider()), + ((provider.type, provider.id), provider), + ] + ) + # pylint: enable=protected-access + + with patch( + "custom_components.auth_oidc.provider.http.current_request" + ) as current_request: + current_request.get.return_value = SimpleNamespace(remote="127.0.0.1") + assert provider.is_trusted_network_host() is True + + +@pytest.mark.asyncio +async def test_provider_is_trusted_network_host_false_for_disallowed_ip( + hass: HomeAssistant, caplog +): + """Provider should return False when trusted provider denies the current IP.""" + await setup( + hass, + DEFAULT_CONFIG, + True, + ) + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + + class TrustedNetworksDenyProvider: + def async_validate_access(self, _ip_addr): + raise InvalidAuthError("Not in trusted_networks") + + # pylint: disable=protected-access + hass.auth._providers = OrderedDict( + [ + (("trusted_networks", None), TrustedNetworksDenyProvider()), + ((provider.type, provider.id), provider), + ] + ) + # pylint: enable=protected-access + + with patch( + "custom_components.auth_oidc.provider.http.current_request" + ) as current_request: + current_request.get.return_value = SimpleNamespace(remote="127.0.0.1") + assert provider.is_trusted_network_host() is False + assert any( + level >= 0 + and "is not in a trusted network, proceeding with OIDC flow" in message + for _, level, message in caplog.record_tuples + ) + assert not any( + level >= 0 and "Error while validating trusted network for IP" in message + for _, level, message in caplog.record_tuples + ) + + +@pytest.mark.asyncio +async def test_provider_is_trusted_network_host_false_without_trusted_provider( + hass: HomeAssistant, +): + """Provider should return False when trusted_networks auth provider is absent.""" + await setup( + hass, + DEFAULT_CONFIG, + True, + ) + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + + # Without actually getting the IP, should also be false + assert provider.is_trusted_network_host() is False + + # With the IP, should be false + with patch( + "custom_components.auth_oidc.provider.http.current_request" + ) as current_request: + current_request.get.return_value = SimpleNamespace(remote="127.0.0.1") + assert provider.is_trusted_network_host() is False + + async def login_user(hass: HomeAssistant, state_id: str): """Helper to login a user from the stored OIDC state.""" @@ -167,8 +263,7 @@ async def test_full_login(hass: HomeAssistant, hass_client): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: False, @@ -198,8 +293,7 @@ async def test_login_with_linking(hass: HomeAssistant, hass_client): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: True, @@ -233,8 +327,7 @@ async def test_login_with_person_create(hass: HomeAssistant, hass_client): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: True, FEATURES_AUTOMATIC_USER_LINKING: False, @@ -267,8 +360,7 @@ async def test_login_without_person_create_does_not_create_person( await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: False, @@ -295,8 +387,7 @@ async def test_login_shows_form(hass: HomeAssistant): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: False, @@ -319,8 +410,7 @@ async def test_login_with_invalid_cookie_aborts(hass: HomeAssistant): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: False, @@ -352,8 +442,7 @@ async def test_login_with_no_cookie_aborts(hass: HomeAssistant): await setup( hass, { - CLIENT_ID: "dummy", - DISCOVERY_URL: MockOIDCServer.get_discovery_url(), + **DEFAULT_CONFIG, FEATURES: { FEATURES_AUTOMATIC_PERSON_CREATION: False, FEATURES_AUTOMATIC_USER_LINKING: False, diff --git a/tests/test_hass_oidc_client_integration.py b/tests/test_hass_oidc_client_integration.py index 356bd64..10dafca 100644 --- a/tests/test_hass_oidc_client_integration.py +++ b/tests/test_hass_oidc_client_integration.py @@ -3,14 +3,17 @@ import base64 import asyncio import re +from types import SimpleNamespace from unittest.mock import AsyncMock, patch from urllib.parse import parse_qs, unquote, urlparse, urlencode import pytest from homeassistant.core import HomeAssistant +from homeassistant.data_entry_flow import FlowResultType from homeassistant.setup import async_setup_component from homeassistant.helpers.aiohttp_client import async_get_clientsession from custom_components.auth_oidc import DOMAIN +from custom_components.auth_oidc.provider import COOKIE_NAME from custom_components.auth_oidc.tools.oidc_client import ( OIDCDiscoveryClient, OIDCDiscoveryInvalid, @@ -248,6 +251,47 @@ async def test_full_oidc_flow(hass: HomeAssistant, hass_client): await verify_back_redirect(client, redirect_uri) +@pytest.mark.asyncio +async def test_login_flow_init_completes_with_state_cookie( + hass: HomeAssistant, hass_client +): + """The provider login flow init step should finalize when the auth cookie is present.""" + await setup(hass) + + with mock_oidc_responses(): + client = await hass_client() + redirect_uri = create_redirect_uri(WEB_CLIENT_ID) + + state, _, status = await get_welcome_for_client(client, redirect_uri) + assert status == 200 + + authorization_url = await get_redirect_auth_url(client) + session = async_get_clientsession(hass) + resp_auth = session.get(authorization_url, allow_redirects=False) + json_auth = await resp_auth.json() + + resp = await client.get( + f"/auth/oidc/callback?code={json_auth['code']}&state={state}", + allow_redirects=False, + ) + assert resp.status == 302 + + provider = hass.auth.get_auth_providers(DOMAIN)[0] + flow = await provider.async_login_flow({}) + + fake_request = SimpleNamespace( + cookies={COOKIE_NAME: state}, + remote="127.0.0.1", + ) + with patch( + "custom_components.auth_oidc.provider.http.current_request" + ) as current_request: + current_request.get.return_value = fake_request + result = await flow.async_step_init({}) + + assert result["type"] == FlowResultType.CREATE_ENTRY + + async def discovery_test_through_redirect( hass_client, caplog, scenario: str, match_log_line: str ): diff --git a/tests/test_hass_webserver.py b/tests/test_hass_webserver.py index eb114e8..7af0a69 100644 --- a/tests/test_hass_webserver.py +++ b/tests/test_hass_webserver.py @@ -2,6 +2,7 @@ import base64 import os +from collections import OrderedDict from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode from unittest.mock import AsyncMock, MagicMock, patch from auth_oidc.config.const import ( @@ -25,6 +26,22 @@ from custom_components.auth_oidc.endpoints.injected_auth_page import ( ) MOBILE_CLIENT_ID = "https://home-assistant.io/Android" +WELCOME_PATH = "/auth/oidc/welcome" +INJECTION_SCRIPT_MARKER = "