"""OIDC Integration for Home Assistant.""" import logging import re from typing import OrderedDict from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.components.http import StaticPathConfig # Import and re-export config schema explictly # pylint: disable=useless-import-alias from .config import CONFIG_SCHEMA as CONFIG_SCHEMA # Get all the constants for the config from .config import ( DOMAIN, DEFAULT_TITLE, CLIENT_ID, CLIENT_SECRET, DISCOVERY_URL, DISPLAY_NAME, ID_TOKEN_SIGNING_ALGORITHM, GROUPS_SCOPE, ADDITIONAL_SCOPES, FEATURES, CLAIMS, ROLES, NETWORK, FEATURES_INCLUDE_GROUPS_SCOPE, FEATURES_DEFAULT_REDIRECT, FEATURES_FORCE_HTTPS, REQUIRED_SCOPES, ) from .config import convert_ui_config_entry_to_internal_format from .endpoints import ( OIDCWelcomeView, OIDCRedirectView, OIDCFinishView, OIDCCallbackView, OIDCInjectedAuthPage, OIDCDeviceSSE, ) from .tools.oidc_client import OIDCClient from .tools.types import OIDCWelcomeOptions from .provider import OpenIDAuthProvider _LOGGER = logging.getLogger(__name__) async def async_setup(hass: HomeAssistant, config): """Add the OIDC Auth Provider to the providers in Home Assistant (YAML config).""" if DOMAIN not in config: return True my_config = config[DOMAIN] # Store YAML config for later access by config flow if DOMAIN not in hass.data: hass.data[DOMAIN] = {} hass.data[DOMAIN]["yaml_config"] = my_config await _setup_oidc_provider( hass, my_config, config[DOMAIN].get(DISPLAY_NAME, DEFAULT_TITLE) ) return True async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): """Set up OIDC Authentication from a config entry (UI config).""" # Convert config entry data to the format expected by the existing setup config_data = entry.data.copy() # Convert config entry format to internal format my_config = convert_ui_config_entry_to_internal_format(config_data) # Get display name from config entry display_name = config_data.get("display_name", DEFAULT_TITLE) await _setup_oidc_provider(hass, my_config, display_name) return True async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry): """Unload a config entry.""" # OIDC auth providers cannot be easily unloaded as they are integrated # into Home Assistant's auth system. A restart is required. return False async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str): """Set up the OIDC provider with the given configuration.""" providers = OrderedDict() # Use private APIs until there is a real auth platform # pylint: disable=protected-access provider = OpenIDAuthProvider(hass, hass.auth._store, my_config) providers[(provider.type, provider.id)] = provider # Get current provider count has_other_auth_providers = len(hass.auth._providers) > 0 providers.update(hass.auth._providers) hass.auth._providers = providers # pylint: enable=protected-access _LOGGER.info("Registered OIDC provider") # Set the correct scopes # Always use 'openid' & 'profile' as they are specified in the OIDC spec # All servers should support this scope = REQUIRED_SCOPES # Include groups if requested (default is to include 'groups' # as a scope for Authelia & Authentik) features_config = my_config.get(FEATURES, {}) include_groups_scope = features_config.get(FEATURES_INCLUDE_GROUPS_SCOPE, True) groups_scope = my_config.get(GROUPS_SCOPE, "groups") if include_groups_scope: scope += " " + groups_scope # Add additional scopes if configured additional_scopes = my_config.get(ADDITIONAL_SCOPES, []) if additional_scopes: # Ensure we have a space before adding additional scopes if scope: scope += " " scope += " ".join(additional_scopes) # Create the OIDC client oidc_client = OIDCClient( hass=hass, discovery_url=my_config.get(DISCOVERY_URL), client_id=my_config.get(CLIENT_ID), scope=scope, client_secret=my_config.get(CLIENT_SECRET), id_token_signing_alg=my_config.get(ID_TOKEN_SIGNING_ALGORITHM), features=my_config.get(FEATURES, {}), claims=my_config.get(CLAIMS, {}), roles=my_config.get(ROLES, {}), network=my_config.get(NETWORK, {}), ) # Register the views name = display_name name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", name) force_https = features_config.get(FEATURES_FORCE_HTTPS, False) default_redirect = features_config.get(FEATURES_DEFAULT_REDIRECT, False) await hass.http.async_register_static_paths( [ StaticPathConfig( "/auth/oidc/static/style.css", hass.config.path("custom_components/auth_oidc/static/style.css"), cache_headers=True, ), ] ) hass.http.register_view( OIDCWelcomeView( provider, OIDCWelcomeOptions( name=name, force_https=force_https, has_other_auth_providers=has_other_auth_providers, prefers_skipping=default_redirect, ), ) ) hass.http.register_view(OIDCDeviceSSE(provider)) hass.http.register_view(OIDCRedirectView(oidc_client, provider, force_https)) hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https)) hass.http.register_view(OIDCFinishView(provider)) _LOGGER.info("Registered OIDC views") # Inject OIDC code into the frontend for /auth/authorize for automatic redirect await OIDCInjectedAuthPage.inject(hass, force_https) return True