Move some code around and improve validation (#128)

This commit is contained in:
Christiaan Goossens
2025-10-04 17:34:31 +02:00
committed by GitHub
parent 3b481cd282
commit d1da841e1f
26 changed files with 1334 additions and 1100 deletions

2
.gitignore vendored
View File

@@ -105,6 +105,6 @@ dmypy.json
.pyre/ .pyre/
# End of https://www.gitignore.io/api/python # End of https://www.gitignore.io/api/python
config/ /config/
.venv .venv

View File

@@ -9,8 +9,10 @@ from homeassistant.core import HomeAssistant
# Import and re-export config schema explictly # Import and re-export config schema explictly
# pylint: disable=useless-import-alias # pylint: disable=useless-import-alias
from .config import CONFIG_SCHEMA as CONFIG_SCHEMA
# Get all the constants for the config
from .config import ( from .config import (
CONFIG_SCHEMA as CONFIG_SCHEMA,
DOMAIN, DOMAIN,
DEFAULT_TITLE, DEFAULT_TITLE,
CLIENT_ID, CLIENT_ID,
@@ -27,17 +29,19 @@ from .config import (
FEATURES_INCLUDE_GROUPS_SCOPE, FEATURES_INCLUDE_GROUPS_SCOPE,
FEATURES_DISABLE_FRONTEND_INJECTION, FEATURES_DISABLE_FRONTEND_INJECTION,
FEATURES_FORCE_HTTPS, FEATURES_FORCE_HTTPS,
REQUIRED_SCOPES,
) )
# pylint: enable=useless-import-alias from .config import convert_ui_config_entry_to_internal_format
from .endpoints.welcome import OIDCWelcomeView from .endpoints import (
from .endpoints.redirect import OIDCRedirectView OIDCWelcomeView,
from .endpoints.finish import OIDCFinishView OIDCRedirectView,
from .endpoints.callback import OIDCCallbackView OIDCFinishView,
from .endpoints.injected_auth_page import OIDCInjectedAuthPage OIDCCallbackView,
OIDCInjectedAuthPage,
from .oidc_client import OIDCClient )
from .tools.oidc_client import OIDCClient
from .provider import OpenIDAuthProvider from .provider import OpenIDAuthProvider
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -62,12 +66,12 @@ async def async_setup(hass: HomeAssistant, config):
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
"""Set up OIDC Authentication from a config entry.""" """Set up OIDC Authentication from a config entry (UI config)."""
# Convert config entry data to the format expected by the existing setup # Convert config entry data to the format expected by the existing setup
config_data = entry.data.copy() config_data = entry.data.copy()
# Convert config entry format to internal format # Convert config entry format to internal format
my_config = _convert_config_entry_to_internal_format(config_data) my_config = convert_ui_config_entry_to_internal_format(config_data)
# Get display name from config entry # Get display name from config entry
display_name = config_data.get("display_name", DEFAULT_TITLE) display_name = config_data.get("display_name", DEFAULT_TITLE)
@@ -83,36 +87,6 @@ async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
return False return False
def _convert_config_entry_to_internal_format(config_data: dict) -> dict:
"""Convert config entry data to internal configuration format."""
my_config = {}
# Required fields
my_config[CLIENT_ID] = config_data["client_id"]
my_config[DISCOVERY_URL] = config_data["discovery_url"]
# Optional fields
if "client_secret" in config_data:
my_config[CLIENT_SECRET] = config_data["client_secret"]
if "display_name" in config_data:
my_config[DISPLAY_NAME] = config_data["display_name"]
# Features configuration
if "features" in config_data:
my_config[FEATURES] = config_data["features"]
# Claims configuration
if "claims" in config_data:
my_config[CLAIMS] = config_data["claims"]
# Roles configuration
if "roles" in config_data:
my_config[ROLES] = config_data["roles"]
return my_config
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str): async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
"""Set up the OIDC provider with the given configuration.""" """Set up the OIDC provider with the given configuration."""
providers = OrderedDict() providers = OrderedDict()
@@ -131,7 +105,7 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
# Set the correct scopes # Set the correct scopes
# Always use 'openid' & 'profile' as they are specified in the OIDC spec # Always use 'openid' & 'profile' as they are specified in the OIDC spec
# All servers should support this # All servers should support this
scope = "openid profile" scope = REQUIRED_SCOPES
# Include groups if requested (default is to include 'groups' # Include groups if requested (default is to include 'groups'
# as a scope for Authelia & Authentik) # as a scope for Authelia & Authentik)

View File

@@ -0,0 +1,8 @@
"""Imports manager"""
from .const import * # noqa: F403
from .schema import CONFIG_SCHEMA as CONFIG_SCHEMA
from .ui_flow import (
OIDCConfigFlow as OIDCConfigFlow,
convert_ui_config_entry_to_internal_format as convert_ui_config_entry_to_internal_format,
)

View File

@@ -0,0 +1,92 @@
"""Config constants."""
from typing import Any, Dict
## ===
## General integration constants
## ===
DEFAULT_TITLE = "OpenID Connect (SSO)"
DOMAIN = "auth_oidc"
REPO_ROOT_URL = "https://github.com/christiaangoossens/hass-oidc-auth/tree/v0.7.0-alpha"
## ===
## Config keys
## ===
CLIENT_ID = "client_id"
CLIENT_SECRET = "client_secret"
DISCOVERY_URL = "discovery_url"
DISPLAY_NAME = "display_name"
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
GROUPS_SCOPE = "groups_scope"
ADDITIONAL_SCOPES = "additional_scopes"
FEATURES = "features"
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
FEATURES_DISABLE_PKCE = "disable_rfc7636"
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
FEATURES_FORCE_HTTPS = "force_https"
CLAIMS = "claims"
CLAIMS_DISPLAY_NAME = "display_name"
CLAIMS_USERNAME = "username"
CLAIMS_GROUPS = "groups"
ROLES = "roles"
ROLE_ADMINS = "admin"
ROLE_USERS = "user"
NETWORK = "network"
NETWORK_TLS_VERIFY = "tls_verify"
NETWORK_TLS_CA_PATH = "tls_ca_path"
## ===
## Default configurations for providers
## ===
REQUIRED_SCOPES = "openid profile"
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM = "RS256"
DEFAULT_GROUPS_SCOPE = "groups"
DEFAULT_ADMIN_GROUP = "admins"
OIDC_PROVIDERS: Dict[str, Dict[str, Any]] = {
"authentik": {
"name": "Authentik",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"authelia": {
"name": "Authelia",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"pocketid": {
"name": "Pocket ID",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"generic": {
"name": "OpenID Connect (SSO)",
"discovery_url": "",
"supports_groups": False,
"claims": {"display_name": "name", "username": "preferred_username"},
},
}

View File

@@ -0,0 +1,35 @@
"""Provider catalog and helpers for OIDC providers."""
from __future__ import annotations
from typing import Any, Dict
from .const import OIDC_PROVIDERS, REPO_ROOT_URL
def get_provider_config(key: str) -> Dict[str, Any]:
"""Return provider configuration by key."""
return OIDC_PROVIDERS.get(key, {})
def get_provider_name(key: str | None) -> str:
"""Return provider display name by key."""
if not key:
return "Unknown Provider"
return OIDC_PROVIDERS.get(key, {}).get("name", "Unknown Provider")
def get_provider_docs_url(key: str | None) -> str:
"""Return documentation URL for a provider key."""
base_url = REPO_ROOT_URL + "/docs/provider-configurations"
provider_docs = {
"authentik": f"{base_url}/authentik.md",
"authelia": f"{base_url}/authelia.md",
"pocketid": f"{base_url}/pocket-id.md",
"kanidm": f"{base_url}/kanidm.md",
"microsoft": f"{base_url}/microsoft-entra.md",
}
if key in provider_docs:
return provider_docs[key]
return REPO_ROOT_URL + "/docs/configuration.md"

View File

@@ -1,36 +1,35 @@
"""Config schema and constants.""" """Config schema"""
import voluptuous as vol import voluptuous as vol
from .const import (
CLIENT_ID,
CLIENT_SECRET,
DISCOVERY_URL,
DISPLAY_NAME,
ID_TOKEN_SIGNING_ALGORITHM,
GROUPS_SCOPE,
ADDITIONAL_SCOPES,
FEATURES,
FEATURES_AUTOMATIC_USER_LINKING,
FEATURES_AUTOMATIC_PERSON_CREATION,
FEATURES_DISABLE_PKCE,
FEATURES_INCLUDE_GROUPS_SCOPE,
FEATURES_DISABLE_FRONTEND_INJECTION,
FEATURES_FORCE_HTTPS,
CLAIMS,
CLAIMS_DISPLAY_NAME,
CLAIMS_USERNAME,
CLAIMS_GROUPS,
ROLES,
ROLE_ADMINS,
ROLE_USERS,
NETWORK,
NETWORK_TLS_VERIFY,
NETWORK_TLS_CA_PATH,
DOMAIN,
DEFAULT_GROUPS_SCOPE,
)
CLIENT_ID = "client_id"
CLIENT_SECRET = "client_secret"
DISCOVERY_URL = "discovery_url"
DISPLAY_NAME = "display_name"
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
GROUPS_SCOPE = "groups_scope"
ADDITIONAL_SCOPES = "additional_scopes"
FEATURES = "features"
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
FEATURES_DISABLE_PKCE = "disable_rfc7636"
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
FEATURES_FORCE_HTTPS = "force_https"
CLAIMS = "claims"
CLAIMS_DISPLAY_NAME = "display_name"
CLAIMS_USERNAME = "username"
CLAIMS_GROUPS = "groups"
ROLES = "roles"
ROLE_ADMINS = "admin"
ROLE_USERS = "user"
NETWORK = "network"
NETWORK_TLS_VERIFY = "tls_verify"
NETWORK_TLS_CA_PATH = "tls_ca_path"
DEFAULT_TITLE = "OpenID Connect (SSO)"
DOMAIN = "auth_oidc"
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
DOMAIN: vol.Schema( DOMAIN: vol.Schema(
@@ -48,7 +47,9 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str), vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str),
# String value to allow changing the groups scope # String value to allow changing the groups scope
# Defaults to 'groups' which is used by Authelia and Authentik # Defaults to 'groups' which is used by Authelia and Authentik
vol.Optional(GROUPS_SCOPE, default="groups"): vol.Coerce(str), vol.Optional(GROUPS_SCOPE, default=DEFAULT_GROUPS_SCOPE): vol.Coerce(
str
),
# Additional scopes to request from the OIDC provider # Additional scopes to request from the OIDC provider
# Optional, this field is unnecessary if you only use the openid and profile scopes. # Optional, this field is unnecessary if you only use the openid and profile scopes.
vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]), vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]),

