Move some code around and improve validation (#128)
This commit is contained in:
committed by
GitHub
parent
3b481cd282
commit
d1da841e1f
@@ -9,8 +9,10 @@ from homeassistant.core import HomeAssistant
|
||||
|
||||
# Import and re-export config schema explictly
|
||||
# pylint: disable=useless-import-alias
|
||||
from .config import CONFIG_SCHEMA as CONFIG_SCHEMA
|
||||
|
||||
# Get all the constants for the config
|
||||
from .config import (
|
||||
CONFIG_SCHEMA as CONFIG_SCHEMA,
|
||||
DOMAIN,
|
||||
DEFAULT_TITLE,
|
||||
CLIENT_ID,
|
||||
@@ -27,17 +29,19 @@ from .config import (
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION,
|
||||
FEATURES_FORCE_HTTPS,
|
||||
REQUIRED_SCOPES,
|
||||
)
|
||||
|
||||
# pylint: enable=useless-import-alias
|
||||
from .config import convert_ui_config_entry_to_internal_format
|
||||
|
||||
from .endpoints.welcome import OIDCWelcomeView
|
||||
from .endpoints.redirect import OIDCRedirectView
|
||||
from .endpoints.finish import OIDCFinishView
|
||||
from .endpoints.callback import OIDCCallbackView
|
||||
from .endpoints.injected_auth_page import OIDCInjectedAuthPage
|
||||
|
||||
from .oidc_client import OIDCClient
|
||||
from .endpoints import (
|
||||
OIDCWelcomeView,
|
||||
OIDCRedirectView,
|
||||
OIDCFinishView,
|
||||
OIDCCallbackView,
|
||||
OIDCInjectedAuthPage,
|
||||
)
|
||||
from .tools.oidc_client import OIDCClient
|
||||
from .provider import OpenIDAuthProvider
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@@ -62,12 +66,12 @@ async def async_setup(hass: HomeAssistant, config):
|
||||
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||
"""Set up OIDC Authentication from a config entry."""
|
||||
"""Set up OIDC Authentication from a config entry (UI config)."""
|
||||
# Convert config entry data to the format expected by the existing setup
|
||||
config_data = entry.data.copy()
|
||||
|
||||
# Convert config entry format to internal format
|
||||
my_config = _convert_config_entry_to_internal_format(config_data)
|
||||
my_config = convert_ui_config_entry_to_internal_format(config_data)
|
||||
|
||||
# Get display name from config entry
|
||||
display_name = config_data.get("display_name", DEFAULT_TITLE)
|
||||
@@ -83,36 +87,6 @@ async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
|
||||
return False
|
||||
|
||||
|
||||
def _convert_config_entry_to_internal_format(config_data: dict) -> dict:
|
||||
"""Convert config entry data to internal configuration format."""
|
||||
my_config = {}
|
||||
|
||||
# Required fields
|
||||
my_config[CLIENT_ID] = config_data["client_id"]
|
||||
my_config[DISCOVERY_URL] = config_data["discovery_url"]
|
||||
|
||||
# Optional fields
|
||||
if "client_secret" in config_data:
|
||||
my_config[CLIENT_SECRET] = config_data["client_secret"]
|
||||
|
||||
if "display_name" in config_data:
|
||||
my_config[DISPLAY_NAME] = config_data["display_name"]
|
||||
|
||||
# Features configuration
|
||||
if "features" in config_data:
|
||||
my_config[FEATURES] = config_data["features"]
|
||||
|
||||
# Claims configuration
|
||||
if "claims" in config_data:
|
||||
my_config[CLAIMS] = config_data["claims"]
|
||||
|
||||
# Roles configuration
|
||||
if "roles" in config_data:
|
||||
my_config[ROLES] = config_data["roles"]
|
||||
|
||||
return my_config
|
||||
|
||||
|
||||
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
|
||||
"""Set up the OIDC provider with the given configuration."""
|
||||
providers = OrderedDict()
|
||||
@@ -131,7 +105,7 @@ async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_nam
|
||||
# Set the correct scopes
|
||||
# Always use 'openid' & 'profile' as they are specified in the OIDC spec
|
||||
# All servers should support this
|
||||
scope = "openid profile"
|
||||
scope = REQUIRED_SCOPES
|
||||
|
||||
# Include groups if requested (default is to include 'groups'
|
||||
# as a scope for Authelia & Authentik)
|
||||
|
||||
8
custom_components/auth_oidc/config/__init__.py
Normal file
8
custom_components/auth_oidc/config/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""Imports manager"""
|
||||
|
||||
from .const import * # noqa: F403
|
||||
from .schema import CONFIG_SCHEMA as CONFIG_SCHEMA
|
||||
from .ui_flow import (
|
||||
OIDCConfigFlow as OIDCConfigFlow,
|
||||
convert_ui_config_entry_to_internal_format as convert_ui_config_entry_to_internal_format,
|
||||
)
|
||||
92
custom_components/auth_oidc/config/const.py
Normal file
92
custom_components/auth_oidc/config/const.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Config constants."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
## ===
|
||||
## General integration constants
|
||||
## ===
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
DOMAIN = "auth_oidc"
|
||||
REPO_ROOT_URL = "https://github.com/christiaangoossens/hass-oidc-auth/tree/v0.7.0-alpha"
|
||||
|
||||
## ===
|
||||
## Config keys
|
||||
## ===
|
||||
|
||||
CLIENT_ID = "client_id"
|
||||
CLIENT_SECRET = "client_secret"
|
||||
DISCOVERY_URL = "discovery_url"
|
||||
DISPLAY_NAME = "display_name"
|
||||
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
|
||||
GROUPS_SCOPE = "groups_scope"
|
||||
ADDITIONAL_SCOPES = "additional_scopes"
|
||||
FEATURES = "features"
|
||||
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
||||
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
|
||||
FEATURES_FORCE_HTTPS = "force_https"
|
||||
CLAIMS = "claims"
|
||||
CLAIMS_DISPLAY_NAME = "display_name"
|
||||
CLAIMS_USERNAME = "username"
|
||||
CLAIMS_GROUPS = "groups"
|
||||
ROLES = "roles"
|
||||
ROLE_ADMINS = "admin"
|
||||
ROLE_USERS = "user"
|
||||
NETWORK = "network"
|
||||
NETWORK_TLS_VERIFY = "tls_verify"
|
||||
NETWORK_TLS_CA_PATH = "tls_ca_path"
|
||||
|
||||
## ===
|
||||
## Default configurations for providers
|
||||
## ===
|
||||
|
||||
REQUIRED_SCOPES = "openid profile"
|
||||
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM = "RS256"
|
||||
|
||||
DEFAULT_GROUPS_SCOPE = "groups"
|
||||
DEFAULT_ADMIN_GROUP = "admins"
|
||||
|
||||
OIDC_PROVIDERS: Dict[str, Dict[str, Any]] = {
|
||||
"authentik": {
|
||||
"name": "Authentik",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"authelia": {
|
||||
"name": "Authelia",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"pocketid": {
|
||||
"name": "Pocket ID",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"generic": {
|
||||
"name": "OpenID Connect (SSO)",
|
||||
"discovery_url": "",
|
||||
"supports_groups": False,
|
||||
"claims": {"display_name": "name", "username": "preferred_username"},
|
||||
},
|
||||
}
|
||||
35
custom_components/auth_oidc/config/provider_catalog.py
Normal file
35
custom_components/auth_oidc/config/provider_catalog.py
Normal file
@@ -0,0 +1,35 @@
|
||||
"""Provider catalog and helpers for OIDC providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Any, Dict
|
||||
|
||||
from .const import OIDC_PROVIDERS, REPO_ROOT_URL
|
||||
|
||||
|
||||
def get_provider_config(key: str) -> Dict[str, Any]:
|
||||
"""Return provider configuration by key."""
|
||||
return OIDC_PROVIDERS.get(key, {})
|
||||
|
||||
|
||||
def get_provider_name(key: str | None) -> str:
|
||||
"""Return provider display name by key."""
|
||||
if not key:
|
||||
return "Unknown Provider"
|
||||
return OIDC_PROVIDERS.get(key, {}).get("name", "Unknown Provider")
|
||||
|
||||
|
||||
def get_provider_docs_url(key: str | None) -> str:
|
||||
"""Return documentation URL for a provider key."""
|
||||
base_url = REPO_ROOT_URL + "/docs/provider-configurations"
|
||||
|
||||
provider_docs = {
|
||||
"authentik": f"{base_url}/authentik.md",
|
||||
"authelia": f"{base_url}/authelia.md",
|
||||
"pocketid": f"{base_url}/pocket-id.md",
|
||||
"kanidm": f"{base_url}/kanidm.md",
|
||||
"microsoft": f"{base_url}/microsoft-entra.md",
|
||||
}
|
||||
|
||||
if key in provider_docs:
|
||||
return provider_docs[key]
|
||||
return REPO_ROOT_URL + "/docs/configuration.md"
|
||||
@@ -1,36 +1,35 @@
|
||||
"""Config schema and constants."""
|
||||
"""Config schema"""
|
||||
|
||||
import voluptuous as vol
|
||||
from .const import (
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
DISCOVERY_URL,
|
||||
DISPLAY_NAME,
|
||||
ID_TOKEN_SIGNING_ALGORITHM,
|
||||
GROUPS_SCOPE,
|
||||
ADDITIONAL_SCOPES,
|
||||
FEATURES,
|
||||
FEATURES_AUTOMATIC_USER_LINKING,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
FEATURES_DISABLE_PKCE,
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION,
|
||||
FEATURES_FORCE_HTTPS,
|
||||
CLAIMS,
|
||||
CLAIMS_DISPLAY_NAME,
|
||||
CLAIMS_USERNAME,
|
||||
CLAIMS_GROUPS,
|
||||
ROLES,
|
||||
ROLE_ADMINS,
|
||||
ROLE_USERS,
|
||||
NETWORK,
|
||||
NETWORK_TLS_VERIFY,
|
||||
NETWORK_TLS_CA_PATH,
|
||||
DOMAIN,
|
||||
DEFAULT_GROUPS_SCOPE,
|
||||
)
|
||||
|
||||
CLIENT_ID = "client_id"
|
||||
CLIENT_SECRET = "client_secret"
|
||||
DISCOVERY_URL = "discovery_url"
|
||||
DISPLAY_NAME = "display_name"
|
||||
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
|
||||
GROUPS_SCOPE = "groups_scope"
|
||||
ADDITIONAL_SCOPES = "additional_scopes"
|
||||
FEATURES = "features"
|
||||
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
||||
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
||||
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
|
||||
FEATURES_DISABLE_FRONTEND_INJECTION = "disable_frontend_changes"
|
||||
FEATURES_FORCE_HTTPS = "force_https"
|
||||
CLAIMS = "claims"
|
||||
CLAIMS_DISPLAY_NAME = "display_name"
|
||||
CLAIMS_USERNAME = "username"
|
||||
CLAIMS_GROUPS = "groups"
|
||||
ROLES = "roles"
|
||||
ROLE_ADMINS = "admin"
|
||||
ROLE_USERS = "user"
|
||||
|
||||
NETWORK = "network"
|
||||
NETWORK_TLS_VERIFY = "tls_verify"
|
||||
NETWORK_TLS_CA_PATH = "tls_ca_path"
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
@@ -48,7 +47,9 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str),
|
||||
# String value to allow changing the groups scope
|
||||
# Defaults to 'groups' which is used by Authelia and Authentik
|
||||
vol.Optional(GROUPS_SCOPE, default="groups"): vol.Coerce(str),
|
||||
vol.Optional(GROUPS_SCOPE, default=DEFAULT_GROUPS_SCOPE): vol.Coerce(
|
||||
str
|
||||
),
|
||||
# Additional scopes to request from the OIDC provider
|
||||
# Optional, this field is unnecessary if you only use the openid and profile scopes.
|
||||
vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]),
|
||||
842
custom_components/auth_oidc/config/ui_flow.py
Normal file
842
custom_components/auth_oidc/config/ui_flow.py
Normal file
@@ -0,0 +1,842 @@
|
||||
"""Config flow for OIDC Authentication integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import aiohttp
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from .const import (
|
||||
DOMAIN,
|
||||
DEFAULT_ADMIN_GROUP,
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
DISCOVERY_URL,
|
||||
DISPLAY_NAME,
|
||||
FEATURES,
|
||||
CLAIMS,
|
||||
ROLES,
|
||||
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||
)
|
||||
|
||||
from ..tools.oidc_client import (
|
||||
OIDCDiscoveryClient,
|
||||
OIDCDiscoveryInvalid,
|
||||
OIDCJWKSInvalid,
|
||||
)
|
||||
|
||||
from .provider_catalog import (
|
||||
OIDC_PROVIDERS,
|
||||
get_provider_name,
|
||||
get_provider_docs_url,
|
||||
)
|
||||
|
||||
from ..tools.validation import (
|
||||
validate_discovery_url,
|
||||
sanitize_client_secret,
|
||||
validate_client_id,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration field names
|
||||
CONF_PROVIDER = "provider"
|
||||
CONF_CLIENT_ID = "client_id"
|
||||
CONF_CLIENT_SECRET = "client_secret"
|
||||
CONF_DISCOVERY_URL = "discovery_url"
|
||||
CONF_ENABLE_GROUPS = "enable_groups"
|
||||
CONF_ADMIN_GROUP = "admin_group"
|
||||
CONF_USER_GROUP = "user_group"
|
||||
CONF_ENABLE_USER_LINKING = "enable_user_linking"
|
||||
|
||||
# Cache settings
|
||||
DISCOVERY_CACHE_TTL = 300 # 5 minutes
|
||||
MAX_CACHE_SIZE = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowState:
|
||||
"""State tracking for the configuration flow."""
|
||||
|
||||
provider: str | None = None
|
||||
discovery_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Client configuration settings."""
|
||||
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureConfig:
|
||||
"""Feature configuration settings."""
|
||||
|
||||
enable_groups: bool = False
|
||||
admin_group: str = DEFAULT_ADMIN_GROUP
|
||||
user_group: str | None = None
|
||||
enable_user_linking: bool = False
|
||||
|
||||
|
||||
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for OIDC Authentication."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def is_matching(self, other_flow):
|
||||
"""Check if this flow is the same as another flow."""
|
||||
self_state = getattr(self, "_flow_state", None)
|
||||
other_state = getattr(other_flow, "_flow_state", None)
|
||||
|
||||
if not self_state or not other_state:
|
||||
return False
|
||||
|
||||
self_discovery_url = self_state.discovery_url
|
||||
other_discovery_url = other_state.discovery_url
|
||||
|
||||
return (
|
||||
self_discovery_url
|
||||
and other_discovery_url
|
||||
and self_discovery_url.rstrip("/").lower()
|
||||
== other_discovery_url.rstrip("/").lower()
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the config flow."""
|
||||
self._flow_state = FlowState()
|
||||
self._client_config = ClientConfig()
|
||||
self._feature_config = FeatureConfig()
|
||||
self._discovery_cache = {}
|
||||
self._cache_timestamps = {}
|
||||
|
||||
@property
|
||||
def current_provider_config(self) -> dict[str, Any]:
|
||||
"""Get the configuration for the currently selected provider."""
|
||||
if not self._flow_state.provider:
|
||||
return {}
|
||||
return OIDC_PROVIDERS.get(self._flow_state.provider, {})
|
||||
|
||||
@property
|
||||
def current_provider_name(self) -> str:
|
||||
"""Get the name of the currently selected provider."""
|
||||
return get_provider_name(self._flow_state.provider)
|
||||
|
||||
def _cleanup_discovery_cache(self) -> None:
|
||||
"""Remove expired and excess cache entries."""
|
||||
current_time = time.time()
|
||||
|
||||
# Remove expired entries
|
||||
expired_keys = [
|
||||
key
|
||||
for key, timestamp in self._cache_timestamps.items()
|
||||
if current_time - timestamp > DISCOVERY_CACHE_TTL
|
||||
]
|
||||
for key in expired_keys:
|
||||
self._discovery_cache.pop(key, None)
|
||||
self._cache_timestamps.pop(key, None)
|
||||
|
||||
# Remove oldest entries if cache is too large
|
||||
if len(self._discovery_cache) > MAX_CACHE_SIZE:
|
||||
sorted_items = sorted(self._cache_timestamps.items(), key=lambda x: x[1])
|
||||
excess_count = len(self._discovery_cache) - MAX_CACHE_SIZE
|
||||
for key, _ in sorted_items[:excess_count]:
|
||||
self._discovery_cache.pop(key, None)
|
||||
self._cache_timestamps.pop(key, None)
|
||||
|
||||
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||
"""Check if a cache entry is still valid."""
|
||||
if cache_key not in self._cache_timestamps:
|
||||
return False
|
||||
|
||||
age = time.time() - self._cache_timestamps[cache_key]
|
||||
return age <= DISCOVERY_CACHE_TTL
|
||||
|
||||
# =================
|
||||
# Step 1: Provider selection
|
||||
# =================
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step - provider selection."""
|
||||
# Check if OIDC is already configured (only one instance allowed)
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
# Check if YAML configuration exists
|
||||
if self.hass.data.get(DOMAIN, {}).get("yaml_config"):
|
||||
return self.async_abort(reason="yaml_configured")
|
||||
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._flow_state.provider = user_input[CONF_PROVIDER]
|
||||
|
||||
# If provider has a predefined discovery URL, prefill it but still
|
||||
# show the discovery URL step so the user can customize it.
|
||||
predefined = self.current_provider_config.get("discovery_url")
|
||||
if predefined:
|
||||
self._flow_state.discovery_url = predefined
|
||||
|
||||
# Always request discovery URL next (prefilled when available)
|
||||
return await self.async_step_discovery_url()
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PROVIDER): vol.In(
|
||||
{key: provider["name"] for key, provider in OIDC_PROVIDERS.items()}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={},
|
||||
)
|
||||
|
||||
# =================
|
||||
# Step 2: Discovery URL
|
||||
# =================
|
||||
|
||||
async def async_step_discovery_url(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle discovery URL input for providers requiring URL configuration."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
discovery_url = user_input[CONF_DISCOVERY_URL].rstrip("/")
|
||||
|
||||
# Validate discovery URL format
|
||||
if not validate_discovery_url(discovery_url):
|
||||
errors["discovery_url"] = "invalid_url_format"
|
||||
else:
|
||||
self._flow_state.discovery_url = discovery_url
|
||||
return await self.async_step_validate_connection()
|
||||
|
||||
provider_name = self.current_provider_name
|
||||
provider_key = self._flow_state.provider
|
||||
|
||||
# Pre-populate with existing discovery URL if available
|
||||
default_url = (
|
||||
self._flow_state.discovery_url
|
||||
if self._flow_state.discovery_url
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="discovery_url",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": provider_name,
|
||||
"documentation_url": get_provider_docs_url(provider_key),
|
||||
},
|
||||
)
|
||||
|
||||
# =================
|
||||
# Step 3: Discovery Validation
|
||||
# =================
|
||||
|
||||
async def _handle_validation_actions(
|
||||
self, user_input: dict[str, Any]
|
||||
) -> FlowResult | None:
|
||||
"""Handle user actions from the validation form so they can fix errors."""
|
||||
action = user_input.get("action")
|
||||
|
||||
# Handle special actions first
|
||||
if action == "retry":
|
||||
return None # Continue with validation
|
||||
if action == "continue":
|
||||
return await self.async_step_client_config()
|
||||
|
||||
# Handle redirect actions
|
||||
action_handlers = {
|
||||
"fix_discovery": self.async_step_discovery_url,
|
||||
"change_provider": self.async_step_user,
|
||||
}
|
||||
|
||||
handler = action_handlers.get(action)
|
||||
return await handler() if handler else None
|
||||
|
||||
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
|
||||
"""Perform the actual OIDC validation and return discovery doc and errors."""
|
||||
errors = {}
|
||||
discovery_doc = {}
|
||||
|
||||
try:
|
||||
http_session = aiohttp.ClientSession()
|
||||
discovery_client = OIDCDiscoveryClient(
|
||||
discovery_url=self._flow_state.discovery_url,
|
||||
http_session=http_session,
|
||||
verification_context={
|
||||
# Cannot be changed from the UI config currently
|
||||
"id_token_signing_alg": DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||
},
|
||||
)
|
||||
|
||||
# Clean up expired cache entries first
|
||||
self._cleanup_discovery_cache()
|
||||
|
||||
# Check if discovery document is already cached and valid
|
||||
cache_key = self._flow_state.discovery_url
|
||||
if cache_key in self._discovery_cache and self._is_cache_valid(cache_key):
|
||||
discovery_doc = self._discovery_cache[cache_key]
|
||||
|
||||
# Still validate JWKS if available since this might be a retry
|
||||
if "jwks_uri" in discovery_doc:
|
||||
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
|
||||
else:
|
||||
# Perform discovery and JWKS validation
|
||||
discovery_doc = await discovery_client.fetch_discovery_document()
|
||||
|
||||
# Cache the discovery document with timestamp
|
||||
self._discovery_cache[cache_key] = discovery_doc
|
||||
self._cache_timestamps[cache_key] = time.time()
|
||||
|
||||
# Validate JWKS if available
|
||||
if "jwks_uri" in discovery_doc:
|
||||
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
|
||||
|
||||
except OIDCDiscoveryInvalid as e:
|
||||
errors["base"] = "discovery_invalid"
|
||||
errors["detail_string"] = e.get_detail_string()
|
||||
except OIDCJWKSInvalid:
|
||||
errors["base"] = "jwks_invalid"
|
||||
except aiohttp.ClientError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected error during validation")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
await http_session.close()
|
||||
return discovery_doc, errors
|
||||
|
||||
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
|
||||
"""Get action options based on validation state."""
|
||||
if has_errors:
|
||||
return {
|
||||
"retry": "Retry Validation",
|
||||
"fix_discovery": "Change Discovery URL",
|
||||
"change_provider": "Change Provider",
|
||||
}
|
||||
return {
|
||||
"continue": "Continue Setup",
|
||||
"fix_discovery": "Change Discovery URL",
|
||||
"change_provider": "Change Provider",
|
||||
}
|
||||
|
||||
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
|
||||
"""Build success details from discovery document."""
|
||||
return (
|
||||
f"✅ Connected and verified succesfully!\n"
|
||||
f"_Discovered valid OIDC issuer: {discovery_doc['issuer']}_\n\n"
|
||||
)
|
||||
|
||||
def _build_error_details(self, errors: dict[str, str]) -> str:
|
||||
"""Build error details from validation errors."""
|
||||
|
||||
base = errors.get("base", "")
|
||||
detail_string = errors.get("detail_string", "")
|
||||
|
||||
error_messages = {
|
||||
"discovery_invalid": (
|
||||
"❌ **Discovery document could not be validated.**\n"
|
||||
"Please verify the discovery URL is correct and accessible.\n\n"
|
||||
f"_({detail_string})_"
|
||||
),
|
||||
"jwks_invalid": (
|
||||
"❌ **JWKS validation failed**\n"
|
||||
"The JSON Web Key Set could not be retrieved or validated."
|
||||
),
|
||||
"cannot_connect": (
|
||||
"❌ **Connection failed**\n"
|
||||
"Unable to connect to the OIDC provider. Check your network and URL."
|
||||
),
|
||||
}
|
||||
return error_messages.get(base, "")
|
||||
|
||||
async def _build_validation_form(
|
||||
self, errors: dict[str, str], discovery_doc: dict | None = None
|
||||
) -> FlowResult:
|
||||
"""Build the validation form with errors and action options."""
|
||||
action_options = self._get_action_options(bool(errors))
|
||||
data_schema = vol.Schema({vol.Required("action"): vol.In(action_options)})
|
||||
|
||||
# Build description with discovery details
|
||||
description_placeholders = {
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"provider_name": self.current_provider_name,
|
||||
"discovery_details": "",
|
||||
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||
}
|
||||
|
||||
# Add appropriate details based on validation state
|
||||
if discovery_doc and not errors:
|
||||
description_placeholders["discovery_details"] = (
|
||||
self._build_discovery_success_details(discovery_doc)
|
||||
)
|
||||
elif errors:
|
||||
description_placeholders["discovery_details"] = self._build_error_details(
|
||||
errors
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="validate_connection",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders=description_placeholders,
|
||||
)
|
||||
|
||||
async def async_step_validate_connection(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Validate the OIDC configuration by testing discovery and JWKS."""
|
||||
# Handle user actions from validation form
|
||||
if user_input is not None:
|
||||
action_result = await self._handle_validation_actions(user_input)
|
||||
if action_result is not None:
|
||||
return action_result
|
||||
|
||||
# Perform validation (either initial attempt or retry)
|
||||
discovery_doc, errors = await self._perform_oidc_validation()
|
||||
|
||||
# Always show validation form with results (success or error)
|
||||
return await self._build_validation_form(errors, discovery_doc)
|
||||
|
||||
# =================
|
||||
# Step 4: Configure client details (client_id & client_secret)
|
||||
# =================
|
||||
|
||||
async def _proceed_to_next_step_after_client_config(self) -> FlowResult:
|
||||
"""Proceed to next step after client config."""
|
||||
if self.current_provider_config.get("supports_groups", True):
|
||||
return await self.async_step_groups_config()
|
||||
return await self.async_step_user_linking()
|
||||
|
||||
async def async_step_client_config(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle client ID and client type selection."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
client_id = user_input[CONF_CLIENT_ID]
|
||||
|
||||
# Validate client ID
|
||||
if not validate_client_id(client_id):
|
||||
errors["client_id"] = "invalid_client_id"
|
||||
if not errors:
|
||||
self._client_config.client_id = client_id.strip()
|
||||
# Optional client secret determines confidential/public
|
||||
provided_secret = sanitize_client_secret(
|
||||
user_input.get(CONF_CLIENT_SECRET, "")
|
||||
)
|
||||
self._client_config.client_secret = provided_secret or None
|
||||
|
||||
if not errors:
|
||||
return await self._proceed_to_next_step_after_client_config()
|
||||
|
||||
provider_name = self.current_provider_name
|
||||
|
||||
# Pre-populate with existing values if available
|
||||
default_client_id = (
|
||||
self._client_config.client_id
|
||||
if self._client_config.client_id
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
default_client_secret = (
|
||||
self._client_config.client_secret
|
||||
if self._client_config.client_secret
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
|
||||
vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="client_config",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": provider_name,
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||
},
|
||||
)
|
||||
|
||||
# =================
|
||||
# Step 5: Configure groups settings
|
||||
# =================
|
||||
|
||||
async def async_step_groups_config(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Configure groups and roles."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._feature_config.enable_groups = user_input.get(
|
||||
CONF_ENABLE_GROUPS, False
|
||||
)
|
||||
if self._feature_config.enable_groups:
|
||||
self._feature_config.admin_group = user_input.get(
|
||||
CONF_ADMIN_GROUP, "admins"
|
||||
)
|
||||
self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
|
||||
|
||||
return await self.async_step_user_linking()
|
||||
|
||||
default_admin_group = self.current_provider_config.get(
|
||||
"default_admin_group", "admins"
|
||||
)
|
||||
|
||||
data_schema_dict = {vol.Optional(CONF_ENABLE_GROUPS, default=True): bool}
|
||||
|
||||
# Add group configuration fields if groups are enabled
|
||||
if user_input is None or user_input.get(CONF_ENABLE_GROUPS, True):
|
||||
data_schema_dict.update(
|
||||
{
|
||||
vol.Optional(CONF_ADMIN_GROUP, default=default_admin_group): str,
|
||||
vol.Optional(CONF_USER_GROUP): str,
|
||||
}
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(data_schema_dict)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="groups_config",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={"provider_name": self.current_provider_name},
|
||||
)
|
||||
|
||||
# =================
|
||||
# Step 6: Configure user linking
|
||||
# =================
|
||||
|
||||
async def async_step_user_linking(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Configure user linking options."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._feature_config.enable_user_linking = user_input.get(
|
||||
CONF_ENABLE_USER_LINKING, False
|
||||
)
|
||||
return await self.async_step_finalize()
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{vol.Optional(CONF_ENABLE_USER_LINKING, default=False): bool}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user_linking",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={},
|
||||
)
|
||||
|
||||
# =================
|
||||
# Step 7: Finalize and create entry
|
||||
# =================
|
||||
|
||||
async def async_step_finalize(self) -> FlowResult:
|
||||
"""Finalize the configuration and create the config entry."""
|
||||
await self.async_set_unique_id(DOMAIN)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
# Build the configuration
|
||||
config_data = {
|
||||
"provider": self._flow_state.provider,
|
||||
"client_id": self._client_config.client_id,
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"display_name": f"{self.current_provider_name}",
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if self._client_config.client_secret:
|
||||
config_data["client_secret"] = self._client_config.client_secret
|
||||
|
||||
# Configure features
|
||||
features = {
|
||||
"automatic_user_linking": self._feature_config.enable_user_linking,
|
||||
"automatic_person_creation": True,
|
||||
"include_groups_scope": self._feature_config.enable_groups,
|
||||
}
|
||||
config_data["features"] = features
|
||||
|
||||
# Configure claims using provider defaults
|
||||
claims = self.current_provider_config["claims"].copy()
|
||||
config_data["claims"] = claims
|
||||
|
||||
# Configure roles if groups are enabled
|
||||
if self._feature_config.enable_groups:
|
||||
roles = {}
|
||||
if self._feature_config.admin_group:
|
||||
roles["admin"] = self._feature_config.admin_group
|
||||
if self._feature_config.user_group:
|
||||
roles["user"] = self._feature_config.user_group
|
||||
config_data["roles"] = roles
|
||||
|
||||
title = f"{self.current_provider_name}"
|
||||
|
||||
return self.async_create_entry(title=title, data=config_data)
|
||||
|
||||
# =================
|
||||
# Allow reconfiguration of client ID and secret
|
||||
# =================
|
||||
|
||||
async def _validate_reconfigure_input(
|
||||
self, entry, user_input: dict[str, Any]
|
||||
) -> tuple[dict[str, str], dict[str, Any] | None]:
|
||||
"""Validate reconfigure input and return errors and data updates."""
|
||||
errors = {}
|
||||
|
||||
# Validate client ID
|
||||
client_id = user_input[CONF_CLIENT_ID].strip()
|
||||
if not validate_client_id(client_id):
|
||||
errors["client_id"] = "invalid_client_id"
|
||||
return errors, None
|
||||
|
||||
# Determine confidentiality by presence of client secret
|
||||
client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
|
||||
# If secret is empty, keep the existing one (if any)
|
||||
if not client_secret:
|
||||
client_secret = entry.data.get("client_secret")
|
||||
|
||||
# Build updated data
|
||||
data_updates = {"client_id": client_id}
|
||||
|
||||
if client_secret:
|
||||
data_updates["client_secret"] = client_secret
|
||||
elif "client_secret" in entry.data and not client_secret:
|
||||
# Remove client secret if switching from confidential to public
|
||||
data_updates = {**entry.data, **data_updates}
|
||||
data_updates.pop("client_secret", None)
|
||||
|
||||
return errors, data_updates
|
||||
|
||||
def _build_reconfigure_schema(
|
||||
self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
|
||||
) -> vol.Schema:
|
||||
"""Build the reconfigure form schema."""
|
||||
schema_dict = {
|
||||
vol.Required(
|
||||
CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
|
||||
): str,
|
||||
}
|
||||
|
||||
# Always allow updating or clearing the client secret
|
||||
schema_dict[vol.Optional(CONF_CLIENT_SECRET)] = str
|
||||
|
||||
return vol.Schema(schema_dict)
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle reconfiguration of OIDC client credentials."""
|
||||
errors = {}
|
||||
entry = self._get_reconfigure_entry()
|
||||
if entry is None:
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
errors, data_updates = await self._validate_reconfigure_input(
|
||||
entry, user_input
|
||||
)
|
||||
|
||||
if not errors:
|
||||
# Update the config entry
|
||||
await self.async_set_unique_id(entry.unique_id)
|
||||
self._abort_if_unique_id_mismatch()
|
||||
|
||||
return self.async_update_reload_and_abort(
|
||||
entry, data_updates=data_updates
|
||||
)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected error during reconfiguration")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
# Show form
|
||||
current_data = entry.data
|
||||
data_schema = self._build_reconfigure_schema(current_data, user_input)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="reconfigure",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": get_provider_name(current_data.get("provider")),
|
||||
"discovery_url": current_data.get("discovery_url", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def _get_reconfigure_entry(self):
|
||||
"""Return the config entry being reconfigured if available.
|
||||
|
||||
Prefer the entry referenced by the flow context's entry_id. Fall back to the
|
||||
first existing entry for this domain when only a single instance is allowed.
|
||||
"""
|
||||
# Try from flow context (preferred)
|
||||
entry_id = None
|
||||
context = getattr(self, "context", None)
|
||||
if context and hasattr(context, "get"):
|
||||
entry_id = context.get("entry_id")
|
||||
|
||||
if entry_id:
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry and entry.domain == DOMAIN:
|
||||
return entry
|
||||
|
||||
# Fallback: this integration allows a single instance; use the first
|
||||
current = self._async_current_entries()
|
||||
if current:
|
||||
return current[0]
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get the options flow for this handler."""
|
||||
return OIDCOptionsFlowHandler()
|
||||
|
||||
|
||||
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle options flow for OIDC Authentication."""
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Handle options flow."""
|
||||
if user_input is not None:
|
||||
# Process the updated configuration
|
||||
updated_features = {
|
||||
"automatic_user_linking": user_input.get("enable_user_linking", False),
|
||||
"include_groups_scope": user_input.get("enable_groups", False),
|
||||
}
|
||||
|
||||
updated_roles = {}
|
||||
if user_input.get("enable_groups", False):
|
||||
if user_input.get("admin_group"):
|
||||
updated_roles["admin"] = user_input["admin_group"]
|
||||
if user_input.get("user_group"):
|
||||
updated_roles["user"] = user_input["user_group"]
|
||||
|
||||
# Update the config entry data
|
||||
new_data = self.config_entry.data.copy()
|
||||
new_data["features"] = {**new_data.get("features", {}), **updated_features}
|
||||
if updated_roles:
|
||||
new_data["roles"] = updated_roles
|
||||
elif "roles" in new_data:
|
||||
# Remove roles if groups are disabled
|
||||
if not user_input.get("enable_groups", False):
|
||||
del new_data["roles"]
|
||||
|
||||
# Update the config entry
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data=new_data
|
||||
)
|
||||
|
||||
return self.async_create_entry(title="", data={})
|
||||
|
||||
current_config = self.config_entry.data
|
||||
current_features = current_config.get("features", {})
|
||||
current_roles = current_config.get("roles", {})
|
||||
|
||||
# Determine if this provider supports groups
|
||||
provider = current_config.get("provider", "authentik")
|
||||
provider_supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
|
||||
"supports_groups", True
|
||||
)
|
||||
|
||||
# Build schema based on provider capabilities
|
||||
schema_dict = {
|
||||
vol.Optional(
|
||||
"enable_user_linking",
|
||||
default=current_features.get("automatic_user_linking", False),
|
||||
): bool
|
||||
}
|
||||
|
||||
# Add groups options if provider supports them
|
||||
if provider_supports_groups:
|
||||
enable_groups_default = current_features.get("include_groups_scope", False)
|
||||
schema_dict[
|
||||
vol.Optional("enable_groups", default=enable_groups_default)
|
||||
] = bool
|
||||
|
||||
# Add group name fields if groups are currently enabled or being enabled
|
||||
if enable_groups_default or (
|
||||
user_input and user_input.get("enable_groups", False)
|
||||
):
|
||||
schema_dict.update(
|
||||
{
|
||||
vol.Optional(
|
||||
"admin_group",
|
||||
default=current_roles.get("admin", DEFAULT_ADMIN_GROUP),
|
||||
): str,
|
||||
vol.Optional(
|
||||
"user_group", default=current_roles.get("user", "")
|
||||
): str,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(schema_dict),
|
||||
description_placeholders={
|
||||
"provider_name": get_provider_name(provider),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def convert_ui_config_entry_to_internal_format(config_data: dict) -> dict:
|
||||
"""Convert config entry data to internal configuration format."""
|
||||
my_config = {}
|
||||
|
||||
# Required fields
|
||||
my_config[CLIENT_ID] = config_data["client_id"]
|
||||
my_config[DISCOVERY_URL] = config_data["discovery_url"]
|
||||
|
||||
# Optional fields
|
||||
if "client_secret" in config_data:
|
||||
my_config[CLIENT_SECRET] = config_data["client_secret"]
|
||||
|
||||
if "display_name" in config_data:
|
||||
my_config[DISPLAY_NAME] = config_data["display_name"]
|
||||
|
||||
# Features configuration
|
||||
if "features" in config_data:
|
||||
my_config[FEATURES] = config_data["features"]
|
||||
|
||||
# Claims configuration
|
||||
if "claims" in config_data:
|
||||
my_config[CLAIMS] = config_data["claims"]
|
||||
|
||||
# Roles configuration
|
||||
if "roles" in config_data:
|
||||
my_config[ROLES] = config_data["roles"]
|
||||
|
||||
return my_config
|
||||
@@ -1,806 +1,5 @@
|
||||
"""Config flow for OIDC Authentication integration."""
|
||||
"""UI config flow re-export"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
import aiohttp
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.data_entry_flow import FlowResult
|
||||
|
||||
from .config import DOMAIN
|
||||
from .oidc_client import OIDCClient, OIDCDiscoveryInvalid, OIDCJWKSInvalid
|
||||
from .provider_catalog import (
|
||||
OIDC_PROVIDERS,
|
||||
get_provider_name,
|
||||
get_provider_docs_url,
|
||||
)
|
||||
from .validation import (
|
||||
validate_discovery_url,
|
||||
sanitize_client_secret,
|
||||
validate_client_id,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Configuration field names
|
||||
CONF_PROVIDER = "provider"
|
||||
CONF_CLIENT_ID = "client_id"
|
||||
CONF_CLIENT_SECRET = "client_secret"
|
||||
CONF_DISCOVERY_URL = "discovery_url"
|
||||
CONF_ENABLE_GROUPS = "enable_groups"
|
||||
CONF_ADMIN_GROUP = "admin_group"
|
||||
CONF_USER_GROUP = "user_group"
|
||||
CONF_ENABLE_USER_LINKING = "enable_user_linking"
|
||||
|
||||
DEFAULT_ADMIN_GROUP = "admins"
|
||||
|
||||
# Cache settings
|
||||
DISCOVERY_CACHE_TTL = 300 # 5 minutes
|
||||
MAX_CACHE_SIZE = 10
|
||||
|
||||
|
||||
@dataclass
|
||||
class FlowState:
|
||||
"""State tracking for the configuration flow."""
|
||||
|
||||
provider: str | None = None
|
||||
discovery_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientConfig:
|
||||
"""Client configuration settings."""
|
||||
|
||||
client_id: str | None = None
|
||||
client_secret: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class FeatureConfig:
|
||||
"""Feature configuration settings."""
|
||||
|
||||
enable_groups: bool = False
|
||||
admin_group: str = DEFAULT_ADMIN_GROUP
|
||||
user_group: str | None = None
|
||||
enable_user_linking: bool = False
|
||||
|
||||
|
||||
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||
"""Handle a config flow for OIDC Authentication."""
|
||||
|
||||
VERSION = 1
|
||||
|
||||
def is_matching(self, other_flow):
|
||||
"""Check if this flow is the same as another flow."""
|
||||
self_state = getattr(self, "_flow_state", None)
|
||||
other_state = getattr(other_flow, "_flow_state", None)
|
||||
|
||||
if not self_state or not other_state:
|
||||
return False
|
||||
|
||||
self_discovery_url = self_state.discovery_url
|
||||
other_discovery_url = other_state.discovery_url
|
||||
|
||||
return (
|
||||
self_discovery_url
|
||||
and other_discovery_url
|
||||
and self_discovery_url.rstrip("/").lower()
|
||||
== other_discovery_url.rstrip("/").lower()
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the config flow."""
|
||||
self._flow_state = FlowState()
|
||||
self._client_config = ClientConfig()
|
||||
self._feature_config = FeatureConfig()
|
||||
self._discovery_cache = {}
|
||||
self._cache_timestamps = {}
|
||||
|
||||
@property
|
||||
def current_provider_config(self) -> dict[str, Any]:
|
||||
"""Get the configuration for the currently selected provider."""
|
||||
if not self._flow_state.provider:
|
||||
return {}
|
||||
return OIDC_PROVIDERS.get(self._flow_state.provider, {})
|
||||
|
||||
@property
|
||||
def current_provider_name(self) -> str:
|
||||
"""Get the name of the currently selected provider."""
|
||||
return get_provider_name(self._flow_state.provider)
|
||||
|
||||
def _cleanup_discovery_cache(self) -> None:
|
||||
"""Remove expired and excess cache entries."""
|
||||
current_time = time.time()
|
||||
|
||||
# Remove expired entries
|
||||
expired_keys = [
|
||||
key
|
||||
for key, timestamp in self._cache_timestamps.items()
|
||||
if current_time - timestamp > DISCOVERY_CACHE_TTL
|
||||
]
|
||||
for key in expired_keys:
|
||||
self._discovery_cache.pop(key, None)
|
||||
self._cache_timestamps.pop(key, None)
|
||||
|
||||
# Remove oldest entries if cache is too large
|
||||
if len(self._discovery_cache) > MAX_CACHE_SIZE:
|
||||
sorted_items = sorted(self._cache_timestamps.items(), key=lambda x: x[1])
|
||||
excess_count = len(self._discovery_cache) - MAX_CACHE_SIZE
|
||||
for key, _ in sorted_items[:excess_count]:
|
||||
self._discovery_cache.pop(key, None)
|
||||
self._cache_timestamps.pop(key, None)
|
||||
|
||||
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||
"""Check if a cache entry is still valid."""
|
||||
if cache_key not in self._cache_timestamps:
|
||||
return False
|
||||
|
||||
age = time.time() - self._cache_timestamps[cache_key]
|
||||
return age <= DISCOVERY_CACHE_TTL
|
||||
|
||||
async def async_step_user(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle the initial step - provider selection."""
|
||||
# Check if OIDC is already configured (only one instance allowed)
|
||||
if self._async_current_entries():
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
# Check if YAML configuration exists
|
||||
if self.hass.data.get(DOMAIN, {}).get("yaml_config"):
|
||||
return self.async_abort(reason="single_instance_allowed")
|
||||
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._flow_state.provider = user_input[CONF_PROVIDER]
|
||||
|
||||
# If provider has a predefined discovery URL, prefill it but still
|
||||
# show the discovery URL step so the user can customize it.
|
||||
predefined = self.current_provider_config.get("discovery_url")
|
||||
if predefined:
|
||||
self._flow_state.discovery_url = predefined
|
||||
|
||||
# Always request discovery URL next (prefilled when available)
|
||||
return await self.async_step_discovery_url()
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_PROVIDER): vol.In(
|
||||
{key: provider["name"] for key, provider in OIDC_PROVIDERS.items()}
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={},
|
||||
)
|
||||
|
||||
async def async_step_discovery_url(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle discovery URL input for providers requiring URL configuration."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
discovery_url = user_input[CONF_DISCOVERY_URL].rstrip("/")
|
||||
|
||||
# Validate discovery URL format
|
||||
if not validate_discovery_url(discovery_url):
|
||||
errors["discovery_url"] = "invalid_url_format"
|
||||
else:
|
||||
self._flow_state.discovery_url = discovery_url
|
||||
return await self.async_step_client_config()
|
||||
|
||||
provider_name = self.current_provider_name
|
||||
provider_key = self._flow_state.provider
|
||||
|
||||
# Pre-populate with existing discovery URL if available
|
||||
default_url = (
|
||||
self._flow_state.discovery_url
|
||||
if self._flow_state.discovery_url
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="discovery_url",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": provider_name,
|
||||
"documentation_url": get_provider_docs_url(provider_key),
|
||||
},
|
||||
)
|
||||
|
||||
async def async_step_client_config(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle client ID and client type selection."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
client_id = user_input[CONF_CLIENT_ID]
|
||||
|
||||
# Validate client ID
|
||||
if not validate_client_id(client_id):
|
||||
errors["client_id"] = "invalid_client_id"
|
||||
if not errors:
|
||||
self._client_config.client_id = client_id.strip()
|
||||
# Optional client secret determines confidential/public
|
||||
provided_secret = sanitize_client_secret(
|
||||
user_input.get(CONF_CLIENT_SECRET, "")
|
||||
)
|
||||
self._client_config.client_secret = provided_secret or None
|
||||
|
||||
if not errors:
|
||||
# Proceed to validation directly from here
|
||||
return await self.async_step_validate_connection()
|
||||
|
||||
provider_name = self.current_provider_name
|
||||
|
||||
# Pre-populate with existing values if available
|
||||
default_client_id = (
|
||||
self._client_config.client_id
|
||||
if self._client_config.client_id
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
default_client_secret = (
|
||||
self._client_config.client_secret
|
||||
if self._client_config.client_secret
|
||||
else vol.UNDEFINED
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{
|
||||
vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
|
||||
vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="client_config",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": provider_name,
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||
},
|
||||
)
|
||||
|
||||
async def _handle_validation_actions(
|
||||
self, user_input: dict[str, Any]
|
||||
) -> FlowResult | None:
|
||||
"""Handle user actions from the validation form so they can fix errors."""
|
||||
action = user_input.get("action")
|
||||
|
||||
# Handle special actions first
|
||||
if action == "retry":
|
||||
return None # Continue with validation
|
||||
if action == "continue":
|
||||
return await self._proceed_to_next_step()
|
||||
|
||||
# Handle redirect actions
|
||||
action_handlers = {
|
||||
"fix_discovery": self.async_step_discovery_url,
|
||||
"fix_client": self.async_step_client_config,
|
||||
"change_provider": self.async_step_user,
|
||||
}
|
||||
|
||||
handler = action_handlers.get(action)
|
||||
return await handler() if handler else None
|
||||
|
||||
async def _proceed_to_next_step(self) -> FlowResult:
|
||||
"""Proceed to next step after successful validation."""
|
||||
if self.current_provider_config.get("supports_groups", True):
|
||||
return await self.async_step_groups_config()
|
||||
return await self.async_step_user_linking()
|
||||
|
||||
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
|
||||
"""Perform the actual OIDC validation and return discovery doc and errors."""
|
||||
errors = {}
|
||||
discovery_doc = {}
|
||||
|
||||
try:
|
||||
# Create a test OIDC client to validate configuration
|
||||
test_client = OIDCClient(
|
||||
hass=self.hass,
|
||||
discovery_url=self._flow_state.discovery_url,
|
||||
client_id=self._client_config.client_id,
|
||||
scope="openid profile",
|
||||
client_secret=self._client_config.client_secret or None,
|
||||
features={},
|
||||
claims={},
|
||||
roles={},
|
||||
network={},
|
||||
)
|
||||
|
||||
# Clean up expired cache entries first
|
||||
self._cleanup_discovery_cache()
|
||||
|
||||
# Check if discovery document is already cached and valid
|
||||
cache_key = self._flow_state.discovery_url
|
||||
if cache_key in self._discovery_cache and self._is_cache_valid(cache_key):
|
||||
discovery_doc = self._discovery_cache[cache_key]
|
||||
# Still validate JWKS if available since this might be a retry
|
||||
if "jwks_uri" in discovery_doc:
|
||||
await test_client.validate_jwks(discovery_doc["jwks_uri"])
|
||||
else:
|
||||
# Perform discovery and JWKS validation
|
||||
discovery_doc = await test_client.validate_discovery()
|
||||
|
||||
# Cache the discovery document with timestamp
|
||||
self._discovery_cache[cache_key] = discovery_doc
|
||||
self._cache_timestamps[cache_key] = time.time()
|
||||
|
||||
# Validate JWKS if available
|
||||
if "jwks_uri" in discovery_doc:
|
||||
await test_client.validate_jwks(discovery_doc["jwks_uri"])
|
||||
|
||||
except OIDCDiscoveryInvalid:
|
||||
errors["base"] = "discovery_invalid"
|
||||
except OIDCJWKSInvalid:
|
||||
errors["base"] = "jwks_invalid"
|
||||
except aiohttp.ClientError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected error during validation")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
return discovery_doc, errors
|
||||
|
||||
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
|
||||
"""Get action options based on validation state."""
|
||||
if has_errors:
|
||||
return {
|
||||
"retry": "Retry Validation",
|
||||
"fix_client": "Fix Client Settings",
|
||||
"fix_discovery": "Fix Discovery URL",
|
||||
"change_provider": "Change Provider",
|
||||
}
|
||||
return {
|
||||
"continue": "Continue Setup",
|
||||
"fix_client": "Modify Client Settings",
|
||||
"fix_discovery": "Modify Discovery URL",
|
||||
"change_provider": "Change Provider",
|
||||
}
|
||||
|
||||
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
|
||||
"""Build success details from discovery document."""
|
||||
discovery_info = []
|
||||
|
||||
endpoints = [
|
||||
("issuer", "✅ Connection Successful", "**Issuer:** {value}"),
|
||||
("authorization_endpoint", "✅ Authorization endpoint found", None),
|
||||
("token_endpoint", "✅ Token endpoint found", None),
|
||||
("jwks_uri", "✅ JWKS endpoint found", None),
|
||||
("userinfo_endpoint", "✅ User info endpoint found", None),
|
||||
]
|
||||
|
||||
for key, message, formatted in endpoints:
|
||||
if key in discovery_doc:
|
||||
discovery_info.append(message)
|
||||
if formatted and key == "issuer":
|
||||
discovery_info.append(formatted.format(value=discovery_doc[key]))
|
||||
|
||||
return "\n".join(discovery_info)
|
||||
|
||||
def _build_error_details(self, errors: dict[str, str]) -> str:
|
||||
"""Build error details from validation errors."""
|
||||
error_messages = {
|
||||
"discovery_invalid": (
|
||||
"❌ **Discovery document could not be retrieved**\n"
|
||||
"Please verify the discovery URL is correct and accessible."
|
||||
),
|
||||
"jwks_invalid": (
|
||||
"❌ **JWKS validation failed**\n"
|
||||
"The JSON Web Key Set could not be retrieved or validated."
|
||||
),
|
||||
"cannot_connect": (
|
||||
"❌ **Connection failed**\n"
|
||||
"Unable to connect to the OIDC provider. Check your network and URL."
|
||||
),
|
||||
}
|
||||
return error_messages.get(errors.get("base", ""), "")
|
||||
|
||||
async def _build_validation_form(
|
||||
self, errors: dict[str, str], discovery_doc: dict | None = None
|
||||
) -> FlowResult:
|
||||
"""Build the validation form with errors and action options."""
|
||||
action_options = self._get_action_options(bool(errors))
|
||||
data_schema = vol.Schema({vol.Required("action"): vol.In(action_options)})
|
||||
|
||||
# Build description with discovery details
|
||||
description_placeholders = {
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"client_id": self._client_config.client_id,
|
||||
"provider_name": self.current_provider_name,
|
||||
"discovery_details": "",
|
||||
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||
}
|
||||
|
||||
# Add appropriate details based on validation state
|
||||
if discovery_doc and not errors:
|
||||
description_placeholders["discovery_details"] = (
|
||||
self._build_discovery_success_details(discovery_doc)
|
||||
)
|
||||
elif errors:
|
||||
description_placeholders["discovery_details"] = self._build_error_details(
|
||||
errors
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="validate_connection",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders=description_placeholders,
|
||||
)
|
||||
|
||||
async def async_step_validate_connection(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Validate the OIDC configuration by testing discovery and JWKS."""
|
||||
# Handle user actions from validation form
|
||||
if user_input is not None:
|
||||
action_result = await self._handle_validation_actions(user_input)
|
||||
if action_result is not None:
|
||||
return action_result
|
||||
|
||||
# Perform validation (either initial attempt or retry)
|
||||
discovery_doc, errors = await self._perform_oidc_validation()
|
||||
|
||||
# Always show validation form with results (success or error)
|
||||
return await self._build_validation_form(errors, discovery_doc)
|
||||
|
||||
async def async_step_groups_config(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Configure groups and roles."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._feature_config.enable_groups = user_input.get(
|
||||
CONF_ENABLE_GROUPS, False
|
||||
)
|
||||
if self._feature_config.enable_groups:
|
||||
self._feature_config.admin_group = user_input.get(
|
||||
CONF_ADMIN_GROUP, "admins"
|
||||
)
|
||||
self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
|
||||
|
||||
return await self.async_step_user_linking()
|
||||
|
||||
default_admin_group = self.current_provider_config.get(
|
||||
"default_admin_group", "admins"
|
||||
)
|
||||
|
||||
data_schema_dict = {vol.Optional(CONF_ENABLE_GROUPS, default=True): bool}
|
||||
|
||||
# Add group configuration fields if groups are enabled
|
||||
if user_input is None or user_input.get(CONF_ENABLE_GROUPS, True):
|
||||
data_schema_dict.update(
|
||||
{
|
||||
vol.Optional(CONF_ADMIN_GROUP, default=default_admin_group): str,
|
||||
vol.Optional(CONF_USER_GROUP): str,
|
||||
}
|
||||
)
|
||||
|
||||
data_schema = vol.Schema(data_schema_dict)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="groups_config",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={"provider_name": self.current_provider_name},
|
||||
)
|
||||
|
||||
async def async_step_user_linking(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Configure user linking options."""
|
||||
errors = {}
|
||||
|
||||
if user_input is not None:
|
||||
self._feature_config.enable_user_linking = user_input.get(
|
||||
CONF_ENABLE_USER_LINKING, False
|
||||
)
|
||||
return await self.async_step_finalize()
|
||||
|
||||
data_schema = vol.Schema(
|
||||
{vol.Optional(CONF_ENABLE_USER_LINKING, default=False): bool}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="user_linking",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={},
|
||||
)
|
||||
|
||||
async def async_step_finalize(self) -> FlowResult:
|
||||
"""Finalize the configuration and create the config entry."""
|
||||
await self.async_set_unique_id(DOMAIN)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
# Build the configuration
|
||||
config_data = {
|
||||
"provider": self._flow_state.provider,
|
||||
"client_id": self._client_config.client_id,
|
||||
"discovery_url": self._flow_state.discovery_url,
|
||||
"display_name": f"{self.current_provider_name} (OIDC)",
|
||||
}
|
||||
|
||||
# Add optional fields
|
||||
if self._client_config.client_secret:
|
||||
config_data["client_secret"] = self._client_config.client_secret
|
||||
|
||||
# Configure features
|
||||
features = {
|
||||
"automatic_user_linking": self._feature_config.enable_user_linking,
|
||||
"automatic_person_creation": True,
|
||||
"include_groups_scope": self._feature_config.enable_groups,
|
||||
}
|
||||
config_data["features"] = features
|
||||
|
||||
# Configure claims using provider defaults
|
||||
claims = self.current_provider_config["claims"].copy()
|
||||
config_data["claims"] = claims
|
||||
|
||||
# Configure roles if groups are enabled
|
||||
if self._feature_config.enable_groups:
|
||||
roles = {}
|
||||
if self._feature_config.admin_group:
|
||||
roles["admin"] = self._feature_config.admin_group
|
||||
if self._feature_config.user_group:
|
||||
roles["user"] = self._feature_config.user_group
|
||||
config_data["roles"] = roles
|
||||
|
||||
title = f"{self.current_provider_name} OIDC"
|
||||
|
||||
return self.async_create_entry(title=title, data=config_data)
|
||||
|
||||
async def _validate_reconfigure_input(
|
||||
self, entry, user_input: dict[str, Any]
|
||||
) -> tuple[dict[str, str], dict[str, Any] | None]:
|
||||
"""Validate reconfigure input and return errors and data updates."""
|
||||
errors = {}
|
||||
|
||||
# Validate client ID
|
||||
client_id = user_input[CONF_CLIENT_ID].strip()
|
||||
if not validate_client_id(client_id):
|
||||
errors["client_id"] = "invalid_client_id"
|
||||
return errors, None
|
||||
|
||||
# Determine confidentiality by presence of client secret
|
||||
client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
|
||||
# If secret is empty, keep the existing one (if any)
|
||||
if not client_secret:
|
||||
client_secret = entry.data.get("client_secret")
|
||||
|
||||
# Test the new configuration
|
||||
test_client = OIDCClient(
|
||||
hass=self.hass,
|
||||
discovery_url=entry.data["discovery_url"],
|
||||
client_id=client_id,
|
||||
scope="openid profile",
|
||||
client_secret=client_secret or None,
|
||||
features={},
|
||||
claims={},
|
||||
roles={},
|
||||
network={},
|
||||
)
|
||||
|
||||
# Validate the new credentials
|
||||
discovery_doc = await test_client.validate_discovery()
|
||||
if "jwks_uri" in discovery_doc:
|
||||
await test_client.validate_jwks(discovery_doc["jwks_uri"])
|
||||
|
||||
# Build updated data
|
||||
data_updates = {"client_id": client_id}
|
||||
|
||||
if client_secret:
|
||||
data_updates["client_secret"] = client_secret
|
||||
elif "client_secret" in entry.data and not client_secret:
|
||||
# Remove client secret if switching from confidential to public
|
||||
data_updates = {**entry.data, **data_updates}
|
||||
data_updates.pop("client_secret", None)
|
||||
|
||||
return errors, data_updates
|
||||
|
||||
def _build_reconfigure_schema(
|
||||
self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
|
||||
) -> vol.Schema:
|
||||
"""Build the reconfigure form schema."""
|
||||
schema_dict = {
|
||||
vol.Required(
|
||||
CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
|
||||
): str,
|
||||
}
|
||||
|
||||
# Always allow updating or clearing the client secret
|
||||
schema_dict[vol.Optional(CONF_CLIENT_SECRET)] = str
|
||||
|
||||
return vol.Schema(schema_dict)
|
||||
|
||||
async def async_step_reconfigure(
|
||||
self, user_input: dict[str, Any] | None = None
|
||||
) -> FlowResult:
|
||||
"""Handle reconfiguration of OIDC client credentials."""
|
||||
errors = {}
|
||||
entry = self._get_reconfigure_entry()
|
||||
if entry is None:
|
||||
return self.async_abort(reason="unknown")
|
||||
|
||||
if user_input is not None:
|
||||
try:
|
||||
errors, data_updates = await self._validate_reconfigure_input(
|
||||
entry, user_input
|
||||
)
|
||||
|
||||
if not errors:
|
||||
# Update the config entry
|
||||
await self.async_set_unique_id(entry.unique_id)
|
||||
self._abort_if_unique_id_mismatch()
|
||||
|
||||
return self.async_update_reload_and_abort(
|
||||
entry, data_updates=data_updates
|
||||
)
|
||||
|
||||
except OIDCDiscoveryInvalid:
|
||||
errors["base"] = "discovery_invalid"
|
||||
except OIDCJWKSInvalid:
|
||||
errors["base"] = "jwks_invalid"
|
||||
except aiohttp.ClientError:
|
||||
errors["base"] = "cannot_connect"
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected error during reconfiguration")
|
||||
errors["base"] = "unknown"
|
||||
|
||||
# Show form
|
||||
current_data = entry.data
|
||||
data_schema = self._build_reconfigure_schema(current_data, user_input)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="reconfigure",
|
||||
data_schema=data_schema,
|
||||
errors=errors,
|
||||
description_placeholders={
|
||||
"provider_name": get_provider_name(current_data.get("provider")),
|
||||
"discovery_url": current_data.get("discovery_url", ""),
|
||||
},
|
||||
)
|
||||
|
||||
def _get_reconfigure_entry(self):
|
||||
"""Return the config entry being reconfigured if available.
|
||||
|
||||
Prefer the entry referenced by the flow context's entry_id. Fall back to the
|
||||
first existing entry for this domain when only a single instance is allowed.
|
||||
"""
|
||||
# Try from flow context (preferred)
|
||||
entry_id = None
|
||||
context = getattr(self, "context", None)
|
||||
if context and hasattr(context, "get"):
|
||||
entry_id = context.get("entry_id")
|
||||
|
||||
if entry_id:
|
||||
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||
if entry and entry.domain == DOMAIN:
|
||||
return entry
|
||||
|
||||
# Fallback: this integration allows a single instance; use the first
|
||||
current = self._async_current_entries()
|
||||
if current:
|
||||
return current[0]
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@callback
|
||||
def async_get_options_flow(config_entry):
|
||||
"""Get the options flow for this handler."""
|
||||
return OIDCOptionsFlowHandler(config_entry)
|
||||
|
||||
|
||||
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
|
||||
"""Handle options flow for OIDC Authentication."""
|
||||
|
||||
def __init__(self, config_entry):
|
||||
"""Initialize options flow."""
|
||||
self.config_entry = config_entry
|
||||
|
||||
async def async_step_init(self, user_input=None):
|
||||
"""Handle options flow."""
|
||||
if user_input is not None:
|
||||
# Process the updated configuration
|
||||
updated_features = {
|
||||
"automatic_user_linking": user_input.get("enable_user_linking", False),
|
||||
"include_groups_scope": user_input.get("enable_groups", False),
|
||||
}
|
||||
|
||||
updated_roles = {}
|
||||
if user_input.get("enable_groups", False):
|
||||
if user_input.get("admin_group"):
|
||||
updated_roles["admin"] = user_input["admin_group"]
|
||||
if user_input.get("user_group"):
|
||||
updated_roles["user"] = user_input["user_group"]
|
||||
|
||||
# Update the config entry data
|
||||
new_data = self.config_entry.data.copy()
|
||||
new_data["features"] = {**new_data.get("features", {}), **updated_features}
|
||||
if updated_roles:
|
||||
new_data["roles"] = updated_roles
|
||||
elif "roles" in new_data:
|
||||
# Remove roles if groups are disabled
|
||||
if not user_input.get("enable_groups", False):
|
||||
del new_data["roles"]
|
||||
|
||||
# Update the config entry
|
||||
self.hass.config_entries.async_update_entry(
|
||||
self.config_entry, data=new_data
|
||||
)
|
||||
|
||||
return self.async_create_entry(title="", data={})
|
||||
|
||||
current_config = self.config_entry.data
|
||||
current_features = current_config.get("features", {})
|
||||
current_roles = current_config.get("roles", {})
|
||||
|
||||
# Determine if this provider supports groups
|
||||
provider = current_config.get("provider", "authentik")
|
||||
provider_supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
|
||||
"supports_groups", True
|
||||
)
|
||||
|
||||
# Build schema based on provider capabilities
|
||||
schema_dict = {
|
||||
vol.Optional(
|
||||
"enable_user_linking",
|
||||
default=current_features.get("automatic_user_linking", False),
|
||||
): bool
|
||||
}
|
||||
|
||||
# Add groups options if provider supports them
|
||||
if provider_supports_groups:
|
||||
enable_groups_default = current_features.get("include_groups_scope", False)
|
||||
schema_dict[
|
||||
vol.Optional("enable_groups", default=enable_groups_default)
|
||||
] = bool
|
||||
|
||||
# Add group name fields if groups are currently enabled or being enabled
|
||||
if enable_groups_default or (
|
||||
user_input and user_input.get("enable_groups", False)
|
||||
):
|
||||
schema_dict.update(
|
||||
{
|
||||
vol.Optional(
|
||||
"admin_group",
|
||||
default=current_roles.get("admin", DEFAULT_ADMIN_GROUP),
|
||||
): str,
|
||||
vol.Optional(
|
||||
"user_group", default=current_roles.get("user", "")
|
||||
): str,
|
||||
}
|
||||
)
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="init",
|
||||
data_schema=vol.Schema(schema_dict),
|
||||
description_placeholders={
|
||||
"provider_name": get_provider_name(provider),
|
||||
},
|
||||
)
|
||||
# pylint: disable=useless-import-alias
|
||||
# pylint: disable=unused-import
|
||||
from .config.ui_flow import OIDCConfigFlow as OIDCConfigFlow
|
||||
|
||||
7
custom_components/auth_oidc/endpoints/__init__.py
Normal file
7
custom_components/auth_oidc/endpoints/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""Imports manager"""
|
||||
|
||||
from .callback import OIDCCallbackView as OIDCCallbackView
|
||||
from .finish import OIDCFinishView as OIDCFinishView
|
||||
from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage
|
||||
from .redirect import OIDCRedirectView as OIDCRedirectView
|
||||
from .welcome import OIDCWelcomeView as OIDCWelcomeView
|
||||
@@ -2,9 +2,9 @@
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from aiohttp import web
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..tools.oidc_client import OIDCClient
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..helpers import get_url, get_view
|
||||
from ..tools.helpers import get_url, get_view
|
||||
|
||||
PATH = "/auth/oidc/callback"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from aiohttp import web
|
||||
from ..helpers import get_view
|
||||
from ..tools.helpers import get_view
|
||||
|
||||
PATH = "/auth/oidc/finish"
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ can either be linked to directly or accessed through the welcome page."""
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..helpers import get_url, get_view
|
||||
from ..tools.oidc_client import OIDCClient
|
||||
from ..tools.helpers import get_url, get_view
|
||||
|
||||
PATH = "/auth/oidc/redirect"
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from ..helpers import get_url, get_view
|
||||
from ..tools.helpers import get_url, get_view
|
||||
|
||||
PATH = "/auth/oidc/welcome"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"domain": "auth_oidc",
|
||||
"name": "OIDC Authentication",
|
||||
"name": "OpenID Connect/SSO Authentication",
|
||||
"codeowners": [
|
||||
"@christiaangoossens"
|
||||
],
|
||||
|
||||
@@ -24,14 +24,14 @@ from homeassistant.components import http, person
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
|
||||
from .config import (
|
||||
from .config.const import (
|
||||
FEATURES,
|
||||
FEATURES_AUTOMATIC_USER_LINKING,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
DEFAULT_TITLE,
|
||||
)
|
||||
from .stores.code_store import CodeStore
|
||||
from .types import UserDetails
|
||||
from .tools.types import UserDetails
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Provider catalog and helpers for OIDC providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
DEFAULT_ADMIN_GROUP = "admins"
|
||||
|
||||
|
||||
OIDC_PROVIDERS: Dict[str, Dict[str, Any]] = {
|
||||
"authentik": {
|
||||
"name": "Authentik",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"authelia": {
|
||||
"name": "Authelia",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"pocketid": {
|
||||
"name": "Pocket ID",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"kanidm": {
|
||||
"name": "Kanidm",
|
||||
"discovery_url": "",
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
"microsoft": {
|
||||
"name": "Microsoft Entra ID",
|
||||
"discovery_url": (
|
||||
"https://login.microsoftonline.com/common/v2.0/"
|
||||
".well-known/openid_configuration"
|
||||
),
|
||||
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||
"supports_groups": True,
|
||||
"claims": {
|
||||
"display_name": "name",
|
||||
"username": "preferred_username",
|
||||
"groups": "groups",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_provider_config(key: str) -> Dict[str, Any]:
|
||||
"""Return provider configuration by key."""
|
||||
return OIDC_PROVIDERS.get(key, {})
|
||||
|
||||
|
||||
def get_provider_name(key: str | None) -> str:
|
||||
"""Return provider display name by key."""
|
||||
if not key:
|
||||
return "Unknown Provider"
|
||||
return OIDC_PROVIDERS.get(key, {}).get("name", "Unknown Provider")
|
||||
|
||||
|
||||
def get_provider_docs_url(key: str | None) -> str:
|
||||
"""Return documentation URL for a provider key."""
|
||||
base_url = (
|
||||
"https://github.com/christiaangoossens/hass-oidc-auth/blob/main"
|
||||
"/docs/provider-configurations"
|
||||
)
|
||||
|
||||
provider_docs = {
|
||||
"authentik": f"{base_url}/authentik.md",
|
||||
"authelia": f"{base_url}/authelia.md",
|
||||
"pocketid": f"{base_url}/pocket-id.md",
|
||||
"kanidm": f"{base_url}/kanidm.md",
|
||||
"microsoft": f"{base_url}/microsoft-entra.md",
|
||||
}
|
||||
|
||||
if key in provider_docs:
|
||||
return provider_docs[key]
|
||||
return (
|
||||
"https://github.com/christiaangoossens/hass-oidc-auth"
|
||||
"/blob/main/docs/configuration.md"
|
||||
)
|
||||
0
custom_components/auth_oidc/stores/__init__.py
Normal file
0
custom_components/auth_oidc/stores/__init__.py
Normal file
@@ -8,7 +8,7 @@ from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from ..types import UserDetails
|
||||
from ..tools.types import UserDetails
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth_provider.auth_oidc.codes"
|
||||
|
||||
0
custom_components/auth_oidc/tools/__init__.py
Normal file
0
custom_components/auth_oidc/tools/__init__.py
Normal file
@@ -1,7 +1,7 @@
|
||||
"""Helper functions for the integration."""
|
||||
|
||||
from homeassistant.components import http
|
||||
from .views.loader import AsyncTemplateRenderer
|
||||
from ..views.loader import AsyncTemplateRenderer
|
||||
|
||||
|
||||
def get_url(path: str, force_https: bool) -> str:
|
||||
@@ -13,7 +13,7 @@ from jose import jwt, jwk
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .types import UserDetails
|
||||
from .config import (
|
||||
from ..config.const import (
|
||||
FEATURES_DISABLE_PKCE,
|
||||
CLAIMS_DISPLAY_NAME,
|
||||
CLAIMS_USERNAME,
|
||||
@@ -22,7 +22,9 @@ from .config import (
|
||||
ROLE_USERS,
|
||||
NETWORK_TLS_VERIFY,
|
||||
NETWORK_TLS_CA_PATH,
|
||||
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||
)
|
||||
from .validation import validate_url
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,6 +36,32 @@ class OIDCClientException(Exception):
|
||||
class OIDCDiscoveryInvalid(OIDCClientException):
|
||||
"Raised when the discovery document is not found, invalid or otherwise malformed."
|
||||
|
||||
type: Optional[str]
|
||||
details: Optional[dict]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args:
|
||||
self.message = args[0]
|
||||
else:
|
||||
self.message = "OIDC Discovery document is invalid"
|
||||
|
||||
self.type = kwargs.pop("type", None)
|
||||
self.details = kwargs.pop("details", None)
|
||||
super().__init__(self.message)
|
||||
|
||||
def get_detail_string(self) -> str:
|
||||
"""Returns a detailed string for logging purposes."""
|
||||
string = []
|
||||
|
||||
if self.type:
|
||||
string.append(f"type: {self.type}")
|
||||
|
||||
if self.details:
|
||||
for key, value in self.details.items():
|
||||
string.append(f"{key}: {value}")
|
||||
|
||||
return ", ".join(string)
|
||||
|
||||
|
||||
class OIDCTokenResponseInvalid(OIDCClientException):
|
||||
"Raised when the token request returns invalid."
|
||||
@@ -68,6 +96,199 @@ class HTTPClientError(aiohttp.ClientResponseError):
|
||||
return f"{self.status} ({self.message}) with response body: {self.body}"
|
||||
|
||||
|
||||
async def http_raise_for_status(response: aiohttp.ClientResponse) -> None:
|
||||
"""Raises an exception if the response is not OK."""
|
||||
if not response.ok:
|
||||
# reason should always be not None for a started response
|
||||
assert response.reason is not None
|
||||
body = await response.text()
|
||||
|
||||
raise HTTPClientError(
|
||||
response.request_info,
|
||||
response.history,
|
||||
status=response.status,
|
||||
message=response.reason,
|
||||
headers=response.headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
|
||||
class OIDCDiscoveryClient:
|
||||
"""OIDC Discovery Client implementation for Python"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
discovery_url: str,
|
||||
http_session: aiohttp.ClientSession,
|
||||
verification_context: dict,
|
||||
):
|
||||
self.discovery_url = discovery_url
|
||||
self.http_session = http_session
|
||||
self.verification_context = verification_context
|
||||
|
||||
async def _fetch_discovery_document(self):
|
||||
"""Fetches discovery document from the given URL."""
|
||||
try:
|
||||
async with self.http_session.get(self.discovery_url) as response:
|
||||
await http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document not found at %s", self.discovery_url
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning("Error fetching discovery: %s", e)
|
||||
raise OIDCDiscoveryInvalid(type="fetch_error") from e
|
||||
|
||||
async def _fetch_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with self.http_session.get(jwks_uri) as response:
|
||||
await http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s", e)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
async def _validate_discovery_document(self, document):
|
||||
"""Validates the discovery document."""
|
||||
|
||||
# Verify that required endpoints are present
|
||||
required_endpoints = [
|
||||
"issuer",
|
||||
"authorization_endpoint",
|
||||
"token_endpoint",
|
||||
"jwks_uri",
|
||||
]
|
||||
|
||||
for endpoint in required_endpoints:
|
||||
if endpoint not in document:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s is missing required endpoint: %s",
|
||||
self.discovery_url,
|
||||
endpoint,
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="missing_endpoint", details={"endpoint": endpoint}
|
||||
)
|
||||
if validate_url(document[endpoint]) is False:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s has invalid URL in endpoint: %s (%s)",
|
||||
self.discovery_url,
|
||||
endpoint,
|
||||
document[endpoint],
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="invalid_endpoint",
|
||||
details={"endpoint": endpoint, "url": document[endpoint]},
|
||||
)
|
||||
|
||||
# Verify optional response_modes_supported
|
||||
if "response_modes_supported" in document:
|
||||
if "query" not in document["response_modes_supported"]:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not support required 'query' "
|
||||
"response mode, only supports: %s",
|
||||
self.discovery_url,
|
||||
document["response_modes_supported"],
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_response_mode",
|
||||
modes=document["response_modes_supported"],
|
||||
)
|
||||
|
||||
# If grant_types_supported is set, should support 'authorization_code'
|
||||
if "grant_types_supported" in document:
|
||||
if "authorization_code" not in document["grant_types_supported"]:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not support required "
|
||||
"'authorization_code' grant type, only supports: %s",
|
||||
self.discovery_url,
|
||||
document["grant_types_supported"],
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_grant_type",
|
||||
details={
|
||||
"required": "authorization_code",
|
||||
"supported": document["grant_types_supported"],
|
||||
},
|
||||
)
|
||||
|
||||
# If response_types_supported is set, should support 'code'
|
||||
if "response_types_supported" in document:
|
||||
if "code" not in document["response_types_supported"]:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not support required "
|
||||
"'code' response type, only supports: %s",
|
||||
self.discovery_url,
|
||||
document["response_types_supported"],
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_response_type",
|
||||
details={
|
||||
"required": "code",
|
||||
"supported": document["response_types_supported"],
|
||||
},
|
||||
)
|
||||
|
||||
# If code_challenge_methods_supported is present, check that it contains S256
|
||||
if "code_challenge_methods_supported" in document:
|
||||
if "S256" not in document["code_challenge_methods_supported"]:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not support required "
|
||||
"'S256' code challenge method, only supports: %s",
|
||||
self.discovery_url,
|
||||
document["code_challenge_methods_supported"],
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_required_code_challenge_method",
|
||||
details={
|
||||
"required": "S256",
|
||||
"supported": document["code_challenge_methods_supported"],
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the id_token_signing_alg_values_supported field is present and filled
|
||||
signing_values = document.get("id_token_signing_alg_values_supported", None)
|
||||
if signing_values is None:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not have "
|
||||
"'id_token_signing_alg_values_supported' field",
|
||||
self.discovery_url,
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(type="missing_id_token_signing_alg_values")
|
||||
|
||||
# Verify that the requested id_token_signing_alg is supported
|
||||
requested_alg = self.verification_context.get("id_token_signing_alg", None)
|
||||
if requested_alg is not None and requested_alg not in signing_values:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document %s does not support requested "
|
||||
"id_token_signing_alg '%s', only supports: %s",
|
||||
self.discovery_url,
|
||||
requested_alg,
|
||||
signing_values,
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_id_token_signing_alg",
|
||||
details={"requested": requested_alg, "supported": signing_values},
|
||||
)
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
"""Fetches discovery document."""
|
||||
document = await self._fetch_discovery_document()
|
||||
await self._validate_discovery_document(document)
|
||||
return document
|
||||
|
||||
async def fetch_jwks(self, jwks_uri: str | None):
|
||||
"""Fetches JWKS."""
|
||||
if jwks_uri is None:
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
jwks_uri = discovery_document["jwks_uri"]
|
||||
return await self._fetch_jwks(jwks_uri)
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
@@ -78,6 +299,9 @@ class OIDCClient:
|
||||
# HTTP session to be used
|
||||
http_session: aiohttp.ClientSession = None
|
||||
|
||||
# OIDC Discovery tool to be used
|
||||
discovery_class: OIDCDiscoveryClient = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
@@ -98,7 +322,7 @@ class OIDCClient:
|
||||
# Default id_token_signing_alg to RS256 if not specified
|
||||
self.id_token_signing_alg = kwargs.get("id_token_signing_alg")
|
||||
if self.id_token_signing_alg is None:
|
||||
self.id_token_signing_alg = "RS256"
|
||||
self.id_token_signing_alg = DEFAULT_ID_TOKEN_SIGNING_ALGORITHM
|
||||
|
||||
features = kwargs.get("features")
|
||||
claims = kwargs.get("claims")
|
||||
@@ -122,23 +346,6 @@ class OIDCClient:
|
||||
_LOGGER.debug("Closing HTTP session")
|
||||
self.http_session.close()
|
||||
|
||||
async def http_raise_for_status(self, response: aiohttp.ClientResponse) -> None:
|
||||
"""Raises an exception if the response is not OK."""
|
||||
if not response.ok:
|
||||
# reason should always be not None for a started response
|
||||
assert response.reason is not None
|
||||
|
||||
body = await response.text()
|
||||
|
||||
raise HTTPClientError(
|
||||
response.request_info,
|
||||
response.history,
|
||||
status=response.status,
|
||||
message=response.reason,
|
||||
headers=response.headers,
|
||||
body=body,
|
||||
)
|
||||
|
||||
def _base64url_encode(self, value: str) -> str:
|
||||
"""Uses base64url encoding on a given string"""
|
||||
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
|
||||
@@ -173,42 +380,13 @@ class OIDCClient:
|
||||
)
|
||||
return self.http_session
|
||||
|
||||
async def _fetch_discovery_document(self):
|
||||
"""Fetches discovery document from the given URL."""
|
||||
try:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.get(self.discovery_url) as response:
|
||||
await self.http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document not found at %s", self.discovery_url
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning("Error fetching discovery: %s", e)
|
||||
raise OIDCDiscoveryInvalid from e
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.get(jwks_uri) as response:
|
||||
await self.http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s", e)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
"""Performs the token POST call"""
|
||||
try:
|
||||
session = await self._get_http_session()
|
||||
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
await self.http_raise_for_status(response)
|
||||
await http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
if e.status == 400:
|
||||
@@ -231,12 +409,34 @@ class OIDCClient:
|
||||
headers = {"Authorization": "Bearer " + access_token}
|
||||
|
||||
async with session.get(userinfo_uri, headers=headers) as response:
|
||||
await self.http_raise_for_status(response)
|
||||
await http_raise_for_status(response)
|
||||
return await response.json()
|
||||
except HTTPClientError as e:
|
||||
_LOGGER.warning("Error fetching userinfo: %s", e)
|
||||
raise OIDCUserinfoInvalid from e
|
||||
|
||||
async def _fetch_discovery_document(self):
|
||||
"""Fetches discovery document."""
|
||||
if self.discovery_document is not None:
|
||||
return self.discovery_document
|
||||
|
||||
if self.discovery_class is None:
|
||||
session = await self._get_http_session()
|
||||
self.discovery_class = OIDCDiscoveryClient(
|
||||
discovery_url=self.discovery_url,
|
||||
http_session=session,
|
||||
verification_context={
|
||||
"id_token_signing_alg": self.id_token_signing_alg,
|
||||
},
|
||||
)
|
||||
|
||||
self.discovery_document = await self.discovery_class.fetch_discovery_document()
|
||||
return self.discovery_document
|
||||
|
||||
async def _fetch_jwks(self, jwks_uri: str):
|
||||
"""Fetches JWKS."""
|
||||
return await self.discovery_class.fetch_jwks(jwks_uri)
|
||||
|
||||
async def _parse_id_token(
|
||||
self, id_token: str, access_token: str | None
|
||||
) -> Optional[dict]:
|
||||
@@ -245,7 +445,7 @@ class OIDCClient:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
jwks_uri = self.discovery_document["jwks_uri"]
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
jwks_data = await self._fetch_jwks(jwks_uri)
|
||||
|
||||
try:
|
||||
# Obtain the id_token header
|
||||
@@ -369,10 +569,8 @@ class OIDCClient:
|
||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
||||
"""Generates the authorization URL for the OIDC flow."""
|
||||
try:
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
auth_endpoint = self.discovery_document["authorization_endpoint"]
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
auth_endpoint = discovery_document["authorization_endpoint"]
|
||||
|
||||
# Generate random nonce & state
|
||||
nonce = self._generate_random_url_string()
|
||||
@@ -417,8 +615,9 @@ class OIDCClient:
|
||||
|
||||
# Fetch userinfo if there is an userinfo_endpoint available
|
||||
# and use the data to supply the missing values in id_token
|
||||
if "userinfo_endpoint" in self.discovery_document:
|
||||
userinfo_endpoint = self.discovery_document["userinfo_endpoint"]
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
if "userinfo_endpoint" in discovery_document:
|
||||
userinfo_endpoint = discovery_document["userinfo_endpoint"]
|
||||
userinfo = await self._get_userinfo(userinfo_endpoint, access_token)
|
||||
|
||||
# Replace missing claims in the id_token with their userinfo version
|
||||
@@ -451,9 +650,7 @@ class OIDCClient:
|
||||
# Only unique per issuer, so we combine it with the issuer and hash it.
|
||||
# This might allow multiple OIDC providers to be used with this integration.
|
||||
"sub": hashlib.sha256(
|
||||
f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode(
|
||||
"utf-8"
|
||||
)
|
||||
f"{discovery_document['issuer']}.{id_token.get('sub')}".encode("utf-8")
|
||||
).hexdigest(),
|
||||
# Display name, configurable
|
||||
"display_name": id_token.get(self.display_name_claim),
|
||||
@@ -474,10 +671,8 @@ class OIDCClient:
|
||||
|
||||
flow = self.flows[state]
|
||||
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
token_endpoint = self.discovery_document["token_endpoint"]
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
token_endpoint = discovery_document["token_endpoint"]
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
@@ -532,21 +727,3 @@ class OIDCClient:
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Failed to complete token flow, returning None. (%s)", e)
|
||||
return None
|
||||
|
||||
async def validate_discovery(self):
|
||||
"""Validate that the discovery document can be fetched and is valid.
|
||||
|
||||
Public method for configuration validation.
|
||||
Returns the discovery document if valid.
|
||||
Raises OIDCDiscoveryInvalid if invalid.
|
||||
"""
|
||||
return await self._fetch_discovery_document()
|
||||
|
||||
async def validate_jwks(self, jwks_uri: str):
|
||||
"""Validate that the JWKS can be fetched from the given URI.
|
||||
|
||||
Public method for configuration validation.
|
||||
Returns the JWKS if valid.
|
||||
Raises OIDCJWKSInvalid if invalid.
|
||||
"""
|
||||
return await self._get_jwks(jwks_uri)
|
||||
@@ -5,11 +5,24 @@ from __future__ import annotations
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def validate_url(url: str) -> bool:
|
||||
"""Validate that a URL is properly formatted."""
|
||||
try:
|
||||
parsed = urlparse(url.strip())
|
||||
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
|
||||
def validate_discovery_url(url: str) -> bool:
|
||||
"""Validate that a URL is properly formatted for OIDC discovery."""
|
||||
try:
|
||||
parsed = urlparse(url.strip())
|
||||
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
|
||||
return bool(
|
||||
parsed.scheme in ("http", "https")
|
||||
and parsed.netloc
|
||||
and parsed.path.endswith("/.well-known/openid-configuration")
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
@@ -3,18 +3,25 @@
|
||||
"step": {
|
||||
"user": {
|
||||
"title": "Choose OIDC Provider",
|
||||
"description": "Select your OpenID Connect identity provider to get started with the setup.",
|
||||
"description": "Select your OpenID Connect identity provider to get started with the setup.\n\nIf you want to use a provider that isn't listed, try the Generic OpenID Connect provider or use the advanced YAML configuration instead.",
|
||||
"data": {
|
||||
"provider": "Provider"
|
||||
}
|
||||
},
|
||||
"discovery_url": {
|
||||
"title": "Provider Configuration",
|
||||
"description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's documentation and usually ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).",
|
||||
"description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's admin interface and ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).",
|
||||
"data": {
|
||||
"discovery_url": "Discovery URL"
|
||||
}
|
||||
},
|
||||
"validate_connection": {
|
||||
"title": "Connection Validation",
|
||||
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
|
||||
"data": {
|
||||
"action": "Choose an action"
|
||||
}
|
||||
},
|
||||
"client_config": {
|
||||
"title": "Client Configuration",
|
||||
"description": "Configure your OIDC client. You can find these details in your {provider_name} application settings.\n\n**Discovery URL:** {discovery_url}\n\n**Setup Instructions:**\n1. Register a new application in your OIDC provider\n2. Set the application type to 'Public Client' (recommended for most users)\n3. Add redirect URLs for Home Assistant\n4. Copy the Client ID below\n\n**Note:** If your provider requires a client secret, check 'Use Confidential Client' and provide your client secret below.\n\n**Need detailed setup instructions?** Check the [setup guide]({documentation_url}) for step-by-step instructions.",
|
||||
@@ -23,20 +30,6 @@
|
||||
"client_secret": "Client Secret (optional; required by some providers)"
|
||||
}
|
||||
},
|
||||
"client_secret": {
|
||||
"title": "Client Secret Configuration",
|
||||
"description": "Since you selected 'Confidential Client', please provide your client secret.\n\n**Provider:** {provider_name}\n**Client ID:** {client_id}\n**Discovery URL:** {discovery_url}\n\n**Security Note:** The client secret will be stored securely in Home Assistant's configuration. Never share your client secret with others.",
|
||||
"data": {
|
||||
"client_secret": "Client Secret"
|
||||
}
|
||||
},
|
||||
"validate_connection": {
|
||||
"title": "Connection Validation",
|
||||
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n**Client ID:** {client_id}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Client Settings:** Go back to change Client ID or secret\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
|
||||
"data": {
|
||||
"action": "Choose an action"
|
||||
}
|
||||
},
|
||||
"groups_config": {
|
||||
"title": "Groups & Role Configuration",
|
||||
"description": "Configure how user groups from {provider_name} should be mapped to Home Assistant roles.\n\n**Groups Support:** Groups allow you to automatically assign admin or user roles based on group membership in your identity provider.\n\n**Admin Group:** Users in this group will have administrator access\n**User Group:** Users in this group will have standard user access (leave empty to allow all authenticated users)",
|
||||
@@ -71,12 +64,8 @@
|
||||
"cannot_connect": "Failed to connect to the OIDC provider. Please check your network connection and discovery URL.",
|
||||
"discovery_invalid": "The discovery document could not be retrieved or is invalid. Please verify the discovery URL is correct.",
|
||||
"jwks_invalid": "Failed to retrieve or validate the JWKS (JSON Web Key Set). Please check your provider configuration.",
|
||||
"invalid_client_credentials": "The client ID or client secret appears to be invalid. Please check your OIDC application settings and ensure the credentials are correct.",
|
||||
"client_secret_required": "Client secret is required when using confidential client mode.",
|
||||
"invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL.",
|
||||
"invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL and should end with '/.well-known/openid-configuration'",
|
||||
"invalid_client_id": "Client ID cannot be empty and must contain valid characters.",
|
||||
"no_url_available": "Unable to determine Home Assistant URL for OAuth redirect. Please check your network configuration.",
|
||||
"auth_url_failed": "Failed to generate authorization URL for OAuth test.",
|
||||
"unknown": "An unexpected error occurred. Please check the logs for more details."
|
||||
},
|
||||
"abort": {
|
||||
@@ -84,7 +73,8 @@
|
||||
"cannot_connect": "Unable to connect to the OIDC provider.",
|
||||
"invalid_discovery": "Invalid discovery document received from the provider.",
|
||||
"reconfigure_successful": "OIDC Authentication has been successfully reconfigured with the updated client credentials.",
|
||||
"single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured (either through YAML or the UI). To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one."
|
||||
"single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured in the UI. To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one.",
|
||||
"yaml_configured": "You are currently using YAML configuration for this integration. To switch to UI configuration, please remove the YAML configuration first. Note that some advanced options configured via YAML may not be available in the UI."
|
||||
}
|
||||
},
|
||||
"options": {
|
||||
|
||||
0
custom_components/auth_oidc/views/__init__.py
Normal file
0
custom_components/auth_oidc/views/__init__.py
Normal file
Reference in New Issue
Block a user