227 lines
7.5 KiB
Python
227 lines
7.5 KiB
Python
"""OIDC Integration for Home Assistant."""
|
|
|
|
import logging
|
|
import re
|
|
from typing import OrderedDict
|
|
|
|
from homeassistant.config_entries import ConfigEntry
|
|
from homeassistant.core import HomeAssistant
|
|
from homeassistant.components.http import StaticPathConfig
|
|
|
|
# Import and re-export config schema explicitly
|
|
# pylint: disable=useless-import-alias
|
|
from .config import CONFIG_SCHEMA as CONFIG_SCHEMA
|
|
|
|
# Get all the constants for the config
|
|
from .config import (
|
|
DOMAIN,
|
|
DEFAULT_TITLE,
|
|
CLIENT_ID,
|
|
CLIENT_SECRET,
|
|
DISCOVERY_URL,
|
|
DISPLAY_NAME,
|
|
ID_TOKEN_SIGNING_ALGORITHM,
|
|
GROUPS_SCOPE,
|
|
ADDITIONAL_SCOPES,
|
|
FEATURES,
|
|
CLAIMS,
|
|
ROLES,
|
|
NETWORK,
|
|
FEATURES_INCLUDE_GROUPS_SCOPE,
|
|
FEATURES_DEFAULT_REDIRECT,
|
|
FEATURES_FORCE_HTTPS,
|
|
REQUIRED_SCOPES,
|
|
)
|
|
|
|
from .config import convert_ui_config_entry_to_internal_format
|
|
|
|
from .endpoints import (
|
|
OIDCWelcomeView,
|
|
OIDCRedirectView,
|
|
OIDCFinishView,
|
|
OIDCCallbackView,
|
|
OIDCInjectedAuthPage,
|
|
OIDCDeviceSSE,
|
|
)
|
|
from .tools.oidc_client import OIDCClient
|
|
from .tools.types import OIDCWelcomeOptions
|
|
from .provider import OpenIDAuthProvider
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
async def async_setup(hass: HomeAssistant, config):
    """Set up the OIDC auth provider from YAML configuration.

    Returns True right away when the integration's domain is not present in
    ``config``. Otherwise the domain's section is stored under
    ``hass.data[DOMAIN]["yaml_config"]`` and the provider is set up using the
    configured display name (falling back to ``DEFAULT_TITLE``).
    """
    if DOMAIN not in config:
        # Nothing was configured through YAML
        return True

    yaml_config = config[DOMAIN]

    # Keep the YAML section around for later access by the config flow
    hass.data.setdefault(DOMAIN, {})["yaml_config"] = yaml_config

    title = yaml_config.get(DISPLAY_NAME, DEFAULT_TITLE)
    await _setup_oidc_provider(hass, yaml_config, title)
    return True
|
|
|
|
|
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
    """Set up OIDC authentication from a config entry (UI config).

    A copy of the entry data is converted to the internal configuration
    format and passed to the shared provider setup, together with the
    ``"display_name"`` value from the entry data (``DEFAULT_TITLE`` when
    it is missing). Always returns True once setup has completed.
    """
    entry_data = entry.data.copy()

    # Arguments are evaluated left to right: the conversion runs before the
    # display name is read, matching the order the setup has always used.
    await _setup_oidc_provider(
        hass,
        convert_ui_config_entry_to_internal_format(entry_data),
        entry_data.get("display_name", DEFAULT_TITLE),
    )
    return True
|
|
|
|
|
|
async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
    """Refuse to unload a config entry.

    OIDC auth providers are integrated into Home Assistant's auth system and
    cannot easily be unloaded, so this always returns False; a restart is
    required instead. Both arguments are unused.
    """
    return False
|
|
|
|
|
|
async def _register_oidc_provider(hass: HomeAssistant, my_config: dict):
    """Insert the OIDC provider into Home Assistant's auth provider registry.

    The OIDC provider is placed at the front of the registry. The one
    exception is when the first existing provider key has the type
    ``"trusted_networks"``: that provider stays first and OIDC is placed
    directly behind it. All remaining providers follow in their existing
    order.

    Returns a tuple of:
      * the newly created ``OpenIDAuthProvider``,
      * the number of providers that were registered before this call,
      * whether a trusted_networks provider was found in first position.
    """
    # Use private APIs until there is a real auth platform
    # pylint: disable=protected-access
    reordered = OrderedDict()
    oidc_provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)

    # Work on a copy so the live registry is only replaced in one assignment
    remaining = hass.auth._providers.copy()
    _LOGGER.debug("Current auth providers: %s", list(remaining.keys()))
    previous_count = len(remaining)

    trusted_networks_first = False
    if remaining:
        # Provider keys are (type, id) tuples; peek at the first one
        head_key = next(iter(remaining))
        if head_key[0] == "trusted_networks":
            _LOGGER.info(
                "Trusted Networks provider detected as the first auth provider. "
                "Keeping registration order intact."
            )
            # Move it across so it stays ahead of the OIDC provider
            reordered[head_key] = remaining.pop(head_key)
            trusted_networks_first = True

    # OIDC needs to be first because it needs to process the login cookie
    # after sign-in
    reordered[(oidc_provider.type, oidc_provider.id)] = oidc_provider

    # Everything that was registered before follows afterwards
    reordered.update(remaining)

    _LOGGER.debug("Final auth providers: %s", list(reordered.values()))
    hass.auth._providers = reordered
    # pylint: enable=protected-access

    _LOGGER.info("Registered OIDC provider")
    return oidc_provider, previous_count, trusted_networks_first