View File

@@ -0,0 +1,842 @@
"""Config flow for OIDC Authentication integration."""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Any
import aiohttp
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from .const import (
DOMAIN,
DEFAULT_ADMIN_GROUP,
CLIENT_ID,
CLIENT_SECRET,
DISCOVERY_URL,
DISPLAY_NAME,
FEATURES,
CLAIMS,
ROLES,
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
)
from ..tools.oidc_client import (
OIDCDiscoveryClient,
OIDCDiscoveryInvalid,
OIDCJWKSInvalid,
)
from .provider_catalog import (
OIDC_PROVIDERS,
get_provider_name,
get_provider_docs_url,
)
from ..tools.validation import (
validate_discovery_url,
sanitize_client_secret,
validate_client_id,
)
_LOGGER = logging.getLogger(__name__)
# Configuration field names
CONF_PROVIDER = "provider"
CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = "client_secret"
CONF_DISCOVERY_URL = "discovery_url"
CONF_ENABLE_GROUPS = "enable_groups"
CONF_ADMIN_GROUP = "admin_group"
CONF_USER_GROUP = "user_group"
CONF_ENABLE_USER_LINKING = "enable_user_linking"
# Cache settings
DISCOVERY_CACHE_TTL = 300 # 5 minutes
MAX_CACHE_SIZE = 10
@dataclass
class FlowState:
"""State tracking for the configuration flow."""
provider: str | None = None
discovery_url: str | None = None
@dataclass
class ClientConfig:
"""Client configuration settings."""
client_id: str | None = None
client_secret: str | None = None
@dataclass
class FeatureConfig:
"""Feature configuration settings."""
enable_groups: bool = False
admin_group: str = DEFAULT_ADMIN_GROUP
user_group: str | None = None
enable_user_linking: bool = False
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OIDC Authentication."""
VERSION = 1
def is_matching(self, other_flow):
"""Check if this flow is the same as another flow."""
self_state = getattr(self, "_flow_state", None)
other_state = getattr(other_flow, "_flow_state", None)
if not self_state or not other_state:
return False
self_discovery_url = self_state.discovery_url
other_discovery_url = other_state.discovery_url
return (
self_discovery_url
and other_discovery_url
and self_discovery_url.rstrip("/").lower()
== other_discovery_url.rstrip("/").lower()
)
def __init__(self):
"""Initialize the config flow."""
self._flow_state = FlowState()
self._client_config = ClientConfig()
self._feature_config = FeatureConfig()
self._discovery_cache = {}
self._cache_timestamps = {}
@property
def current_provider_config(self) -> dict[str, Any]:
"""Get the configuration for the currently selected provider."""
if not self._flow_state.provider:
return {}
return OIDC_PROVIDERS.get(self._flow_state.provider, {})
@property
def current_provider_name(self) -> str:
"""Get the name of the currently selected provider."""
return get_provider_name(self._flow_state.provider)
def _cleanup_discovery_cache(self) -> None:
"""Remove expired and excess cache entries."""
current_time = time.time()
# Remove expired entries
expired_keys = [
key
for key, timestamp in self._cache_timestamps.items()
if current_time - timestamp > DISCOVERY_CACHE_TTL
]
for key in expired_keys:
self._discovery_cache.pop(key, None)
self._cache_timestamps.pop(key, None)
# Remove oldest entries if cache is too large
if len(self._discovery_cache) > MAX_CACHE_SIZE:
sorted_items = sorted(self._cache_timestamps.items(), key=lambda x: x[1])
excess_count = len(self._discovery_cache) - MAX_CACHE_SIZE
for key, _ in sorted_items[:excess_count]:
self._discovery_cache.pop(key, None)
self._cache_timestamps.pop(key, None)
def _is_cache_valid(self, cache_key: str) -> bool:
"""Check if a cache entry is still valid."""
if cache_key not in self._cache_timestamps:
return False
age = time.time() - self._cache_timestamps[cache_key]
return age <= DISCOVERY_CACHE_TTL
# =================
# Step 1: Provider selection
# =================
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the initial step - provider selection."""
# Check if OIDC is already configured (only one instance allowed)
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
# Check if YAML configuration exists
if self.hass.data.get(DOMAIN, {}).get("yaml_config"):
return self.async_abort(reason="yaml_configured")
errors = {}
if user_input is not None:
self._flow_state.provider = user_input[CONF_PROVIDER]
# If provider has a predefined discovery URL, prefill it but still
# show the discovery URL step so the user can customize it.
predefined = self.current_provider_config.get("discovery_url")
if predefined:
self._flow_state.discovery_url = predefined
# Always request discovery URL next (prefilled when available)
return await self.async_step_discovery_url()
data_schema = vol.Schema(
{
vol.Required(CONF_PROVIDER): vol.In(
{key: provider["name"] for key, provider in OIDC_PROVIDERS.items()}
)
}
)
return self.async_show_form(
step_id="user",
data_schema=data_schema,
errors=errors,
description_placeholders={},
)
# =================
# Step 2: Discovery URL
# =================
async def async_step_discovery_url(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle discovery URL input for providers requiring URL configuration."""
errors = {}
if user_input is not None:
discovery_url = user_input[CONF_DISCOVERY_URL].rstrip("/")
# Validate discovery URL format
if not validate_discovery_url(discovery_url):
errors["discovery_url"] = "invalid_url_format"
else:
self._flow_state.discovery_url = discovery_url
return await self.async_step_validate_connection()
provider_name = self.current_provider_name
provider_key = self._flow_state.provider
# Pre-populate with existing discovery URL if available
default_url = (
self._flow_state.discovery_url
if self._flow_state.discovery_url
else vol.UNDEFINED
)
data_schema = vol.Schema(
{vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
)
return self.async_show_form(
step_id="discovery_url",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": provider_name,
"documentation_url": get_provider_docs_url(provider_key),
},
)
# =================
# Step 3: Discovery Validation
# =================
async def _handle_validation_actions(
self, user_input: dict[str, Any]
) -> FlowResult | None:
"""Handle user actions from the validation form so they can fix errors."""
action = user_input.get("action")
# Handle special actions first
if action == "retry":
return None # Continue with validation
if action == "continue":
return await self.async_step_client_config()
# Handle redirect actions
action_handlers = {
"fix_discovery": self.async_step_discovery_url,
"change_provider": self.async_step_user,
}
handler = action_handlers.get(action)
return await handler() if handler else None
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
"""Perform the actual OIDC validation and return discovery doc and errors."""
errors = {}
discovery_doc = {}
try:
http_session = aiohttp.ClientSession()
discovery_client = OIDCDiscoveryClient(
discovery_url=self._flow_state.discovery_url,
http_session=http_session,
verification_context={
# Cannot be changed from the UI config currently
"id_token_signing_alg": DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
},
)
# Clean up expired cache entries first
self._cleanup_discovery_cache()
# Check if discovery document is already cached and valid
cache_key = self._flow_state.discovery_url
if cache_key in self._discovery_cache and self._is_cache_valid(cache_key):
discovery_doc = self._discovery_cache[cache_key]
# Still validate JWKS if available since this might be a retry
if "jwks_uri" in discovery_doc:
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
else:
# Perform discovery and JWKS validation
discovery_doc = await discovery_client.fetch_discovery_document()
# Cache the discovery document with timestamp
self._discovery_cache[cache_key] = discovery_doc
self._cache_timestamps[cache_key] = time.time()
# Validate JWKS if available
if "jwks_uri" in discovery_doc:
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
except OIDCDiscoveryInvalid as e:
errors["base"] = "discovery_invalid"
errors["detail_string"] = e.get_detail_string()
except OIDCJWKSInvalid:
errors["base"] = "jwks_invalid"
except aiohttp.ClientError:
errors["base"] = "cannot_connect"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected error during validation")
errors["base"] = "unknown"
await http_session.close()
return discovery_doc, errors
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
"""Get action options based on validation state."""
if has_errors:
return {
"retry": "Retry Validation",
"fix_discovery": "Change Discovery URL",
"change_provider": "Change Provider",
}
return {
"continue": "Continue Setup",
"fix_discovery": "Change Discovery URL",
"change_provider": "Change Provider",
}
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
"""Build success details from discovery document."""
return (
f"✅ Connected and verified succesfully!\n"
f"_Discovered valid OIDC issuer: {discovery_doc['issuer']}_\n\n"
)
def _build_error_details(self, errors: dict[str, str]) -> str:
"""Build error details from validation errors."""
base = errors.get("base", "")
detail_string = errors.get("detail_string", "")
error_messages = {
"discovery_invalid": (
"❌ **Discovery document could not be validated.**\n"
"Please verify the discovery URL is correct and accessible.\n\n"
f"_({detail_string})_"
),
"jwks_invalid": (
"❌ **JWKS validation failed**\n"
"The JSON Web Key Set could not be retrieved or validated."
),
"cannot_connect": (
"❌ **Connection failed**\n"
"Unable to connect to the OIDC provider. Check your network and URL."
),
}
return error_messages.get(base, "")
async def _build_validation_form(
self, errors: dict[str, str], discovery_doc: dict | None = None
) -> FlowResult:
"""Build the validation form with errors and action options."""
action_options = self._get_action_options(bool(errors))
data_schema = vol.Schema({vol.Required("action"): vol.In(action_options)})
# Build description with discovery details
description_placeholders = {
"discovery_url": self._flow_state.discovery_url,
"provider_name": self.current_provider_name,
"discovery_details": "",
"documentation_url": get_provider_docs_url(self._flow_state.provider),
}
# Add appropriate details based on validation state
if discovery_doc and not errors:
description_placeholders["discovery_details"] = (
self._build_discovery_success_details(discovery_doc)
)
elif errors:
description_placeholders["discovery_details"] = self._build_error_details(
errors
)
return self.async_show_form(
step_id="validate_connection",
data_schema=data_schema,
errors=errors,
description_placeholders=description_placeholders,
)
async def async_step_validate_connection(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Validate the OIDC configuration by testing discovery and JWKS."""
# Handle user actions from validation form
if user_input is not None:
action_result = await self._handle_validation_actions(user_input)
if action_result is not None:
return action_result
# Perform validation (either initial attempt or retry)
discovery_doc, errors = await self._perform_oidc_validation()
# Always show validation form with results (success or error)
return await self._build_validation_form(errors, discovery_doc)
# =================
# Step 4: Configure client details (client_id & client_secret)
# =================
async def _proceed_to_next_step_after_client_config(self) -> FlowResult:
"""Proceed to next step after client config."""
if self.current_provider_config.get("supports_groups", True):
return await self.async_step_groups_config()
return await self.async_step_user_linking()
async def async_step_client_config(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle client ID and client type selection."""
errors = {}
if user_input is not None:
client_id = user_input[CONF_CLIENT_ID]
# Validate client ID
if not validate_client_id(client_id):
errors["client_id"] = "invalid_client_id"
if not errors:
self._client_config.client_id = client_id.strip()
# Optional client secret determines confidential/public
provided_secret = sanitize_client_secret(
user_input.get(CONF_CLIENT_SECRET, "")
)
self._client_config.client_secret = provided_secret or None
if not errors:
return await self._proceed_to_next_step_after_client_config()
provider_name = self.current_provider_name
# Pre-populate with existing values if available
default_client_id = (
self._client_config.client_id
if self._client_config.client_id
else vol.UNDEFINED
)
default_client_secret = (
self._client_config.client_secret
if self._client_config.client_secret
else vol.UNDEFINED
)
data_schema = vol.Schema(
{
vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
}
)
return self.async_show_form(
step_id="client_config",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": provider_name,
"discovery_url": self._flow_state.discovery_url,
"documentation_url": get_provider_docs_url(self._flow_state.provider),
},
)
# =================
# Step 5: Configure groups settings
# =================
async def async_step_groups_config(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Configure groups and roles."""
errors = {}
if user_input is not None:
self._feature_config.enable_groups = user_input.get(
CONF_ENABLE_GROUPS, False
)
if self._feature_config.enable_groups:
self._feature_config.admin_group = user_input.get(
CONF_ADMIN_GROUP, "admins"
)
self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
return await self.async_step_user_linking()
default_admin_group = self.current_provider_config.get(
"default_admin_group", "admins"
)
data_schema_dict = {vol.Optional(CONF_ENABLE_GROUPS, default=True): bool}
# Add group configuration fields if groups are enabled
if user_input is None or user_input.get(CONF_ENABLE_GROUPS, True):
data_schema_dict.update(
{
vol.Optional(CONF_ADMIN_GROUP, default=default_admin_group): str,
vol.Optional(CONF_USER_GROUP): str,
}
)
data_schema = vol.Schema(data_schema_dict)
return self.async_show_form(
step_id="groups_config",
data_schema=data_schema,
errors=errors,
description_placeholders={"provider_name": self.current_provider_name},
)
# =================
# Step 6: Configure user linking
# =================
async def async_step_user_linking(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Configure user linking options."""
errors = {}
if user_input is not None:
self._feature_config.enable_user_linking = user_input.get(
CONF_ENABLE_USER_LINKING, False
)
return await self.async_step_finalize()
data_schema = vol.Schema(
{vol.Optional(CONF_ENABLE_USER_LINKING, default=False): bool}
)
return self.async_show_form(
step_id="user_linking",
data_schema=data_schema,
errors=errors,
description_placeholders={},
)
# =================
# Step 7: Finalize and create entry
# =================
async def async_step_finalize(self) -> FlowResult:
"""Finalize the configuration and create the config entry."""
await self.async_set_unique_id(DOMAIN)
self._abort_if_unique_id_configured()
# Build the configuration
config_data = {
"provider": self._flow_state.provider,
"client_id": self._client_config.client_id,
"discovery_url": self._flow_state.discovery_url,
"display_name": f"{self.current_provider_name}",
}
# Add optional fields
if self._client_config.client_secret:
config_data["client_secret"] = self._client_config.client_secret
# Configure features
features = {
"automatic_user_linking": self._feature_config.enable_user_linking,
"automatic_person_creation": True,
"include_groups_scope": self._feature_config.enable_groups,
}
config_data["features"] = features
# Configure claims using provider defaults
claims = self.current_provider_config["claims"].copy()
config_data["claims"] = claims
# Configure roles if groups are enabled
if self._feature_config.enable_groups:
roles = {}
if self._feature_config.admin_group:
roles["admin"] = self._feature_config.admin_group
if self._feature_config.user_group:
roles["user"] = self._feature_config.user_group
config_data["roles"] = roles
title = f"{self.current_provider_name}"
return self.async_create_entry(title=title, data=config_data)
# =================
# Allow reconfiguration of client ID and secret
# =================
async def _validate_reconfigure_input(
self, entry, user_input: dict[str, Any]
) -> tuple[dict[str, str], dict[str, Any] | None]:
"""Validate reconfigure input and return errors and data updates."""
errors = {}
# Validate client ID
client_id = user_input[CONF_CLIENT_ID].strip()
if not validate_client_id(client_id):
errors["client_id"] = "invalid_client_id"
return errors, None
# Determine confidentiality by presence of client secret
client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
# If secret is empty, keep the existing one (if any)
if not client_secret:
client_secret = entry.data.get("client_secret")
# Build updated data
data_updates = {"client_id": client_id}
if client_secret:
data_updates["client_secret"] = client_secret
elif "client_secret" in entry.data and not client_secret:
# Remove client secret if switching from confidential to public
data_updates = {**entry.data, **data_updates}
data_updates.pop("client_secret", None)
return errors, data_updates
def _build_reconfigure_schema(
self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
) -> vol.Schema:
"""Build the reconfigure form schema."""
schema_dict = {
vol.Required(
CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
): str,
}
# Always allow updating or clearing the client secret
schema_dict[vol.Optional(CONF_CLIENT_SECRET)] = str
return vol.Schema(schema_dict)
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle reconfiguration of OIDC client credentials."""
errors = {}
entry = self._get_reconfigure_entry()
if entry is None:
return self.async_abort(reason="unknown")
if user_input is not None:
try:
errors, data_updates = await self._validate_reconfigure_input(
entry, user_input
)
if not errors:
# Update the config entry
await self.async_set_unique_id(entry.unique_id)
self._abort_if_unique_id_mismatch()
return self.async_update_reload_and_abort(
entry, data_updates=data_updates
)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected error during reconfiguration")
errors["base"] = "unknown"
# Show form
current_data = entry.data
data_schema = self._build_reconfigure_schema(current_data, user_input)
return self.async_show_form(
step_id="reconfigure",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": get_provider_name(current_data.get("provider")),
"discovery_url": current_data.get("discovery_url", ""),
},
)
def _get_reconfigure_entry(self):
"""Return the config entry being reconfigured if available.
Prefer the entry referenced by the flow context's entry_id. Fall back to the
first existing entry for this domain when only a single instance is allowed.
"""
# Try from flow context (preferred)
entry_id = None
context = getattr(self, "context", None)
if context and hasattr(context, "get"):
entry_id = context.get("entry_id")
if entry_id:
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry and entry.domain == DOMAIN:
return entry
# Fallback: this integration allows a single instance; use the first
current = self._async_current_entries()
if current:
return current[0]
return None
@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return OIDCOptionsFlowHandler()
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
"""Handle options flow for OIDC Authentication."""
async def async_step_init(self, user_input=None):
"""Handle options flow."""
if user_input is not None:
# Process the updated configuration
updated_features = {
"automatic_user_linking": user_input.get("enable_user_linking", False),
"include_groups_scope": user_input.get("enable_groups", False),
}
updated_roles = {}
if user_input.get("enable_groups", False):
if user_input.get("admin_group"):
updated_roles["admin"] = user_input["admin_group"]
if user_input.get("user_group"):
updated_roles["user"] = user_input["user_group"]
# Update the config entry data
new_data = self.config_entry.data.copy()
new_data["features"] = {**new_data.get("features", {}), **updated_features}
if updated_roles:
new_data["roles"] = updated_roles
elif "roles" in new_data:
# Remove roles if groups are disabled
if not user_input.get("enable_groups", False):
del new_data["roles"]
# Update the config entry
self.hass.config_entries.async_update_entry(
self.config_entry, data=new_data
)
return self.async_create_entry(title="", data={})
current_config = self.config_entry.data
current_features = current_config.get("features", {})
current_roles = current_config.get("roles", {})
# Determine if this provider supports groups
provider = current_config.get("provider", "authentik")
provider_supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
"supports_groups", True
)
# Build schema based on provider capabilities
schema_dict = {
vol.Optional(
"enable_user_linking",
default=current_features.get("automatic_user_linking", False),
): bool
}
# Add groups options if provider supports them
if provider_supports_groups:
enable_groups_default = current_features.get("include_groups_scope", False)
schema_dict[
vol.Optional("enable_groups", default=enable_groups_default)
] = bool
# Add group name fields if groups are currently enabled or being enabled
if enable_groups_default or (
user_input and user_input.get("enable_groups", False)
):
schema_dict.update(
{
vol.Optional(
"admin_group",
default=current_roles.get("admin", DEFAULT_ADMIN_GROUP),
): str,
vol.Optional(
"user_group", default=current_roles.get("user", "")
): str,
}
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema_dict),
description_placeholders={
"provider_name": get_provider_name(provider),
},
)
def convert_ui_config_entry_to_internal_format(config_data: dict) -> dict:
"""Convert config entry data to internal configuration format."""
my_config = {}
# Required fields
my_config[CLIENT_ID] = config_data["client_id"]
my_config[DISCOVERY_URL] = config_data["discovery_url"]
# Optional fields
if "client_secret" in config_data:
my_config[CLIENT_SECRET] = config_data["client_secret"]
if "display_name" in config_data:
my_config[DISPLAY_NAME] = config_data["display_name"]
# Features configuration
if "features" in config_data:
my_config[FEATURES] = config_data["features"]
# Claims configuration
if "claims" in config_data:
my_config[CLAIMS] = config_data["claims"]
# Roles configuration
if "roles" in config_data:
my_config[ROLES] = config_data["roles"]
return my_config

