Code quality improvements (v0.2.0-pre-alpha) (#5)
* Bumped version to 0.2.0 * Implemented Github Actions for HACS, Hassfest, Linting * Improved code quality (compliant with the linter now) * Added link to the finish page to automatically login on the same device/browser
This commit is contained in:
committed by
GitHub
parent
a30d42ffce
commit
b4a08b17ab
@@ -1,3 +1,5 @@
|
||||
"""OIDC Integration for Home Assistant."""
|
||||
|
||||
import logging
|
||||
from typing import OrderedDict
|
||||
|
||||
@@ -10,29 +12,31 @@ from .endpoints.finish import OIDCFinishView
|
||||
from .endpoints.callback import OIDCCallbackView
|
||||
|
||||
from .oidc_client import OIDCClient
|
||||
from .provider import OpenIDAuthProvider
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
from .provider import OpenIDAuthProvider
|
||||
|
||||
CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
vol.Required("client_id"): vol.Coerce(str),
|
||||
vol.Optional("client_secret"): vol.Coerce(str),
|
||||
vol.Required("discovery_url"): vol.Url(),
|
||||
vol.Required("discovery_url"): vol.Coerce(str),
|
||||
}
|
||||
)
|
||||
},
|
||||
extra=vol.ALLOW_EXTRA,
|
||||
)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config):
|
||||
"""Add the OIDC Auth Provider to the providers in Home Assistant"""
|
||||
providers = OrderedDict()
|
||||
|
||||
# Use private APIs until there is a real auth platform
|
||||
# pylint: disable=protected-access
|
||||
provider = OpenIDAuthProvider(
|
||||
hass,
|
||||
hass.auth._store,
|
||||
@@ -42,13 +46,14 @@ async def async_setup(hass: HomeAssistant, config):
|
||||
providers[(provider.type, provider.id)] = provider
|
||||
providers.update(hass.auth._providers)
|
||||
hass.auth._providers = providers
|
||||
# pylint: enable=protected-access
|
||||
|
||||
_LOGGER.debug("Added OIDC provider for Home Assistant")
|
||||
|
||||
# Define some fields
|
||||
discovery_url = config[DOMAIN]["discovery_url"]
|
||||
client_id = config[DOMAIN]["client_id"]
|
||||
scope = "openid profile email"
|
||||
discovery_url: str = config[DOMAIN]["discovery_url"]
|
||||
client_id: str = config[DOMAIN]["client_id"]
|
||||
scope: str = "openid profile email"
|
||||
|
||||
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
from aiohttp import web
|
||||
"""Callback route to return the user to after external OIDC interaction."""
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
from aiohttp import web
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..provider import OpenIDAuthProvider
|
||||
from ..helpers import get_url
|
||||
|
||||
PATH = "/auth/oidc/callback"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCCallbackView(HomeAssistantView):
|
||||
"""OIDC Plugin Callback View."""
|
||||
@@ -24,12 +25,9 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Callback view accessed")
|
||||
|
||||
params = request.rel_url.query
|
||||
code = params.get("code")
|
||||
state = params.get("state")
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
|
||||
if not (code and state):
|
||||
return web.Response(
|
||||
@@ -37,13 +35,16 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
text="<h1>Error</h1><p>Missing code or state parameter</p>",
|
||||
)
|
||||
|
||||
user_details = await self.oidc_client.complete_token_flow(base_uri, code, state)
|
||||
redirect_uri = get_url("/auth/oidc/callback")
|
||||
user_details = await self.oidc_client.async_complete_token_flow(
|
||||
redirect_uri, code, state
|
||||
)
|
||||
if user_details is None:
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Error</h1><p>Failed to get user details, see console.</p>",
|
||||
)
|
||||
|
||||
code = await self.oidc_provider.save_user_info(user_details)
|
||||
code = await self.oidc_provider.async_save_user_info(user_details)
|
||||
|
||||
return web.HTTPFound(base_uri + "/auth/oidc/finish?code=" + code)
|
||||
return web.HTTPFound(get_url("/auth/oidc/finish?code=" + code))
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
from aiohttp import web
|
||||
"""Finish route to allow the user to view their code."""
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
from aiohttp import web
|
||||
|
||||
from ..helpers import get_url
|
||||
|
||||
PATH = "/auth/oidc/finish"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCFinishView(HomeAssistantView):
|
||||
"""OIDC Plugin Finish View."""
|
||||
@@ -17,8 +19,20 @@ class OIDCFinishView(HomeAssistantView):
|
||||
"""Receive response."""
|
||||
|
||||
code = request.query.get("code", "FAIL")
|
||||
link = get_url("/")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p><p>Please return to the Home Assistant login screen (or your mobile app) and fill in this code into the single login field. It should be visible if you select 'Login with OpenID Connect (SSO)'.</p>",
|
||||
)
|
||||
headers={
|
||||
"content-type": "text/html",
|
||||
"set-cookie": "auth_oidc_code="
|
||||
+ code
|
||||
+ "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=300",
|
||||
},
|
||||
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p>"
|
||||
+ "<p>Please return to the Home Assistant login "
|
||||
+ "screen (or your mobile app) and fill in this code into the single login field. "
|
||||
+ "It should be visible if you "
|
||||
+ "select 'Login with OpenID Connect (SSO)'.</p><p><a href='"
|
||||
+ link
|
||||
+ "'>Click here to login automatically (on desktop).</a></p>",
|
||||
)
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
"""Redirect route to redirect the user to the external OIDC server,
|
||||
can either be linked to directly or accessed through the welcome page."""
|
||||
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..helpers import get_url
|
||||
|
||||
PATH = "/auth/oidc/redirect"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCRedirectView(HomeAssistantView):
|
||||
"""OIDC Plugin Redirect View."""
|
||||
@@ -15,32 +17,23 @@ class OIDCRedirectView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:redirect"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient
|
||||
) -> None:
|
||||
def __init__(self, oidc_client: OIDCClient) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Redirect view accessed")
|
||||
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
_LOGGER.debug("Base URI: %s", base_uri)
|
||||
|
||||
auth_url = await self.oidc_client.get_authorization_url(base_uri)
|
||||
_LOGGER.debug("Auth URL: %s", auth_url)
|
||||
redirect_uri = get_url("/auth/oidc/callback")
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
||||
|
||||
if auth_url:
|
||||
return web.HTTPFound(auth_url)
|
||||
else:
|
||||
return web.Response(
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Plugin is misconfigured, discovery could not be obtained</h1>",
|
||||
)
|
||||
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""POST"""
|
||||
|
||||
_LOGGER.debug("Redirect POST view accessed")
|
||||
return await self.get(request)
|
||||
return await self.get(request)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
"""Welcome route to show the user the OIDC login button and give instructions."""
|
||||
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
PATH = "/auth/oidc/welcome"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCWelcomeView(HomeAssistantView):
|
||||
"""OIDC Plugin Welcome View."""
|
||||
@@ -13,12 +13,10 @@ class OIDCWelcomeView(HomeAssistantView):
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
async def get(self, _: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Welcome view accessed")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>OIDC Login (beta)</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
|
||||
)
|
||||
text="<h1>OIDC Login</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
|
||||
)
|
||||
|
||||
12
custom_components/auth_oidc/helpers.py
Normal file
12
custom_components/auth_oidc/helpers.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Helper functions for the integration."""
|
||||
|
||||
from homeassistant.components import http
|
||||
|
||||
|
||||
def get_url(path: str) -> str:
|
||||
"""Returns the requested path appended to the current request base URL."""
|
||||
if (req := http.current_request.get()) is None:
|
||||
raise RuntimeError("No current request in context")
|
||||
|
||||
base_uri = str(req.url).split("/auth", 2)[0]
|
||||
return f"{base_uri}{path}"
|
||||
@@ -1,14 +1,20 @@
|
||||
{
|
||||
"domain": "auth_oidc",
|
||||
"name": "OIDC Authentication",
|
||||
"documentation": "",
|
||||
"requirements": [],
|
||||
"ssdp": [],
|
||||
"zeroconf": [],
|
||||
"homekit": {},
|
||||
"dependencies": [
|
||||
"auth"
|
||||
"codeowners": [
|
||||
"@christiaangoossens"
|
||||
],
|
||||
"codeowners": ["@christiaangoossens"],
|
||||
"version": "0.1"
|
||||
}
|
||||
"config_flow": false,
|
||||
"dependencies": [
|
||||
"auth",
|
||||
"http"
|
||||
],
|
||||
"documentation": "https://github.com/christiaangoossens/hass-oidc-auth",
|
||||
"integration_type": "service",
|
||||
"iot_class": "calculated",
|
||||
"issue_tracker": "https://github.com/christiaangoossens/hass-oidc-auth/issues",
|
||||
"requirements": [
|
||||
"python-jose>=3.3.0"
|
||||
],
|
||||
"version": "0.2.0"
|
||||
}
|
||||
@@ -1,25 +1,50 @@
|
||||
import aiohttp
|
||||
"""OIDC Client class"""
|
||||
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from jose import jwt
|
||||
|
||||
from jose import jwk, jwt
|
||||
from typing import Optional
|
||||
import aiohttp
|
||||
from jose import jwt, jwk
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OIDCClientException(Exception):
|
||||
"Raised when the OIDC Client encounters an error"
|
||||
|
||||
|
||||
class OIDCDiscoveryInvalid(OIDCClientException):
|
||||
"Raised when the discovery document is not found, invalid or otherwise malformed."
|
||||
|
||||
|
||||
class OIDCTokenResponseInvalid(OIDCClientException):
|
||||
"Raised when the token request returns invalid."
|
||||
|
||||
|
||||
class OIDCJWKSInvalid(OIDCClientException):
|
||||
"Raised when the JWKS is invalid or cannot be obtained."
|
||||
|
||||
|
||||
class OIDCStateInvalid(OIDCClientException):
|
||||
"Raised when the state for your request cannot be matched against a stored state."
|
||||
|
||||
|
||||
class OIDCClient:
|
||||
"""OIDC Client implementation for Python, including PKCE."""
|
||||
|
||||
# Flows stores the state, code_verifier and nonce of all current flows.
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url, client_id, scope):
|
||||
def __init__(self, discovery_url: str, client_id: str, scope: str):
|
||||
self.discovery_url = discovery_url
|
||||
self.discovery_document = None
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
async def _fetch_discovery_document(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
@@ -27,47 +52,60 @@ class OIDCClient:
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
|
||||
_LOGGER.warning(
|
||||
"Error: Discovery document not found at %s", self.discovery_url
|
||||
)
|
||||
else:
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def get_authorization_url(self, base_uri):
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
_LOGGER.warning("Error: %s - %s", e.status, e.message)
|
||||
raise OIDCDiscoveryInvalid from e
|
||||
|
||||
if not self.discovery_document:
|
||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
||||
"""Generates the authorization URL for the OIDC flow."""
|
||||
try:
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
auth_endpoint = self.discovery_document["authorization_endpoint"]
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = (
|
||||
base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
code_challenge = (
|
||||
base64.urlsafe_b64encode(
|
||||
hashlib.sha256(code_verifier.encode("utf-8")).digest()
|
||||
)
|
||||
.rstrip(b"=")
|
||||
.decode("utf-8")
|
||||
)
|
||||
nonce = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
state = (
|
||||
base64.urlsafe_b64encode(os.urandom(16)).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {"code_verifier": code_verifier, "nonce": nonce}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": redirect_uri,
|
||||
"scope": self.scope,
|
||||
"state": state,
|
||||
"nonce": nonce,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error generating authorization URL: %s", e)
|
||||
return None
|
||||
|
||||
auth_endpoint = self.discovery_document['authorization_endpoint']
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
|
||||
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {
|
||||
'code_verifier': code_verifier,
|
||||
'nonce': nonce
|
||||
}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'scope': self.scope,
|
||||
'state': state,
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -75,12 +113,9 @@ class OIDCClient:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
response_json = await response.json()
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
|
||||
return None
|
||||
_LOGGER.warning("Error exchanging token: %s - %s", e.status, e.message)
|
||||
raise OIDCTokenResponseInvalid from e
|
||||
|
||||
return None
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
@@ -89,23 +124,18 @@ class OIDCClient:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
|
||||
return None
|
||||
_LOGGER.warning("Error fetching JWKS: %s - %s", e.status, e.message)
|
||||
raise OIDCJWKSInvalid from e
|
||||
|
||||
async def _parse_id_token(self, id_token: str):
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
async def _parse_id_token(self, id_token):
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
jwks_uri = self.discovery_document['jwks_uri']
|
||||
|
||||
jwks_uri = self.discovery_document["jwks_uri"]
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
if not jwks_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
@@ -113,12 +143,11 @@ class OIDCClient:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get('kid')
|
||||
kid = unverified_header.get("kid")
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
@@ -139,66 +168,60 @@ class OIDCClient:
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document['issuer'],
|
||||
issuer=self.discovery_document["issuer"],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
|
||||
return None
|
||||
|
||||
async def complete_token_flow(self, base_uri, code, state):
|
||||
if state not in self.flows:
|
||||
|
||||
async def async_complete_token_flow(
|
||||
self, redirect_uri: str, code: str, state: str
|
||||
) -> dict[str, str | dict]:
|
||||
"""Completes the OIDC token flow to obtain a user's details."""
|
||||
|
||||
try:
|
||||
if state not in self.flows:
|
||||
raise OIDCStateInvalid
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow["code_verifier"]
|
||||
|
||||
if self.discovery_document is None:
|
||||
self.discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
token_endpoint = self.discovery_document["token_endpoint"]
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": self.client_id,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
"code_verifier": code_verifier,
|
||||
}
|
||||
|
||||
token_response = await self._make_token_request(
|
||||
token_endpoint, query_params
|
||||
)
|
||||
id_token = token_response.get("id_token")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get("nonce") != flow["nonce"]:
|
||||
_LOGGER.warning("Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"username": id_token.get("preferred_username"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
except OIDCClientException as e:
|
||||
_LOGGER.warning("Error completing token flow: %s", e)
|
||||
return None
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow['code_verifier']
|
||||
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
token_endpoint = self.discovery_document['token_endpoint']
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'code': code,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'code_verifier': code_verifier,
|
||||
}
|
||||
|
||||
_LOGGER.debug(f"Token request params: {query_params}")
|
||||
|
||||
token_response = await self._make_token_request(token_endpoint, query_params)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
access_token = token_response.get('access_token')
|
||||
id_token = token_response.get('id_token')
|
||||
_LOGGER.debug(f"Access Token: {access_token}")
|
||||
_LOGGER.debug(f"ID Token: {id_token}")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get('nonce') != flow['nonce']:
|
||||
_LOGGER.warning(f"Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"email": id_token.get("email"),
|
||||
"preferred_username": id_token.get("preferred_username"),
|
||||
"nickname": id_token.get("nickname"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
@@ -1,8 +1,11 @@
|
||||
"""OIDC Authentication provider.
|
||||
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from typing import Dict, Optional
|
||||
import asyncio
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
@@ -11,30 +14,26 @@ from homeassistant.auth.providers import (
|
||||
Credentials,
|
||||
UserMeta,
|
||||
)
|
||||
from homeassistant.components import http
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import string
|
||||
from homeassistant.helpers.storage import Store
|
||||
from collections.abc import Mapping
|
||||
|
||||
from .stores.code_store import CodeStore
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvalidAuthError(HomeAssistantError):
|
||||
"""Raised when submitting invalid authentication."""
|
||||
|
||||
|
||||
@AUTH_PROVIDERS.register("oidc")
|
||||
class OpenIDAuthProvider(AuthProvider):
|
||||
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
|
||||
"""Allow access to users based on login with an external
|
||||
OpenID Connect Identity Provider (IdP)."""
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return "auth_oidc"
|
||||
@@ -42,13 +41,62 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
self._code_store: CodeStore | None = None
|
||||
self._init_lock = asyncio.Lock()
|
||||
|
||||
async def async_initialize(self) -> None:
|
||||
"""Initialize the auth provider."""
|
||||
|
||||
# Init the code store first
|
||||
# Use the same technique as the HomeAssistant auth provider for storage
|
||||
# (/auth/providers/homeassistant.py#L392)
|
||||
async with self._init_lock:
|
||||
if self._code_store is not None:
|
||||
return
|
||||
|
||||
store = CodeStore(self.hass)
|
||||
await store.async_load()
|
||||
self._code_store = store
|
||||
self._user_meta = {}
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[str]:
|
||||
"""Retrieve user from the code, return username and save meta
|
||||
for later use with this provider instance."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
|
||||
user_data = await self._code_store.receive_userinfo_for_code(code)
|
||||
if user_data is None:
|
||||
return None
|
||||
|
||||
username = user_data["username"]
|
||||
self._user_meta[username] = user_data
|
||||
return username
|
||||
|
||||
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
|
||||
"""Save user info and return a code."""
|
||||
if self._code_store is None:
|
||||
await self.async_initialize()
|
||||
assert self._code_store is not None
|
||||
|
||||
return await self._code_store.async_generate_code_for_userinfo(user_info)
|
||||
|
||||
# ====
|
||||
# Required functions for Home Assistant Auth Providers
|
||||
# ====
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
return OpenIdLoginFlow(self)
|
||||
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Mapping[str, str]
|
||||
self, flow_result: dict[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
@@ -64,7 +112,7 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Currently, supports name, group and local_only.
|
||||
Currently, supports name, is_active, group and local_only.
|
||||
"""
|
||||
meta = self._user_meta.get(credentials.data["username"], {})
|
||||
groups = meta.get("groups", [])
|
||||
@@ -76,96 +124,70 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
group=group,
|
||||
local_only="true",
|
||||
)
|
||||
|
||||
async def save_user_info(self, user_info: dict) -> str:
|
||||
"""Save user info during login."""
|
||||
_LOGGER.info("User info to be saved: %s", user_info)
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
user_data = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat()
|
||||
}
|
||||
|
||||
await self._save_to_db(self._get_code_key(code), user_data)
|
||||
return code
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[dict]:
|
||||
"""Retrieve user info based on the code."""
|
||||
user_data = await self._get_from_db(self._get_code_key(code))
|
||||
await self._wipe_from_db(self._get_code_key(code))
|
||||
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
|
||||
username = user_data["user_info"]["preferred_username"]
|
||||
self._user_meta[username] = user_data["user_info"]
|
||||
return username
|
||||
return None
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return ''.join(random.choices(string.digits, k=6))
|
||||
|
||||
def _get_code_key(self, code: str) -> str:
|
||||
return f"provider_oidc_auth_user_{code}"
|
||||
|
||||
async def _save_to_db(self, key: str, value: dict) -> None:
|
||||
"""Save key-value data to the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
await store.async_save(value)
|
||||
|
||||
async def _get_from_db(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_load()
|
||||
|
||||
async def _wipe_from_db(self, key: str) -> None:
|
||||
"""Delete key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_remove()
|
||||
|
||||
|
||||
class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
async def _finalize_user(self, code: str) -> AuthFlowResult:
|
||||
username = await self._auth_provider.async_retrieve_username(code)
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
return await self.async_finish(
|
||||
{
|
||||
"username": username,
|
||||
}
|
||||
)
|
||||
|
||||
raise InvalidAuthError
|
||||
|
||||
def _show_login_form(
|
||||
self, errors: Optional[dict[str, str]] = None
|
||||
) -> AuthFlowResult:
|
||||
if errors is None:
|
||||
errors = {}
|
||||
|
||||
# Show the login form
|
||||
# Abuses the MFA form, as it works better for our usecase
|
||||
# UI suggestions are welcome (make a PR!)
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
|
||||
# Show the login form
|
||||
# Currently, this form looks bad because the frontend gives no options to make it look better
|
||||
# We will investigate options to make it look better in the future
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors={},
|
||||
)
|
||||
|
||||
# Try to use the user input first
|
||||
if user_input is not None:
|
||||
try:
|
||||
return await self._finalize_user(user_input["code"])
|
||||
except InvalidAuthError:
|
||||
return self._show_login_form({"base": "invalid_auth"})
|
||||
|
||||
# If not available, check the cookie
|
||||
req = http.current_request.get()
|
||||
code_cookie = req.cookies.get("auth_oidc_code")
|
||||
|
||||
if code_cookie:
|
||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
||||
try:
|
||||
return await self._finalize_user(code_cookie)
|
||||
except InvalidAuthError:
|
||||
pass
|
||||
|
||||
# If none are available, just show the form
|
||||
return self._show_login_form()
|
||||
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the result of the form."""
|
||||
|
||||
if user_input is None:
|
||||
return self.async_abort(reason="no_code_given")
|
||||
|
||||
# Log
|
||||
_LOGGER.info("User input %s", user_input)
|
||||
_LOGGER.info("Code %s was entered", user_input["code"])
|
||||
|
||||
username = await self._auth_provider.async_retrieve_username(user_input["code"])
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
|
||||
return await self.async_finish({
|
||||
"username": username,
|
||||
})
|
||||
|
||||
return self.async_abort(reason="invalid_code")
|
||||
# This is a dummy step function just to use the nicer MFA UI instead
|
||||
return await self.async_step_init(user_input)
|
||||
|
||||
80
custom_components/auth_oidc/stores/code_store.py
Normal file
80
custom_components/auth_oidc/stores/code_store.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Code Store, stores the codes and their associated authenticated user temporarily."""
|
||||
|
||||
import random
|
||||
import string
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
STORAGE_VERSION = 1
|
||||
STORAGE_KEY = "auth_provider.auth_oidc.codes"
|
||||
|
||||
|
||||
class CodeStore:
|
||||
"""Holds the codes and associated data"""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize the user data store."""
|
||||
self.hass = hass
|
||||
self._store = Store[dict[str, dict[str, dict | str]]](
|
||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||
)
|
||||
self._data: dict[str, dict[str, dict | str]] | None = None
|
||||
|
||||
async def async_load(self) -> None:
|
||||
"""Load stored data."""
|
||||
if (data := await self._store.async_load()) is None:
|
||||
data = cast(dict[str, dict[str, dict | str]], {})
|
||||
self._data = data
|
||||
|
||||
async def async_save(self) -> None:
|
||||
"""Save data."""
|
||||
if self._data is not None:
|
||||
await self._store.async_save(self._data)
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return "".join(random.choices(string.digits, k=6))
|
||||
|
||||
async def async_generate_code_for_userinfo(
|
||||
self, user_info: dict[str, dict | str]
|
||||
) -> str:
|
||||
"""Generates a one time code and adds it to the database for 5 minutes."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
|
||||
self._data[code] = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat(),
|
||||
}
|
||||
|
||||
await self.async_save()
|
||||
return code
|
||||
|
||||
async def receive_userinfo_for_code(
|
||||
self, code: str
|
||||
) -> Optional[dict[str, dict | str]]:
|
||||
"""Retrieve user info based on the code."""
|
||||
if self._data is None:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
user_data = self._data.get(code)
|
||||
|
||||
if user_data:
|
||||
# We should now wipe it from the database, as it's one time use code
|
||||
self._data.pop(code)
|
||||
await self.async_save()
|
||||
|
||||
if (
|
||||
user_data
|
||||
and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow()
|
||||
):
|
||||
return user_data["user_info"]
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user