Add unit tests (#133)
* Add initial test & add pipeline * Add very basic YAML config tests * Add coverage reporting * Add some webserver & template loading tests * Add test cases for the helpers * Implement initial OIDC server tests * Test codestore & discovery checker * Test basics of the config flow * Add test for the HA auth provider * Cleaned up tests & test injection
This commit is contained in:
committed by
GitHub
parent
5714e844a7
commit
404d2451df
@@ -67,6 +67,4 @@ class OIDCCallbackView(HomeAssistantView):
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
code = await self.oidc_provider.async_save_user_info(user_details)
|
||||
return web.HTTPFound(
|
||||
get_url("/auth/oidc/finish?code=" + code, self.force_https)
|
||||
)
|
||||
raise web.HTTPFound(get_url("/auth/oidc/finish?code=" + code, self.force_https))
|
||||
|
||||
@@ -40,7 +40,7 @@ class OIDCFinishView(HomeAssistantView):
|
||||
return web.Response(text="No code received", status=500)
|
||||
|
||||
# Return redirect to the main page for sign in with a cookie
|
||||
return web.HTTPFound(
|
||||
raise web.HTTPFound(
|
||||
location="/?storeToken=true",
|
||||
headers={
|
||||
# Set a cookie to enable autologin on only the specific path used
|
||||
|
||||
@@ -25,10 +25,14 @@ class OIDCRedirectView(HomeAssistantView):
|
||||
"""Receive response."""
|
||||
|
||||
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
||||
|
||||
if auth_url:
|
||||
return web.HTTPFound(auth_url)
|
||||
try:
|
||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
||||
|
||||
if auth_url:
|
||||
raise web.HTTPFound(auth_url)
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
view_html = await get_view(
|
||||
"error",
|
||||
|
||||
@@ -23,7 +23,7 @@ class OIDCWelcomeView(HomeAssistantView):
|
||||
"""Receive response."""
|
||||
|
||||
if not self.is_enabled:
|
||||
return web.HTTPTemporaryRedirect(get_url("/", self.force_https))
|
||||
raise web.HTTPTemporaryRedirect(get_url("/", self.force_https))
|
||||
|
||||
view_html = await get_view("welcome", {"name": self.name})
|
||||
return web.Response(text=view_html, content_type="text/html")
|
||||
|
||||
@@ -177,9 +177,9 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
# If person creation is enabled, add a person for this user
|
||||
if self.create_persons:
|
||||
user_meta = await self.async_user_meta_for_credentials(credential)
|
||||
await self.async_create_person(user, user_meta.name)
|
||||
await self._async_create_person(user, user_meta.name)
|
||||
|
||||
async def async_create_person(self, user: User, name: str) -> None:
|
||||
async def _async_create_person(self, user: User, name: str) -> None:
|
||||
"""Create a person for the user."""
|
||||
_LOGGER.info("Automatically creating person for new user %s", user.id)
|
||||
|
||||
@@ -194,7 +194,7 @@ class OpenIDAuthProvider(AuthProvider):
|
||||
# pylint: disable=broad-exception-caught
|
||||
except Exception:
|
||||
_LOGGER.warning(
|
||||
"Requested automatic person creation, but person creation failed."
|
||||
"Requested automatic person creation, but person creation failed"
|
||||
)
|
||||
# pylint: enable=broad-exception-caught
|
||||
|
||||
@@ -315,7 +315,7 @@ class OpenIdLoginFlow(LoginFlow):
|
||||
"""Handle the step of the form."""
|
||||
|
||||
# Try to use the user input first
|
||||
if user_input is not None:
|
||||
if user_input is not None and "code" in user_input:
|
||||
try:
|
||||
return await self._finalize_user(user_input["code"])
|
||||
except InvalidAuthError:
|
||||
@@ -323,14 +323,15 @@ class OpenIdLoginFlow(LoginFlow):
|
||||
|
||||
# If not available, check the cookie
|
||||
req = http.current_request.get()
|
||||
code_cookie = req.cookies.get("auth_oidc_code")
|
||||
if req and req.cookies:
|
||||
code_cookie = req.cookies.get("auth_oidc_code")
|
||||
|
||||
if code_cookie:
|
||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
||||
try:
|
||||
return await self._finalize_user(code_cookie)
|
||||
except InvalidAuthError:
|
||||
pass
|
||||
if code_cookie:
|
||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
||||
try:
|
||||
return await self._finalize_user(code_cookie)
|
||||
except InvalidAuthError:
|
||||
pass
|
||||
|
||||
# If none are available, just show the form
|
||||
return self._show_login_form()
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import cast, Optional
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.core import HomeAssistant
|
||||
@@ -31,7 +31,7 @@ class CodeStore:
|
||||
data = cast(dict[str, UserDetails], {})
|
||||
self._data = data
|
||||
|
||||
async def async_save(self) -> None:
|
||||
async def _async_save(self) -> None:
|
||||
"""Save data."""
|
||||
if self._data is not None:
|
||||
await self._store.async_save(self._data)
|
||||
@@ -46,7 +46,7 @@ class CodeStore:
|
||||
raise RuntimeError("Data not loaded")
|
||||
|
||||
code = self._generate_code()
|
||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
||||
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||
|
||||
self._data[code] = {
|
||||
"user_info": user_info,
|
||||
@@ -54,7 +54,7 @@ class CodeStore:
|
||||
"expiration": expiration.isoformat(),
|
||||
}
|
||||
|
||||
await self.async_save()
|
||||
await self._async_save()
|
||||
return code
|
||||
|
||||
async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]:
|
||||
@@ -67,12 +67,15 @@ class CodeStore:
|
||||
if user_data:
|
||||
# We should now wipe it from the database, as it's one time use code
|
||||
self._data.pop(code)
|
||||
await self.async_save()
|
||||
await self._async_save()
|
||||
|
||||
if (
|
||||
user_data
|
||||
and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow()
|
||||
if user_data and datetime.fromisoformat(user_data["expiration"]) > datetime.now(
|
||||
timezone.utc
|
||||
):
|
||||
return user_data["user_info"]
|
||||
|
||||
return None
|
||||
|
||||
def get_data(self):
|
||||
"""Get the internal data for testing purposes."""
|
||||
return self._data
|
||||
|
||||
@@ -39,12 +39,8 @@ class OIDCDiscoveryInvalid(OIDCClientException):
|
||||
type: Optional[str]
|
||||
details: Optional[dict]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if args:
|
||||
self.message = args[0]
|
||||
else:
|
||||
self.message = "OIDC Discovery document is invalid"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.message = "OIDC Discovery document is invalid"
|
||||
self.type = kwargs.pop("type", None)
|
||||
self.details = kwargs.pop("details", None)
|
||||
super().__init__(self.message)
|
||||
@@ -196,7 +192,7 @@ class OIDCDiscoveryClient:
|
||||
)
|
||||
raise OIDCDiscoveryInvalid(
|
||||
type="does_not_support_response_mode",
|
||||
modes=document["response_modes_supported"],
|
||||
details={"modes": document["response_modes_supported"]},
|
||||
)
|
||||
|
||||
# If grant_types_supported is set, should support 'authorization_code'
|
||||
@@ -281,7 +277,7 @@ class OIDCDiscoveryClient:
|
||||
await self._validate_discovery_document(document)
|
||||
return document
|
||||
|
||||
async def fetch_jwks(self, jwks_uri: str | None):
|
||||
async def fetch_jwks(self, jwks_uri: str | None = None):
|
||||
"""Fetches JWKS."""
|
||||
if jwks_uri is None:
|
||||
discovery_document = await self._fetch_discovery_document()
|
||||
|
||||
@@ -10,7 +10,7 @@ def validate_url(url: str) -> bool:
|
||||
try:
|
||||
parsed = urlparse(url.strip())
|
||||
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
|
||||
except (ValueError, TypeError):
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ def validate_discovery_url(url: str) -> bool:
|
||||
and parsed.netloc
|
||||
and parsed.path.endswith("/.well-known/openid-configuration")
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
except (ValueError, TypeError, AttributeError):
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class AsyncTemplateRenderer:
|
||||
) as f:
|
||||
content = await f.read()
|
||||
templates[filename] = content
|
||||
except (OSError, IOError) as e:
|
||||
except (OSError, IOError) as e: # pragma: no cover
|
||||
_LOGGER.warning("Error reading template file %s: %s", filename, e)
|
||||
|
||||
async def render_template(self, template_name: str, **kwargs: Any) -> str:
|
||||
|
||||
Reference in New Issue
Block a user