View File

@@ -1,806 +1,5 @@
"""Config flow for OIDC Authentication integration.""" """UI config flow re-export"""
from __future__ import annotations # pylint: disable=useless-import-alias
# pylint: disable=unused-import
import logging from .config.ui_flow import OIDCConfigFlow as OIDCConfigFlow
import time
from dataclasses import dataclass
from typing import Any
import aiohttp
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from .config import DOMAIN
from .oidc_client import OIDCClient, OIDCDiscoveryInvalid, OIDCJWKSInvalid
from .provider_catalog import (
OIDC_PROVIDERS,
get_provider_name,
get_provider_docs_url,
)
from .validation import (
validate_discovery_url,
sanitize_client_secret,
validate_client_id,
)
_LOGGER = logging.getLogger(__name__)
# Configuration field names
CONF_PROVIDER = "provider"
CONF_CLIENT_ID = "client_id"
CONF_CLIENT_SECRET = "client_secret"
CONF_DISCOVERY_URL = "discovery_url"
CONF_ENABLE_GROUPS = "enable_groups"
CONF_ADMIN_GROUP = "admin_group"
CONF_USER_GROUP = "user_group"
CONF_ENABLE_USER_LINKING = "enable_user_linking"
DEFAULT_ADMIN_GROUP = "admins"
# Cache settings
DISCOVERY_CACHE_TTL = 300 # 5 minutes
MAX_CACHE_SIZE = 10
@dataclass
class FlowState:
"""State tracking for the configuration flow."""
provider: str | None = None
discovery_url: str | None = None
@dataclass
class ClientConfig:
"""Client configuration settings."""
client_id: str | None = None
client_secret: str | None = None
@dataclass
class FeatureConfig:
"""Feature configuration settings."""
enable_groups: bool = False
admin_group: str = DEFAULT_ADMIN_GROUP
user_group: str | None = None
enable_user_linking: bool = False
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for OIDC Authentication."""
VERSION = 1
def is_matching(self, other_flow):
"""Check if this flow is the same as another flow."""
self_state = getattr(self, "_flow_state", None)
other_state = getattr(other_flow, "_flow_state", None)
if not self_state or not other_state:
return False
self_discovery_url = self_state.discovery_url
other_discovery_url = other_state.discovery_url
return (
self_discovery_url
and other_discovery_url
and self_discovery_url.rstrip("/").lower()
== other_discovery_url.rstrip("/").lower()
)
def __init__(self):
"""Initialize the config flow."""
self._flow_state = FlowState()
self._client_config = ClientConfig()
self._feature_config = FeatureConfig()
self._discovery_cache = {}
self._cache_timestamps = {}
@property
def current_provider_config(self) -> dict[str, Any]:
"""Get the configuration for the currently selected provider."""
if not self._flow_state.provider:
return {}
return OIDC_PROVIDERS.get(self._flow_state.provider, {})
@property
def current_provider_name(self) -> str:
"""Get the name of the currently selected provider."""
return get_provider_name(self._flow_state.provider)
def _cleanup_discovery_cache(self) -> None:
"""Remove expired and excess cache entries."""
current_time = time.time()
# Remove expired entries
expired_keys = [
key
for key, timestamp in self._cache_timestamps.items()
if current_time - timestamp > DISCOVERY_CACHE_TTL
]
for key in expired_keys:
self._discovery_cache.pop(key, None)
self._cache_timestamps.pop(key, None)
# Remove oldest entries if cache is too large
if len(self._discovery_cache) > MAX_CACHE_SIZE:
sorted_items = sorted(self._cache_timestamps.items(), key=lambda x: x[1])
excess_count = len(self._discovery_cache) - MAX_CACHE_SIZE
for key, _ in sorted_items[:excess_count]:
self._discovery_cache.pop(key, None)
self._cache_timestamps.pop(key, None)
def _is_cache_valid(self, cache_key: str) -> bool:
"""Check if a cache entry is still valid."""
if cache_key not in self._cache_timestamps:
return False
age = time.time() - self._cache_timestamps[cache_key]
return age <= DISCOVERY_CACHE_TTL
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle the initial step - provider selection."""
# Check if OIDC is already configured (only one instance allowed)
if self._async_current_entries():
return self.async_abort(reason="single_instance_allowed")
# Check if YAML configuration exists
if self.hass.data.get(DOMAIN, {}).get("yaml_config"):
return self.async_abort(reason="single_instance_allowed")
errors = {}
if user_input is not None:
self._flow_state.provider = user_input[CONF_PROVIDER]
# If provider has a predefined discovery URL, prefill it but still
# show the discovery URL step so the user can customize it.
predefined = self.current_provider_config.get("discovery_url")
if predefined:
self._flow_state.discovery_url = predefined
# Always request discovery URL next (prefilled when available)
return await self.async_step_discovery_url()
data_schema = vol.Schema(
{
vol.Required(CONF_PROVIDER): vol.In(
{key: provider["name"] for key, provider in OIDC_PROVIDERS.items()}
)
}
)
return self.async_show_form(
step_id="user",
data_schema=data_schema,
errors=errors,
description_placeholders={},
)
async def async_step_discovery_url(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle discovery URL input for providers requiring URL configuration."""
errors = {}
if user_input is not None:
discovery_url = user_input[CONF_DISCOVERY_URL].rstrip("/")
# Validate discovery URL format
if not validate_discovery_url(discovery_url):
errors["discovery_url"] = "invalid_url_format"
else:
self._flow_state.discovery_url = discovery_url
return await self.async_step_client_config()
provider_name = self.current_provider_name
provider_key = self._flow_state.provider
# Pre-populate with existing discovery URL if available
default_url = (
self._flow_state.discovery_url
if self._flow_state.discovery_url
else vol.UNDEFINED
)
data_schema = vol.Schema(
{vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
)
return self.async_show_form(
step_id="discovery_url",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": provider_name,
"documentation_url": get_provider_docs_url(provider_key),
},
)
async def async_step_client_config(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle client ID and client type selection."""
errors = {}
if user_input is not None:
client_id = user_input[CONF_CLIENT_ID]
# Validate client ID
if not validate_client_id(client_id):
errors["client_id"] = "invalid_client_id"
if not errors:
self._client_config.client_id = client_id.strip()
# Optional client secret determines confidential/public
provided_secret = sanitize_client_secret(
user_input.get(CONF_CLIENT_SECRET, "")
)
self._client_config.client_secret = provided_secret or None
if not errors:
# Proceed to validation directly from here
return await self.async_step_validate_connection()
provider_name = self.current_provider_name
# Pre-populate with existing values if available
default_client_id = (
self._client_config.client_id
if self._client_config.client_id
else vol.UNDEFINED
)
default_client_secret = (
self._client_config.client_secret
if self._client_config.client_secret
else vol.UNDEFINED
)
data_schema = vol.Schema(
{
vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
}
)
return self.async_show_form(
step_id="client_config",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": provider_name,
"discovery_url": self._flow_state.discovery_url,
"documentation_url": get_provider_docs_url(self._flow_state.provider),
},
)
async def _handle_validation_actions(
self, user_input: dict[str, Any]
) -> FlowResult | None:
"""Handle user actions from the validation form so they can fix errors."""
action = user_input.get("action")
# Handle special actions first
if action == "retry":
return None # Continue with validation
if action == "continue":
return await self._proceed_to_next_step()
# Handle redirect actions
action_handlers = {
"fix_discovery": self.async_step_discovery_url,
"fix_client": self.async_step_client_config,
"change_provider": self.async_step_user,
}
handler = action_handlers.get(action)
return await handler() if handler else None
async def _proceed_to_next_step(self) -> FlowResult:
"""Proceed to next step after successful validation."""
if self.current_provider_config.get("supports_groups", True):
return await self.async_step_groups_config()
return await self.async_step_user_linking()
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
"""Perform the actual OIDC validation and return discovery doc and errors."""
errors = {}
discovery_doc = {}
try:
# Create a test OIDC client to validate configuration
test_client = OIDCClient(
hass=self.hass,
discovery_url=self._flow_state.discovery_url,
client_id=self._client_config.client_id,
scope="openid profile",
client_secret=self._client_config.client_secret or None,
features={},
claims={},
roles={},
network={},
)
# Clean up expired cache entries first
self._cleanup_discovery_cache()
# Check if discovery document is already cached and valid
cache_key = self._flow_state.discovery_url
if cache_key in self._discovery_cache and self._is_cache_valid(cache_key):
discovery_doc = self._discovery_cache[cache_key]
# Still validate JWKS if available since this might be a retry
if "jwks_uri" in discovery_doc:
await test_client.validate_jwks(discovery_doc["jwks_uri"])
else:
# Perform discovery and JWKS validation
discovery_doc = await test_client.validate_discovery()
# Cache the discovery document with timestamp
self._discovery_cache[cache_key] = discovery_doc
self._cache_timestamps[cache_key] = time.time()
# Validate JWKS if available
if "jwks_uri" in discovery_doc:
await test_client.validate_jwks(discovery_doc["jwks_uri"])
except OIDCDiscoveryInvalid:
errors["base"] = "discovery_invalid"
except OIDCJWKSInvalid:
errors["base"] = "jwks_invalid"
except aiohttp.ClientError:
errors["base"] = "cannot_connect"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected error during validation")
errors["base"] = "unknown"
return discovery_doc, errors
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
"""Get action options based on validation state."""
if has_errors:
return {
"retry": "Retry Validation",
"fix_client": "Fix Client Settings",
"fix_discovery": "Fix Discovery URL",
"change_provider": "Change Provider",
}
return {
"continue": "Continue Setup",
"fix_client": "Modify Client Settings",
"fix_discovery": "Modify Discovery URL",
"change_provider": "Change Provider",
}
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
"""Build success details from discovery document."""
discovery_info = []
endpoints = [
("issuer", "✅ Connection Successful", "**Issuer:** {value}"),
("authorization_endpoint", "✅ Authorization endpoint found", None),
("token_endpoint", "✅ Token endpoint found", None),
("jwks_uri", "✅ JWKS endpoint found", None),
("userinfo_endpoint", "✅ User info endpoint found", None),
]
for key, message, formatted in endpoints:
if key in discovery_doc:
discovery_info.append(message)
if formatted and key == "issuer":
discovery_info.append(formatted.format(value=discovery_doc[key]))
return "\n".join(discovery_info)
def _build_error_details(self, errors: dict[str, str]) -> str:
"""Build error details from validation errors."""
error_messages = {
"discovery_invalid": (
"❌ **Discovery document could not be retrieved**\n"
"Please verify the discovery URL is correct and accessible."
),
"jwks_invalid": (
"❌ **JWKS validation failed**\n"
"The JSON Web Key Set could not be retrieved or validated."
),
"cannot_connect": (
"❌ **Connection failed**\n"
"Unable to connect to the OIDC provider. Check your network and URL."
),
}
return error_messages.get(errors.get("base", ""), "")
async def _build_validation_form(
self, errors: dict[str, str], discovery_doc: dict | None = None
) -> FlowResult:
"""Build the validation form with errors and action options."""
action_options = self._get_action_options(bool(errors))
data_schema = vol.Schema({vol.Required("action"): vol.In(action_options)})
# Build description with discovery details
description_placeholders = {
"discovery_url": self._flow_state.discovery_url,
"client_id": self._client_config.client_id,
"provider_name": self.current_provider_name,
"discovery_details": "",
"documentation_url": get_provider_docs_url(self._flow_state.provider),
}
# Add appropriate details based on validation state
if discovery_doc and not errors:
description_placeholders["discovery_details"] = (
self._build_discovery_success_details(discovery_doc)
)
elif errors:
description_placeholders["discovery_details"] = self._build_error_details(
errors
)
return self.async_show_form(
step_id="validate_connection",
data_schema=data_schema,
errors=errors,
description_placeholders=description_placeholders,
)
async def async_step_validate_connection(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Validate the OIDC configuration by testing discovery and JWKS."""
# Handle user actions from validation form
if user_input is not None:
action_result = await self._handle_validation_actions(user_input)
if action_result is not None:
return action_result
# Perform validation (either initial attempt or retry)
discovery_doc, errors = await self._perform_oidc_validation()
# Always show validation form with results (success or error)
return await self._build_validation_form(errors, discovery_doc)
async def async_step_groups_config(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Configure groups and roles."""
errors = {}
if user_input is not None:
self._feature_config.enable_groups = user_input.get(
CONF_ENABLE_GROUPS, False
)
if self._feature_config.enable_groups:
self._feature_config.admin_group = user_input.get(
CONF_ADMIN_GROUP, "admins"
)
self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
return await self.async_step_user_linking()
default_admin_group = self.current_provider_config.get(
"default_admin_group", "admins"
)
data_schema_dict = {vol.Optional(CONF_ENABLE_GROUPS, default=True): bool}
# Add group configuration fields if groups are enabled
if user_input is None or user_input.get(CONF_ENABLE_GROUPS, True):
data_schema_dict.update(
{
vol.Optional(CONF_ADMIN_GROUP, default=default_admin_group): str,
vol.Optional(CONF_USER_GROUP): str,
}
)
data_schema = vol.Schema(data_schema_dict)
return self.async_show_form(
step_id="groups_config",
data_schema=data_schema,
errors=errors,
description_placeholders={"provider_name": self.current_provider_name},
)
async def async_step_user_linking(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Configure user linking options."""
errors = {}
if user_input is not None:
self._feature_config.enable_user_linking = user_input.get(
CONF_ENABLE_USER_LINKING, False
)
return await self.async_step_finalize()
data_schema = vol.Schema(
{vol.Optional(CONF_ENABLE_USER_LINKING, default=False): bool}
)
return self.async_show_form(
step_id="user_linking",
data_schema=data_schema,
errors=errors,
description_placeholders={},
)
async def async_step_finalize(self) -> FlowResult:
"""Finalize the configuration and create the config entry."""
await self.async_set_unique_id(DOMAIN)
self._abort_if_unique_id_configured()
# Build the configuration
config_data = {
"provider": self._flow_state.provider,
"client_id": self._client_config.client_id,
"discovery_url": self._flow_state.discovery_url,
"display_name": f"{self.current_provider_name} (OIDC)",
}
# Add optional fields
if self._client_config.client_secret:
config_data["client_secret"] = self._client_config.client_secret
# Configure features
features = {
"automatic_user_linking": self._feature_config.enable_user_linking,
"automatic_person_creation": True,
"include_groups_scope": self._feature_config.enable_groups,
}
config_data["features"] = features
# Configure claims using provider defaults
claims = self.current_provider_config["claims"].copy()
config_data["claims"] = claims
# Configure roles if groups are enabled
if self._feature_config.enable_groups:
roles = {}
if self._feature_config.admin_group:
roles["admin"] = self._feature_config.admin_group
if self._feature_config.user_group:
roles["user"] = self._feature_config.user_group
config_data["roles"] = roles
title = f"{self.current_provider_name} OIDC"
return self.async_create_entry(title=title, data=config_data)
async def _validate_reconfigure_input(
self, entry, user_input: dict[str, Any]
) -> tuple[dict[str, str], dict[str, Any] | None]:
"""Validate reconfigure input and return errors and data updates."""
errors = {}
# Validate client ID
client_id = user_input[CONF_CLIENT_ID].strip()
if not validate_client_id(client_id):
errors["client_id"] = "invalid_client_id"
return errors, None
# Determine confidentiality by presence of client secret
client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
# If secret is empty, keep the existing one (if any)
if not client_secret:
client_secret = entry.data.get("client_secret")
# Test the new configuration
test_client = OIDCClient(
hass=self.hass,
discovery_url=entry.data["discovery_url"],
client_id=client_id,
scope="openid profile",
client_secret=client_secret or None,
features={},
claims={},
roles={},
network={},
)
# Validate the new credentials
discovery_doc = await test_client.validate_discovery()
if "jwks_uri" in discovery_doc:
await test_client.validate_jwks(discovery_doc["jwks_uri"])
# Build updated data
data_updates = {"client_id": client_id}
if client_secret:
data_updates["client_secret"] = client_secret
elif "client_secret" in entry.data and not client_secret:
# Remove client secret if switching from confidential to public
data_updates = {**entry.data, **data_updates}
data_updates.pop("client_secret", None)
return errors, data_updates
def _build_reconfigure_schema(
self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
) -> vol.Schema:
"""Build the reconfigure form schema."""
schema_dict = {
vol.Required(
CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
): str,
}
# Always allow updating or clearing the client secret
schema_dict[vol.Optional(CONF_CLIENT_SECRET)] = str
return vol.Schema(schema_dict)
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Handle reconfiguration of OIDC client credentials."""
errors = {}
entry = self._get_reconfigure_entry()
if entry is None:
return self.async_abort(reason="unknown")
if user_input is not None:
try:
errors, data_updates = await self._validate_reconfigure_input(
entry, user_input
)
if not errors:
# Update the config entry
await self.async_set_unique_id(entry.unique_id)
self._abort_if_unique_id_mismatch()
return self.async_update_reload_and_abort(
entry, data_updates=data_updates
)
except OIDCDiscoveryInvalid:
errors["base"] = "discovery_invalid"
except OIDCJWKSInvalid:
errors["base"] = "jwks_invalid"
except aiohttp.ClientError:
errors["base"] = "cannot_connect"
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected error during reconfiguration")
errors["base"] = "unknown"
# Show form
current_data = entry.data
data_schema = self._build_reconfigure_schema(current_data, user_input)
return self.async_show_form(
step_id="reconfigure",
data_schema=data_schema,
errors=errors,
description_placeholders={
"provider_name": get_provider_name(current_data.get("provider")),
"discovery_url": current_data.get("discovery_url", ""),
},
)
def _get_reconfigure_entry(self):
"""Return the config entry being reconfigured if available.
Prefer the entry referenced by the flow context's entry_id. Fall back to the
first existing entry for this domain when only a single instance is allowed.
"""
# Try from flow context (preferred)
entry_id = None
context = getattr(self, "context", None)
if context and hasattr(context, "get"):
entry_id = context.get("entry_id")
if entry_id:
entry = self.hass.config_entries.async_get_entry(entry_id)
if entry and entry.domain == DOMAIN:
return entry
# Fallback: this integration allows a single instance; use the first
current = self._async_current_entries()
if current:
return current[0]
return None
@staticmethod
@callback
def async_get_options_flow(config_entry):
"""Get the options flow for this handler."""
return OIDCOptionsFlowHandler(config_entry)
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
"""Handle options flow for OIDC Authentication."""
def __init__(self, config_entry):
"""Initialize options flow."""
self.config_entry = config_entry
async def async_step_init(self, user_input=None):
"""Handle options flow."""
if user_input is not None:
# Process the updated configuration
updated_features = {
"automatic_user_linking": user_input.get("enable_user_linking", False),
"include_groups_scope": user_input.get("enable_groups", False),
}
updated_roles = {}
if user_input.get("enable_groups", False):
if user_input.get("admin_group"):
updated_roles["admin"] = user_input["admin_group"]
if user_input.get("user_group"):
updated_roles["user"] = user_input["user_group"]
# Update the config entry data
new_data = self.config_entry.data.copy()
new_data["features"] = {**new_data.get("features", {}), **updated_features}
if updated_roles:
new_data["roles"] = updated_roles
elif "roles" in new_data:
# Remove roles if groups are disabled
if not user_input.get("enable_groups", False):
del new_data["roles"]
# Update the config entry
self.hass.config_entries.async_update_entry(
self.config_entry, data=new_data
)
return self.async_create_entry(title="", data={})
current_config = self.config_entry.data
current_features = current_config.get("features", {})
current_roles = current_config.get("roles", {})
# Determine if this provider supports groups
provider = current_config.get("provider", "authentik")
provider_supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
"supports_groups", True
)
# Build schema based on provider capabilities
schema_dict = {
vol.Optional(
"enable_user_linking",
default=current_features.get("automatic_user_linking", False),
): bool
}
# Add groups options if provider supports them
if provider_supports_groups:
enable_groups_default = current_features.get("include_groups_scope", False)
schema_dict[
vol.Optional("enable_groups", default=enable_groups_default)
] = bool
# Add group name fields if groups are currently enabled or being enabled
if enable_groups_default or (
user_input and user_input.get("enable_groups", False)
):
schema_dict.update(
{
vol.Optional(
"admin_group",
default=current_roles.get("admin", DEFAULT_ADMIN_GROUP),
): str,
vol.Optional(
"user_group", default=current_roles.get("user", "")
): str,
}
)
return self.async_show_form(
step_id="init",
data_schema=vol.Schema(schema_dict),
description_placeholders={
"provider_name": get_provider_name(provider),
},
)

View File

@@ -0,0 +1,7 @@
"""Imports manager"""
from .callback import OIDCCallbackView as OIDCCallbackView
from .finish import OIDCFinishView as OIDCFinishView
from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage
from .redirect import OIDCRedirectView as OIDCRedirectView
from .welcome import OIDCWelcomeView as OIDCWelcomeView

View File

@@ -2,9 +2,9 @@
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from aiohttp import web from aiohttp import web
from ..oidc_client import OIDCClient from ..tools.oidc_client import OIDCClient
from ..provider import OpenIDAuthProvider from ..provider import OpenIDAuthProvider
from ..helpers import get_url, get_view from ..tools.helpers import get_url, get_view
PATH = "/auth/oidc/callback" PATH = "/auth/oidc/callback"

View File

@@ -2,7 +2,7 @@
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from aiohttp import web from aiohttp import web
from ..helpers import get_view from ..tools.helpers import get_view
PATH = "/auth/oidc/finish" PATH = "/auth/oidc/finish"

View File

@@ -4,8 +4,8 @@ can either be linked to directly or accessed through the welcome page."""
from aiohttp import web from aiohttp import web
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from ..oidc_client import OIDCClient from ..tools.oidc_client import OIDCClient
from ..helpers import get_url, get_view from ..tools.helpers import get_url, get_view
PATH = "/auth/oidc/redirect" PATH = "/auth/oidc/redirect"

View File

@@ -2,7 +2,7 @@
from aiohttp import web from aiohttp import web
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from ..helpers import get_url, get_view from ..tools.helpers import get_url, get_view
PATH = "/auth/oidc/welcome" PATH = "/auth/oidc/welcome"

View File

@@ -1,6 +1,6 @@
{ {
"domain": "auth_oidc", "domain": "auth_oidc",
"name": "OIDC Authentication", "name": "OpenID Connect/SSO Authentication",
"codeowners": [ "codeowners": [
"@christiaangoossens" "@christiaangoossens"
], ],

View File

@@ -24,14 +24,14 @@ from homeassistant.components import http, person
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import voluptuous as vol import voluptuous as vol
from .config import ( from .config.const import (
FEATURES, FEATURES,
FEATURES_AUTOMATIC_USER_LINKING, FEATURES_AUTOMATIC_USER_LINKING,
FEATURES_AUTOMATIC_PERSON_CREATION, FEATURES_AUTOMATIC_PERSON_CREATION,
DEFAULT_TITLE, DEFAULT_TITLE,
) )
from .stores.code_store import CodeStore from .stores.code_store import CodeStore
from .types import UserDetails from .tools.types import UserDetails
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)

View File

@@ -1,104 +0,0 @@
"""Provider catalog and helpers for OIDC providers."""
from __future__ import annotations
from typing import Any, Dict
DEFAULT_ADMIN_GROUP = "admins"
OIDC_PROVIDERS: Dict[str, Dict[str, Any]] = {
"authentik": {
"name": "Authentik",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"authelia": {
"name": "Authelia",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"pocketid": {
"name": "Pocket ID",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"kanidm": {
"name": "Kanidm",
"discovery_url": "",
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
"microsoft": {
"name": "Microsoft Entra ID",
"discovery_url": (
"https://login.microsoftonline.com/common/v2.0/"
".well-known/openid_configuration"
),
"default_admin_group": DEFAULT_ADMIN_GROUP,
"supports_groups": True,
"claims": {
"display_name": "name",
"username": "preferred_username",
"groups": "groups",
},
},
}
def get_provider_config(key: str) -> Dict[str, Any]:
"""Return provider configuration by key."""
return OIDC_PROVIDERS.get(key, {})
def get_provider_name(key: str | None) -> str:
"""Return provider display name by key."""
if not key:
return "Unknown Provider"
return OIDC_PROVIDERS.get(key, {}).get("name", "Unknown Provider")
def get_provider_docs_url(key: str | None) -> str:
"""Return documentation URL for a provider key."""
base_url = (
"https://github.com/christiaangoossens/hass-oidc-auth/blob/main"
"/docs/provider-configurations"
)
provider_docs = {
"authentik": f"{base_url}/authentik.md",
"authelia": f"{base_url}/authelia.md",
"pocketid": f"{base_url}/pocket-id.md",
"kanidm": f"{base_url}/kanidm.md",
"microsoft": f"{base_url}/microsoft-entra.md",
}
if key in provider_docs:
return provider_docs[key]
return (
"https://github.com/christiaangoossens/hass-oidc-auth"
"/blob/main/docs/configuration.md"
)

View File

@@ -8,7 +8,7 @@ from typing import cast, Optional
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from ..types import UserDetails from ..tools.types import UserDetails
STORAGE_VERSION = 1 STORAGE_VERSION = 1
STORAGE_KEY = "auth_provider.auth_oidc.codes" STORAGE_KEY = "auth_provider.auth_oidc.codes"

View File

@@ -1,7 +1,7 @@
"""Helper functions for the integration.""" """Helper functions for the integration."""
from homeassistant.components import http from homeassistant.components import http
from .views.loader import AsyncTemplateRenderer from ..views.loader import AsyncTemplateRenderer
def get_url(path: str, force_https: bool) -> str: def get_url(path: str, force_https: bool) -> str:

View File

@@ -13,7 +13,7 @@ from jose import jwt, jwk
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .types import UserDetails from .types import UserDetails
from .config import ( from ..config.const import (
FEATURES_DISABLE_PKCE, FEATURES_DISABLE_PKCE,
CLAIMS_DISPLAY_NAME, CLAIMS_DISPLAY_NAME,
CLAIMS_USERNAME, CLAIMS_USERNAME,
@@ -22,7 +22,9 @@ from .config import (
ROLE_USERS, ROLE_USERS,
NETWORK_TLS_VERIFY, NETWORK_TLS_VERIFY,
NETWORK_TLS_CA_PATH, NETWORK_TLS_CA_PATH,
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
) )
from .validation import validate_url
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -34,6 +36,32 @@ class OIDCClientException(Exception):
class OIDCDiscoveryInvalid(OIDCClientException): class OIDCDiscoveryInvalid(OIDCClientException):
"Raised when the discovery document is not found, invalid or otherwise malformed." "Raised when the discovery document is not found, invalid or otherwise malformed."
type: Optional[str]
details: Optional[dict]
def __init__(self, *args, **kwargs):
if args:
self.message = args[0]
else:
self.message = "OIDC Discovery document is invalid"
self.type = kwargs.pop("type", None)
self.details = kwargs.pop("details", None)
super().__init__(self.message)
def get_detail_string(self) -> str:
"""Returns a detailed string for logging purposes."""
string = []
if self.type:
string.append(f"type: {self.type}")
if self.details:
for key, value in self.details.items():
string.append(f"{key}: {value}")
return ", ".join(string)
class OIDCTokenResponseInvalid(OIDCClientException): class OIDCTokenResponseInvalid(OIDCClientException):
"Raised when the token request returns invalid." "Raised when the token request returns invalid."
@@ -68,6 +96,199 @@ class HTTPClientError(aiohttp.ClientResponseError):
return f"{self.status} ({self.message}) with response body: {self.body}" return f"{self.status} ({self.message}) with response body: {self.body}"
async def http_raise_for_status(response: aiohttp.ClientResponse) -> None:
"""Raises an exception if the response is not OK."""
if not response.ok:
# reason should always be not None for a started response
assert response.reason is not None
body = await response.text()
raise HTTPClientError(
response.request_info,
response.history,
status=response.status,
message=response.reason,
headers=response.headers,
body=body,
)
class OIDCDiscoveryClient:
"""OIDC Discovery Client implementation for Python"""
def __init__(
self,
discovery_url: str,
http_session: aiohttp.ClientSession,
verification_context: dict,
):
self.discovery_url = discovery_url
self.http_session = http_session
self.verification_context = verification_context
async def _fetch_discovery_document(self):
"""Fetches discovery document from the given URL."""
try:
async with self.http_session.get(self.discovery_url) as response:
await http_raise_for_status(response)
return await response.json()
except HTTPClientError as e:
if e.status == 404:
_LOGGER.warning(
"Error: Discovery document not found at %s", self.discovery_url
)
else:
_LOGGER.warning("Error fetching discovery: %s", e)
raise OIDCDiscoveryInvalid(type="fetch_error") from e
async def _fetch_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
async with self.http_session.get(jwks_uri) as response:
await http_raise_for_status(response)
return await response.json()
except HTTPClientError as e:
_LOGGER.warning("Error fetching JWKS: %s", e)
raise OIDCJWKSInvalid from e
# pylint: disable=too-many-branches
async def _validate_discovery_document(self, document):
"""Validates the discovery document."""
# Verify that required endpoints are present
required_endpoints = [
"issuer",
"authorization_endpoint",
"token_endpoint",
"jwks_uri",
]
for endpoint in required_endpoints:
if endpoint not in document:
_LOGGER.warning(
"Error: Discovery document %s is missing required endpoint: %s",
self.discovery_url,
endpoint,
)
raise OIDCDiscoveryInvalid(
type="missing_endpoint", details={"endpoint": endpoint}
)
if validate_url(document[endpoint]) is False:
_LOGGER.warning(
"Error: Discovery document %s has invalid URL in endpoint: %s (%s)",
self.discovery_url,
endpoint,
document[endpoint],
)
raise OIDCDiscoveryInvalid(
type="invalid_endpoint",
details={"endpoint": endpoint, "url": document[endpoint]},
)
# Verify optional response_modes_supported
if "response_modes_supported" in document:
if "query" not in document["response_modes_supported"]:
_LOGGER.warning(
"Error: Discovery document %s does not support required 'query' "
"response mode, only supports: %s",
self.discovery_url,
document["response_modes_supported"],
)
raise OIDCDiscoveryInvalid(
type="does_not_support_response_mode",
modes=document["response_modes_supported"],
)
# If grant_types_supported is set, should support 'authorization_code'
if "grant_types_supported" in document:
if "authorization_code" not in document["grant_types_supported"]:
_LOGGER.warning(
"Error: Discovery document %s does not support required "
"'authorization_code' grant type, only supports: %s",
self.discovery_url,
document["grant_types_supported"],
)
raise OIDCDiscoveryInvalid(
type="does_not_support_grant_type",
details={
"required": "authorization_code",
"supported": document["grant_types_supported"],
},
)
# If response_types_supported is set, should support 'code'
if "response_types_supported" in document:
if "code" not in document["response_types_supported"]:
_LOGGER.warning(
"Error: Discovery document %s does not support required "
"'code' response type, only supports: %s",
self.discovery_url,
document["response_types_supported"],
)
raise OIDCDiscoveryInvalid(
type="does_not_support_response_type",
details={
"required": "code",
"supported": document["response_types_supported"],
},
)
# If code_challenge_methods_supported is present, check that it contains S256
if "code_challenge_methods_supported" in document:
if "S256" not in document["code_challenge_methods_supported"]:
_LOGGER.warning(
"Error: Discovery document %s does not support required "
"'S256' code challenge method, only supports: %s",
self.discovery_url,
document["code_challenge_methods_supported"],
)
raise OIDCDiscoveryInvalid(
type="does_not_support_required_code_challenge_method",
details={
"required": "S256",
"supported": document["code_challenge_methods_supported"],
},
)
# Verify the id_token_signing_alg_values_supported field is present and filled
signing_values = document.get("id_token_signing_alg_values_supported", None)
if signing_values is None:
_LOGGER.warning(
"Error: Discovery document %s does not have "
"'id_token_signing_alg_values_supported' field",
self.discovery_url,
)
raise OIDCDiscoveryInvalid(type="missing_id_token_signing_alg_values")
# Verify that the requested id_token_signing_alg is supported
requested_alg = self.verification_context.get("id_token_signing_alg", None)
if requested_alg is not None and requested_alg not in signing_values:
_LOGGER.warning(
"Error: Discovery document %s does not support requested "
"id_token_signing_alg '%s', only supports: %s",
self.discovery_url,
requested_alg,
signing_values,
)
raise OIDCDiscoveryInvalid(
type="does_not_support_id_token_signing_alg",
details={"requested": requested_alg, "supported": signing_values},
)
async def fetch_discovery_document(self):
"""Fetches discovery document."""
document = await self._fetch_discovery_document()
await self._validate_discovery_document(document)
return document
async def fetch_jwks(self, jwks_uri: str | None):
"""Fetches JWKS."""
if jwks_uri is None:
discovery_document = await self._fetch_discovery_document()
jwks_uri = discovery_document["jwks_uri"]
return await self._fetch_jwks(jwks_uri)
# pylint: disable=too-many-instance-attributes # pylint: disable=too-many-instance-attributes
class OIDCClient: class OIDCClient:
"""OIDC Client implementation for Python, including PKCE.""" """OIDC Client implementation for Python, including PKCE."""
@@ -78,6 +299,9 @@ class OIDCClient:
# HTTP session to be used # HTTP session to be used
http_session: aiohttp.ClientSession = None http_session: aiohttp.ClientSession = None
# OIDC Discovery tool to be used
discovery_class: OIDCDiscoveryClient = None
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
@@ -98,7 +322,7 @@ class OIDCClient:
# Default id_token_signing_alg to RS256 if not specified # Default id_token_signing_alg to RS256 if not specified
self.id_token_signing_alg = kwargs.get("id_token_signing_alg") self.id_token_signing_alg = kwargs.get("id_token_signing_alg")
if self.id_token_signing_alg is None: if self.id_token_signing_alg is None:
self.id_token_signing_alg = "RS256" self.id_token_signing_alg = DEFAULT_ID_TOKEN_SIGNING_ALGORITHM
features = kwargs.get("features") features = kwargs.get("features")
claims = kwargs.get("claims") claims = kwargs.get("claims")
@@ -122,23 +346,6 @@ class OIDCClient:
_LOGGER.debug("Closing HTTP session") _LOGGER.debug("Closing HTTP session")
self.http_session.close() self.http_session.close()
async def http_raise_for_status(self, response: aiohttp.ClientResponse) -> None:
"""Raises an exception if the response is not OK."""
if not response.ok:
# reason should always be not None for a started response
assert response.reason is not None
body = await response.text()
raise HTTPClientError(
response.request_info,
response.history,
status=response.status,
message=response.reason,
headers=response.headers,
body=body,
)
def _base64url_encode(self, value: str) -> str: def _base64url_encode(self, value: str) -> str:
"""Uses base64url encoding on a given string""" """Uses base64url encoding on a given string"""
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8") return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
@@ -173,42 +380,13 @@ class OIDCClient:
) )
return self.http_session return self.http_session
async def _fetch_discovery_document(self):
"""Fetches discovery document from the given URL."""
try:
session = await self._get_http_session()
async with session.get(self.discovery_url) as response:
await self.http_raise_for_status(response)
return await response.json()
except HTTPClientError as e:
if e.status == 404:
_LOGGER.warning(
"Error: Discovery document not found at %s", self.discovery_url
)
else:
_LOGGER.warning("Error fetching discovery: %s", e)
raise OIDCDiscoveryInvalid from e
async def _get_jwks(self, jwks_uri):
"""Fetches JWKS from the given URL."""
try:
session = await self._get_http_session()
async with session.get(jwks_uri) as response:
await self.http_raise_for_status(response)
return await response.json()
except HTTPClientError as e:
_LOGGER.warning("Error fetching JWKS: %s", e)
raise OIDCJWKSInvalid from e
async def _make_token_request(self, token_endpoint, query_params): async def _make_token_request(self, token_endpoint, query_params):
"""Performs the token POST call""" """Performs the token POST call"""
try: try:
session = await self._get_http_session() session = await self._get_http_session()
async with session.post(token_endpoint, data=query_params) as response: async with session.post(token_endpoint, data=query_params) as response:
await self.http_raise_for_status(response) await http_raise_for_status(response)
return await response.json() return await response.json()
except HTTPClientError as e: except HTTPClientError as e:
if e.status == 400: if e.status == 400:
@@ -231,12 +409,34 @@ class OIDCClient:
headers = {"Authorization": "Bearer " + access_token} headers = {"Authorization": "Bearer " + access_token}
async with session.get(userinfo_uri, headers=headers) as response: async with session.get(userinfo_uri, headers=headers) as response:
await self.http_raise_for_status(response) await http_raise_for_status(response)
return await response.json() return await response.json()
except HTTPClientError as e: except HTTPClientError as e:
_LOGGER.warning("Error fetching userinfo: %s", e) _LOGGER.warning("Error fetching userinfo: %s", e)
raise OIDCUserinfoInvalid from e raise OIDCUserinfoInvalid from e
async def _fetch_discovery_document(self):
"""Fetches discovery document."""
if self.discovery_document is not None:
return self.discovery_document
if self.discovery_class is None:
session = await self._get_http_session()
self.discovery_class = OIDCDiscoveryClient(
discovery_url=self.discovery_url,
http_session=session,
verification_context={
"id_token_signing_alg": self.id_token_signing_alg,
},
)
self.discovery_document = await self.discovery_class.fetch_discovery_document()
return self.discovery_document
async def _fetch_jwks(self, jwks_uri: str):
"""Fetches JWKS."""
return await self.discovery_class.fetch_jwks(jwks_uri)
async def _parse_id_token( async def _parse_id_token(
self, id_token: str, access_token: str | None self, id_token: str, access_token: str | None
) -> Optional[dict]: ) -> Optional[dict]:
@@ -245,7 +445,7 @@ class OIDCClient:
self.discovery_document = await self._fetch_discovery_document() self.discovery_document = await self._fetch_discovery_document()
jwks_uri = self.discovery_document["jwks_uri"] jwks_uri = self.discovery_document["jwks_uri"]
jwks_data = await self._get_jwks(jwks_uri) jwks_data = await self._fetch_jwks(jwks_uri)
try: try:
# Obtain the id_token header # Obtain the id_token header
@@ -369,10 +569,8 @@ class OIDCClient:
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]: async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
"""Generates the authorization URL for the OIDC flow.""" """Generates the authorization URL for the OIDC flow."""
try: try:
if self.discovery_document is None: discovery_document = await self._fetch_discovery_document()
self.discovery_document = await self._fetch_discovery_document() auth_endpoint = discovery_document["authorization_endpoint"]
auth_endpoint = self.discovery_document["authorization_endpoint"]
# Generate random nonce & state # Generate random nonce & state
nonce = self._generate_random_url_string() nonce = self._generate_random_url_string()
@@ -417,8 +615,9 @@ class OIDCClient:
# Fetch userinfo if there is an userinfo_endpoint available # Fetch userinfo if there is an userinfo_endpoint available
# and use the data to supply the missing values in id_token # and use the data to supply the missing values in id_token
if "userinfo_endpoint" in self.discovery_document: discovery_document = await self._fetch_discovery_document()
userinfo_endpoint = self.discovery_document["userinfo_endpoint"] if "userinfo_endpoint" in discovery_document:
userinfo_endpoint = discovery_document["userinfo_endpoint"]
userinfo = await self._get_userinfo(userinfo_endpoint, access_token) userinfo = await self._get_userinfo(userinfo_endpoint, access_token)
# Replace missing claims in the id_token with their userinfo version # Replace missing claims in the id_token with their userinfo version
@@ -451,9 +650,7 @@ class OIDCClient:
# Only unique per issuer, so we combine it with the issuer and hash it. # Only unique per issuer, so we combine it with the issuer and hash it.
# This might allow multiple OIDC providers to be used with this integration. # This might allow multiple OIDC providers to be used with this integration.
"sub": hashlib.sha256( "sub": hashlib.sha256(
f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode( f"{discovery_document['issuer']}.{id_token.get('sub')}".encode("utf-8")
"utf-8"
)
).hexdigest(), ).hexdigest(),
# Display name, configurable # Display name, configurable
"display_name": id_token.get(self.display_name_claim), "display_name": id_token.get(self.display_name_claim),
@@ -474,10 +671,8 @@ class OIDCClient:
flow = self.flows[state] flow = self.flows[state]
if self.discovery_document is None: discovery_document = await self._fetch_discovery_document()
self.discovery_document = await self._fetch_discovery_document() token_endpoint = discovery_document["token_endpoint"]
token_endpoint = self.discovery_document["token_endpoint"]
# Construct the params # Construct the params
query_params = { query_params = {
@@ -532,21 +727,3 @@ class OIDCClient:
except OIDCClientException as e: except OIDCClientException as e:
_LOGGER.warning("Failed to complete token flow, returning None. (%s)", e) _LOGGER.warning("Failed to complete token flow, returning None. (%s)", e)
return None return None
async def validate_discovery(self):
"""Validate that the discovery document can be fetched and is valid.
Public method for configuration validation.
Returns the discovery document if valid.
Raises OIDCDiscoveryInvalid if invalid.
"""
return await self._fetch_discovery_document()
async def validate_jwks(self, jwks_uri: str):
"""Validate that the JWKS can be fetched from the given URI.
Public method for configuration validation.
Returns the JWKS if valid.
Raises OIDCJWKSInvalid if invalid.
"""
return await self._get_jwks(jwks_uri)

View File

@@ -5,11 +5,24 @@ from __future__ import annotations
from urllib.parse import urlparse from urllib.parse import urlparse
def validate_url(url: str) -> bool:
"""Validate that a URL is properly formatted."""
try:
parsed = urlparse(url.strip())
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
except (ValueError, TypeError):
return False
def validate_discovery_url(url: str) -> bool: def validate_discovery_url(url: str) -> bool:
"""Validate that a URL is properly formatted for OIDC discovery.""" """Validate that a URL is properly formatted for OIDC discovery."""
try: try:
parsed = urlparse(url.strip()) parsed = urlparse(url.strip())
return bool(parsed.scheme in ("http", "https") and parsed.netloc) return bool(
parsed.scheme in ("http", "https")
and parsed.netloc
and parsed.path.endswith("/.well-known/openid-configuration")
)
except (ValueError, TypeError): except (ValueError, TypeError):
return False return False

View File

@@ -3,18 +3,25 @@
"step": { "step": {
"user": { "user": {
"title": "Choose OIDC Provider", "title": "Choose OIDC Provider",
"description": "Select your OpenID Connect identity provider to get started with the setup.", "description": "Select your OpenID Connect identity provider to get started with the setup.\n\nIf you want to use a provider that isn't listed, try the Generic OpenID Connect provider or use the advanced YAML configuration instead.",
"data": { "data": {
"provider": "Provider" "provider": "Provider"
} }
}, },
"discovery_url": { "discovery_url": {
"title": "Provider Configuration", "title": "Provider Configuration",
"description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's documentation and usually ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).", "description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's admin interface and ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).",
"data": { "data": {
"discovery_url": "Discovery URL" "discovery_url": "Discovery URL"
} }
}, },
"validate_connection": {
"title": "Connection Validation",
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
"data": {
"action": "Choose an action"
}
},
"client_config": { "client_config": {
"title": "Client Configuration", "title": "Client Configuration",
"description": "Configure your OIDC client. You can find these details in your {provider_name} application settings.\n\n**Discovery URL:** {discovery_url}\n\n**Setup Instructions:**\n1. Register a new application in your OIDC provider\n2. Set the application type to 'Public Client' (recommended for most users)\n3. Add redirect URLs for Home Assistant\n4. Copy the Client ID below\n\n**Note:** If your provider requires a client secret, check 'Use Confidential Client' and provide your client secret below.\n\n**Need detailed setup instructions?** Check the [setup guide]({documentation_url}) for step-by-step instructions.", "description": "Configure your OIDC client. You can find these details in your {provider_name} application settings.\n\n**Discovery URL:** {discovery_url}\n\n**Setup Instructions:**\n1. Register a new application in your OIDC provider\n2. Set the application type to 'Public Client' (recommended for most users)\n3. Add redirect URLs for Home Assistant\n4. Copy the Client ID below\n\n**Note:** If your provider requires a client secret, check 'Use Confidential Client' and provide your client secret below.\n\n**Need detailed setup instructions?** Check the [setup guide]({documentation_url}) for step-by-step instructions.",
@@ -23,20 +30,6 @@
"client_secret": "Client Secret (optional; required by some providers)" "client_secret": "Client Secret (optional; required by some providers)"
} }
}, },
"client_secret": {
"title": "Client Secret Configuration",
"description": "Since you selected 'Confidential Client', please provide your client secret.\n\n**Provider:** {provider_name}\n**Client ID:** {client_id}\n**Discovery URL:** {discovery_url}\n\n**Security Note:** The client secret will be stored securely in Home Assistant's configuration. Never share your client secret with others.",
"data": {
"client_secret": "Client Secret"
}
},
"validate_connection": {
"title": "Connection Validation",
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n**Client ID:** {client_id}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Client Settings:** Go back to change Client ID or secret\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
"data": {
"action": "Choose an action"
}
},
"groups_config": { "groups_config": {
"title": "Groups & Role Configuration", "title": "Groups & Role Configuration",
"description": "Configure how user groups from {provider_name} should be mapped to Home Assistant roles.\n\n**Groups Support:** Groups allow you to automatically assign admin or user roles based on group membership in your identity provider.\n\n**Admin Group:** Users in this group will have administrator access\n**User Group:** Users in this group will have standard user access (leave empty to allow all authenticated users)", "description": "Configure how user groups from {provider_name} should be mapped to Home Assistant roles.\n\n**Groups Support:** Groups allow you to automatically assign admin or user roles based on group membership in your identity provider.\n\n**Admin Group:** Users in this group will have administrator access\n**User Group:** Users in this group will have standard user access (leave empty to allow all authenticated users)",
@@ -71,12 +64,8 @@
"cannot_connect": "Failed to connect to the OIDC provider. Please check your network connection and discovery URL.", "cannot_connect": "Failed to connect to the OIDC provider. Please check your network connection and discovery URL.",
"discovery_invalid": "The discovery document could not be retrieved or is invalid. Please verify the discovery URL is correct.", "discovery_invalid": "The discovery document could not be retrieved or is invalid. Please verify the discovery URL is correct.",
"jwks_invalid": "Failed to retrieve or validate the JWKS (JSON Web Key Set). Please check your provider configuration.", "jwks_invalid": "Failed to retrieve or validate the JWKS (JSON Web Key Set). Please check your provider configuration.",
"invalid_client_credentials": "The client ID or client secret appears to be invalid. Please check your OIDC application settings and ensure the credentials are correct.", "invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL and should end with '/.well-known/openid-configuration'",
"client_secret_required": "Client secret is required when using confidential client mode.",
"invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL.",
"invalid_client_id": "Client ID cannot be empty and must contain valid characters.", "invalid_client_id": "Client ID cannot be empty and must contain valid characters.",
"no_url_available": "Unable to determine Home Assistant URL for OAuth redirect. Please check your network configuration.",
"auth_url_failed": "Failed to generate authorization URL for OAuth test.",
"unknown": "An unexpected error occurred. Please check the logs for more details." "unknown": "An unexpected error occurred. Please check the logs for more details."
}, },
"abort": { "abort": {
@@ -84,7 +73,8 @@
"cannot_connect": "Unable to connect to the OIDC provider.", "cannot_connect": "Unable to connect to the OIDC provider.",
"invalid_discovery": "Invalid discovery document received from the provider.", "invalid_discovery": "Invalid discovery document received from the provider.",
"reconfigure_successful": "OIDC Authentication has been successfully reconfigured with the updated client credentials.", "reconfigure_successful": "OIDC Authentication has been successfully reconfigured with the updated client credentials.",
"single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured (either through YAML or the UI). To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one." "single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured in the UI. To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one.",
"yaml_configured": "You are currently using YAML configuration for this integration. To switch to UI configuration, please remove the YAML configuration first. Note that some advanced options configured via YAML may not be available in the UI."
} }
}, },
"options": { "options": {

View File

@@ -1,4 +1,4 @@
#! /bin/bash #! /bin/bash
uv run ruff check uv run ruff check
uv run ruff format --check uv run ruff format --check
uv run pylint custom_components uv run pylint custom_components --allow-reexport-from-package true