|
|
|
|
|
|
# pylint: disable=too-many-locals
|
|
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
    """Register the OIDC provider, its client, and all HTTP views.

    Steps, in order:
      1. Register the auth provider in Home Assistant.
      2. Build the scope string from ``REQUIRED_SCOPES``, the optional groups
         scope, and any additional configured scopes.
      3. Create the ``OIDCClient`` from the configuration.
      4. Register the static stylesheet path and the OIDC views.
      5. Inject the OIDC code into the frontend authorize page.

    ``display_name`` is stripped of every character other than letters,
    digits, space, ``_``, ``-``, ``(`` and ``)`` before it is handed to the
    welcome view. Always returns True.
    """
    registration = await _register_oidc_provider(hass, my_config)
    oidc_provider, previous_provider_count, trusted_networks_first = registration

    feature_flags = my_config.get(FEATURES, {})

    # Always start from the required scopes ('openid' & 'profile' are part of
    # the OIDC spec, so all servers should support them)
    requested_scope = REQUIRED_SCOPES

    # The groups scope is included unless explicitly disabled (default is to
    # include 'groups' as a scope for Authelia & Authentik)
    if feature_flags.get(FEATURES_INCLUDE_GROUPS_SCOPE, True):
        requested_scope += " " + my_config.get(GROUPS_SCOPE, "groups")

    # Append any additional scopes, space separated
    extra_scopes = my_config.get(ADDITIONAL_SCOPES, [])
    if extra_scopes:
        separator = " " if requested_scope else ""
        requested_scope += separator + " ".join(extra_scopes)

    # Create the OIDC client
    client = OIDCClient(
        hass=hass,
        discovery_url=my_config.get(DISCOVERY_URL),
        client_id=my_config.get(CLIENT_ID),
        scope=requested_scope,
        client_secret=my_config.get(CLIENT_SECRET),
        id_token_signing_alg=my_config.get(ID_TOKEN_SIGNING_ALGORITHM),
        features=my_config.get(FEATURES, {}),
        claims=my_config.get(CLAIMS, {}),
        roles=my_config.get(ROLES, {}),
        network=my_config.get(NETWORK, {}),
    )

    # Only keep a conservative character set in the displayed name
    safe_name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", display_name)

    use_https = feature_flags.get(FEATURES_FORCE_HTTPS, False)
    redirect_by_default = feature_flags.get(FEATURES_DEFAULT_REDIRECT, False)

    stylesheet = StaticPathConfig(
        "/auth/oidc/static/style.css",
        hass.config.path("custom_components/auth_oidc/static/style.css"),
        cache_headers=True,
    )
    await hass.http.async_register_static_paths([stylesheet])

    # True when trusted_networks was the one and only provider before OIDC
    only_trusted_networks = previous_provider_count == 1 and trusted_networks_first

    welcome_options = OIDCWelcomeOptions(
        name=safe_name,
        force_https=use_https,
        has_other_auth_providers=previous_provider_count > 0,
        prefers_skipping=redirect_by_default or only_trusted_networks,
    )

    # Register the views
    hass.http.register_view(OIDCWelcomeView(oidc_provider, welcome_options))
    hass.http.register_view(OIDCDeviceSSE(oidc_provider))
    hass.http.register_view(OIDCRedirectView(client, oidc_provider, use_https))
    hass.http.register_view(OIDCCallbackView(client, oidc_provider, use_https))
    hass.http.register_view(OIDCFinishView(oidc_provider))

    _LOGGER.info("Registered OIDC views")

    # Inject OIDC code into the frontend for /auth/authorize for automatic
    # redirect
    await OIDCInjectedAuthPage.inject(
        hass, oidc_provider, use_https, trusted_networks_first
    )

    return True
|