Implement trusted_networks support (#283)
* Implement bypass for trusted_networks * Trusted Network tests * Test cleanup * Improve integration tests * Defensive programming * Fix wrong import issue
This commit is contained in:
committed by
GitHub
parent
04abb0fdb3
commit
c7370ed266
@@ -90,24 +90,59 @@ async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
|
async def _register_oidc_provider(hass: HomeAssistant, my_config: dict):
|
||||||
"""Set up the OIDC provider with the given configuration."""
|
"""Register the OIDC provider in Home Assistant's auth system."""
|
||||||
providers = OrderedDict()
|
|
||||||
|
|
||||||
# Use private APIs until there is a real auth platform
|
# Use private APIs until there is a real auth platform
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
|
providers = OrderedDict()
|
||||||
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
|
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
|
||||||
|
|
||||||
|
existing_auth_providers = hass.auth._providers.copy()
|
||||||
|
_LOGGER.debug("Current auth providers: %s", list(existing_auth_providers.keys()))
|
||||||
|
has_other_auth_providers = len(existing_auth_providers) > 0
|
||||||
|
has_trusted_networks_provider_first = False
|
||||||
|
|
||||||
|
if has_other_auth_providers:
|
||||||
|
# Pop the first provider from the existing providers to check if it's trusted_networks
|
||||||
|
first_provider_key, first_provider_obj = next(
|
||||||
|
iter(existing_auth_providers.items())
|
||||||
|
)
|
||||||
|
existing_auth_providers.pop(first_provider_key)
|
||||||
|
|
||||||
|
if first_provider_key[0] == "trusted_networks":
|
||||||
|
_LOGGER.info(
|
||||||
|
"Trusted Networks provider detected as the first auth provider. "
|
||||||
|
+ "Keeping registration order intact."
|
||||||
|
)
|
||||||
|
providers[first_provider_key] = first_provider_obj
|
||||||
|
has_trusted_networks_provider_first = True
|
||||||
|
else:
|
||||||
|
# Reset back to what we had before
|
||||||
|
existing_auth_providers = hass.auth._providers.copy()
|
||||||
|
|
||||||
|
# Register OIDC at the start of the array
|
||||||
|
# OIDC needs to be first because it needs to process the login cookie after sign-in
|
||||||
providers[(provider.type, provider.id)] = provider
|
providers[(provider.type, provider.id)] = provider
|
||||||
|
|
||||||
# Get current provider count
|
# Add back any other providers that were already registered
|
||||||
has_other_auth_providers = len(hass.auth._providers) > 0
|
providers.update(existing_auth_providers)
|
||||||
|
|
||||||
providers.update(hass.auth._providers)
|
_LOGGER.debug("Final auth providers: %s", list(providers.values()))
|
||||||
hass.auth._providers = providers
|
hass.auth._providers = providers
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
_LOGGER.info("Registered OIDC provider")
|
_LOGGER.info("Registered OIDC provider")
|
||||||
|
return provider, has_other_auth_providers, has_trusted_networks_provider_first
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
|
||||||
|
"""Set up the OIDC provider with the given configuration."""
|
||||||
|
(
|
||||||
|
provider,
|
||||||
|
has_other_auth_providers,
|
||||||
|
has_trusted_networks_provider_first,
|
||||||
|
) = await _register_oidc_provider(hass, my_config)
|
||||||
|
|
||||||
# Set the correct scopes
|
# Set the correct scopes
|
||||||
# Always use 'openid' & 'profile' as they are specified in the OIDC spec
|
# Always use 'openid' & 'profile' as they are specified in the OIDC spec
|
||||||
@@ -179,6 +214,8 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
|
|||||||
_LOGGER.info("Registered OIDC views")
|
_LOGGER.info("Registered OIDC views")
|
||||||
|
|
||||||
# Inject OIDC code into the frontend for /auth/authorize for automatic redirect
|
# Inject OIDC code into the frontend for /auth/authorize for automatic redirect
|
||||||
await OIDCInjectedAuthPage.inject(hass, force_https)
|
await OIDCInjectedAuthPage.inject(
|
||||||
|
hass, provider, force_https, has_trusted_networks_provider_first
|
||||||
|
)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from homeassistant.components.http import HomeAssistantView, StaticPathConfig
|
|||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .welcome import PATH as WELCOME_PATH
|
from .welcome import PATH as WELCOME_PATH
|
||||||
|
from ..provider import OpenIDAuthProvider
|
||||||
from ..tools.helpers import get_url
|
from ..tools.helpers import get_url
|
||||||
|
|
||||||
PATH = "/auth/authorize"
|
PATH = "/auth/authorize"
|
||||||
@@ -24,7 +25,12 @@ async def read_file(path: str) -> str:
|
|||||||
return await f.read()
|
return await f.read()
|
||||||
|
|
||||||
|
|
||||||
async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None:
|
async def frontend_injection(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
provider: OpenIDAuthProvider,
|
||||||
|
force_https: bool,
|
||||||
|
has_trusted_networks_provider_first: bool,
|
||||||
|
) -> None:
|
||||||
"""Inject new frontend code into /auth/authorize."""
|
"""Inject new frontend code into /auth/authorize."""
|
||||||
router = hass.http.app.router
|
router = hass.http.app.router
|
||||||
frontend_path = None
|
frontend_path = None
|
||||||
@@ -81,7 +87,11 @@ async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# If everything is succesful, register a fake view that just returns the modified HTML
|
# If everything is succesful, register a fake view that just returns the modified HTML
|
||||||
hass.http.register_view(OIDCInjectedAuthPage(frontend_code, force_https))
|
hass.http.register_view(
|
||||||
|
OIDCInjectedAuthPage(
|
||||||
|
frontend_code, provider, force_https, has_trusted_networks_provider_first
|
||||||
|
)
|
||||||
|
)
|
||||||
_LOGGER.info("Performed OIDC frontend injection")
|
_LOGGER.info("Performed OIDC frontend injection")
|
||||||
|
|
||||||
|
|
||||||
@@ -92,21 +102,36 @@ class OIDCInjectedAuthPage(HomeAssistantView):
|
|||||||
url = PATH
|
url = PATH
|
||||||
name = "auth:oidc:authorize_page"
|
name = "auth:oidc:authorize_page"
|
||||||
|
|
||||||
def __init__(self, html: str, force_https: bool) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
html: str,
|
||||||
|
provider: OpenIDAuthProvider,
|
||||||
|
force_https: bool,
|
||||||
|
has_trusted_networks_provider_first: bool,
|
||||||
|
) -> None:
|
||||||
"""Initialize the injected auth page."""
|
"""Initialize the injected auth page."""
|
||||||
self.html = html
|
self.html = html
|
||||||
|
self.provider = provider
|
||||||
self.force_https = force_https
|
self.force_https = force_https
|
||||||
|
self.has_trusted_networks_provider_first = has_trusted_networks_provider_first
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def inject(hass: HomeAssistant, force_https: bool) -> None:
|
async def inject(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
provider: OpenIDAuthProvider,
|
||||||
|
force_https: bool,
|
||||||
|
has_trusted_networks_provider_first: bool,
|
||||||
|
) -> None:
|
||||||
"""Inject the OIDC auth page into the frontend."""
|
"""Inject the OIDC auth page into the frontend."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await frontend_injection(hass, force_https)
|
await frontend_injection(
|
||||||
|
hass, provider, force_https, has_trusted_networks_provider_first
|
||||||
|
)
|
||||||
except Exception as e: # pylint: disable=broad-except
|
except Exception as e: # pylint: disable=broad-except
|
||||||
_LOGGER.error("Failed to inject OIDC auth page: %s", e)
|
_LOGGER.error("Failed to inject OIDC auth page: %s", e)
|
||||||
|
|
||||||
@staticmethod
|
def _should_do_oidc_redirect(self, req: web.Request) -> bool:
|
||||||
def _should_do_oidc_redirect(req: web.Request) -> bool:
|
|
||||||
"""Check if we should redirect to the OIDC flow."""
|
"""Check if we should redirect to the OIDC flow."""
|
||||||
# Set when we return from finish
|
# Set when we return from finish
|
||||||
if req.query.get("skip_oidc_redirect") == "true":
|
if req.query.get("skip_oidc_redirect") == "true":
|
||||||
@@ -118,6 +143,13 @@ class OIDCInjectedAuthPage(HomeAssistantView):
|
|||||||
if not redirect_uri:
|
if not redirect_uri:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
# Check if we are on a trusted network if we have trusted networks registered first
|
||||||
|
if (
|
||||||
|
self.has_trusted_networks_provider_first
|
||||||
|
and self.provider.is_trusted_network_host()
|
||||||
|
):
|
||||||
|
return False
|
||||||
|
|
||||||
# Handle both encoded and plain redirect_uri values.
|
# Handle both encoded and plain redirect_uri values.
|
||||||
decoded_redirect_uri = unquote(redirect_uri)
|
decoded_redirect_uri = unquote(redirect_uri)
|
||||||
return "skip_oidc_redirect=true" not in decoded_redirect_uri
|
return "skip_oidc_redirect=true" not in decoded_redirect_uri
|
||||||
|
|||||||
@@ -6,7 +6,12 @@ import logging
|
|||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
from homeassistant.auth import EVENT_USER_ADDED
|
from ipaddress import (
|
||||||
|
ip_address,
|
||||||
|
IPv4Address,
|
||||||
|
IPv6Address,
|
||||||
|
)
|
||||||
|
from homeassistant.auth import EVENT_USER_ADDED, InvalidAuthError as HAInvalidAuthError
|
||||||
from homeassistant.auth.providers import (
|
from homeassistant.auth.providers import (
|
||||||
AUTH_PROVIDERS,
|
AUTH_PROVIDERS,
|
||||||
AuthProvider,
|
AuthProvider,
|
||||||
@@ -20,7 +25,6 @@ from homeassistant.auth.providers import (
|
|||||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.components import http, person
|
from homeassistant.components import http, person
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
|
||||||
|
|
||||||
from .config.const import (
|
from .config.const import (
|
||||||
FEATURES,
|
FEATURES,
|
||||||
@@ -31,6 +35,8 @@ from .config.const import (
|
|||||||
from .stores.state_store import StateStore
|
from .stores.state_store import StateStore
|
||||||
from .tools.types import UserDetails
|
from .tools.types import UserDetails
|
||||||
|
|
||||||
|
type IPAddress = IPv4Address | IPv6Address
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
PROVIDER_TYPE = "auth_oidc"
|
PROVIDER_TYPE = "auth_oidc"
|
||||||
@@ -38,7 +44,7 @@ HASS_PROVIDER_TYPE = "homeassistant"
|
|||||||
COOKIE_NAME = "auth_oidc_state"
|
COOKIE_NAME = "auth_oidc_state"
|
||||||
|
|
||||||
|
|
||||||
class InvalidAuthError(HomeAssistantError):
|
class InvalidAuthError(HAInvalidAuthError):
|
||||||
"""Raised when submitting invalid authentication."""
|
"""Raised when submitting invalid authentication."""
|
||||||
|
|
||||||
|
|
||||||
@@ -114,6 +120,41 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def is_trusted_network_host(self) -> bool:
|
||||||
|
"""Check if the current request is coming from a trusted network host."""
|
||||||
|
ip = self._resolve_ip()
|
||||||
|
if ip is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if trusted networks auth provider is present
|
||||||
|
trusted_network_provider = self.hass.auth.get_auth_provider(
|
||||||
|
"trusted_networks", None
|
||||||
|
)
|
||||||
|
if not trusted_network_provider:
|
||||||
|
return False
|
||||||
|
|
||||||
|
_LOGGER.debug(
|
||||||
|
"Trusted networks present and checking if we should OIDC redirect"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
trusted_network_provider.async_validate_access(ip_address(ip))
|
||||||
|
_LOGGER.info("IP %s is in a trusted network, skipping OIDC flow", ip)
|
||||||
|
return True
|
||||||
|
except HAInvalidAuthError:
|
||||||
|
# Log the error
|
||||||
|
_LOGGER.info(
|
||||||
|
"IP %s is not in a trusted network, proceeding with OIDC flow", ip
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
# Catch every other error, HA might have changed the API.
|
||||||
|
# pylint: disable=broad-exception-caught
|
||||||
|
except Exception as e:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error while validating trusted network for IP %s: %s", ip, e
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
|
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
|
||||||
"""Create a new OIDC state and return the state id."""
|
"""Create a new OIDC state and return the state id."""
|
||||||
if self._state_store is None:
|
if self._state_store is None:
|
||||||
|
|||||||
@@ -2,11 +2,13 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import re
|
import re
|
||||||
|
from collections import OrderedDict
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from urllib.parse import parse_qs, unquote, urlparse
|
from urllib.parse import parse_qs, unquote, urlparse
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.data_entry_flow import FlowResultType
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@@ -24,6 +26,10 @@ from custom_components.auth_oidc.config.const import (
|
|||||||
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||||
|
|
||||||
FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example"
|
FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example"
|
||||||
|
DEFAULT_CONFIG = {
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
||||||
@@ -40,10 +46,7 @@ async def test_setup_success_auth_provider_registration(hass: HomeAssistant):
|
|||||||
"""Test successful setup"""
|
"""Test successful setup"""
|
||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
DEFAULT_CONFIG,
|
||||||
CLIENT_ID: "dummy",
|
|
||||||
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
|
||||||
},
|
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -62,10 +65,7 @@ async def test_provider_ip_fallback_fails_closed_without_request_context(
|
|||||||
"""Provider should not invent a shared IP when request context is missing."""
|
"""Provider should not invent a shared IP when request context is missing."""
|
||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
DEFAULT_CONFIG,
|
||||||
CLIENT_ID: "dummy",
|
|
||||||
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
|
||||||
},
|
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -83,10 +83,7 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis
|
|||||||
"""Cookie header should include Secure when HTTPS is in use."""
|
"""Cookie header should include Secure when HTTPS is in use."""
|
||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
DEFAULT_CONFIG,
|
||||||
CLIENT_ID: "dummy",
|
|
||||||
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
|
||||||
},
|
|
||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -98,6 +95,105 @@ async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssis
|
|||||||
assert "Secure" in cookie_header
|
assert "Secure" in cookie_header
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_is_trusted_network_host_true_for_allowed_ip(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Provider should detect trusted network host when trusted provider allows the IP."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
|
||||||
|
class TrustedNetworksAllowProvider:
|
||||||
|
def async_validate_access(self, _ip_addr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
hass.auth._providers = OrderedDict(
|
||||||
|
[
|
||||||
|
(("trusted_networks", None), TrustedNetworksAllowProvider()),
|
||||||
|
((provider.type, provider.id), provider),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
|
||||||
|
assert provider.is_trusted_network_host() is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_is_trusted_network_host_false_for_disallowed_ip(
|
||||||
|
hass: HomeAssistant, caplog
|
||||||
|
):
|
||||||
|
"""Provider should return False when trusted provider denies the current IP."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
|
||||||
|
class TrustedNetworksDenyProvider:
|
||||||
|
def async_validate_access(self, _ip_addr):
|
||||||
|
raise InvalidAuthError("Not in trusted_networks")
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
hass.auth._providers = OrderedDict(
|
||||||
|
[
|
||||||
|
(("trusted_networks", None), TrustedNetworksDenyProvider()),
|
||||||
|
((provider.type, provider.id), provider),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
|
||||||
|
assert provider.is_trusted_network_host() is False
|
||||||
|
assert any(
|
||||||
|
level >= 0
|
||||||
|
and "is not in a trusted network, proceeding with OIDC flow" in message
|
||||||
|
for _, level, message in caplog.record_tuples
|
||||||
|
)
|
||||||
|
assert not any(
|
||||||
|
level >= 0 and "Error while validating trusted network for IP" in message
|
||||||
|
for _, level, message in caplog.record_tuples
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_is_trusted_network_host_false_without_trusted_provider(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Provider should return False when trusted_networks auth provider is absent."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
DEFAULT_CONFIG,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
|
||||||
|
# Without actually getting the IP, should also be false
|
||||||
|
assert provider.is_trusted_network_host() is False
|
||||||
|
|
||||||
|
# With the IP, should be false
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = SimpleNamespace(remote="127.0.0.1")
|
||||||
|
assert provider.is_trusted_network_host() is False
|
||||||
|
|
||||||
|
|
||||||
async def login_user(hass: HomeAssistant, state_id: str):
|
async def login_user(hass: HomeAssistant, state_id: str):
|
||||||
"""Helper to login a user from the stored OIDC state."""
|
"""Helper to login a user from the stored OIDC state."""
|
||||||
|
|
||||||
@@ -167,8 +263,7 @@ async def test_full_login(hass: HomeAssistant, hass_client):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
@@ -198,8 +293,7 @@ async def test_login_with_linking(hass: HomeAssistant, hass_client):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: True,
|
FEATURES_AUTOMATIC_USER_LINKING: True,
|
||||||
@@ -233,8 +327,7 @@ async def test_login_with_person_create(hass: HomeAssistant, hass_client):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
@@ -267,8 +360,7 @@ async def test_login_without_person_create_does_not_create_person(
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
@@ -295,8 +387,7 @@ async def test_login_shows_form(hass: HomeAssistant):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
@@ -319,8 +410,7 @@ async def test_login_with_invalid_cookie_aborts(hass: HomeAssistant):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
@@ -352,8 +442,7 @@ async def test_login_with_no_cookie_aborts(hass: HomeAssistant):
|
|||||||
await setup(
|
await setup(
|
||||||
hass,
|
hass,
|
||||||
{
|
{
|
||||||
CLIENT_ID: "dummy",
|
**DEFAULT_CONFIG,
|
||||||
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
|
||||||
FEATURES: {
|
FEATURES: {
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING: False,
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
|||||||
@@ -3,14 +3,17 @@
|
|||||||
import base64
|
import base64
|
||||||
import asyncio
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
from urllib.parse import parse_qs, unquote, urlparse, urlencode
|
from urllib.parse import parse_qs, unquote, urlparse, urlencode
|
||||||
import pytest
|
import pytest
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
|
|
||||||
from custom_components.auth_oidc import DOMAIN
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.provider import COOKIE_NAME
|
||||||
from custom_components.auth_oidc.tools.oidc_client import (
|
from custom_components.auth_oidc.tools.oidc_client import (
|
||||||
OIDCDiscoveryClient,
|
OIDCDiscoveryClient,
|
||||||
OIDCDiscoveryInvalid,
|
OIDCDiscoveryInvalid,
|
||||||
@@ -248,6 +251,47 @@ async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
|
|||||||
await verify_back_redirect(client, redirect_uri)
|
await verify_back_redirect(client, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_flow_init_completes_with_state_cookie(
|
||||||
|
hass: HomeAssistant, hass_client
|
||||||
|
):
|
||||||
|
"""The provider login flow init step should finalize when the auth cookie is present."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
state, _, status = await get_welcome_for_client(client, redirect_uri)
|
||||||
|
assert status == 200
|
||||||
|
|
||||||
|
authorization_url = await get_redirect_auth_url(client)
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp_auth = session.get(authorization_url, allow_redirects=False)
|
||||||
|
json_auth = await resp_auth.json()
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 302
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
flow = await provider.async_login_flow({})
|
||||||
|
|
||||||
|
fake_request = SimpleNamespace(
|
||||||
|
cookies={COOKIE_NAME: state},
|
||||||
|
remote="127.0.0.1",
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = fake_request
|
||||||
|
result = await flow.async_step_init({})
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
|
|
||||||
|
|
||||||
async def discovery_test_through_redirect(
|
async def discovery_test_through_redirect(
|
||||||
hass_client, caplog, scenario: str, match_log_line: str
|
hass_client, caplog, scenario: str, match_log_line: str
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import base64
|
import base64
|
||||||
import os
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
|
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
from auth_oidc.config.const import (
|
from auth_oidc.config.const import (
|
||||||
@@ -25,6 +26,22 @@ from custom_components.auth_oidc.endpoints.injected_auth_page import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
||||||
|
WELCOME_PATH = "/auth/oidc/welcome"
|
||||||
|
INJECTION_SCRIPT_MARKER = "<script src='/auth/oidc/static/injection.js"
|
||||||
|
|
||||||
|
|
||||||
|
def assert_redirects_to_welcome(resp) -> None:
|
||||||
|
"""Assert a response redirects to the OIDC welcome endpoint."""
|
||||||
|
assert resp.status == 302
|
||||||
|
location = resp.headers["Location"]
|
||||||
|
parsed_location = urlparse(location)
|
||||||
|
assert parsed_location.path == WELCOME_PATH
|
||||||
|
|
||||||
|
|
||||||
|
async def assert_normal_login_screen(resp) -> None:
|
||||||
|
"""Assert we stayed on the auth page and render the injected normal login HTML."""
|
||||||
|
assert resp.status == 200
|
||||||
|
assert INJECTION_SCRIPT_MARKER in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
def create_redirect_uri(client_id: str) -> str:
|
def create_redirect_uri(client_id: str) -> str:
|
||||||
@@ -310,7 +327,9 @@ async def test_welcome_desktop_auto_redirects_without_other_providers(
|
|||||||
"""Welcome should auto-redirect desktop clients when no other providers exist."""
|
"""Welcome should auto-redirect desktop clients when no other providers exist."""
|
||||||
|
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
hass.auth._providers = [] # Clear initial providers out
|
hass.auth._providers = {} # Clear initial providers out
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
await setup(hass)
|
await setup(hass)
|
||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
@@ -334,7 +353,7 @@ async def test_redirect_without_cookie_goes_to_welcome(
|
|||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
assert resp.status == 302
|
assert resp.status == 302
|
||||||
assert "/auth/oidc/welcome" in resp.headers["Location"]
|
assert_redirects_to_welcome(resp)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -730,10 +749,7 @@ async def test_frontend_injection(
|
|||||||
|
|
||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
resp = await client.get("/auth/authorize", allow_redirects=False)
|
resp = await client.get("/auth/authorize", allow_redirects=False)
|
||||||
assert resp.status == 200 # 200 because there is no redirect_uri
|
await assert_normal_login_screen(resp)
|
||||||
text = await resp.text()
|
|
||||||
|
|
||||||
assert "<script src='/auth/oidc/static/injection.js" in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@@ -760,8 +776,12 @@ async def test_frontend_injection_logs_and_returns_when_route_handler_is_unexpec
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter([FakeRoute()])
|
return iter([FakeRoute()])
|
||||||
|
|
||||||
|
provider = MagicMock()
|
||||||
|
|
||||||
with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]):
|
with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]):
|
||||||
await frontend_injection(hass, force_https=False)
|
await frontend_injection(
|
||||||
|
hass, provider, force_https=False, has_trusted_networks_provider_first=False
|
||||||
|
)
|
||||||
|
|
||||||
assert "Unexpected route handler type" in caplog.text
|
assert "Unexpected route handler type" in caplog.text
|
||||||
assert (
|
assert (
|
||||||
@@ -780,7 +800,10 @@ async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog
|
|||||||
"custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection",
|
"custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection",
|
||||||
side_effect=RuntimeError("boom"),
|
side_effect=RuntimeError("boom"),
|
||||||
):
|
):
|
||||||
await OIDCInjectedAuthPage.inject(hass, force_https=False)
|
provider = MagicMock()
|
||||||
|
await OIDCInjectedAuthPage.inject(
|
||||||
|
hass, provider, force_https=False, has_trusted_networks_provider_first=False
|
||||||
|
)
|
||||||
|
|
||||||
assert "Failed to inject OIDC auth page: boom" in caplog.text
|
assert "Failed to inject OIDC auth page: boom" in caplog.text
|
||||||
|
|
||||||
@@ -836,5 +859,71 @@ async def test_injected_auth_page_returns_original_html_when_skipped(
|
|||||||
client = await hass_client()
|
client = await hass_client()
|
||||||
response = await client.get(request_target, allow_redirects=False)
|
response = await client.get(request_target, allow_redirects=False)
|
||||||
|
|
||||||
assert response.status == 200
|
await assert_normal_login_screen(response)
|
||||||
assert "<script src='/auth/oidc/static/injection.js" in await response.text()
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_injected_auth_page_trusted_networks_bypass_skips_oidc_redirect(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Trusted network hosts should bypass OIDC redirect when trusted_networks is first."""
|
||||||
|
|
||||||
|
class TrustedNetworksAllowProvider:
|
||||||
|
def async_validate_access(self, _ip_addr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
hass.auth._providers = OrderedDict(
|
||||||
|
[(("trusted_networks", None), TrustedNetworksAllowProvider())]
|
||||||
|
)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
await setup_mock_authorize_route(hass)
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
await assert_normal_login_screen(resp)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_injected_auth_page_ignores_trusted_networks_when_not_first(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""OIDC redirect should continue when trusted_networks is not the first provider."""
|
||||||
|
|
||||||
|
class DummyProvider:
|
||||||
|
pass
|
||||||
|
|
||||||
|
class TrustedNetworksAllowProvider:
|
||||||
|
def async_validate_access(self, _ip_addr):
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Keep trusted_networks present but not first, so bypass should not apply.
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
hass.auth._providers = OrderedDict(
|
||||||
|
[
|
||||||
|
(("homeassistant", None), DummyProvider()),
|
||||||
|
(("trusted_networks", None), TrustedNetworksAllowProvider()),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
await setup_mock_authorize_route(hass)
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_redirects_to_welcome(resp)
|
||||||
|
|||||||
Reference in New Issue
Block a user