Implement trusted_networks support (#283)

* Implement bypass for trusted_networks

* Trusted Network tests

* Test cleanup

* Improve integration tests

* Defensive programming

* Fix wrong import issue
This commit is contained in:
Christiaan Goossens
2026-05-01 14:03:14 +02:00
committed by GitHub
parent 04abb0fdb3
commit c7370ed266
6 changed files with 386 additions and 54 deletions

View File

@@ -6,7 +6,12 @@ import logging
from typing import Dict, Optional
import asyncio
from homeassistant.auth import EVENT_USER_ADDED
from ipaddress import (
ip_address,
IPv4Address,
IPv6Address,
)
from homeassistant.auth import EVENT_USER_ADDED, InvalidAuthError as HAInvalidAuthError
from homeassistant.auth.providers import (
AUTH_PROVIDERS,
AuthProvider,
@@ -20,7 +25,6 @@ from homeassistant.auth.providers import (
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback
from homeassistant.components import http, person
from homeassistant.exceptions import HomeAssistantError
from .config.const import (
FEATURES,
@@ -31,6 +35,8 @@ from .config.const import (
from .stores.state_store import StateStore
from .tools.types import UserDetails
type IPAddress = IPv4Address | IPv6Address
_LOGGER = logging.getLogger(__name__)
PROVIDER_TYPE = "auth_oidc"
@@ -38,7 +44,7 @@ HASS_PROVIDER_TYPE = "homeassistant"
COOKIE_NAME = "auth_oidc_state"
class InvalidAuthError(HomeAssistantError):
class InvalidAuthError(HAInvalidAuthError):
"""Raised when submitting invalid authentication."""
@@ -114,6 +120,41 @@ class OpenIDAuthProvider(AuthProvider):
return None
def is_trusted_network_host(self) -> bool:
"""Check if the current request is coming from a trusted network host."""
ip = self._resolve_ip()
if ip is None:
return False
# Check if trusted networks auth provider is present
trusted_network_provider = self.hass.auth.get_auth_provider(
"trusted_networks", None
)
if not trusted_network_provider:
return False
_LOGGER.debug(
"Trusted networks present and checking if we should OIDC redirect"
)
try:
trusted_network_provider.async_validate_access(ip_address(ip))
_LOGGER.info("IP %s is in a trusted network, skipping OIDC flow", ip)
return True
except HAInvalidAuthError:
# Log the error
_LOGGER.info(
"IP %s is not in a trusted network, proceeding with OIDC flow", ip
)
return False
# Catch every other error, HA might have changed the API.
# pylint: disable=broad-exception-caught
except Exception as e:
_LOGGER.warning(
"Error while validating trusted network for IP %s: %s", ip, e
)
return False
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
"""Create a new OIDC state and return the state id."""
if self._state_store is None: