Files
hass-oidc-auth/custom_components/auth_oidc/config/ui_flow.py
2025-10-04 17:34:31 +02:00

843 lines
29 KiB
Python

"""Config flow for OIDC Authentication integration."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Any
import aiohttp
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from .const import (
DOMAIN,
DEFAULT_ADMIN_GROUP,
CLIENT_ID,
CLIENT_SECRET,
DISCOVERY_URL,
DISPLAY_NAME,
FEATURES,
CLAIMS,
ROLES,
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
)
from ..tools.oidc_client import (
OIDCDiscoveryClient,
OIDCDiscoveryInvalid,
OIDCJWKSInvalid,
)
from .provider_catalog import (
OIDC_PROVIDERS,
get_provider_name,
get_provider_docs_url,
)
from ..tools.validation import (
validate_discovery_url,
sanitize_client_secret,
validate_client_id,
)
# Module-level logger used by the config flow.
_LOGGER = logging.getLogger(__name__)
# Configuration field names (keys of the user_input dicts submitted by the forms)
CONF_PROVIDER = "provider"
CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = "client_secret"
CONF_DISCOVERY_URL = "discovery_url"
CONF_ENABLE_GROUPS = "enable_groups"
CONF_ADMIN_GROUP = "admin_group"
CONF_USER_GROUP = "user_group"
CONF_ENABLE_USER_LINKING = "enable_user_linking"
# Cache settings for discovery documents fetched during connection validation.
# The TTL is compared against time.time() differences, i.e. it is in seconds.
DISCOVERY_CACHE_TTL = 300 # 5 minutes
# Maximum number of cached discovery documents kept after a cleanup pass.
MAX_CACHE_SIZE = 10
@dataclass
class FlowState:
    """State tracking for the configuration flow.

    Holds the choices made in the first steps of the flow.
    """

    # Key of the provider picked in the first step (a key of OIDC_PROVIDERS).
    provider: str | None = None
    # Discovery URL, prefilled from the provider entry or entered by the user.
    discovery_url: str | None = None
@dataclass
class ClientConfig:
    """Client configuration settings."""

    # Client ID entered in the client configuration step.
    client_id: str | None = None
    # Client secret; None when the user did not provide one.
    client_secret: str | None = None
@dataclass
class FeatureConfig:
    """Feature configuration settings."""

    # Stored as the "include_groups_scope" feature when the entry is created.
    enable_groups: bool = False
    # Group name written to roles["admin"] when groups are enabled.
    admin_group: str = DEFAULT_ADMIN_GROUP
    # Optional group name written to roles["user"] when groups are enabled.
    user_group: str | None = None
    # Stored as the "automatic_user_linking" feature when the entry is created.
    enable_user_linking: bool = False
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
    """Handle a config flow for OIDC Authentication.

    The steps defined on this class are, in order: provider selection,
    discovery URL entry, connection validation, client credentials, optional
    group configuration, user linking and finalization. A separate
    reconfigure step allows the client credentials to be changed later.
    """

    VERSION = 1
def is_matching(self, other_flow) -> bool:
    """Return True when this flow and *other_flow* target the same discovery URL.

    URLs are compared case-insensitively and with trailing slashes ignored.
    The result is always a real ``bool``: a flow without flow state, or with
    a missing/empty discovery URL, never matches.
    """
    own_state = getattr(self, "_flow_state", None)
    other_state = getattr(other_flow, "_flow_state", None)
    if not own_state or not other_state:
        return False

    own_url = own_state.discovery_url
    other_url = other_state.discovery_url
    # Guard explicitly instead of chaining with `and`, which would leak
    # None or "" to the caller instead of False.
    if not own_url or not other_url:
        return False

    return own_url.rstrip("/").lower() == other_url.rstrip("/").lower()
def __init__(self) -> None:
    """Create empty per-flow state and an empty discovery-document cache."""
    # Cached discovery documents and the time each one was stored; both
    # dicts are keyed by the discovery URL.
    self._cache_timestamps = {}
    self._discovery_cache = {}
    # Values collected step by step while the user walks through the flow.
    self._feature_config = FeatureConfig()
    self._client_config = ClientConfig()
    self._flow_state = FlowState()
@property
def current_provider_config(self) -> dict[str, Any]:
    """Return the catalog entry of the selected provider.

    An empty dict is returned when no provider has been selected yet or
    the selected key is not present in OIDC_PROVIDERS.
    """
    provider_key = self._flow_state.provider
    return OIDC_PROVIDERS.get(provider_key, {}) if provider_key else {}
@property
def current_provider_name(self) -> str:
    """Return the name reported by get_provider_name for the selected provider."""
    selected_provider = self._flow_state.provider
    return get_provider_name(selected_provider)
def _cleanup_discovery_cache(self) -> None:
    """Evict stale discovery documents, then trim the cache to its size limit."""
    documents = self._discovery_cache
    timestamps = self._cache_timestamps

    def evict(url) -> None:
        # Drop the document and its timestamp together so both dicts stay in sync.
        documents.pop(url, None)
        timestamps.pop(url, None)

    # Pass 1: everything older than the TTL goes. The keys are collected
    # first because the dicts must not change size while being iterated.
    now = time.time()
    stale_urls = [
        url
        for url, stored_at in timestamps.items()
        if now - stored_at > DISCOVERY_CACHE_TTL
    ]
    for url in stale_urls:
        evict(url)

    # Pass 2: if still over the limit, drop the oldest entries first.
    overflow = len(documents) - MAX_CACHE_SIZE
    if overflow > 0:
        oldest_first = sorted(timestamps, key=timestamps.get)
        for url in oldest_first[:overflow]:
            evict(url)
def _is_cache_valid(self, cache_key: str) -> bool:
    """Return True when *cache_key* has a timestamp no older than the TTL."""
    try:
        stored_at = self._cache_timestamps[cache_key]
    except KeyError:
        # Never cached (or already evicted).
        return False
    return (time.time() - stored_at) <= DISCOVERY_CACHE_TTL
# =================
# Step 1: Provider selection
# =================
async def async_step_user(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Handle the initial step - provider selection.

    Aborts when an entry for this integration already exists or when a
    YAML configuration is registered under hass.data[DOMAIN]["yaml_config"].
    Without input the provider selection form is shown; with input the
    chosen provider is stored and the flow moves on to the discovery URL step.
    """
    # Only a single OIDC instance is allowed.
    if self._async_current_entries():
        return self.async_abort(reason="single_instance_allowed")

    # Refuse UI setup when a YAML configuration is present.
    yaml_config = self.hass.data.get(DOMAIN, {}).get("yaml_config")
    if yaml_config:
        return self.async_abort(reason="yaml_configured")

    if user_input is None:
        provider_choices = {}
        for key, provider in OIDC_PROVIDERS.items():
            provider_choices[key] = provider["name"]
        return self.async_show_form(
            step_id="user",
            data_schema=vol.Schema(
                {vol.Required(CONF_PROVIDER): vol.In(provider_choices)}
            ),
            errors={},
            description_placeholders={},
        )

    self._flow_state.provider = user_input[CONF_PROVIDER]

    # A provider may ship a predefined discovery URL. It is only used to
    # prefill the next step, which is always shown so it can be customized.
    predefined_url = self.current_provider_config.get("discovery_url")
    if predefined_url:
        self._flow_state.discovery_url = predefined_url

    return await self.async_step_discovery_url()
