Files
Christiaan Goossens d3c359064d Do not reveal existence of trusted networks provider (#302)
* Skip welcome page if the only other provider is trusted networks

* Add test
2026-05-01 14:27:23 +02:00

227 lines
7.5 KiB
Python

"""OIDC Integration for Home Assistant."""
import logging
import re
from typing import OrderedDict
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.components.http import StaticPathConfig
# Import and re-export the config schema explicitly
# pylint: disable=useless-import-alias
from .config import CONFIG_SCHEMA as CONFIG_SCHEMA
# Get all the constants for the config
from .config import (
DOMAIN,
DEFAULT_TITLE,
CLIENT_ID,
CLIENT_SECRET,
DISCOVERY_URL,
DISPLAY_NAME,
ID_TOKEN_SIGNING_ALGORITHM,
GROUPS_SCOPE,
ADDITIONAL_SCOPES,
FEATURES,
CLAIMS,
ROLES,
NETWORK,
FEATURES_INCLUDE_GROUPS_SCOPE,
FEATURES_DEFAULT_REDIRECT,
FEATURES_FORCE_HTTPS,
REQUIRED_SCOPES,
)
from .config import convert_ui_config_entry_to_internal_format
from .endpoints import (
OIDCWelcomeView,
OIDCRedirectView,
OIDCFinishView,
OIDCCallbackView,
OIDCInjectedAuthPage,
OIDCDeviceSSE,
)
from .tools.oidc_client import OIDCClient
from .tools.types import OIDCWelcomeOptions
from .provider import OpenIDAuthProvider
# Module-level logger used by the setup/registration functions in this module.
_LOGGER = logging.getLogger(__name__)
async def async_setup(hass: HomeAssistant, config):
    """Set up the OIDC auth provider from YAML configuration.

    Nothing happens when ``config`` has no ``DOMAIN`` section. Otherwise the
    section is stored as ``hass.data[DOMAIN]["yaml_config"]`` (for later access
    by the config flow), and the provider is set up using the section's
    ``DISPLAY_NAME`` entry, falling back to ``DEFAULT_TITLE``.

    Args:
        hass: The Home Assistant instance.
        config: Configuration mapping; only its ``DOMAIN`` key is read here.

    Returns:
        Always True.
    """
    if DOMAIN in config:
        yaml_section = config[DOMAIN]
        # Store YAML config for later access by config flow
        hass.data.setdefault(DOMAIN, {})["yaml_config"] = yaml_section
        title = yaml_section.get(DISPLAY_NAME, DEFAULT_TITLE)
        await _setup_oidc_provider(hass, yaml_section, title)
    return True
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Set up the OIDC auth provider from a UI config entry.

    The entry's data is copied and translated into the internal config format
    used by the YAML path. The provider is then set up with the entry's
    ``"display_name"`` value, or ``DEFAULT_TITLE`` when that key is absent.

    Args:
        hass: The Home Assistant instance.
        entry: The config entry created by the config flow.

    Returns:
        Always True.
    """
    raw_entry_data = entry.data.copy()
    # Translate first; the display name is then read from the copied raw data
    # (not from the translated config), exactly as before.
    internal_config = convert_ui_config_entry_to_internal_format(raw_entry_data)
    title = raw_entry_data.get("display_name", DEFAULT_TITLE)
    await _setup_oidc_provider(hass, internal_config, title)
    return True
async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
    """Report that an OIDC config entry cannot be unloaded.

    Both arguments are ignored. The return value is always False.
    """
    # The provider is spliced into Home Assistant's auth provider table and
    # its views are registered on the HTTP server during setup. Nothing here
    # undoes that, so unloading is refused and a restart is required.
    unload_supported = False
    return unload_supported
async def _register_oidc_provider(hass: HomeAssistant, my_config: dict):
    """Create an OpenIDAuthProvider and splice it into hass.auth's provider table.

    The OIDC provider is placed first, because it needs to process the login
    cookie after sign-in. The one exception is when the first
    already-registered provider has type ``"trusted_networks"``: that provider
    keeps the first slot and OIDC goes second. All other previously registered
    providers follow in their original order.

    Args:
        hass: The Home Assistant instance.
        my_config: OIDC configuration in the internal format.

    Returns:
        A tuple ``(provider, auth_provider_count,
        has_trusted_networks_provider_first)``:
        - ``provider``: the new provider.
        - ``auth_provider_count``: how many providers were registered before
          this call.
        - ``has_trusted_networks_provider_first``: whether the first of those
          was a trusted_networks provider.
    """
    # Use private APIs until there is a real auth platform
    # pylint: disable=protected-access
    oidc_provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
    remaining = hass.auth._providers.copy()
    _LOGGER.debug("Current auth providers: %s", list(remaining.keys()))
    previous_count = len(remaining)

    reordered = OrderedDict()
    trusted_first = False
    if remaining:
        # Keys are (provider type, provider id) tuples, matching the key
        # built for the OIDC provider below.
        leading_key = next(iter(remaining))
        if leading_key[0] == "trusted_networks":
            _LOGGER.info(
                "Trusted Networks provider detected as the first auth provider. "
                + "Keeping registration order intact."
            )
            reordered[leading_key] = remaining.pop(leading_key)
            trusted_first = True

    # OIDC goes ahead of every remaining provider so it can process the login
    # cookie after sign-in.
    reordered[(oidc_provider.type, oidc_provider.id)] = oidc_provider
    reordered.update(remaining)
    _LOGGER.debug("Final auth providers: %s", list(reordered.values()))
    hass.auth._providers = reordered
    # pylint: enable=protected-access
    _LOGGER.info("Registered OIDC provider")
    return oidc_provider, previous_count, trusted_first
# pylint: disable=too-many-locals
def _build_scope(my_config: dict) -> str:
    """Return the space-delimited OAuth 2.0 scope string to request.

    The string is combined from three sources:
    - ``REQUIRED_SCOPES``.
    - The groups scope (``GROUPS_SCOPE``, default ``"groups"``), when the
      include-groups feature is enabled. It is enabled unless
      explicitly turned off.
    - Any configured ``ADDITIONAL_SCOPES``.

    Every fragment is split on whitespace, so empty strings and stray spaces
    never yield empty tokens. Duplicate tokens are dropped, keeping the first
    occurrence, so the result is a clean single-space-separated token list.

    Args:
        my_config: OIDC configuration in the internal format.

    Returns:
        The scope string, e.g. ``"<required scopes> groups"``.
    """
    features_config = my_config.get(FEATURES) or {}

    # Always use the required scopes ('openid' & 'profile'), as they are
    # specified in the OIDC spec and all servers should support them.
    fragments = [REQUIRED_SCOPES]

    # Include groups if requested (default is to include 'groups'
    # as a scope for Authelia & Authentik)
    if features_config.get(FEATURES_INCLUDE_GROUPS_SCOPE, True):
        fragments.append(my_config.get(GROUPS_SCOPE, "groups") or "")

    # Add additional scopes if configured
    fragments.extend(my_config.get(ADDITIONAL_SCOPES) or [])

    # RFC 6749 section 3.3: scope is a list of space-delimited tokens, so
    # normalize whitespace and de-duplicate while preserving order.
    tokens = [token for fragment in fragments for token in fragment.split()]
    return " ".join(dict.fromkeys(tokens))


async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
    """Register the OIDC provider, its HTTP views and the auth-page injection.

    Args:
        hass: The Home Assistant instance.
        my_config: OIDC configuration in the internal format. This is either
            the YAML section or a UI entry after
            convert_ui_config_entry_to_internal_format.
        display_name: Provider name for the welcome view. Characters outside
            ``A-Za-z0-9``, space, ``_``, ``-``, ``(`` and ``)`` are removed.

    Returns:
        True once everything has been registered.
    """
    (
        provider,
        auth_provider_count,
        has_trusted_networks_provider_first,
    ) = await _register_oidc_provider(hass, my_config)

    # Tolerate an explicitly empty/null features section.
    features_config = my_config.get(FEATURES) or {}

    # Create the OIDC client
    oidc_client = OIDCClient(
        hass=hass,
        discovery_url=my_config.get(DISCOVERY_URL),
        client_id=my_config.get(CLIENT_ID),
        scope=_build_scope(my_config),
        client_secret=my_config.get(CLIENT_SECRET),
        id_token_signing_alg=my_config.get(ID_TOKEN_SIGNING_ALGORITHM),
        features=features_config,
        claims=my_config.get(CLAIMS, {}),
        roles=my_config.get(ROLES, {}),
        network=my_config.get(NETWORK, {}),
    )

    # Reduce the display name to a conservative character set before handing
    # it to the welcome view (presumably so it is safe to embed in HTML).
    name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", display_name)
    force_https = features_config.get(FEATURES_FORCE_HTTPS, False)
    default_redirect = features_config.get(FEATURES_DEFAULT_REDIRECT, False)

    # Serve the stylesheet at /auth/oidc/static/style.css.
    # NOTE(review): the file path assumes the integration is installed as
    # custom_components/auth_oidc inside the Home Assistant config directory.
    await hass.http.async_register_static_paths(
        [
            StaticPathConfig(
                "/auth/oidc/static/style.css",
                hass.config.path("custom_components/auth_oidc/static/style.css"),
                cache_headers=True,
            ),
        ]
    )

    # True when trusted_networks was the one and only pre-existing provider.
    has_only_trusted_networks = (
        auth_provider_count == 1 and has_trusted_networks_provider_first
    )

    # Register the views.
    # NOTE(review): has_other_auth_providers stays True when trusted_networks
    # is the only other provider. In that case prefers_skipping is set
    # instead, which is what keeps that provider out of sight.
    hass.http.register_view(
        OIDCWelcomeView(
            provider,
            OIDCWelcomeOptions(
                name=name,
                force_https=force_https,
                has_other_auth_providers=auth_provider_count > 0,
                prefers_skipping=default_redirect or has_only_trusted_networks,
            ),
        )
    )
    hass.http.register_view(OIDCDeviceSSE(provider))
    hass.http.register_view(OIDCRedirectView(oidc_client, provider, force_https))
    hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https))
    hass.http.register_view(OIDCFinishView(provider))
    _LOGGER.info("Registered OIDC views")

    # Inject OIDC code into the frontend for /auth/authorize for automatic redirect
    await OIDCInjectedAuthPage.inject(
        hass, provider, force_https, has_trusted_networks_provider_first
    )
    return True