Implement initial flow (#2)
This commit is contained in:
committed by
GitHub
parent
1c8c7ed14a
commit
8ba494c49c
@@ -4,6 +4,13 @@ from typing import OrderedDict
|
||||
import voluptuous as vol
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .endpoints.welcome import OIDCWelcomeView
|
||||
from .endpoints.redirect import OIDCRedirectView
|
||||
from .endpoints.finish import OIDCFinishView
|
||||
from .endpoints.callback import OIDCCallbackView
|
||||
|
||||
from .oidc_client import OIDCClient
|
||||
|
||||
DOMAIN = "auth_oidc"
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -13,7 +20,9 @@ CONFIG_SCHEMA = vol.Schema(
|
||||
{
|
||||
DOMAIN: vol.Schema(
|
||||
{
|
||||
|
||||
vol.Required("client_id"): vol.Coerce(str),
|
||||
vol.Optional("client_secret"): vol.Coerce(str),
|
||||
vol.Required("discovery_url"): vol.Url(),
|
||||
}
|
||||
)
|
||||
},
|
||||
@@ -34,5 +43,18 @@ async def async_setup(hass: HomeAssistant, config):
|
||||
providers.update(hass.auth._providers)
|
||||
hass.auth._providers = providers
|
||||
|
||||
_LOGGER.debug("Added OIDC provider")
|
||||
_LOGGER.debug("Added OIDC provider for Home Assistant")
|
||||
|
||||
# Define some fields
|
||||
discovery_url = config[DOMAIN]["discovery_url"]
|
||||
client_id = config[DOMAIN]["client_id"]
|
||||
scope = "openid profile email"
|
||||
|
||||
oidc_client = oidc_client = OIDCClient(discovery_url, client_id, scope)
|
||||
|
||||
hass.http.register_view(OIDCWelcomeView())
|
||||
hass.http.register_view(OIDCRedirectView(oidc_client))
|
||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
|
||||
hass.http.register_view(OIDCFinishView())
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
import logging
|
||||
|
||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||
AUTH_CALLBACK_PATH = "/auth/oidc/callback"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@callback
|
||||
def async_register_view(hass: HomeAssistant) -> None:
|
||||
"""Make sure callback view is registered."""
|
||||
if not hass.data.get(DATA_VIEW_REGISTERED, False):
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
||||
hass.data[DATA_VIEW_REGISTERED] = True
|
||||
|
||||
|
||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
"""OAuth2 Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = AUTH_CALLBACK_PATH
|
||||
name = "auth:oidc:callback"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug(request.query)
|
||||
|
||||
hass = request.app["hass"]
|
||||
flow_mgr = hass.auth.login_flow
|
||||
|
||||
await flow_mgr.async_configure(
|
||||
flow_id=request.query["flow_id"], user_input=request.query["test"]
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<script>if (window.opener) { window.opener.postMessage({type: 'externalCallback'}); } window.close();</script>",
|
||||
)
|
||||
49
custom_components/auth_oidc/endpoints/callback.py
Normal file
49
custom_components/auth_oidc/endpoints/callback.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
from ..oidc_client import OIDCClient
|
||||
from ..provider import OpenIDAuthProvider
|
||||
|
||||
PATH = "/auth/oidc/callback"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCCallbackView(HomeAssistantView):
|
||||
"""OIDC Plugin Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:callback"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient, oidc_provider: OpenIDAuthProvider
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
self.oidc_provider = oidc_provider
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Callback view accessed")
|
||||
|
||||
params = request.rel_url.query
|
||||
code = params.get("code")
|
||||
state = params.get("state")
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
|
||||
if not (code and state):
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Error</h1><p>Missing code or state parameter</p>",
|
||||
)
|
||||
|
||||
user_details = await self.oidc_client.complete_token_flow(base_uri, code, state)
|
||||
if user_details is None:
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Error</h1><p>Failed to get user details, see console.</p>",
|
||||
)
|
||||
|
||||
code = await self.oidc_provider.save_user_info(user_details)
|
||||
|
||||
return web.HTTPFound(base_uri + "/auth/oidc/finish?code=" + code)
|
||||
24
custom_components/auth_oidc/endpoints/finish.py
Normal file
24
custom_components/auth_oidc/endpoints/finish.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
PATH = "/auth/oidc/finish"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCFinishView(HomeAssistantView):
|
||||
"""OIDC Plugin Finish View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:finish"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
code = request.query.get("code", "FAIL")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text=f"<h1>Done!</h1><p>Your code is: <b>{code}</b></p><p>Please return to the Home Assistant login screen (or your mobile app) and fill in this code into the single login field. It should be visible if you select 'Login with OpenID Connect (SSO)'.</p>",
|
||||
)
|
||||
46
custom_components/auth_oidc/endpoints/redirect.py
Normal file
46
custom_components/auth_oidc/endpoints/redirect.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
from ..oidc_client import OIDCClient
|
||||
|
||||
PATH = "/auth/oidc/redirect"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCRedirectView(HomeAssistantView):
|
||||
"""OIDC Plugin Redirect View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:redirect"
|
||||
|
||||
def __init__(
|
||||
self, oidc_client: OIDCClient
|
||||
) -> None:
|
||||
self.oidc_client = oidc_client
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Redirect view accessed")
|
||||
|
||||
base_uri = str(request.url).split('/auth', 2)[0]
|
||||
_LOGGER.debug("Base URI: %s", base_uri)
|
||||
|
||||
auth_url = await self.oidc_client.get_authorization_url(base_uri)
|
||||
_LOGGER.debug("Auth URL: %s", auth_url)
|
||||
|
||||
if auth_url:
|
||||
return web.HTTPFound(auth_url)
|
||||
else:
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>Plugin is misconfigured, discovery could not be obtained</h1>",
|
||||
)
|
||||
|
||||
async def post(self, request: web.Request) -> web.Response:
|
||||
"""POST"""
|
||||
|
||||
_LOGGER.debug("Redirect POST view accessed")
|
||||
return await self.get(request)
|
||||
24
custom_components/auth_oidc/endpoints/welcome.py
Normal file
24
custom_components/auth_oidc/endpoints/welcome.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from aiohttp import web
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
import logging
|
||||
|
||||
PATH = "/auth/oidc/welcome"
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCWelcomeView(HomeAssistantView):
|
||||
"""OIDC Plugin Welcome View."""
|
||||
|
||||
requires_auth = False
|
||||
url = PATH
|
||||
name = "auth:oidc:welcome"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive response."""
|
||||
|
||||
_LOGGER.debug("Welcome view accessed")
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<h1>OIDC Login (beta)</h1><p><a href='/auth/oidc/redirect'>Login with OIDC</a></p>",
|
||||
)
|
||||
204
custom_components/auth_oidc/oidc_client.py
Normal file
204
custom_components/auth_oidc/oidc_client.py
Normal file
@@ -0,0 +1,204 @@
|
||||
import aiohttp
|
||||
|
||||
import urllib.parse
|
||||
import logging
|
||||
import os
|
||||
import base64
|
||||
import hashlib
|
||||
from jose import jwt
|
||||
|
||||
from jose import jwk, jwt
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
class OIDCClient:
|
||||
flows = {}
|
||||
|
||||
def __init__(self, discovery_url, client_id, scope):
|
||||
self.discovery_url = discovery_url
|
||||
self.client_id = client_id
|
||||
self.scope = scope
|
||||
|
||||
async def fetch_discovery_document(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(self.discovery_url) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
if e.status == 404:
|
||||
_LOGGER.warning(f"Error: Discovery document not found at {self.discovery_url}")
|
||||
else:
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def get_authorization_url(self, base_uri):
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
auth_endpoint = self.discovery_document['authorization_endpoint']
|
||||
|
||||
# Generate the necessary PKCE parameters, nonce & state
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b'=').decode('utf-8')
|
||||
code_challenge = base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode('utf-8')).digest()).rstrip(b'=').decode('utf-8')
|
||||
nonce = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
state = base64.urlsafe_b64encode(os.urandom(16)).rstrip(b'=').decode('utf-8')
|
||||
|
||||
# Save all of them for later verification
|
||||
self.flows[state] = {
|
||||
'code_verifier': code_verifier,
|
||||
'nonce': nonce
|
||||
}
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'response_type': 'code',
|
||||
'client_id': self.client_id,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'scope': self.scope,
|
||||
'state': state,
|
||||
'nonce': nonce,
|
||||
'code_challenge': code_challenge,
|
||||
'code_challenge_method': 'S256',
|
||||
}
|
||||
|
||||
url = f"{auth_endpoint}?{urllib.parse.urlencode(query_params)}"
|
||||
return url
|
||||
|
||||
async def _make_token_request(self, token_endpoint, query_params):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(token_endpoint, data=query_params) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
response_json = await response.json()
|
||||
_LOGGER.warning(f"Error: {e.status} - {e.message}, Response: {response_json}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
||||
async def _get_jwks(self, jwks_uri):
|
||||
"""Fetches JWKS from the given URL."""
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(jwks_uri) as response:
|
||||
response.raise_for_status()
|
||||
return await response.json()
|
||||
except aiohttp.ClientResponseError as e:
|
||||
_LOGGER.warning(f"Error fetching JWKS: {e.status} - {e.message}")
|
||||
return None
|
||||
|
||||
async def _parse_id_token(self, id_token):
|
||||
# Parse the id token to obtain the relevant details
|
||||
# Use python-jose
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
jwks_uri = self.discovery_document['jwks_uri']
|
||||
|
||||
jwks_data = await self._get_jwks(jwks_uri)
|
||||
if not jwks_data:
|
||||
return None
|
||||
|
||||
try:
|
||||
unverified_header = jwt.get_unverified_header(id_token)
|
||||
if not unverified_header:
|
||||
print("Could not parse JWT Header")
|
||||
return None
|
||||
|
||||
kid = unverified_header.get('kid')
|
||||
if not kid:
|
||||
print("JWT does not have kid (Key ID)")
|
||||
return None
|
||||
|
||||
|
||||
# Get the correct key
|
||||
rsa_key = None
|
||||
for key in jwks_data["keys"]:
|
||||
if key["kid"] == kid:
|
||||
rsa_key = key
|
||||
break
|
||||
|
||||
if not rsa_key:
|
||||
print(f"Could not find matching key with kid:{kid}")
|
||||
return None
|
||||
|
||||
# Construct the JWK
|
||||
jwk_obj = jwk.construct(rsa_key)
|
||||
|
||||
# Verify the token
|
||||
decoded_token = jwt.decode(
|
||||
id_token,
|
||||
jwk_obj,
|
||||
algorithms=["RS256"], # Adjust if your algorithm is different
|
||||
audience=self.client_id,
|
||||
issuer=self.discovery_document['issuer'],
|
||||
)
|
||||
return decoded_token
|
||||
|
||||
except jwt.JWTError as e:
|
||||
print(f"JWT Verification failed: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Unexpected error: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def complete_token_flow(self, base_uri, code, state):
|
||||
if state not in self.flows:
|
||||
return None
|
||||
|
||||
flow = self.flows[state]
|
||||
code_verifier = flow['code_verifier']
|
||||
|
||||
if not hasattr(self, 'discovery_document'):
|
||||
self.discovery_document = await self.fetch_discovery_document()
|
||||
|
||||
if not self.discovery_document:
|
||||
return None
|
||||
|
||||
token_endpoint = self.discovery_document['token_endpoint']
|
||||
|
||||
# Construct the params
|
||||
query_params = {
|
||||
'grant_type': 'authorization_code',
|
||||
'client_id': self.client_id,
|
||||
'code': code,
|
||||
'redirect_uri': base_uri + '/auth/oidc/callback',
|
||||
'code_verifier': code_verifier,
|
||||
}
|
||||
|
||||
_LOGGER.debug(f"Token request params: {query_params}")
|
||||
|
||||
token_response = await self._make_token_request(token_endpoint, query_params)
|
||||
|
||||
if not token_response:
|
||||
return None
|
||||
|
||||
access_token = token_response.get('access_token')
|
||||
id_token = token_response.get('id_token')
|
||||
_LOGGER.debug(f"Access Token: {access_token}")
|
||||
_LOGGER.debug(f"ID Token: {id_token}")
|
||||
|
||||
# Parse the id token to obtain the relevant details
|
||||
id_token = await self._parse_id_token(id_token)
|
||||
|
||||
# Verify nonce
|
||||
if id_token.get('nonce') != flow['nonce']:
|
||||
_LOGGER.warning(f"Nonce mismatch!")
|
||||
return None
|
||||
|
||||
return {
|
||||
"name": id_token.get("name"),
|
||||
"email": id_token.get("email"),
|
||||
"preferred_username": id_token.get("preferred_username"),
|
||||
"nickname": id_token.get("nickname"),
|
||||
"groups": id_token.get("groups"),
|
||||
}
|
||||
@@ -2,18 +2,22 @@
|
||||
Allow access to users based on login with an external OpenID Connect Identity Provider (IdP).
|
||||
"""
|
||||
import logging
|
||||
from secrets import token_hex
|
||||
from typing import Any, Dict, Optional, cast
|
||||
from typing import Dict, Optional
|
||||
from homeassistant.auth.providers import (
|
||||
AUTH_PROVIDERS,
|
||||
AuthProvider,
|
||||
LoginFlow,
|
||||
AuthFlowResult,
|
||||
Credentials,
|
||||
UserMeta,
|
||||
)
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
import voluptuous as vol
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
||||
from .callback import async_register_view, AUTH_CALLBACK_PATH
|
||||
from datetime import datetime, timedelta
|
||||
import random
|
||||
import string
|
||||
from homeassistant.helpers.storage import Store
|
||||
from collections.abc import Mapping
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@@ -24,7 +28,12 @@ class InvalidAuthError(HomeAssistantError):
|
||||
class OpenIDAuthProvider(AuthProvider):
|
||||
"""Allow access to users based on login with an external OpenID Connect Identity Provider (IdP)."""
|
||||
|
||||
DEFAULT_TITLE = "OpenID Connect"
|
||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""Initialize the OpenIDAuthProvider."""
|
||||
super().__init__(*args, **kwargs)
|
||||
self._user_meta = {}
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
@@ -33,43 +42,130 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
@property
|
||||
def support_mfa(self) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def async_login_flow(self, context: Optional[Dict]) -> LoginFlow:
|
||||
"""Return a flow to login."""
|
||||
|
||||
async_register_view(self.hass)
|
||||
return OpenIdLoginFlow(self)
|
||||
|
||||
async def async_get_or_create_credentials(
|
||||
self, flow_result: Mapping[str, str]
|
||||
) -> Credentials:
|
||||
"""Get credentials based on the flow result."""
|
||||
username = flow_result["username"]
|
||||
for credential in await self.async_credentials():
|
||||
if credential.data["username"] == username:
|
||||
return credential
|
||||
|
||||
# Create new credentials.
|
||||
return self.async_create_credentials({"username": username})
|
||||
|
||||
async def async_user_meta_for_credentials(
|
||||
self, credentials: Credentials
|
||||
) -> UserMeta:
|
||||
"""Return extra user metadata for credentials.
|
||||
|
||||
Currently, supports name, group and local_only.
|
||||
"""
|
||||
meta = self._user_meta.get(credentials.data["username"], {})
|
||||
groups = meta.get("groups", [])
|
||||
|
||||
group = "system-admin" if "admins" in groups else "system-users"
|
||||
return UserMeta(
|
||||
name=meta.get("name"),
|
||||
is_active=True,
|
||||
group=group,
|
||||
local_only="true",
|
||||
)
|
||||
|
||||
async def save_user_info(self, user_info: dict) -> str:
|
||||
"""Save user info during login."""
|
||||
_LOGGER.info("User info to be saved: %s", user_info)
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
user_data = {
|
||||
"user_info": user_info,
|
||||
"code": code,
|
||||
"expiration": expiration.isoformat()
|
||||
}
|
||||
|
||||
await self._save_to_db(self._get_code_key(code), user_data)
|
||||
return code
|
||||
|
||||
async def async_retrieve_username(self, code: str) -> Optional[dict]:
|
||||
"""Retrieve user info based on the code."""
|
||||
user_data = await self._get_from_db(self._get_code_key(code))
|
||||
await self._wipe_from_db(self._get_code_key(code))
|
||||
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow():
|
||||
username = user_data["user_info"]["preferred_username"]
|
||||
self._user_meta[username] = user_data["user_info"]
|
||||
return username
|
||||
return None
|
||||
|
||||
def _generate_code(self) -> str:
|
||||
"""Generate a random six-digit code."""
|
||||
return ''.join(random.choices(string.digits, k=6))
|
||||
|
||||
def _get_code_key(self, code: str) -> str:
|
||||
return f"provider_oidc_auth_user_{code}"
|
||||
|
||||
async def _save_to_db(self, key: str, value: dict) -> None:
|
||||
"""Save key-value data to the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
await store.async_save(value)
|
||||
|
||||
async def _get_from_db(self, key: str) -> Optional[dict]:
|
||||
"""Retrieve key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_load()
|
||||
|
||||
async def _wipe_from_db(self, key: str) -> None:
|
||||
"""Delete key-value data from the Home Assistant storage."""
|
||||
store = Store(self.hass, 1, key)
|
||||
return await store.async_remove()
|
||||
|
||||
|
||||
class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handler for the login flow."""
|
||||
|
||||
external_data: Any
|
||||
|
||||
async def async_step_init(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the step of the form."""
|
||||
return await self.async_step_authenticate()
|
||||
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
return f"{get_url(self.hass, allow_external=True, require_current_request=True)}{AUTH_CALLBACK_PATH}?test=value&flow_id={self.flow_id}"
|
||||
# Show the login form
|
||||
# Currently, this form looks bad because the frontend gives no options to make it look better
|
||||
# We will investigate options to make it look better in the future
|
||||
return self.async_show_form(
|
||||
step_id="mfa",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required("code"): str,
|
||||
}
|
||||
),
|
||||
errors={},
|
||||
)
|
||||
|
||||
|
||||
async def async_step_authenticate(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Authenticate user using external step."""
|
||||
async def async_step_mfa(
|
||||
self, user_input: dict[str, str] | None = None
|
||||
) -> AuthFlowResult:
|
||||
"""Handle the result of the form."""
|
||||
|
||||
if user_input:
|
||||
self.external_data = str(user_input)
|
||||
return self.async_external_step_done(next_step_id="authorize")
|
||||
if user_input is None:
|
||||
return self.async_abort(reason="no_code_given")
|
||||
|
||||
return self.async_external_step(step_id="authenticate", url=self.redirect_uri())
|
||||
# Log
|
||||
_LOGGER.info("User input %s", user_input)
|
||||
_LOGGER.info("Code %s was entered", user_input["code"])
|
||||
|
||||
async def async_step_authorize(
|
||||
self, user_input: Optional[Dict[str, str]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""Authorize user received from external step."""
|
||||
_LOGGER.debug(self.external_data)
|
||||
return self.async_abort(reason="invalid_auth")
|
||||
username = await self._auth_provider.async_retrieve_username(user_input["code"])
|
||||
if username:
|
||||
_LOGGER.info("Logged in user: %s", username)
|
||||
|
||||
return await self.async_finish({
|
||||
"username": username,
|
||||
})
|
||||
|
||||
return self.async_abort(reason="invalid_code")
|
||||
Reference in New Issue
Block a user