# =================
# Step 2: Discovery URL
# =================
async def async_step_discovery_url(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Ask for the discovery URL and continue to connection validation.

    Without input (or after a format error) the form is shown, prefilled
    with the discovery URL currently held in the flow state. With input the
    URL is normalized and checked with validate_discovery_url; a valid URL
    is stored and the flow proceeds to the validation step.
    """
    errors: dict[str, str] = {}

    if user_input is not None:
        # Strip surrounding whitespace before the trailing slash: a pasted
        # URL such as "https://idp/ " would otherwise keep both.
        discovery_url = user_input[CONF_DISCOVERY_URL].strip().rstrip("/")
        if validate_discovery_url(discovery_url):
            self._flow_state.discovery_url = discovery_url
            return await self.async_step_validate_connection()
        errors["discovery_url"] = "invalid_url_format"

    provider_key = self._flow_state.provider

    # Pre-populate with the existing discovery URL if available.
    default_url = self._flow_state.discovery_url or vol.UNDEFINED
    data_schema = vol.Schema(
        {vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
    )

    return self.async_show_form(
        step_id="discovery_url",
        data_schema=data_schema,
        errors=errors,
        description_placeholders={
            "provider_name": self.current_provider_name,
            "documentation_url": get_provider_docs_url(provider_key),
        },
    )
# =================
# Step 3: Discovery Validation
# =================
async def _handle_validation_actions(
self, user_input: dict[str, Any]
) -> FlowResult | None:
"""Handle user actions from the validation form so they can fix errors."""
action = user_input.get("action")
# Handle special actions first
if action == "retry":
return None # Continue with validation
if action == "continue":
return await self.async_step_client_config()
# Handle redirect actions
action_handlers = {
"fix_discovery": self.async_step_discovery_url,
"change_provider": self.async_step_user,
}
handler = action_handlers.get(action)
return await handler() if handler else None
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
    """Fetch and validate the discovery document and its JWKS.

    Returns:
        A ``(discovery_doc, errors)`` tuple. ``errors`` is empty on success.
        On failure ``errors["base"]`` holds one of "discovery_invalid",
        "jwks_invalid", "cannot_connect" or "unknown"; for an invalid
        discovery document ``errors["detail_string"]`` carries the detail
        text reported by the exception. ``discovery_doc`` is whatever was
        obtained before the failure (possibly an empty dict).
    """
    errors: dict[str, str] = {}
    discovery_doc: dict = {}
    cache_key = self._flow_state.discovery_url

    try:
        # The context manager closes the session on every exit path,
        # including cancellation and exceptions not handled below.
        async with aiohttp.ClientSession() as http_session:
            discovery_client = OIDCDiscoveryClient(
                discovery_url=self._flow_state.discovery_url,
                http_session=http_session,
                verification_context={
                    # Cannot be changed from the UI config currently
                    "id_token_signing_alg": DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
                },
            )

            # Clean up expired cache entries first
            self._cleanup_discovery_cache()

            if cache_key in self._discovery_cache and self._is_cache_valid(
                cache_key
            ):
                discovery_doc = self._discovery_cache[cache_key]
            else:
                discovery_doc = await discovery_client.fetch_discovery_document()
                # Cache the discovery document with its timestamp
                self._discovery_cache[cache_key] = discovery_doc
                self._cache_timestamps[cache_key] = time.time()

            # The JWKS is fetched even for a cached document, since this
            # call might be a retry after an earlier JWKS failure.
            if "jwks_uri" in discovery_doc:
                await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
    except OIDCDiscoveryInvalid as err:
        errors["base"] = "discovery_invalid"
        errors["detail_string"] = err.get_detail_string()
    except OIDCJWKSInvalid:
        errors["base"] = "jwks_invalid"
    except aiohttp.ClientError:
        errors["base"] = "cannot_connect"
    except Exception:  # pylint: disable=broad-except
        _LOGGER.exception("Unexpected error during validation")
        errors["base"] = "unknown"

    return discovery_doc, errors
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
"""Get action options based on validation state."""
if has_errors:
return {
"retry": "Retry Validation",
"fix_discovery": "Change Discovery URL",
"change_provider": "Change Provider",
}
return {
"continue": "Continue Setup",
"fix_discovery": "Change Discovery URL",
"change_provider": "Change Provider",
}
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
"""Build success details from discovery document."""
return (
f"✅ Connected and verified succesfully!\n"
f"_Discovered valid OIDC issuer: {discovery_doc['issuer']}_\n\n"
)
def _build_error_details(self, errors: dict[str, str]) -> str:
"""Build error details from validation errors."""
base = errors.get("base", "")
detail_string = errors.get("detail_string", "")
error_messages = {
"discovery_invalid": (
"❌ **Discovery document could not be validated.**\n"
"Please verify the discovery URL is correct and accessible.\n\n"
f"_({detail_string})_"
),
"jwks_invalid": (
"❌ **JWKS validation failed**\n"
"The JSON Web Key Set could not be retrieved or validated."
),
"cannot_connect": (
"❌ **Connection failed**\n"
"Unable to connect to the OIDC provider. Check your network and URL."
),
}
return error_messages.get(base, "")
async def _build_validation_form(
    self, errors: dict[str, str], discovery_doc: dict | None = None
) -> FlowResult:
    """Show the validation form with the outcome and the follow-up actions.

    Args:
        errors: Validation errors; an empty dict means validation succeeded.
        discovery_doc: The discovery document, used for the success text.
    """
    # Pick the details text: errors win over a (possibly cached) document.
    if errors:
        details = self._build_error_details(errors)
    elif discovery_doc:
        details = self._build_discovery_success_details(discovery_doc)
    else:
        details = ""

    placeholders = {
        "discovery_url": self._flow_state.discovery_url,
        "provider_name": self.current_provider_name,
        "discovery_details": details,
        "documentation_url": get_provider_docs_url(self._flow_state.provider),
    }

    action_field = vol.Required("action")
    action_choices = vol.In(self._get_action_options(bool(errors)))

    # NOTE(review): `errors` is passed through unchanged, so it may still
    # contain the free-text "detail_string" entry next to "base" - confirm
    # that is intended for the form's errors mapping.
    return self.async_show_form(
        step_id="validate_connection",
        data_schema=vol.Schema({action_field: action_choices}),
        errors=errors,
        description_placeholders=placeholders,
    )
async def async_step_validate_connection(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Validate the OIDC configuration by testing discovery and JWKS.

    A submitted action is dispatched first; when it leads to another step
    that step's result is returned. Otherwise (first visit or retry) the
    validation runs and its outcome is always presented on the form.
    """
    if user_input is not None:
        redirect = await self._handle_validation_actions(user_input)
        if redirect is not None:
            return redirect

    validated_doc, problems = await self._perform_oidc_validation()
    return await self._build_validation_form(problems, validated_doc)
# =================
# Step 4: Configure client details (client_id & client_secret)
# =================
async def _proceed_to_next_step_after_client_config(self) -> FlowResult:
"""Proceed to next step after client config."""
if self.current_provider_config.get("supports_groups", True):
return await self.async_step_groups_config()
return await self.async_step_user_linking()
async def async_step_client_config(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Handle entry of the client ID and the optional client secret.

    The client ID is stripped of surrounding whitespace before it is
    validated, matching what the reconfigure step does. A valid ID is
    stored together with the sanitized secret (None when no secret was
    given) and the flow proceeds; otherwise the form is shown again with
    an error on the client ID field.
    """
    errors: dict[str, str] = {}

    if user_input is not None:
        client_id = user_input[CONF_CLIENT_ID].strip()
        if validate_client_id(client_id):
            self._client_config.client_id = client_id
            # Optional client secret determines confidential/public
            provided_secret = sanitize_client_secret(
                user_input.get(CONF_CLIENT_SECRET, "")
            )
            self._client_config.client_secret = provided_secret or None
            return await self._proceed_to_next_step_after_client_config()
        errors["client_id"] = "invalid_client_id"

    # Pre-populate with existing values if available
    default_client_id = self._client_config.client_id or vol.UNDEFINED
    default_client_secret = self._client_config.client_secret or vol.UNDEFINED
    data_schema = vol.Schema(
        {
            vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
            vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
        }
    )

    return self.async_show_form(
        step_id="client_config",
        data_schema=data_schema,
        errors=errors,
        description_placeholders={
            "provider_name": self.current_provider_name,
            "discovery_url": self._flow_state.discovery_url,
            "documentation_url": get_provider_docs_url(self._flow_state.provider),
        },
    )
# =================
# Step 5: Configure groups settings
# =================
async def async_step_groups_config(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Configure groups and roles.

    With input, the group settings are stored (group names only when
    groups are enabled) and the flow continues with user linking. Without
    input, the form is shown with the provider's default admin group.
    """
    if user_input is not None:
        groups_enabled = user_input.get(CONF_ENABLE_GROUPS, False)
        self._feature_config.enable_groups = groups_enabled
        if groups_enabled:
            # NOTE(review): the fallback here is the literal "admins", not
            # the provider default offered on the form - confirm intended.
            self._feature_config.admin_group = user_input.get(
                CONF_ADMIN_GROUP, "admins"
            )
            self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
        return await self.async_step_user_linking()

    admin_group_default = self.current_provider_config.get(
        "default_admin_group", "admins"
    )

    # This point is only reached on the initial display, so the group name
    # fields are always part of the form.
    form_schema = vol.Schema(
        {
            vol.Optional(CONF_ENABLE_GROUPS, default=True): bool,
            vol.Optional(CONF_ADMIN_GROUP, default=admin_group_default): str,
            vol.Optional(CONF_USER_GROUP): str,
        }
    )

    return self.async_show_form(
        step_id="groups_config",
        data_schema=form_schema,
        errors={},
        description_placeholders={"provider_name": self.current_provider_name},
    )
# =================
# Step 6: Configure user linking
# =================
async def async_step_user_linking(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Configure user linking options.

    Shows a single opt-in checkbox; once submitted, the choice is stored
    and the flow is finalized.
    """
    if user_input is None:
        linking_field = vol.Optional(CONF_ENABLE_USER_LINKING, default=False)
        return self.async_show_form(
            step_id="user_linking",
            data_schema=vol.Schema({linking_field: bool}),
            errors={},
            description_placeholders={},
        )

    linking_enabled = user_input.get(CONF_ENABLE_USER_LINKING, False)
    self._feature_config.enable_user_linking = linking_enabled
    return await self.async_step_finalize()
# =================
# Step 7: Finalize and create entry
# =================
async def async_step_finalize(self) -> FlowResult:
    """Finalize the configuration and create the config entry.

    The entry data contains the provider key, client ID, discovery URL and
    display name, the client secret when one was given, the feature flags,
    the provider's default claims when the catalog entry defines them, and
    the role mapping when groups are enabled. The entry title is the
    provider name.
    """
    await self.async_set_unique_id(DOMAIN)
    self._abort_if_unique_id_configured()

    provider_name = f"{self.current_provider_name}"

    # Build the configuration
    config_data: dict[str, Any] = {
        "provider": self._flow_state.provider,
        "client_id": self._client_config.client_id,
        "discovery_url": self._flow_state.discovery_url,
        "display_name": provider_name,
    }

    # Add optional fields
    if self._client_config.client_secret:
        config_data["client_secret"] = self._client_config.client_secret

    # Configure features
    config_data["features"] = {
        "automatic_user_linking": self._feature_config.enable_user_linking,
        "automatic_person_creation": True,
        "include_groups_scope": self._feature_config.enable_groups,
    }

    # Configure claims using provider defaults. A catalog entry without
    # claims must not crash the flow with a KeyError; the key is simply
    # left out, which convert_ui_config_entry_to_internal_format accepts.
    provider_claims = self.current_provider_config.get("claims")
    if provider_claims is not None:
        config_data["claims"] = dict(provider_claims)

    # Configure roles if groups are enabled
    if self._feature_config.enable_groups:
        roles = {}
        if self._feature_config.admin_group:
            roles["admin"] = self._feature_config.admin_group
        if self._feature_config.user_group:
            roles["user"] = self._feature_config.user_group
        config_data["roles"] = roles

    return self.async_create_entry(title=provider_name, data=config_data)
# =================
# Allow reconfiguration of client ID and secret
# =================
async def _validate_reconfigure_input(
    self, entry, user_input: dict[str, Any]
) -> tuple[dict[str, str], dict[str, Any] | None]:
    """Validate reconfigure input and return errors and data updates.

    Args:
        entry: The config entry being reconfigured; only ``entry.data`` is read.
        user_input: The submitted form values.

    Returns:
        ``(errors, data_updates)``. On an invalid client ID, ``errors`` has a
        "client_id" entry and ``data_updates`` is None. Otherwise ``errors``
        is empty and ``data_updates`` holds the new client ID plus the
        client secret to use: the submitted one, or the stored one when the
        secret field was left blank.
    """
    errors: dict[str, str] = {}

    # Validate client ID
    client_id = user_input[CONF_CLIENT_ID].strip()
    if not validate_client_id(client_id):
        errors["client_id"] = "invalid_client_id"
        return errors, None

    # A blank (or missing/None) secret means "keep the existing one".
    submitted_secret = (user_input.get(CONF_CLIENT_SECRET) or "").strip()
    client_secret = submitted_secret or entry.data.get("client_secret")

    # Only the changed keys are returned: the reconfigure step passes this
    # mapping on as `data_updates`, so untouched entry data need not be
    # repeated here.
    data_updates: dict[str, Any] = {"client_id": client_id}
    if client_secret:
        data_updates["client_secret"] = client_secret

    return errors, data_updates
def _build_reconfigure_schema(
    self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
) -> vol.Schema:
    """Build the reconfigure form schema.

    The client ID is required and prefilled from the current entry data.
    The client secret is optional and carries no default, so the stored
    secret is never prefilled into the form.
    """
    client_id_field = vol.Required(
        CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
    )
    client_secret_field = vol.Optional(CONF_CLIENT_SECRET)
    return vol.Schema({client_id_field: str, client_secret_field: str})
async def async_step_reconfigure(
    self, user_input: dict[str, Any] | None = None
) -> FlowResult:
    """Handle reconfiguration of OIDC client credentials.

    Aborts with reason "unknown" when no entry to reconfigure can be found.
    Valid input updates the entry and reloads it; invalid input, or an
    unexpected error while validating, shows the form again with errors.
    """
    errors: dict[str, str] = {}
    entry = self._get_reconfigure_entry()
    if entry is None:
        return self.async_abort(reason="unknown")

    if user_input is not None:
        data_updates = None
        try:
            errors, data_updates = await self._validate_reconfigure_input(
                entry, user_input
            )
        except Exception:  # pylint: disable=broad-except
            _LOGGER.exception("Unexpected error during reconfiguration")
            errors = {"base": "unknown"}

        if not errors:
            # Deliberately outside the try block: Home Assistant's unique-id
            # helpers abort a flow by raising AbortFlow, which is an
            # Exception subclass. Catching it above would turn a legitimate
            # abort into an "unknown" form error.
            await self.async_set_unique_id(entry.unique_id)
            self._abort_if_unique_id_mismatch()
            return self.async_update_reload_and_abort(
                entry, data_updates=data_updates
            )

    # Show form
    current_data = entry.data
    data_schema = self._build_reconfigure_schema(current_data, user_input)

    return self.async_show_form(
        step_id="reconfigure",
        data_schema=data_schema,
        errors=errors,
        description_placeholders={
            "provider_name": get_provider_name(current_data.get("provider")),
            "discovery_url": current_data.get("discovery_url", ""),
        },
    )
def _get_reconfigure_entry(self):
"""Return the config entry being reconfigured if available.
Prefer the entry referenced by the flow context's entry_id. Fall back to the
first existing entry for this domain when only a single instance is allowed.
"""
# Try from flow context (preferred)
entry_id = None
context = getattr(self, "context", None)
if context and hasattr(context, "get"):
entry_id = context.get("entry_id")
if entry_id:
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry and entry.domain == DOMAIN:
return entry
# Fallback: this integration allows a single instance; use the first
current = self._async_current_entries()
if current:
return current[0]
return None
@staticmethod
@callback
def async_get_options_flow(config_entry):
    """Create the options flow handler.

    The ``config_entry`` argument is accepted but not used by this method.
    """
    options_flow = OIDCOptionsFlowHandler()
    return options_flow
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
    """Handle options flow for OIDC Authentication."""

    async def async_step_init(self, user_input=None):
        """Show the options form, or apply the submitted options.

        Submitted options are written into the config entry's data (feature
        flags and role mapping); the options entry itself is created empty.
        """
        if user_input is not None:
            groups_enabled = user_input.get("enable_groups", False)

            # Collect the role mapping from the non-empty group fields.
            roles = {}
            if groups_enabled:
                for role, field in (("admin", "admin_group"), ("user", "user_group")):
                    if user_input.get(field):
                        roles[role] = user_input[field]

            # Merge the submitted feature flags over the stored ones.
            new_data = self.config_entry.data.copy()
            new_data["features"] = {
                **new_data.get("features", {}),
                "automatic_user_linking": user_input.get(
                    "enable_user_linking", False
                ),
                "include_groups_scope": groups_enabled,
            }

            if roles:
                new_data["roles"] = roles
            elif not groups_enabled:
                # Remove roles if groups are disabled
                new_data.pop("roles", None)

            self.hass.config_entries.async_update_entry(
                self.config_entry, data=new_data
            )
            return self.async_create_entry(title="", data={})

        stored = self.config_entry.data
        stored_features = stored.get("features", {})
        stored_roles = stored.get("roles", {})

        # Providers are assumed to support groups unless the catalog says otherwise.
        provider = stored.get("provider", "authentik")
        supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
            "supports_groups", True
        )

        linking_field = vol.Optional(
            "enable_user_linking",
            default=stored_features.get("automatic_user_linking", False),
        )
        fields = {linking_field: bool}

        if supports_groups:
            groups_currently_enabled = stored_features.get(
                "include_groups_scope", False
            )
            groups_field = vol.Optional(
                "enable_groups", default=groups_currently_enabled
            )
            fields[groups_field] = bool

            # The group name fields are only offered while groups are
            # currently enabled in the stored configuration.
            if groups_currently_enabled:
                admin_field = vol.Optional(
                    "admin_group",
                    default=stored_roles.get("admin", DEFAULT_ADMIN_GROUP),
                )
                user_field = vol.Optional(
                    "user_group", default=stored_roles.get("user", "")
                )
                fields[admin_field] = str
                fields[user_field] = str

        return self.async_show_form(
            step_id="init",
            data_schema=vol.Schema(fields),
            description_placeholders={
                "provider_name": get_provider_name(provider),
            },
        )
def convert_ui_config_entry_to_internal_format(config_data: dict) -> dict:
    """Convert config entry data to internal configuration format.

    The client ID and discovery URL are required (a missing one raises
    KeyError). Client secret, display name, features, claims and roles are
    copied only when present in *config_data*.
    """
    internal_config = {
        CLIENT_ID: config_data["client_id"],
        DISCOVERY_URL: config_data["discovery_url"],
    }

    # (config entry key, internal key) pairs for the optional settings.
    optional_keys = (
        ("client_secret", CLIENT_SECRET),
        ("display_name", DISPLAY_NAME),
        ("features", FEATURES),
        ("claims", CLAIMS),
        ("roles", ROLES),
    )
    for entry_key, internal_key in optional_keys:
        if entry_key in config_data:
            internal_config[internal_key] = config_data[entry_key]

    return internal_config