Implement trusted_networks support (#283)
* Implement bypass for trusted_networks * Trusted Network tests * Test cleanup * Improve integration tests * Defensive programming * Fix wrong import issue
This commit is contained in:
committed by
GitHub
parent
04abb0fdb3
commit
c7370ed266
@@ -6,7 +6,12 @@ import logging
|
||||
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from homeassistant.auth import EVENT_USER_ADDED
|
||||
from ipaddress import (
|
||||
ip_address,
|
||||
IPv4Address,
|
||||
IPv6Address,
|
||||
)
|
||||
from homeassistant.auth import EVENT_USER_ADDED, InvalidAuthError as HAInvalidAuthError
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
@@ -20,7 +25,6 @@ from homeassistant.auth.providers import (
|
||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.components import http, person
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .config.const import (
|
||||
FEATURES,
|
||||
@@ -31,6 +35,8 @@ from .config.const import (
|
||||
from .stores.state_store import StateStore
|
||||
from .tools.types import UserDetails
|
||||
|
||||
type IPAddress = IPv4Address | IPv6Address
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_TYPE = "auth_oidc"
|
||||
@@ -38,7 +44,7 @@ HASS_PROVIDER_TYPE = "homeassistant"
|
||||
COOKIE_NAME = "auth_oidc_state"
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
class InvalidAuthError(HAInvalidAuthError):
|
||||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
@@ -114,6 +120,41 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
|
||||
return None
|
||||
|
||||
def is_trusted_network_host(self) -> bool:
|
||||
"""Check if the current request is coming from a trusted network host."""
|
||||
ip = self._resolve_ip()
|
||||
if ip is None:
|
||||
return False
|
||||
|
||||
# Check if trusted networks auth provider is present
|
||||
trusted_network_provider = self.hass.auth.get_auth_provider(
|
||||
"trusted_networks", None
|
||||
)
|
||||
if not trusted_network_provider:
|
||||
return False
|
||||
|
||||
_LOGGER.debug(
|
||||
"Trusted networks present and checking if we should OIDC redirect"
|
||||
)
|
||||
|
||||
try:
|
||||
trusted_network_provider.async_validate_access(ip_address(ip))
|
||||
_LOGGER.info("IP %s is in a trusted network, skipping OIDC flow", ip)
|
||||
return True
|
||||
except HAInvalidAuthError:
|
||||
# Log the error
|
||||
_LOGGER.info(
|
||||
"IP %s is not in a trusted network, proceeding with OIDC flow", ip
|
||||
)
|
||||
return False
|
||||
# Catch every other error, HA might have changed the API.
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception as e:
|
||||
_LOGGER.warning(
|
||||
"Error while validating trusted network for IP %s: %s", ip, e
|
||||
)
|
||||
return False
|
||||
|
||||
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
|
||||
"""Create a new OIDC state and return the state id."""
|
||||
if self._state_store is None:
|
||||
|
||||
Reference in New Issue
Block a user