Improved config options for OIDC (#9)
Added many new configuration options, including claim configuration and client_secret/confidential client support. Also enables user linking & creates person entries upon first sign in.
This commit is contained in:
committed by
GitHub
parent
ca83e86acb
commit
db4c6bcade
@@ -3,9 +3,25 @@
|
||||
import logging
|
||||
from typing import OrderedDict
|
||||
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
# Import and re-export config schema explictly
|
||||
# pylint: disable=useless-import-alias
|
||||
from .config import (
|
||||
CONFIG_SCHEMA as CONFIG_SCHEMA,
|
||||
DOMAIN,
|
||||
DEFAULT_TITLE,
|
||||
CLIENT_ID,
|
||||
CLIENT_SECRET,
|
||||
DISCOVERY_URL,
|
||||
DISPLAY_NAME,
|
||||
ID_TOKEN_SIGNING_ALGORITHM,
|
||||
FEATURES,
|
||||
CLAIMS,
|
||||
)
|
||||
|
||||
# pylint: enable=useless-import-alias
|
||||
|
||||
from .endpoints.welcome import OIDCWelcomeView
|
||||
from .endpoints.redirect import OIDCRedirectView
|
||||
from .endpoints.finish import OIDCFinishView
|
||||
@@ -14,52 +30,47 @@ from .endpoints.callback import OIDCCallbackView
|
||||
from .oidc_client import OIDCClient
|
||||
from .provider import OpenIDAuthProvider
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Required("client_id"): vol.Coerce(str),
|
||||
vol.Optional("client_secret"): vol.Coerce(str),
|
||||
vol.Required("discovery_url"): vol.Coerce(str),
|
||||
}
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config):
|
||||
"""Add the OIDC Auth Provider to the providers in Home Assistant"""
|
||||
my_config = config[DOMAIN]
|
||||
|
||||
providers = OrderedDict()
|
||||
|
||||
# Use private APIs until there is a real auth platform
|
||||
# pylint: disable=protected-access
|
||||
provider = OpenIDAuthProvider(
|
||||
hass,
|
||||
hass.auth._store,
|
||||
config[DOMAIN],
|
||||
)
|
||||
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
|
||||
|
||||
providers[(provider.type, provider.id)] = provider
|
||||
providers.update(hass.auth._providers)
|
||||
hass.auth._providers = providers
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_LOGGER.debug("Added OIDC provider for Home Assistant")
|
||||
_LOGGER.info("Registered OIDC provider")
|
||||
|
||||
# Define some fields
|
||||
discovery_url: str = config[DOMAIN]["discovery_url"]
|
||||
client_id: str = config[DOMAIN]["client_id"]
|
||||
scope: str = "openid profile email"
|
||||
# We only use openid & profile, never email
|
||||
scope = "openid profile"
|
||||
|
||||
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)
|
||||
oidc_client = oidc_client = OIDCClient(
|
||||
discovery_url=my_config.get(DISCOVERY_URL),
|
||||
client_id=my_config.get(CLIENT_ID),
|
||||
scope=scope,
|
||||
client_secret=my_config.get(CLIENT_SECRET),
|
||||
id_token_signing_alg=my_config.get(ID_TOKEN_SIGNING_ALGORITHM),
|
||||
features=my_config.get(FEATURES, {}),
|
||||
claims=my_config.get(CLAIMS, {}),
|
||||
)
|
||||
|
||||
hass.http.register_view(OIDCWelcomeView())
|
||||
# Register the views
|
||||
name = config[DOMAIN].get(DISPLAY_NAME, DEFAULT_TITLE)
|
||||
|
||||
hass.http.register_view(OIDCWelcomeView(name))
|
||||
hass.http.register_view(OIDCRedirectView(oidc_client))
|
||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
|
||||
hass.http.register_view(OIDCFinishView())
|
||||
|
||||
_LOGGER.info("Registered OIDC views")
|
||||
|
||||
return True
|
||||
|
||||
72
custom_components/auth_oidc/config.py
Normal file
72
custom_components/auth_oidc/config.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""Config schema and constants."""
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
CLIENT_ID = "client_id"
|
||||
CLIENT_SECRET = "client_secret"
|
||||
DISCOVERY_URL = "discovery_url"
|
||||
DISPLAY_NAME = "display_name"
|
||||
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
|
||||
FEATURES = "features"
|
||||
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
||||
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
||||
CLAIMS = "claims"
|
||||
CLAIMS_DISPLAY_NAME = "display_name"
|
||||
CLAIMS_USERNAME = "username"
|
||||
CLAIMS_GROUPS = "groups"
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
# Required client ID as registered with the OIDC provider
|
||||
vol.Required(CLIENT_ID): vol.Coerce(str),
|
||||
# Optional Client Secret to enable confidential client mode
|
||||
vol.Optional(CLIENT_SECRET): vol.Coerce(str),
|
||||
# Which OIDC well-known URL should we use?
|
||||
vol.Required(DISCOVERY_URL): vol.Coerce(str),
|
||||
# Which name should be shown on the login screens?
|
||||
vol.Optional(DISPLAY_NAME): vol.Coerce(str),
|
||||
# Should we enforce a specific signing algorithm on the id tokens?
|
||||
# Defaults to RS256/RSA-pubkey
|
||||
vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str),
|
||||
# Which features should be enabled/disabled?
|
||||
# Optional, defaults to sane/secure defaults
|
||||
vol.Optional(FEATURES): vol.Schema(
|
||||
{
|
||||
# Automatically links users to the HA user based on OIDC username claim
|
||||
# See provider.py for explanation
|
||||
vol.Optional(FEATURES_AUTOMATIC_USER_LINKING): vol.Coerce(bool),
|
||||
# Automatically creates a person entry for your new OIDC user
|
||||
# See provider.py for explanation
|
||||
vol.Optional(FEATURES_AUTOMATIC_PERSON_CREATION): vol.Coerce(
|
||||
bool
|
||||
),
|
||||
# Feature flag to disable PKCE to support OIDC servers that do not
|
||||
# allow additional parameters and don't support RFC 7636
|
||||
vol.Optional(FEATURES_DISABLE_PKCE): vol.Coerce(bool),
|
||||
}
|
||||
),
|
||||
# Determine which specific claims will be used from the id_token
|
||||
# Optional, defaults to most common claims
|
||||
vol.Optional(CLAIMS): vol.Schema(
|
||||
{
|
||||
# Which claim should we use to obtain the display name from OIDC?
|
||||
vol.Optional(CLAIMS_DISPLAY_NAME): vol.Coerce(str),
|
||||
# Which claim should we use to obtain the username from OIDC?
|
||||
vol.Optional(CLAIMS_USERNAME): vol.Coerce(str),
|
||||
# Which claim should we use to obtain the group(s) from OIDC?
|
||||
vol.Optional(CLAIMS_GROUPS): vol.Coerce(str),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
},
|
||||
# Any extra fields should not go into our config right now
|
||||
# You may set them for upgrading etc
|
||||
extra=vol.REMOVE_EXTRA,
|
||||
)
|
||||
@@ -14,7 +14,10 @@ class OIDCWelcomeView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
view_html = await get_view("welcome")
|
||||
view_html = await get_view("welcome", {"name": self.name})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
@@ -18,5 +18,5 @@
|
||||
"aiofiles>=24.1.0",
|
||||
"jinja2>=3.1.4"
|
||||
],
|
||||
"version": "0.3.0"
|
||||
"version": "0.4.0"
|
||||
}
|
||||
@@ -9,6 +9,14 @@ from typing import Optional
|
||||
import aiohttp
|
||||
from jose import jwt, jwk
|
||||
|
||||
from .types import UserDetails
|
||||
from .config import (
|
||||
FEATURES_DISABLE_PKCE,
|
||||
CLAIMS_DISPLAY_NAME,
|
||||
CLAIMS_USERNAME,
|
||||
CLAIMS_GROUPS,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -32,19 +40,49 @@ class OIDCStateInvalid(OIDCClientException):
|
||||
"Raised when the state for your request cannot be matched against a stored state."
|
||||
|
||||
|
||||
class OIDCIdTokenSigningAlgorithmInvalid(OIDCTokenResponseInvalid):
|
||||
"Raised when the id_token is signed with the wrong algorithm, adjust your config accordingly."
|
||||
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
|
||||
# Flows stores the state, code_verifier and nonce of all current flows.
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url: str, client_id: str, scope: str):
|
||||
def __init__(self, discovery_url: str, client_id: str, scope: str, **kwargs: str):
|
||||
self.discovery_url = discovery_url
|
||||
self.discovery_document = None
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
# Optional parameters
|
||||
self.client_secret = kwargs.get("client_secret")
|
||||
|
||||
# Default id_token_signing_alg to RS256 if not specified
|
||||
self.id_token_signing_alg = kwargs.get("id_token_signing_alg")
|
||||
if self.id_token_signing_alg is None:
|
||||
self.id_token_signing_alg = "RS256"
|
||||
|
||||
features = kwargs.get("features")
|
||||
claims = kwargs.get("claims")
|
||||
|
||||
self.disable_pkce: bool = features.get(FEATURES_DISABLE_PKCE)
|
||||
self.display_name_claim = claims.get(CLAIMS_DISPLAY_NAME, "name")
|
||||
self.username_claim = claims.get(CLAIMS_USERNAME, "preferred_username")
|
||||
self.groups_claim = claims.get(CLAIMS_GROUPS, "groups")
|
||||
|
||||
def _base64url_encode(self, value: str) -> str:
|
||||
"""Uses base64url encoding on a given string"""
|
||||
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
|
||||
|
||||
def _generate_random_url_string(self, length: int = 16) -> str:
|
||||
"""Generates a random URL safe string (base64_url encoded)"""
|
||||
return self._base64url_encode(os.urandom(length))
|
||||
|
||||
async def _fetch_discovery_document(self):
|
||||
"""Fetches discovery document from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
@@ -59,6 +97,161 @@ class OIDCClient:
|
||||
_LOGGER.warning("Error: %s - %s", e.status, e.message)
|
||||
raise OIDCDiscoveryInvalid from e
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
"""Performs the token POST call"""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 400:
|
||||
_LOGGER.warning(
|
||||
"Error: Token could not be obtained (Bad Request), "
|
||||
+ "did you forget the client_secret?"
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(
|
||||
"Unexpected error exchanging token: %s - %s", e.status, e.message
|
||||
)
|
||||
raise OIDCTokenResponseInvalid from e
|
||||
|
||||
async def _parse_id_token(
|
||||
self, id_token: str, access_token: str | None
|
||||
) -> Optional[dict]:
|
||||
"""Parses the ID token into a dict containing token contents."""
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
jwks_uri = self.discovery_document["jwks_uri"]
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
|
||||
try:
|
||||
# Obtain the id_token header
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
if not unverified_header:
|
||||
_LOGGER.warning("Could not get header from received id_token.")
|
||||
return None
|
||||
|
||||
# Obtain the signing algorithm from the header of the id_token
|
||||
alg = unverified_header.get("alg")
|
||||
if alg != self.id_token_signing_alg:
|
||||
# Verify that it matches our requested algorithm
|
||||
_LOGGER.warning(
|
||||
"ID Token received signed with the wrong algorithm: %s, expected %s",
|
||||
alg,
|
||||
self.id_token_signing_alg,
|
||||
)
|
||||
raise OIDCIdTokenSigningAlgorithmInvalid()
|
||||
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.8
|
||||
# If the JWT alg Header Parameter uses a MAC based algorithm
|
||||
# such as HS256, HS384, or HS512, the octets of the UTF-8 [RFC3629]
|
||||
# representation of the client_secret corresponding to the client_id
|
||||
# contained in the aud (audience) Claim are used as the key to
|
||||
# validate the signature.
|
||||
if alg.startswith("HS"):
|
||||
if not self.client_secret:
|
||||
_LOGGER.warning(
|
||||
"ID Token signed with HMAC algorithm, but no client_secret provided."
|
||||
)
|
||||
raise OIDCIdTokenSigningAlgorithmInvalid()
|
||||
|
||||
jwk_obj = jwk.construct(
|
||||
{
|
||||
"kty": "oct",
|
||||
"k": base64.urlsafe_b64encode(
|
||||
self.client_secret.encode()
|
||||
).decode(),
|
||||
"alg": alg,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# TODO: Deal with cases where kid is not specified (just take the first key?)
|
||||
# Obtain the kid (Key ID) from the header of the id_token
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
_LOGGER.warning("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
# Get the correct key
|
||||
signing_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
if key["kid"] == kid:
|
||||
signing_key = key
|
||||
break
|
||||
|
||||
if not signing_key:
|
||||
_LOGGER.warning("Could not find matching key with kid: %s", kid)
|
||||
return None
|
||||
|
||||
# Construct the JWK from the RSA key
|
||||
jwk_obj = jwk.construct(signing_key)
|
||||
|
||||
# Verify the token
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
jwk_obj,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.6
|
||||
# The Client MUST validate the signature of all other ID Tokens
|
||||
# according to JWS [JWS] using the algorithm specified in the JWT
|
||||
# alg Header Parameter.
|
||||
algorithms=[self.id_token_signing_alg],
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.3
|
||||
# The Client MUST validate that the aud (audience) Claim contains
|
||||
# its client_id value registered at the Issuer identified by the
|
||||
# iss (issuer) Claim as an audience.
|
||||
audience=self.client_id,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.2
|
||||
# The Issuer Identifier for the OpenID Provider MUST exactly
|
||||
# match the value of the iss (issuer) Claim.
|
||||
issuer=self.discovery_document["issuer"],
|
||||
access_token=access_token,
|
||||
options={
|
||||
# Verify everything if present
|
||||
"verify_signature": True,
|
||||
"verify_aud": True,
|
||||
"verify_iat": True,
|
||||
"verify_exp": True,
|
||||
"verify_nbf": True,
|
||||
"verify_iss": True,
|
||||
"verify_sub": True,
|
||||
"verify_jti": True,
|
||||
"verify_at_hash": True,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.3
|
||||
"require_aud": True,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.10
|
||||
"require_iat": True,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.9
|
||||
"require_exp": True,
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.2
|
||||
"require_iss": True,
|
||||
# We need the sub as it's used to identify the user
|
||||
"require_sub": True,
|
||||
# Other values, not required.
|
||||
"require_nbf": False,
|
||||
"require_jti": False,
|
||||
"require_at_hash": False,
|
||||
"leeway": 5,
|
||||
},
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
_LOGGER.warning("JWT Verification failed: %s", e)
|
||||
return None
|
||||
|
||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
||||
"""Generates the authorization URL for the OIDC flow."""
|
||||
try:
|
||||
@@ -67,22 +260,14 @@ class OIDCClient:
|
||||
|
||||
auth_endpoint = self.discovery_document["authorization_endpoint"]
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(
|
||||
hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
)
|
||||
.rstrip(b"=")
|
||||
.decode("utf-8")
|
||||
)
|
||||
nonce = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
state = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
# Generate random nonce & state
|
||||
nonce = self._generate_random_url_string()
|
||||
state = self._generate_random_url_string()
|
||||
|
||||
# Generate PKCE (RFC 7636) parameters
|
||||
code_verifier = self._generate_random_url_string(32)
|
||||
code_challenge = self._base64url_encode(
|
||||
hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
)
|
||||
|
||||
# Save all of them for later verification
|
||||
@@ -95,92 +280,27 @@ class OIDCClient:
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": self.scope,
|
||||
"state": state,
|
||||
# Nonce is always set in accordance with OpenID Connect Core 1.0
|
||||
"nonce": nonce,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
# We always want to use PKCE (RFC 7636), unless it's disabled for compatibility.
|
||||
# PKCE is the recommended method of securing the authorization code grant
|
||||
# for public clients as much as possible.
|
||||
# (see https://datatracker.ietf.org/doc/html/draft-ietf-oauth-v2-1-11#section-7.5.1)
|
||||
if not self.disable_pkce:
|
||||
query_params["code_challenge"] = code_challenge
|
||||
query_params["code_challenge_method"] = "S256"
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error generating authorization URL: %s", e)
|
||||
return None
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning("Error exchanging token: %s - %s", e.status, e.message)
|
||||
raise OIDCTokenResponseInvalid from e
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _parse_id_token(self, id_token: str):
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
|
||||
jwks_uri = self.discovery_document["jwks_uri"]
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
if not unverified_header:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
if key["kid"] == kid:
|
||||
rsa_key = key
|
||||
break
|
||||
|
||||
if not rsa_key:
|
||||
print(f"Could not find matching key with kid:{kid}")
|
||||
return None
|
||||
|
||||
# Construct the JWK
|
||||
jwk_obj = jwk.construct(rsa_key)
|
||||
|
||||
# Verify the token
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document["issuer"],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def async_complete_token_flow(
|
||||
self, redirect_uri: str, code: str, state: str
|
||||
) -> dict[str, str | dict]:
|
||||
) -> Optional[UserDetails]:
|
||||
"""Completes the OIDC token flow to obtain a user's details."""
|
||||
|
||||
try:
|
||||
@@ -188,7 +308,6 @@ class OIDCClient:
|
||||
raise OIDCStateInvalid
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow["code_verifier"]
|
||||
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
@@ -201,27 +320,69 @@ class OIDCClient:
|
||||
"client_id": self.client_id,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
|
||||
# Send the client secret if we have one
|
||||
if self.client_secret is not None:
|
||||
query_params["client_secret"] = self.client_secret
|
||||
|
||||
# If we disable PKCE, don't send the code verifier
|
||||
if not self.disable_pkce:
|
||||
query_params["code_verifier"] = flow["code_verifier"]
|
||||
|
||||
# Exchange the code for a token
|
||||
token_response = await self._make_token_request(
|
||||
token_endpoint, query_params
|
||||
)
|
||||
|
||||
id_token = token_response.get("id_token")
|
||||
access_token = token_response.get("access_token")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
# Access token is supplied to check at_hash if present
|
||||
id_token = await self._parse_id_token(id_token, access_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token is None:
|
||||
_LOGGER.warning("ID token could not be parsed!")
|
||||
return None
|
||||
|
||||
# OpenID Connect Core 1.0 Section 3.1.3.7.11
|
||||
# If a nonce value was sent in the Authentication Request,
|
||||
# a nonce Claim MUST be present and its value checked to verify
|
||||
# that it is the same value as the one that was sent in the Authentication Request.
|
||||
if id_token.get("nonce") != flow["nonce"]:
|
||||
_LOGGER.warning("Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"username": id_token.get("preferred_username"),
|
||||
"groups": id_token.get("groups"),
|
||||
# TODO: If the configured claims are not present in id_token, we should fetch userinfo
|
||||
|
||||
# Create a user details dict based on the contents of the id_token & userinfo
|
||||
data: UserDetails = {
|
||||
# Subject Identifier. A locally unique and never reassigned identifier within the
|
||||
# Issuer for the End-User, which is intended to be consumed by the Client
|
||||
# Only unique per issuer, so we combine it with the issuer and hash it.
|
||||
# This might allow multiple OIDC providers to be used with this integration.
|
||||
"sub": hashlib.sha256(
|
||||
f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode(
|
||||
"utf-8"
|
||||
)
|
||||
).hexdigest(),
|
||||
# Display name, configurable
|
||||
"display_name": id_token.get(self.display_name_claim),
|
||||
# Username, configurable
|
||||
"username": id_token.get(self.username_claim),
|
||||
# Groups, configurable
|
||||
"groups": id_token.get(self.groups_claim),
|
||||
}
|
||||
|
||||
# Log which details were obtained for debugging
|
||||
# Also log the original subject identifier such that you can look it up in your provider
|
||||
_LOGGER.debug(
|
||||
"Obtained user details from OIDC provider: %s (issuer subject: %s)",
|
||||
data,
|
||||
id_token.get("sub"),
|
||||
)
|
||||
return data
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error completing token flow: %s", e)
|
||||
return None
|
||||
|
||||
@@ -6,6 +6,7 @@ import logging
|
||||
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from homeassistant.auth import EVENT_USER_ADDED
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
@@ -13,15 +14,29 @@ from homeassistant.auth.providers import (
|
||||
AuthFlowResult,
|
||||
Credentials,
|
||||
UserMeta,
|
||||
User,
|
||||
AuthStore,
|
||||
)
|
||||
from homeassistant.components import http
|
||||
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.components import http, person
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
|
||||
from .config import (
|
||||
FEATURES,
|
||||
FEATURES_AUTOMATIC_USER_LINKING,
|
||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||
DEFAULT_TITLE,
|
||||
)
|
||||
from .stores.code_store import CodeStore
|
||||
from .types import UserDetails
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
PROVIDER_TYPE = "auth_oidc"
|
||||
HASS_PROVIDER_TYPE = "homeassistant"
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
"""Raised when submitting invalid authentication."""
|
||||
@@ -32,23 +47,44 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
"""Allow access to users based on login with an external
|
||||
OpenID Connect Identity Provider (IdP)."""
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "auth_oidc"
|
||||
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
return False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, hass: HomeAssistant, store: AuthStore, config: dict[str, str]):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
super().__init__(
|
||||
hass,
|
||||
store,
|
||||
{
|
||||
# Currently register as default, might be used when we have multiple OIDC providers
|
||||
CONF_ID: "default",
|
||||
# Name displayed in the UI
|
||||
CONF_NAME: config.get("display_name", DEFAULT_TITLE),
|
||||
# Type
|
||||
CONF_TYPE: PROVIDER_TYPE,
|
||||
},
|
||||
)
|
||||
|
||||
self._user_meta: dict[UserDetails] = {}
|
||||
self._code_store: CodeStore | None = None
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
features = config.get(
|
||||
FEATURES,
|
||||
{},
|
||||
)
|
||||
|
||||
# Link users automatically?
|
||||
# False by default to always make new accounts for OIDC users
|
||||
# Turn this on to migrate from HA accounts to OIDC
|
||||
self.user_linking = features.get(FEATURES_AUTOMATIC_USER_LINKING, False)
|
||||
|
||||
# Create person entries automatically?
|
||||
# True by default to create a person for each new user (just like normal HA)
|
||||
# Turn this off if you don't want OIDC to interfere more than necessary
|
||||
self.create_persons = features.get(FEATURES_AUTOMATIC_PERSON_CREATION, True)
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
|
||||
@@ -64,8 +100,11 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
self._code_store = store
|
||||
self._user_meta = {}
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[str]:
|
||||
"""Retrieve user from the code, return username and save meta
|
||||
# Listen for user creation events
|
||||
self.hass.bus.async_listen(EVENT_USER_ADDED, self.async_user_created)
|
||||
|
||||
async def async_get_subject(self, code: str) -> Optional[str]:
|
||||
"""Retrieve user from the code, return subject and save meta
|
||||
for later use with this provider instance."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
@@ -75,9 +114,9 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
if user_data is None:
|
||||
return None
|
||||
|
||||
username = user_data["username"]
|
||||
self._user_meta[username] = user_data
|
||||
return username
|
||||
sub = user_data["sub"]
|
||||
self._user_meta[sub] = user_data
|
||||
return sub
|
||||
|
||||
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
|
||||
"""Save user info and return a code."""
|
||||
@@ -87,6 +126,77 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
|
||||
return await self._code_store.async_generate_code_for_userinfo(user_info)
|
||||
|
||||
async def _async_find_user_by_username(self, username: str) -> Optional[User]:
|
||||
"""Find a user by username."""
|
||||
users = await self.store.async_get_users()
|
||||
for user in users:
|
||||
# System generated users don't have usernames and aren't our target here
|
||||
if user.system_generated:
|
||||
continue
|
||||
|
||||
# Check if we have a homeassistant credential with the provided username
|
||||
for credential in user.credentials:
|
||||
if (
|
||||
credential.auth_provider_type == HASS_PROVIDER_TYPE
|
||||
and credential.data.get("username") == username
|
||||
):
|
||||
return user
|
||||
|
||||
return None
|
||||
|
||||
# ====
|
||||
# Handler for user created and related functions (person creation)
|
||||
# ====
|
||||
|
||||
@callback
|
||||
async def async_user_created(self, event) -> None:
|
||||
"""Handle the user created event."""
|
||||
user_id = event.data["user_id"]
|
||||
user = await self.store.async_get_user(user_id)
|
||||
|
||||
# Get the first credential, if it's not ours, return
|
||||
if not user.credentials or len(user.credentials) == 0:
|
||||
return
|
||||
|
||||
credential = user.credentials[0]
|
||||
if not (
|
||||
credential.auth_provider_type == self.type
|
||||
and credential.auth_provider_id == self.id
|
||||
):
|
||||
# Not mine, return
|
||||
return
|
||||
|
||||
# Audit log the user creation
|
||||
_LOGGER.info(
|
||||
"User was created for first OIDC sign in: %s from subject %s",
|
||||
user.id,
|
||||
credential.data["sub"],
|
||||
)
|
||||
|
||||
# If person creation is enabled, add a person for this user
|
||||
if self.create_persons:
|
||||
user_meta = await self.async_user_meta_for_credentials(credential)
|
||||
await self.async_create_person(user, user_meta.name)
|
||||
|
||||
async def async_create_person(self, user: User, name: str) -> None:
|
||||
"""Create a person for the user."""
|
||||
_LOGGER.info("Automatically creating person for new user %s", user.id)
|
||||
|
||||
# Create a person for the user
|
||||
try:
|
||||
await person.async_create_person(
|
||||
hass=self.hass,
|
||||
name=name,
|
||||
user_id=user.id,
|
||||
)
|
||||
# Catch all, we don't want to fail here
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception:
|
||||
_LOGGER.warning(
|
||||
"Requested automatic person creation, but person creation failed."
|
||||
)
|
||||
# pylint: enable=broad-exception-caught
|
||||
|
||||
# ====
|
||||
# Required functions for Home Assistant Auth Providers
|
||||
# ====
|
||||
@@ -99,13 +209,43 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
sub = flow_result["sub"]
|
||||
meta = self._user_meta.get(sub)
|
||||
|
||||
# Audit logging for the login that is about to occur
|
||||
_LOGGER.info(
|
||||
"Logged in user through OIDC: %s, %s", meta["sub"], meta["display_name"]
|
||||
)
|
||||
|
||||
# Iterate over previously created credentials to find one with the same sub
|
||||
for credential in await self.async_credentials():
|
||||
if credential.data["username"] == username:
|
||||
# When logging in again, use the subject to check if the credential exist
|
||||
# OpenID spec says that sub is the only claim we can rely on, as username
|
||||
# might change over time.
|
||||
if credential.data.get("sub") == sub:
|
||||
return credential
|
||||
|
||||
# Create new credentials.
|
||||
return self.async_create_credentials({"username": username})
|
||||
# If no credential was found, create a new one
|
||||
# Username cannot be supplied here as it won't be shown by Home Assistant regardless
|
||||
# Source: homeassistant/components/config/auth.py, line 162
|
||||
credential = self.async_create_credentials({"sub": sub})
|
||||
|
||||
# If we have user linking enabled, try to link the user here
|
||||
if self.user_linking:
|
||||
user = await self._async_find_user_by_username(meta["username"])
|
||||
if user is not None:
|
||||
_LOGGER.info(
|
||||
"User already exists, adding credential for "
|
||||
+ "OIDC to existing user with username '%s'.",
|
||||
meta["username"],
|
||||
)
|
||||
|
||||
# Link the credential to the existing user
|
||||
# Will set the credential isNew = false
|
||||
await self.store.async_link_user(user, credential)
|
||||
|
||||
# If the credential is new, HA will automatically create a new user for us
|
||||
return credential
|
||||
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials
|
||||
@@ -114,15 +254,19 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
|
||||
Currently, supports name, is_active, group and local_only.
|
||||
"""
|
||||
meta = self._user_meta.get(credentials.data["username"], {})
|
||||
|
||||
sub = credentials.data["sub"]
|
||||
meta = self._user_meta.get(sub, {})
|
||||
|
||||
groups = meta.get("groups", [])
|
||||
|
||||
# TODO: Allow setting which group is for admins
|
||||
group = "system-admin" if "admins" in groups else "system-users"
|
||||
return UserMeta(
|
||||
name=meta.get("name"),
|
||||
name=meta.get("display_name"),
|
||||
is_active=True,
|
||||
group=group,
|
||||
local_only="true",
|
||||
local_only=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -130,12 +274,11 @@ class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def _finalize_user(self, code: str) -> AuthFlowResult:
|
||||
username = await self._auth_provider.async_retrieve_username(code)
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
sub = await self._auth_provider.async_get_subject(code)
|
||||
if sub:
|
||||
return await self.async_finish(
|
||||
{
|
||||
"username": username,
|
||||
"sub": sub,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from ..types import UserDetails
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth_provider.auth_oidc.codes"
|
||||
|
||||
@@ -18,7 +20,7 @@ class CodeStore:
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = Store[dict[str, dict[str, dict | str]]](
|
||||
self._store = Store[dict[str, UserDetails]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, dict[str, dict | str]] | None = None
|
||||
@@ -26,7 +28,7 @@ class CodeStore:
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
if (data := await self._store.async_load()) is None:
|
||||
data = cast(dict[str, dict[str, dict | str]], {})
|
||||
data = cast(dict[str, UserDetails], {})
|
||||
self._data = data
|
||||
|
||||
async def async_save(self) -> None:
|
||||
@@ -38,9 +40,7 @@ class CodeStore:
|
||||
"""Generate a random six-digit code."""
|
||||
return "".join(random.choices(string.digits, k=6))
|
||||
|
||||
async def async_generate_code_for_userinfo(
|
||||
self, user_info: dict[str, dict | str]
|
||||
) -> str:
|
||||
async def async_generate_code_for_userinfo(self, user_info: UserDetails) -> str:
|
||||
"""Generates a one time code and adds it to the database for 5 minutes."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
@@ -57,9 +57,7 @@ class CodeStore:
|
||||
await self.async_save()
|
||||
return code
|
||||
|
||||
async def receive_userinfo_for_code(
|
||||
self, code: str
|
||||
) -> Optional[dict[str, dict | str]]:
|
||||
async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]:
|
||||
"""Retrieve user info based on the code."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
16
custom_components/auth_oidc/types.py
Normal file
16
custom_components/auth_oidc/types.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Generic data types"""
|
||||
|
||||
|
||||
# Dict class to give a type to the user details
|
||||
class UserDetails(dict):
|
||||
"""User details representation"""
|
||||
|
||||
# User subject, persistent identifier
|
||||
sub: str
|
||||
# Full name of the user for display purposes
|
||||
display_name: str
|
||||
# Preferred username for the user, will be used when first generating the account
|
||||
# or to link the account on first login
|
||||
username: str
|
||||
# Groups that the user has, if any are sent from the OIDC provider
|
||||
groups: list[str]
|
||||
@@ -16,7 +16,9 @@ class AsyncTemplateRenderer:
|
||||
"""An asynchronous template renderer that caches rendered templates."""
|
||||
|
||||
def __init__(self, template_dir: str = None):
|
||||
self.template_dir = template_dir or path.dirname(path.abspath(__file__))
|
||||
self.template_dir = template_dir or path.join(
|
||||
path.dirname(path.abspath(__file__)), "templates"
|
||||
)
|
||||
|
||||
async def fetch_templates(self) -> None:
|
||||
"""Fetches all HTML files from the template directory."""
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
<div>
|
||||
<button id="oidc-login-btn"
|
||||
class="w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg shadow-md hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400 focus:ring-opacity-75">
|
||||
Login with OpenID Connect (SSO)
|
||||
Login with {{ name }}
|
||||
</button>
|
||||
|
||||
<div role="status" id="loader" class="items-center justify-center flex hidden">
|
||||
Reference in New Issue
Block a user