Compare commits
92 Commits
v0.6.5-alp
...
v1.0.0-rc3
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7cc960e4db | ||
|
|
07c1e3a4c4 | ||
|
|
0ca300c385 | ||
|
|
a9483e2038 | ||
|
|
6f1d2bcb3f | ||
|
|
67f58a39aa | ||
|
|
ddb2952e64 | ||
|
|
baf3ac6b5a | ||
|
|
c7672f65d9 | ||
|
|
fd3643685d | ||
|
|
fdc93e2719 | ||
|
|
3d789be33b | ||
|
|
58f52e4720 | ||
|
|
a135be2d81 | ||
|
|
26f3590a69 | ||
|
|
a3e839d0a8 | ||
|
|
52c9bd2a50 | ||
|
|
94b0fe80fe | ||
|
|
51b5e5662a | ||
|
|
0e173cfe06 | ||
|
|
55375395ea | ||
|
|
a0e58448e7 | ||
|
|
c85002167f | ||
|
|
5050cb4d5e | ||
|
|
52ff6a899d | ||
|
|
620df3006b | ||
|
|
fb4b5785de | ||
|
|
dfd7f83d88 | ||
|
|
b146441a23 | ||
|
|
624b5ec189 | ||
|
|
9a8bca4c1f | ||
|
|
0eef84d092 | ||
|
|
5187ceffbd | ||
|
|
1a35f953da | ||
|
|
a29e0e6730 | ||
|
|
0f0679d46d | ||
|
|
d6b8f6bbb1 | ||
|
|
6f93a22c37 | ||
|
|
b2d07c28f0 | ||
|
|
eaed91016a | ||
|
|
307af07c0c | ||
|
|
1f95efd0aa | ||
|
|
04a693ccfe | ||
|
|
d663d28ece | ||
|
|
e2bcbf4aa7 | ||
|
|
50e78a5189 | ||
|
|
36d3255b0b | ||
|
|
bc5039b951 | ||
|
|
d197ebade1 | ||
|
|
ff13b1db8f | ||
|
|
6f45858909 | ||
|
|
da39443d95 | ||
|
|
8914dba95c | ||
|
|
0133446975 | ||
|
|
674c342a81 | ||
|
|
a8e0162d25 | ||
|
|
2b224b3002 | ||
|
|
6fc1cd2944 | ||
|
|
dde2a70a39 | ||
|
|
4e898087d4 | ||
|
|
75eb5378c8 | ||
|
|
979a55dd12 | ||
|
|
2aa083a88d | ||
|
|
4bf1923b8b | ||
|
|
92df670c3f | ||
|
|
744be2a4d8 | ||
|
|
404d2451df | ||
|
|
5714e844a7 | ||
|
|
d1da841e1f | ||
|
|
3b481cd282 | ||
|
|
9fb65084b1 | ||
|
|
19cbc6e3ad | ||
|
|
daa13fa94f | ||
|
|
db8a5575f8 | ||
|
|
513e8e33f7 | ||
|
|
b87dd35577 | ||
|
|
44794f7cfd | ||
|
|
b0532ec2ec | ||
|
|
4fca5d8de4 | ||
|
|
02a4937ce4 | ||
|
|
86663dd5e4 | ||
|
|
c13eb7c438 | ||
|
|
764326a9e1 | ||
|
|
601127d6b3 | ||
|
|
e22f960d69 | ||
|
|
054f0e4bca | ||
|
|
1e0310afc9 | ||
|
|
0888ea0400 | ||
|
|
27de2bcf71 | ||
|
|
2e85f4bd16 | ||
|
|
5651e9bff3 | ||
|
|
86c663700c |
21
.github/workflows/build-pr.yaml
vendored
Normal file
21
.github/workflows/build-pr.yaml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Build pull request artifact
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: scripts/build
|
||||||
|
|
||||||
|
- name: Upload Artifact
|
||||||
|
uses: actions/upload-artifact@v7
|
||||||
|
with:
|
||||||
|
path: ./hass-oidc-auth.zip
|
||||||
|
archive: false
|
||||||
3
.github/workflows/hacs.yaml
vendored
3
.github/workflows/hacs.yaml
vendored
@@ -13,10 +13,9 @@ jobs:
|
|||||||
validate:
|
validate:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
- name: HACS validation
|
- name: HACS validation
|
||||||
uses: hacs/action@22.5.0
|
uses: hacs/action@22.5.0
|
||||||
with:
|
with:
|
||||||
category: "integration"
|
category: "integration"
|
||||||
ignore: brands
|
|
||||||
|
|
||||||
2
.github/workflows/hassfest.yaml
vendored
2
.github/workflows/hassfest.yaml
vendored
@@ -13,5 +13,5 @@ jobs:
|
|||||||
validate:
|
validate:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
- uses: home-assistant/actions/hassfest@master
|
- uses: home-assistant/actions/hassfest@master
|
||||||
16
.github/workflows/lint.yaml
vendored
16
.github/workflows/lint.yaml
vendored
@@ -9,12 +9,16 @@ jobs:
|
|||||||
build:
|
build:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v6
|
||||||
- name: Install the latest version of rye
|
- name: "Set up Python"
|
||||||
uses: eifinger/setup-rye@v4
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version-file: ".python-version"
|
||||||
|
- name: Install the latest version of uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
with:
|
with:
|
||||||
enable-cache: true
|
enable-cache: true
|
||||||
- name: Sync dependencies
|
- name: Sync dependencies
|
||||||
run: rye sync
|
run: scripts/sync
|
||||||
- name: Lint (pylint/rye lint)
|
- name: Lint (pylint/ruff lint)
|
||||||
run: rye run check
|
run: scripts/check
|
||||||
|
|||||||
27
.github/workflows/release.yaml
vendored
Normal file
27
.github/workflows/release.yaml
vendored
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
name: Build and create draft release
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- v*.*.*
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
release:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
|
||||||
|
- name: Build
|
||||||
|
run: scripts/build
|
||||||
|
|
||||||
|
- name: Create or update draft release with ZIP
|
||||||
|
uses: softprops/action-gh-release@v3
|
||||||
|
with:
|
||||||
|
draft: true
|
||||||
|
fail_on_unmatched_files: true
|
||||||
|
generate_release_notes: true
|
||||||
|
files: ./hass-oidc-auth.zip
|
||||||
26
.github/workflows/security.yaml
vendored
Normal file
26
.github/workflows/security.yaml
vendored
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
---
|
||||||
|
name: Security (pysentry)
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- main
|
||||||
|
pull_request:
|
||||||
|
schedule:
|
||||||
|
- cron: "0 8 */3 * *"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
vulnerability-scan:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- name: "Set up Python"
|
||||||
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version-file: ".python-version"
|
||||||
|
- name: Install the latest version of uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
- name: Scan dependencies for vulnerabilities
|
||||||
|
run: uvx pysentry-rs .
|
||||||
24
.github/workflows/test.yaml
vendored
Normal file
24
.github/workflows/test.yaml
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
---
|
||||||
|
name: Tests (pytest)
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
pull_request:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v6
|
||||||
|
- name: "Set up Python"
|
||||||
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version-file: ".python-version"
|
||||||
|
- name: Install the latest version of uv
|
||||||
|
uses: astral-sh/setup-uv@v7
|
||||||
|
with:
|
||||||
|
enable-cache: true
|
||||||
|
- name: Sync dependencies
|
||||||
|
run: scripts/sync
|
||||||
|
- name: Test
|
||||||
|
run: scripts/test
|
||||||
8
.gitignore
vendored
8
.gitignore
vendored
@@ -105,6 +105,12 @@ dmypy.json
|
|||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
# End of https://www.gitignore.io/api/python
|
# End of https://www.gitignore.io/api/python
|
||||||
config/
|
/config/
|
||||||
|
|
||||||
.venv
|
.venv
|
||||||
|
|
||||||
|
.pytest_logs.log
|
||||||
|
|
||||||
|
# Build NPM
|
||||||
|
node_modules
|
||||||
|
custom_components/auth_oidc/static/style.css
|
||||||
@@ -106,7 +106,7 @@ source-roots=
|
|||||||
|
|
||||||
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
# When enabled, pylint would attempt to guess common misconfiguration and emit
|
||||||
# user-friendly hints instead of false-positive error messages.
|
# user-friendly hints instead of false-positive error messages.
|
||||||
suggestion-mode=yes
|
#suggestion-mode=yes
|
||||||
|
|
||||||
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
# Allow loading of arbitrary C extensions. Extensions are imported into the
|
||||||
# active Python interpreter and may run arbitrary code.
|
# active Python interpreter and may run arbitrary code.
|
||||||
|
|||||||
50
.pysentry.toml
Normal file
50
.pysentry.toml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
version = 1
|
||||||
|
|
||||||
|
[defaults]
|
||||||
|
format = "human"
|
||||||
|
severity = "low"
|
||||||
|
fail_on = "medium"
|
||||||
|
scope = "main"
|
||||||
|
direct_only = false
|
||||||
|
detailed = false
|
||||||
|
include_withdrawn = false
|
||||||
|
no_ci_detect = false
|
||||||
|
|
||||||
|
[sources]
|
||||||
|
enabled = [
|
||||||
|
"pypa",
|
||||||
|
"pypi",
|
||||||
|
"osv",
|
||||||
|
]
|
||||||
|
|
||||||
|
[resolver]
|
||||||
|
type = "uv"
|
||||||
|
|
||||||
|
[cache]
|
||||||
|
enabled = true
|
||||||
|
resolution_ttl = 24
|
||||||
|
vulnerability_ttl = 48
|
||||||
|
|
||||||
|
[ignore]
|
||||||
|
ids = []
|
||||||
|
while_no_fix = []
|
||||||
|
|
||||||
|
[http]
|
||||||
|
timeout = 120
|
||||||
|
connect_timeout = 30
|
||||||
|
max_retries = 3
|
||||||
|
retry_initial_backoff = 1
|
||||||
|
retry_max_backoff = 60
|
||||||
|
show_progress = true
|
||||||
|
|
||||||
|
[maintenance]
|
||||||
|
enabled = true
|
||||||
|
forbid_archived = false
|
||||||
|
forbid_deprecated = false
|
||||||
|
forbid_quarantined = false
|
||||||
|
forbid_unmaintained = false
|
||||||
|
check_direct_only = false
|
||||||
|
cache_ttl = 1
|
||||||
|
|
||||||
|
[notifications]
|
||||||
|
enabled = true
|
||||||
@@ -1 +1 @@
|
|||||||
3.13.1
|
3.14.4
|
||||||
@@ -13,9 +13,38 @@ If you are not a programmer, you can still contribute by:
|
|||||||
You may also submit Pull Requests (PRs) to add features yourself! You can find a list that we are currently working on below. Please note that workflows will be run on your pull request and a pull request will only be merged when all checks pass and a review has been conducted (together with a manual test).
|
You may also submit Pull Requests (PRs) to add features yourself! You can find a list that we are currently working on below. Please note that workflows will be run on your pull request and a pull request will only be merged when all checks pass and a review has been conducted (together with a manual test).
|
||||||
|
|
||||||
### Development
|
### Development
|
||||||
This project uses the Rye package manager for development. You can find installation instructions here: https://rye.astral.sh/guide/installation/. Start by installing the dependencies using rye sync and then point your editor towards the environment created in the .venv directory.
|
This project uses the uv package manager for development. You can find installation instructions here: https://docs.astral.sh/uv/getting-started/installation/. Start by installing the dependencies using `uv sync` and then point your editor towards the environment created in the .venv directory.
|
||||||
You can then run Home Assistant and put the `custom_components/auth_oidc` directory in your HA `config` folder.
|
You can then run Home Assistant and put the `custom_components/auth_oidc` directory in your HA `config` folder.
|
||||||
|
|
||||||
|
#### Other useful commands
|
||||||
|
Some useful scripts are in the `scripts` directory. If you run Linux (or WSL under Windows), you can run these directly:
|
||||||
|
|
||||||
|
- `scripts/check` will check your Python files for linting errors
|
||||||
|
- `scripts/fix` will fix some formatting mistakes automatically
|
||||||
|
- `scripts/test` will run the testing suite
|
||||||
|
- `scripts/coverage-report` will run the testing suite and generate a code coverage report (and runs a webserver to serve the report)
|
||||||
|
|
||||||
|
You can also run these commands manually on Windows:
|
||||||
|
|
||||||
|
##### Compiling css
|
||||||
|
|
||||||
|
To compile tailwind css styles for the pages you need the NodeJS and NPM installed.
|
||||||
|
|
||||||
|
You can run the `npm run css` script to generate the css once and you can run the `npm run css:watch` to recompile the css every time the templates change
|
||||||
|
|
||||||
|
##### Check
|
||||||
|
```
|
||||||
|
uv run ruff check
|
||||||
|
uv run ruff format --check
|
||||||
|
uv run pylint custom_components
|
||||||
|
```
|
||||||
|
|
||||||
|
##### Fix
|
||||||
|
```
|
||||||
|
uv run ruff check --fix
|
||||||
|
uv run ruff format
|
||||||
|
```
|
||||||
|
|
||||||
### Docker Compose Development Environment
|
### Docker Compose Development Environment
|
||||||
You can also use the following Docker Compose configuration to automatically start up the latest HA release with the `auth_oidc` integration:
|
You can also use the following Docker Compose configuration to automatically start up the latest HA release with the `auth_oidc` integration:
|
||||||
|
|
||||||
@@ -34,59 +63,3 @@ services:
|
|||||||
|
|
||||||
# Found a security issue?
|
# Found a security issue?
|
||||||
Please see [SECURITY.md](./SECURITY.md) for more information on how to submit your security issue securely. You can find previously found vulnerablities and their corresponding security advisories at the [Security Advisories page](https://github.com/christiaangoossens/hass-oidc-auth/security/advisories).
|
Please see [SECURITY.md](./SECURITY.md) for more information on how to submit your security issue securely. You can find previously found vulnerablities and their corresponding security advisories at the [Security Advisories page](https://github.com/christiaangoossens/hass-oidc-auth/security/advisories).
|
||||||
|
|
||||||
# Roadmap
|
|
||||||
The following features are on the roadmap:
|
|
||||||
|
|
||||||
## Better user experience
|
|
||||||
*Copied from https://github.com/christiaangoossens/hass-oidc-auth/issues/19*
|
|
||||||
|
|
||||||
Current status on the user experience:
|
|
||||||
|
|
||||||
- I cannot change the login screen as all of this is hard coded in the frontend code. So, I am stuck with the title of "Just checking" and without any description or even a title for the input box. Changing this would require a PR on the Home Assistant frontend repository.
|
|
||||||
- If anyone can refactor their code to allow integrations (Auth Providers) to send custom translations to the frontend when sending the form (here: [custom_components/auth_oidc/provider.py, line 302](https://github.com/christiaangoossens/hass-oidc-auth/blob/main/custom_components/auth_oidc/provider.py#L302)), such that I can send custom translation keys for the title (instead of just using the `mfa` version), description and input label, I would be very happy to accept a PR here as well that accomplishes that.
|
|
||||||
- Bonus points if it uses the same translation system you would use for any normal setup/config flow in the UI.
|
|
||||||
- Extra bonus points if we can add a button or link besides it that allows for opening the start of the OIDC flow there too, within the description for instance.
|
|
||||||
|
|
||||||
- I cannot redirect you to the start of the OIDC process yet, both on mobile and on desktop. Whenever [the PR](https://github.com/home-assistant/frontend/pull/23204) gets merged and a Home Assistant version that's includes the PR is released (or planned), I will hopefully be able to get something like that to work on desktop.
|
|
||||||
- It likely will not work on mobile, as the PR that's now approved only does it for desktop, I tested mobile with that code 2 years ago and it didn't work. I will contact someone on the Android team to see if we can make that happen too at some point.
|
|
||||||
- Mobile will need to open the `window.open` call using Android Custom Tab (Android) / SFSafariViewController (iOS) instead of the normal webview. It seems that external links didn't work at all when I tried it.
|
|
||||||
|
|
||||||
PR's that improve the user experience are welcome, but they should be stable and preferably hack as little as possible.
|
|
||||||
|
|
||||||
## Tests
|
|
||||||
The project still needs the following automated tests on every PR:
|
|
||||||
|
|
||||||
- Spin up Home Assistant (both the required version from the `hacs.json` and the latest version) and verify that it starts up with no warnings or errors
|
|
||||||
- Normal pytest unit testing (https://developers.home-assistant.io/docs/development_testing/)
|
|
||||||
- You might be able to re-use some unit tests from the original implementation by @elupus: https://github.com/home-assistant/core/pull/32926 or from it's inspired work by @allenporter: https://github.com/allenporter/home-assistant-openid-auth-provider/tree/main/tests
|
|
||||||
- Integration test that performs an automatic run-through of an entire flow with an example/mocked OIDC provider, either in Python code or using an external tool (such as Playwright)
|
|
||||||
|
|
||||||
Together, these should test the following:
|
|
||||||
- The integration registers correctly without any errors (spin-up test)
|
|
||||||
- The integration works with both the minimum HA version as well as the latest HA version (spin-up test)
|
|
||||||
- Configuration can be set without any errors (unit test)
|
|
||||||
- Configuration has the correct effects (unit test)
|
|
||||||
- Code works correctly on its own (unit test)
|
|
||||||
- Full flow is functional and displays as expected, including integration with an external OIDC provider (integration test)
|
|
||||||
|
|
||||||
Preferably, we run all tests on every PR to make manual testing unnecessary.
|
|
||||||
|
|
||||||
## Better configuration experience
|
|
||||||
As a conclusion to the poll (https://github.com/christiaangoossens/hass-oidc-auth/discussions/6), it seems that the best option would be to keep the current YAML configuration for advanced uses and add a UI configuration for the common providers.
|
|
||||||
|
|
||||||
I planned for the following user flow:
|
|
||||||
|
|
||||||
1. Add integration in the HA UI
|
|
||||||
2. Get config dialog with a selector for which OIDC provider you are using
|
|
||||||
3. Preconfigure claim configuration using the chosen provider
|
|
||||||
4. Have user input client id & discovery URL with an instruction to configure as public client
|
|
||||||
5. (Optionally) allow users to choose confidential client and input client secret
|
|
||||||
6. Check these fields by requesting the discovery, JWKS
|
|
||||||
7. Ask user if they want to enable groups and allow them to input the correct group name for both roles
|
|
||||||
8. (Optionally) allow users to enable user linking, explain the issues to them with leaving it enabled and allow disabling later
|
|
||||||
9. Inform users that advanced options are only available in YAML, such as networking settings or specific claim configurations
|
|
||||||
10. Have the user perform one login to check that all the fields are correct, just as any OAuth2 integration would, preferably using our oidc_provider
|
|
||||||
11. Save the integration and request restart to enable it (if necessary)
|
|
||||||
|
|
||||||
While I welcome adding configuration by UI, it's not at the top of my priority list. Ask me in the PR if you have any other suggestions and don't forget to add tests for this too. Existing YAML configuration should also remain unaffected, whenever possible.
|
|
||||||
@@ -41,9 +41,6 @@
|
|||||||
</p>
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
> [!CAUTION]
|
|
||||||
> This is an alpha release. I give no guarantees about code quality, error handling or security at this stage. Use at your own risk.
|
|
||||||
|
|
||||||
Provides an OpenID Connect (OIDC) implementation for Home Assistant through a custom component/integration. Through this integration, you can create an SSO (single-sign-on) environment within your self-hosted application stack / homelab.
|
Provides an OpenID Connect (OIDC) implementation for Home Assistant through a custom component/integration. Through this integration, you can create an SSO (single-sign-on) environment within your self-hosted application stack / homelab.
|
||||||
|
|
||||||
### Background
|
### Background
|
||||||
@@ -54,7 +51,7 @@ If you would like to read the background/open letter that lead to this component
|
|||||||
|
|
||||||
## Installation guide
|
## Installation guide
|
||||||
|
|
||||||
1. Add this repository to [HACS](https://hacs.xyz/).
|
1. Add this repository to [HACS](https://hacs.xyz/) (or search for "OpenID Connect" in HACS).
|
||||||
|
|
||||||
[](https://my.home-assistant.io/redirect/hacs_repository/?owner=christiaangoossens&repository=hass-oidc-auth&category=Integration)
|
[](https://my.home-assistant.io/redirect/hacs_repository/?owner=christiaangoossens&repository=hass-oidc-auth&category=Integration)
|
||||||
|
|
||||||
|
|||||||
@@ -1,14 +1,18 @@
|
|||||||
"""OIDC Integration for Home Assistant."""
|
"""OIDC Integration for Home Assistant."""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from typing import OrderedDict
|
from typing import OrderedDict
|
||||||
|
|
||||||
|
from homeassistant.config_entries import ConfigEntry
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
# Import and re-export config schema explictly
|
# Import and re-export config schema explictly
|
||||||
# pylint: disable=useless-import-alias
|
# pylint: disable=useless-import-alias
|
||||||
|
from .config import CONFIG_SCHEMA as CONFIG_SCHEMA
|
||||||
|
|
||||||
|
# Get all the constants for the config
|
||||||
from .config import (
|
from .config import (
|
||||||
CONFIG_SCHEMA as CONFIG_SCHEMA,
|
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
CLIENT_ID,
|
CLIENT_ID,
|
||||||
@@ -23,25 +27,68 @@ from .config import (
|
|||||||
ROLES,
|
ROLES,
|
||||||
NETWORK,
|
NETWORK,
|
||||||
FEATURES_INCLUDE_GROUPS_SCOPE,
|
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||||
|
FEATURES_FORCE_HTTPS,
|
||||||
|
REQUIRED_SCOPES,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pylint: enable=useless-import-alias
|
from .config import convert_ui_config_entry_to_internal_format
|
||||||
|
|
||||||
from .endpoints.welcome import OIDCWelcomeView
|
from .endpoints import (
|
||||||
from .endpoints.redirect import OIDCRedirectView
|
OIDCWelcomeView,
|
||||||
from .endpoints.finish import OIDCFinishView
|
OIDCRedirectView,
|
||||||
from .endpoints.callback import OIDCCallbackView
|
OIDCFinishView,
|
||||||
|
OIDCCallbackView,
|
||||||
from .oidc_client import OIDCClient
|
OIDCInjectedAuthPage,
|
||||||
|
OIDCDeviceSSE,
|
||||||
|
)
|
||||||
|
from .tools.oidc_client import OIDCClient
|
||||||
from .provider import OpenIDAuthProvider
|
from .provider import OpenIDAuthProvider
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def async_setup(hass: HomeAssistant, config):
|
async def async_setup(hass: HomeAssistant, config):
|
||||||
"""Add the OIDC Auth Provider to the providers in Home Assistant"""
|
"""Add the OIDC Auth Provider to the providers in Home Assistant (YAML config)."""
|
||||||
|
if DOMAIN not in config:
|
||||||
|
return True
|
||||||
|
|
||||||
my_config = config[DOMAIN]
|
my_config = config[DOMAIN]
|
||||||
|
|
||||||
|
# Store YAML config for later access by config flow
|
||||||
|
if DOMAIN not in hass.data:
|
||||||
|
hass.data[DOMAIN] = {}
|
||||||
|
hass.data[DOMAIN]["yaml_config"] = my_config
|
||||||
|
|
||||||
|
await _setup_oidc_provider(
|
||||||
|
hass, my_config, config[DOMAIN].get(DISPLAY_NAME, DEFAULT_TITLE)
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||||
|
"""Set up OIDC Authentication from a config entry (UI config)."""
|
||||||
|
# Convert config entry data to the format expected by the existing setup
|
||||||
|
config_data = entry.data.copy()
|
||||||
|
|
||||||
|
# Convert config entry format to internal format
|
||||||
|
my_config = convert_ui_config_entry_to_internal_format(config_data)
|
||||||
|
|
||||||
|
# Get display name from config entry
|
||||||
|
display_name = config_data.get("display_name", DEFAULT_TITLE)
|
||||||
|
|
||||||
|
await _setup_oidc_provider(hass, my_config, display_name)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
async def async_unload_entry(_hass: HomeAssistant, _entry: ConfigEntry):
|
||||||
|
"""Unload a config entry."""
|
||||||
|
# OIDC auth providers cannot be easily unloaded as they are integrated
|
||||||
|
# into Home Assistant's auth system. A restart is required.
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
async def _setup_oidc_provider(hass: HomeAssistant, my_config: dict, display_name: str):
|
||||||
|
"""Set up the OIDC provider with the given configuration."""
|
||||||
providers = OrderedDict()
|
providers = OrderedDict()
|
||||||
|
|
||||||
# Use private APIs until there is a real auth platform
|
# Use private APIs until there is a real auth platform
|
||||||
@@ -49,6 +96,10 @@ async def async_setup(hass: HomeAssistant, config):
|
|||||||
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
|
provider = OpenIDAuthProvider(hass, hass.auth._store, my_config)
|
||||||
|
|
||||||
providers[(provider.type, provider.id)] = provider
|
providers[(provider.type, provider.id)] = provider
|
||||||
|
|
||||||
|
# Get current provider count
|
||||||
|
has_other_auth_providers = len(hass.auth._providers) > 0
|
||||||
|
|
||||||
providers.update(hass.auth._providers)
|
providers.update(hass.auth._providers)
|
||||||
hass.auth._providers = providers
|
hass.auth._providers = providers
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
@@ -58,7 +109,7 @@ async def async_setup(hass: HomeAssistant, config):
|
|||||||
# Set the correct scopes
|
# Set the correct scopes
|
||||||
# Always use 'openid' & 'profile' as they are specified in the OIDC spec
|
# Always use 'openid' & 'profile' as they are specified in the OIDC spec
|
||||||
# All servers should support this
|
# All servers should support this
|
||||||
scope = "openid profile"
|
scope = REQUIRED_SCOPES
|
||||||
|
|
||||||
# Include groups if requested (default is to include 'groups'
|
# Include groups if requested (default is to include 'groups'
|
||||||
# as a scope for Authelia & Authentik)
|
# as a scope for Authelia & Authentik)
|
||||||
@@ -76,7 +127,7 @@ async def async_setup(hass: HomeAssistant, config):
|
|||||||
scope += " ".join(additional_scopes)
|
scope += " ".join(additional_scopes)
|
||||||
|
|
||||||
# Create the OIDC client
|
# Create the OIDC client
|
||||||
oidc_client = oidc_client = OIDCClient(
|
oidc_client = OIDCClient(
|
||||||
hass=hass,
|
hass=hass,
|
||||||
discovery_url=my_config.get(DISCOVERY_URL),
|
discovery_url=my_config.get(DISCOVERY_URL),
|
||||||
client_id=my_config.get(CLIENT_ID),
|
client_id=my_config.get(CLIENT_ID),
|
||||||
@@ -90,13 +141,22 @@ async def async_setup(hass: HomeAssistant, config):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Register the views
|
# Register the views
|
||||||
name = config[DOMAIN].get(DISPLAY_NAME, DEFAULT_TITLE)
|
name = display_name
|
||||||
|
name = re.sub(r"[^A-Za-z0-9 _\-\(\)]", "", name)
|
||||||
|
|
||||||
hass.http.register_view(OIDCWelcomeView(name))
|
force_https = features_config.get(FEATURES_FORCE_HTTPS, False)
|
||||||
hass.http.register_view(OIDCRedirectView(oidc_client))
|
|
||||||
hass.http.register_view(OIDCCallbackView(oidc_client, provider))
|
hass.http.register_view(
|
||||||
hass.http.register_view(OIDCFinishView())
|
OIDCWelcomeView(provider, name, force_https, has_other_auth_providers)
|
||||||
|
)
|
||||||
|
hass.http.register_view(OIDCDeviceSSE(provider))
|
||||||
|
hass.http.register_view(OIDCRedirectView(oidc_client, provider, force_https))
|
||||||
|
hass.http.register_view(OIDCCallbackView(oidc_client, provider, force_https))
|
||||||
|
hass.http.register_view(OIDCFinishView(provider))
|
||||||
|
|
||||||
_LOGGER.info("Registered OIDC views")
|
_LOGGER.info("Registered OIDC views")
|
||||||
|
|
||||||
|
# Inject OIDC code into the frontend for /auth/authorize for automatic redirect
|
||||||
|
await OIDCInjectedAuthPage.inject(hass, force_https)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|||||||
8
custom_components/auth_oidc/config/__init__.py
Normal file
8
custom_components/auth_oidc/config/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Imports manager"""
|
||||||
|
|
||||||
|
from .const import * # noqa: F403
|
||||||
|
from .schema import CONFIG_SCHEMA as CONFIG_SCHEMA
|
||||||
|
from .ui_flow import (
|
||||||
|
OIDCConfigFlow as OIDCConfigFlow,
|
||||||
|
convert_ui_config_entry_to_internal_format as convert_ui_config_entry_to_internal_format,
|
||||||
|
)
|
||||||
91
custom_components/auth_oidc/config/const.py
Normal file
91
custom_components/auth_oidc/config/const.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
"""Config constants."""
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
## ===
|
||||||
|
## General integration constants
|
||||||
|
## ===
|
||||||
|
|
||||||
|
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
||||||
|
DOMAIN = "auth_oidc"
|
||||||
|
REPO_ROOT_URL = "https://github.com/christiaangoossens/hass-oidc-auth/tree/v1.0.0-rc3"
|
||||||
|
|
||||||
|
## ===
|
||||||
|
## Config keys
|
||||||
|
## ===
|
||||||
|
|
||||||
|
CLIENT_ID = "client_id"
|
||||||
|
CLIENT_SECRET = "client_secret"
|
||||||
|
DISCOVERY_URL = "discovery_url"
|
||||||
|
DISPLAY_NAME = "display_name"
|
||||||
|
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
|
||||||
|
GROUPS_SCOPE = "groups_scope"
|
||||||
|
ADDITIONAL_SCOPES = "additional_scopes"
|
||||||
|
FEATURES = "features"
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
||||||
|
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
|
||||||
|
FEATURES_FORCE_HTTPS = "force_https"
|
||||||
|
CLAIMS = "claims"
|
||||||
|
CLAIMS_DISPLAY_NAME = "display_name"
|
||||||
|
CLAIMS_USERNAME = "username"
|
||||||
|
CLAIMS_GROUPS = "groups"
|
||||||
|
ROLES = "roles"
|
||||||
|
ROLE_ADMINS = "admin"
|
||||||
|
ROLE_USERS = "user"
|
||||||
|
NETWORK = "network"
|
||||||
|
NETWORK_TLS_VERIFY = "tls_verify"
|
||||||
|
NETWORK_TLS_CA_PATH = "tls_ca_path"
|
||||||
|
|
||||||
|
## ===
|
||||||
|
## Default configurations for providers
|
||||||
|
## ===
|
||||||
|
|
||||||
|
REQUIRED_SCOPES = "openid profile"
|
||||||
|
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM = "RS256"
|
||||||
|
|
||||||
|
DEFAULT_GROUPS_SCOPE = "groups"
|
||||||
|
DEFAULT_ADMIN_GROUP = "admins"
|
||||||
|
|
||||||
|
OIDC_PROVIDERS: Dict[str, Dict[str, Any]] = {
|
||||||
|
"authentik": {
|
||||||
|
"name": "Authentik",
|
||||||
|
"discovery_url": "",
|
||||||
|
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||||
|
"supports_groups": True,
|
||||||
|
"claims": {
|
||||||
|
"display_name": "name",
|
||||||
|
"username": "preferred_username",
|
||||||
|
"groups": "groups",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"authelia": {
|
||||||
|
"name": "Authelia",
|
||||||
|
"discovery_url": "",
|
||||||
|
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||||
|
"supports_groups": True,
|
||||||
|
"claims": {
|
||||||
|
"display_name": "name",
|
||||||
|
"username": "preferred_username",
|
||||||
|
"groups": "groups",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"pocketid": {
|
||||||
|
"name": "Pocket ID",
|
||||||
|
"discovery_url": "",
|
||||||
|
"default_admin_group": DEFAULT_ADMIN_GROUP,
|
||||||
|
"supports_groups": True,
|
||||||
|
"claims": {
|
||||||
|
"display_name": "name",
|
||||||
|
"username": "preferred_username",
|
||||||
|
"groups": "groups",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"generic": {
|
||||||
|
"name": "OpenID Connect (SSO)",
|
||||||
|
"discovery_url": "",
|
||||||
|
"supports_groups": False,
|
||||||
|
"claims": {"display_name": "name", "username": "preferred_username"},
|
||||||
|
},
|
||||||
|
}
|
||||||
35
custom_components/auth_oidc/config/provider_catalog.py
Normal file
35
custom_components/auth_oidc/config/provider_catalog.py
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
"""Provider catalog and helpers for OIDC providers."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from .const import OIDC_PROVIDERS, REPO_ROOT_URL
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_config(key: str) -> Dict[str, Any]:
|
||||||
|
"""Return provider configuration by key."""
|
||||||
|
return OIDC_PROVIDERS.get(key, {})
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_name(key: str | None) -> str:
|
||||||
|
"""Return provider display name by key."""
|
||||||
|
if not key:
|
||||||
|
return "Unknown Provider"
|
||||||
|
return OIDC_PROVIDERS.get(key, {}).get("name", "Unknown Provider")
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider_docs_url(key: str | None) -> str:
|
||||||
|
"""Return documentation URL for a provider key."""
|
||||||
|
base_url = REPO_ROOT_URL + "/docs/provider-configurations"
|
||||||
|
|
||||||
|
provider_docs = {
|
||||||
|
"authentik": f"{base_url}/authentik.md",
|
||||||
|
"authelia": f"{base_url}/authelia.md",
|
||||||
|
"pocketid": f"{base_url}/pocket-id.md",
|
||||||
|
"kanidm": f"{base_url}/kanidm.md",
|
||||||
|
"microsoft": f"{base_url}/microsoft-entra.md",
|
||||||
|
}
|
||||||
|
|
||||||
|
if key in provider_docs:
|
||||||
|
return provider_docs[key]
|
||||||
|
return REPO_ROOT_URL + "/docs/configuration.md"
|
||||||
@@ -1,34 +1,34 @@
|
|||||||
"""Config schema and constants."""
|
"""Config schema"""
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
from .const import (
|
||||||
|
CLIENT_ID,
|
||||||
|
CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL,
|
||||||
|
DISPLAY_NAME,
|
||||||
|
ID_TOKEN_SIGNING_ALGORITHM,
|
||||||
|
GROUPS_SCOPE,
|
||||||
|
ADDITIONAL_SCOPES,
|
||||||
|
FEATURES,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||||
|
FEATURES_DISABLE_PKCE,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||||
|
FEATURES_FORCE_HTTPS,
|
||||||
|
CLAIMS,
|
||||||
|
CLAIMS_DISPLAY_NAME,
|
||||||
|
CLAIMS_USERNAME,
|
||||||
|
CLAIMS_GROUPS,
|
||||||
|
ROLES,
|
||||||
|
ROLE_ADMINS,
|
||||||
|
ROLE_USERS,
|
||||||
|
NETWORK,
|
||||||
|
NETWORK_TLS_VERIFY,
|
||||||
|
NETWORK_TLS_CA_PATH,
|
||||||
|
DOMAIN,
|
||||||
|
DEFAULT_GROUPS_SCOPE,
|
||||||
|
)
|
||||||
|
|
||||||
CLIENT_ID = "client_id"
|
|
||||||
CLIENT_SECRET = "client_secret"
|
|
||||||
DISCOVERY_URL = "discovery_url"
|
|
||||||
DISPLAY_NAME = "display_name"
|
|
||||||
ID_TOKEN_SIGNING_ALGORITHM = "id_token_signing_alg"
|
|
||||||
GROUPS_SCOPE = "groups_scope"
|
|
||||||
ADDITIONAL_SCOPES = "additional_scopes"
|
|
||||||
FEATURES = "features"
|
|
||||||
FEATURES_AUTOMATIC_USER_LINKING = "automatic_user_linking"
|
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION = "automatic_person_creation"
|
|
||||||
FEATURES_DISABLE_PKCE = "disable_rfc7636"
|
|
||||||
FEATURES_INCLUDE_GROUPS_SCOPE = "include_groups_scope"
|
|
||||||
CLAIMS = "claims"
|
|
||||||
CLAIMS_DISPLAY_NAME = "display_name"
|
|
||||||
CLAIMS_USERNAME = "username"
|
|
||||||
CLAIMS_GROUPS = "groups"
|
|
||||||
ROLES = "roles"
|
|
||||||
ROLE_ADMINS = "admin"
|
|
||||||
ROLE_USERS = "user"
|
|
||||||
|
|
||||||
NETWORK = "network"
|
|
||||||
NETWORK_TLS_VERIFY = "tls_verify"
|
|
||||||
NETWORK_TLS_CA_PATH = "tls_ca_path"
|
|
||||||
|
|
||||||
DEFAULT_TITLE = "OpenID Connect (SSO)"
|
|
||||||
|
|
||||||
DOMAIN = "auth_oidc"
|
|
||||||
CONFIG_SCHEMA = vol.Schema(
|
CONFIG_SCHEMA = vol.Schema(
|
||||||
{
|
{
|
||||||
DOMAIN: vol.Schema(
|
DOMAIN: vol.Schema(
|
||||||
@@ -46,7 +46,9 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str),
|
vol.Optional(ID_TOKEN_SIGNING_ALGORITHM): vol.Coerce(str),
|
||||||
# String value to allow changing the groups scope
|
# String value to allow changing the groups scope
|
||||||
# Defaults to 'groups' which is used by Authelia and Authentik
|
# Defaults to 'groups' which is used by Authelia and Authentik
|
||||||
vol.Optional(GROUPS_SCOPE, default="groups"): vol.Coerce(str),
|
vol.Optional(GROUPS_SCOPE, default=DEFAULT_GROUPS_SCOPE): vol.Coerce(
|
||||||
|
str
|
||||||
|
),
|
||||||
# Additional scopes to request from the OIDC provider
|
# Additional scopes to request from the OIDC provider
|
||||||
# Optional, this field is unnecessary if you only use the openid and profile scopes.
|
# Optional, this field is unnecessary if you only use the openid and profile scopes.
|
||||||
vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]),
|
vol.Optional(ADDITIONAL_SCOPES, default=[]): vol.Coerce(list[str]),
|
||||||
@@ -69,6 +71,10 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
vol.Optional(
|
vol.Optional(
|
||||||
FEATURES_INCLUDE_GROUPS_SCOPE, default=True
|
FEATURES_INCLUDE_GROUPS_SCOPE, default=True
|
||||||
): vol.Coerce(bool),
|
): vol.Coerce(bool),
|
||||||
|
# Force HTTPS on all generated URLs (like redirect_uri)
|
||||||
|
vol.Optional(FEATURES_FORCE_HTTPS, default=False): vol.Coerce(
|
||||||
|
bool
|
||||||
|
),
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
# Determine which specific claims will be used from the id_token
|
# Determine which specific claims will be used from the id_token
|
||||||
839
custom_components/auth_oidc/config/ui_flow.py
Normal file
839
custom_components/auth_oidc/config/ui_flow.py
Normal file
@@ -0,0 +1,839 @@
|
|||||||
|
"""Config flow for OIDC Authentication integration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant import config_entries
|
||||||
|
from homeassistant.core import callback
|
||||||
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
|
|
||||||
|
from .const import (
|
||||||
|
DOMAIN,
|
||||||
|
DEFAULT_ADMIN_GROUP,
|
||||||
|
CLIENT_ID,
|
||||||
|
CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL,
|
||||||
|
DISPLAY_NAME,
|
||||||
|
FEATURES,
|
||||||
|
CLAIMS,
|
||||||
|
ROLES,
|
||||||
|
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..tools.oidc_client import (
|
||||||
|
OIDCDiscoveryClient,
|
||||||
|
OIDCDiscoveryInvalid,
|
||||||
|
OIDCJWKSInvalid,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .provider_catalog import (
|
||||||
|
OIDC_PROVIDERS,
|
||||||
|
get_provider_name,
|
||||||
|
get_provider_docs_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
from ..tools.validation import (
|
||||||
|
validate_discovery_url,
|
||||||
|
sanitize_client_secret,
|
||||||
|
validate_client_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration field names
|
||||||
|
CONF_PROVIDER = "provider"
|
||||||
|
CONF_CLIENT_ID = "client_id"
|
||||||
|
CONF_CLIENT_SECRET = "client_secret"
|
||||||
|
CONF_DISCOVERY_URL = "discovery_url"
|
||||||
|
CONF_ENABLE_GROUPS = "enable_groups"
|
||||||
|
CONF_ADMIN_GROUP = "admin_group"
|
||||||
|
CONF_USER_GROUP = "user_group"
|
||||||
|
CONF_ENABLE_USER_LINKING = "enable_user_linking"
|
||||||
|
|
||||||
|
# Cache settings
|
||||||
|
DISCOVERY_CACHE_TTL = 300 # 5 minutes
|
||||||
|
MAX_CACHE_SIZE = 10
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FlowState:
|
||||||
|
"""State tracking for the configuration flow."""
|
||||||
|
|
||||||
|
provider: str | None = None
|
||||||
|
discovery_url: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClientConfig:
|
||||||
|
"""Client configuration settings."""
|
||||||
|
|
||||||
|
client_id: str | None = None
|
||||||
|
client_secret: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FeatureConfig:
|
||||||
|
"""Feature configuration settings."""
|
||||||
|
|
||||||
|
enable_groups: bool = False
|
||||||
|
admin_group: str = DEFAULT_ADMIN_GROUP
|
||||||
|
user_group: str | None = None
|
||||||
|
enable_user_linking: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
|
||||||
|
"""Handle a config flow for OIDC Authentication."""
|
||||||
|
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
def is_matching(self, other_flow):
|
||||||
|
"""Check if this flow is the same as another flow."""
|
||||||
|
self_state = getattr(self, "_flow_state", None)
|
||||||
|
other_state = getattr(other_flow, "_flow_state", None)
|
||||||
|
|
||||||
|
if not self_state or not other_state:
|
||||||
|
return False
|
||||||
|
|
||||||
|
self_discovery_url = self_state.discovery_url
|
||||||
|
other_discovery_url = other_state.discovery_url
|
||||||
|
|
||||||
|
return (
|
||||||
|
self_discovery_url
|
||||||
|
and other_discovery_url
|
||||||
|
and self_discovery_url.rstrip("/").lower()
|
||||||
|
== other_discovery_url.rstrip("/").lower()
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Initialize the config flow."""
|
||||||
|
self._flow_state = FlowState()
|
||||||
|
self._client_config = ClientConfig()
|
||||||
|
self._feature_config = FeatureConfig()
|
||||||
|
self._discovery_cache = {}
|
||||||
|
self._cache_timestamps = {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_provider_config(self) -> dict[str, Any]:
|
||||||
|
"""Get the configuration for the currently selected provider."""
|
||||||
|
if not self._flow_state.provider:
|
||||||
|
return {}
|
||||||
|
return OIDC_PROVIDERS.get(self._flow_state.provider, {})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def current_provider_name(self) -> str:
|
||||||
|
"""Get the name of the currently selected provider."""
|
||||||
|
return get_provider_name(self._flow_state.provider)
|
||||||
|
|
||||||
|
def _cleanup_discovery_cache(self) -> None:
|
||||||
|
"""Remove expired and excess cache entries."""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Remove expired entries
|
||||||
|
expired_keys = [
|
||||||
|
key
|
||||||
|
for key, timestamp in self._cache_timestamps.items()
|
||||||
|
if current_time - timestamp > DISCOVERY_CACHE_TTL
|
||||||
|
]
|
||||||
|
for key in expired_keys:
|
||||||
|
self._discovery_cache.pop(key, None)
|
||||||
|
self._cache_timestamps.pop(key, None)
|
||||||
|
|
||||||
|
# Remove oldest entries if cache is too large
|
||||||
|
if len(self._discovery_cache) > MAX_CACHE_SIZE:
|
||||||
|
sorted_items = sorted(self._cache_timestamps.items(), key=lambda x: x[1])
|
||||||
|
excess_count = len(self._discovery_cache) - MAX_CACHE_SIZE
|
||||||
|
for key, _ in sorted_items[:excess_count]:
|
||||||
|
self._discovery_cache.pop(key, None)
|
||||||
|
self._cache_timestamps.pop(key, None)
|
||||||
|
|
||||||
|
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||||
|
"""Check if a cache entry is still valid."""
|
||||||
|
if cache_key not in self._cache_timestamps:
|
||||||
|
return False
|
||||||
|
|
||||||
|
age = time.time() - self._cache_timestamps[cache_key]
|
||||||
|
return age <= DISCOVERY_CACHE_TTL
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 1: Provider selection
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def async_step_user(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle the initial step - provider selection."""
|
||||||
|
# Check if OIDC is already configured (only one instance allowed)
|
||||||
|
if self._async_current_entries():
|
||||||
|
return self.async_abort(reason="single_instance_allowed")
|
||||||
|
|
||||||
|
# Check if YAML configuration exists
|
||||||
|
if self.hass.data.get(DOMAIN, {}).get("yaml_config"):
|
||||||
|
return self.async_abort(reason="yaml_configured")
|
||||||
|
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
self._flow_state.provider = user_input[CONF_PROVIDER]
|
||||||
|
|
||||||
|
# If provider has a predefined discovery URL, prefill it but still
|
||||||
|
# show the discovery URL step so the user can customize it.
|
||||||
|
predefined = self.current_provider_config.get("discovery_url")
|
||||||
|
if predefined:
|
||||||
|
self._flow_state.discovery_url = predefined
|
||||||
|
|
||||||
|
# Always request discovery URL next (prefilled when available)
|
||||||
|
return await self.async_step_discovery_url()
|
||||||
|
|
||||||
|
data_schema = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(CONF_PROVIDER): vol.In(
|
||||||
|
{key: provider["name"] for key, provider in OIDC_PROVIDERS.items()}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="user",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 2: Discovery URL
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def async_step_discovery_url(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle discovery URL input for providers requiring URL configuration."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
discovery_url = user_input[CONF_DISCOVERY_URL].rstrip("/")
|
||||||
|
|
||||||
|
# Validate discovery URL format
|
||||||
|
if not validate_discovery_url(discovery_url):
|
||||||
|
errors["discovery_url"] = "invalid_url_format"
|
||||||
|
else:
|
||||||
|
self._flow_state.discovery_url = discovery_url
|
||||||
|
return await self.async_step_validate_connection()
|
||||||
|
|
||||||
|
provider_name = self.current_provider_name
|
||||||
|
provider_key = self._flow_state.provider
|
||||||
|
|
||||||
|
# Pre-populate with existing discovery URL if available
|
||||||
|
default_url = (
|
||||||
|
self._flow_state.discovery_url
|
||||||
|
if self._flow_state.discovery_url
|
||||||
|
else vol.UNDEFINED
|
||||||
|
)
|
||||||
|
|
||||||
|
data_schema = vol.Schema(
|
||||||
|
{vol.Required(CONF_DISCOVERY_URL, default=default_url): str}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="discovery_url",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={
|
||||||
|
"provider_name": provider_name,
|
||||||
|
"documentation_url": get_provider_docs_url(provider_key),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 3: Discovery Validation
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def _handle_validation_actions(
|
||||||
|
self, user_input: dict[str, Any]
|
||||||
|
) -> FlowResult | None:
|
||||||
|
"""Handle user actions from the validation form so they can fix errors."""
|
||||||
|
action = user_input.get("action")
|
||||||
|
|
||||||
|
# Handle special actions first
|
||||||
|
if action == "retry":
|
||||||
|
return None # Continue with validation
|
||||||
|
if action == "continue":
|
||||||
|
return await self.async_step_client_config()
|
||||||
|
|
||||||
|
# Handle redirect actions
|
||||||
|
action_handlers = {
|
||||||
|
"fix_discovery": self.async_step_discovery_url,
|
||||||
|
"change_provider": self.async_step_user,
|
||||||
|
}
|
||||||
|
|
||||||
|
handler = action_handlers.get(action)
|
||||||
|
return await handler() if handler else None
|
||||||
|
|
||||||
|
async def _perform_oidc_validation(self) -> tuple[dict, dict]:
|
||||||
|
"""Perform the actual OIDC validation and return discovery doc and errors."""
|
||||||
|
errors = {}
|
||||||
|
discovery_doc = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
http_session = aiohttp.ClientSession()
|
||||||
|
discovery_client = OIDCDiscoveryClient(
|
||||||
|
discovery_url=self._flow_state.discovery_url,
|
||||||
|
http_session=http_session,
|
||||||
|
verification_context={
|
||||||
|
# Cannot be changed from the UI config currently
|
||||||
|
"id_token_signing_alg": DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Clean up expired cache entries first
|
||||||
|
self._cleanup_discovery_cache()
|
||||||
|
|
||||||
|
# Check if discovery document is already cached and valid
|
||||||
|
cache_key = self._flow_state.discovery_url
|
||||||
|
if cache_key in self._discovery_cache and self._is_cache_valid(cache_key):
|
||||||
|
discovery_doc = self._discovery_cache[cache_key]
|
||||||
|
|
||||||
|
# Still validate JWKS if available since this might be a retry
|
||||||
|
if "jwks_uri" in discovery_doc:
|
||||||
|
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
|
||||||
|
else:
|
||||||
|
# Perform discovery and JWKS validation
|
||||||
|
discovery_doc = await discovery_client.fetch_discovery_document()
|
||||||
|
|
||||||
|
# Cache the discovery document with timestamp
|
||||||
|
self._discovery_cache[cache_key] = discovery_doc
|
||||||
|
self._cache_timestamps[cache_key] = time.time()
|
||||||
|
|
||||||
|
# Validate JWKS if available
|
||||||
|
if "jwks_uri" in discovery_doc:
|
||||||
|
await discovery_client.fetch_jwks(discovery_doc["jwks_uri"])
|
||||||
|
|
||||||
|
except OIDCDiscoveryInvalid as e:
|
||||||
|
errors["base"] = "discovery_invalid"
|
||||||
|
errors["detail_string"] = e.get_detail_string()
|
||||||
|
except OIDCJWKSInvalid:
|
||||||
|
errors["base"] = "jwks_invalid"
|
||||||
|
except aiohttp.ClientError:
|
||||||
|
errors["base"] = "cannot_connect"
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Unexpected error during validation")
|
||||||
|
errors["base"] = "unknown"
|
||||||
|
|
||||||
|
await http_session.close()
|
||||||
|
return discovery_doc, errors
|
||||||
|
|
||||||
|
def _get_action_options(self, has_errors: bool) -> dict[str, str]:
|
||||||
|
"""Get action options based on validation state."""
|
||||||
|
if has_errors:
|
||||||
|
return {
|
||||||
|
"retry": "Retry Validation",
|
||||||
|
"fix_discovery": "Change Discovery URL",
|
||||||
|
"change_provider": "Change Provider",
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"continue": "Continue Setup",
|
||||||
|
"fix_discovery": "Change Discovery URL",
|
||||||
|
"change_provider": "Change Provider",
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_discovery_success_details(self, discovery_doc: dict) -> str:
|
||||||
|
"""Build success details from discovery document."""
|
||||||
|
return (
|
||||||
|
f"✅ Connected and verified succesfully!\n"
|
||||||
|
f"_Discovered valid OIDC issuer: {discovery_doc['issuer']}_\n\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _build_error_details(self, errors: dict[str, str]) -> str:
|
||||||
|
"""Build error details from validation errors."""
|
||||||
|
|
||||||
|
base = errors.get("base", "")
|
||||||
|
detail_string = errors.get("detail_string", "")
|
||||||
|
|
||||||
|
error_messages = {
|
||||||
|
"discovery_invalid": (
|
||||||
|
"❌ **Discovery document could not be validated.**\n"
|
||||||
|
"Please verify the discovery URL is correct and accessible.\n\n"
|
||||||
|
f"_({detail_string})_"
|
||||||
|
),
|
||||||
|
"jwks_invalid": (
|
||||||
|
"❌ **JWKS validation failed**\n"
|
||||||
|
"The JSON Web Key Set could not be retrieved or validated."
|
||||||
|
),
|
||||||
|
"cannot_connect": (
|
||||||
|
"❌ **Connection failed**\n"
|
||||||
|
"Unable to connect to the OIDC provider. Check your network and URL."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
return error_messages.get(base, "")
|
||||||
|
|
||||||
|
async def _build_validation_form(
|
||||||
|
self, errors: dict[str, str], discovery_doc: dict | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Build the validation form with errors and action options."""
|
||||||
|
action_options = self._get_action_options(bool(errors))
|
||||||
|
data_schema = vol.Schema({vol.Required("action"): vol.In(action_options)})
|
||||||
|
|
||||||
|
# Build description with discovery details
|
||||||
|
description_placeholders = {
|
||||||
|
"discovery_url": self._flow_state.discovery_url,
|
||||||
|
"provider_name": self.current_provider_name,
|
||||||
|
"discovery_details": "",
|
||||||
|
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add appropriate details based on validation state
|
||||||
|
if discovery_doc and not errors:
|
||||||
|
description_placeholders["discovery_details"] = (
|
||||||
|
self._build_discovery_success_details(discovery_doc)
|
||||||
|
)
|
||||||
|
elif errors:
|
||||||
|
description_placeholders["discovery_details"] = self._build_error_details(
|
||||||
|
errors
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="validate_connection",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders=description_placeholders,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_step_validate_connection(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Validate the OIDC configuration by testing discovery and JWKS."""
|
||||||
|
# Handle user actions from validation form
|
||||||
|
if user_input is not None:
|
||||||
|
action_result = await self._handle_validation_actions(user_input)
|
||||||
|
if action_result is not None:
|
||||||
|
return action_result
|
||||||
|
|
||||||
|
# Perform validation (either initial attempt or retry)
|
||||||
|
discovery_doc, errors = await self._perform_oidc_validation()
|
||||||
|
|
||||||
|
# Always show validation form with results (success or error)
|
||||||
|
return await self._build_validation_form(errors, discovery_doc)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 4: Configure client details (client_id & client_secret)
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def _proceed_to_next_step_after_client_config(self) -> FlowResult:
|
||||||
|
"""Proceed to next step after client config."""
|
||||||
|
if self.current_provider_config.get("supports_groups", True):
|
||||||
|
return await self.async_step_groups_config()
|
||||||
|
return await self.async_step_user_linking()
|
||||||
|
|
||||||
|
async def async_step_client_config(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle client ID and client type selection."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
client_id = user_input[CONF_CLIENT_ID]
|
||||||
|
|
||||||
|
# Validate client ID
|
||||||
|
if not validate_client_id(client_id):
|
||||||
|
errors["client_id"] = "invalid_client_id"
|
||||||
|
if not errors:
|
||||||
|
self._client_config.client_id = client_id.strip()
|
||||||
|
# Optional client secret determines confidential/public
|
||||||
|
provided_secret = sanitize_client_secret(
|
||||||
|
user_input.get(CONF_CLIENT_SECRET, "")
|
||||||
|
)
|
||||||
|
self._client_config.client_secret = provided_secret or None
|
||||||
|
|
||||||
|
if not errors:
|
||||||
|
return await self._proceed_to_next_step_after_client_config()
|
||||||
|
|
||||||
|
provider_name = self.current_provider_name
|
||||||
|
|
||||||
|
# Pre-populate with existing values if available
|
||||||
|
default_client_id = (
|
||||||
|
self._client_config.client_id
|
||||||
|
if self._client_config.client_id
|
||||||
|
else vol.UNDEFINED
|
||||||
|
)
|
||||||
|
default_client_secret = (
|
||||||
|
self._client_config.client_secret
|
||||||
|
if self._client_config.client_secret
|
||||||
|
else vol.UNDEFINED
|
||||||
|
)
|
||||||
|
|
||||||
|
data_schema = vol.Schema(
|
||||||
|
{
|
||||||
|
vol.Required(CONF_CLIENT_ID, default=default_client_id): str,
|
||||||
|
vol.Optional(CONF_CLIENT_SECRET, default=default_client_secret): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="client_config",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={
|
||||||
|
"provider_name": provider_name,
|
||||||
|
"discovery_url": self._flow_state.discovery_url,
|
||||||
|
"documentation_url": get_provider_docs_url(self._flow_state.provider),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 5: Configure groups settings
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def async_step_groups_config(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Configure groups and roles."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
self._feature_config.enable_groups = user_input.get(
|
||||||
|
CONF_ENABLE_GROUPS, False
|
||||||
|
)
|
||||||
|
if self._feature_config.enable_groups:
|
||||||
|
self._feature_config.admin_group = user_input.get(
|
||||||
|
CONF_ADMIN_GROUP, "admins"
|
||||||
|
)
|
||||||
|
self._feature_config.user_group = user_input.get(CONF_USER_GROUP)
|
||||||
|
|
||||||
|
return await self.async_step_user_linking()
|
||||||
|
|
||||||
|
default_admin_group = self.current_provider_config.get(
|
||||||
|
"default_admin_group", "admins"
|
||||||
|
)
|
||||||
|
|
||||||
|
data_schema_dict = {vol.Optional(CONF_ENABLE_GROUPS, default=True): bool}
|
||||||
|
|
||||||
|
# Add group configuration fields if groups are enabled
|
||||||
|
if user_input is None or user_input.get(CONF_ENABLE_GROUPS, True):
|
||||||
|
data_schema_dict.update(
|
||||||
|
{
|
||||||
|
vol.Optional(CONF_ADMIN_GROUP, default=default_admin_group): str,
|
||||||
|
vol.Optional(CONF_USER_GROUP): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
data_schema = vol.Schema(data_schema_dict)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="groups_config",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={"provider_name": self.current_provider_name},
|
||||||
|
)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 6: Configure user linking
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def async_step_user_linking(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Configure user linking options."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
self._feature_config.enable_user_linking = user_input.get(
|
||||||
|
CONF_ENABLE_USER_LINKING, False
|
||||||
|
)
|
||||||
|
return await self.async_step_finalize()
|
||||||
|
|
||||||
|
data_schema = vol.Schema(
|
||||||
|
{vol.Optional(CONF_ENABLE_USER_LINKING, default=False): bool}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="user_linking",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={},
|
||||||
|
)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Step 7: Finalize and create entry
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def async_step_finalize(self) -> FlowResult:
|
||||||
|
"""Finalize the configuration and create the config entry."""
|
||||||
|
await self.async_set_unique_id(DOMAIN)
|
||||||
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
|
# Build the configuration
|
||||||
|
config_data = {
|
||||||
|
"provider": self._flow_state.provider,
|
||||||
|
"client_id": self._client_config.client_id,
|
||||||
|
"discovery_url": self._flow_state.discovery_url,
|
||||||
|
"display_name": f"{self.current_provider_name}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add optional fields
|
||||||
|
if self._client_config.client_secret:
|
||||||
|
config_data["client_secret"] = self._client_config.client_secret
|
||||||
|
|
||||||
|
# Configure features
|
||||||
|
features = {
|
||||||
|
"automatic_user_linking": self._feature_config.enable_user_linking,
|
||||||
|
"automatic_person_creation": True,
|
||||||
|
"include_groups_scope": self._feature_config.enable_groups,
|
||||||
|
}
|
||||||
|
config_data["features"] = features
|
||||||
|
|
||||||
|
# Configure claims using provider defaults
|
||||||
|
claims = self.current_provider_config["claims"].copy()
|
||||||
|
config_data["claims"] = claims
|
||||||
|
|
||||||
|
# Configure roles if groups are enabled
|
||||||
|
if self._feature_config.enable_groups:
|
||||||
|
roles = {}
|
||||||
|
if self._feature_config.admin_group:
|
||||||
|
roles["admin"] = self._feature_config.admin_group
|
||||||
|
if self._feature_config.user_group:
|
||||||
|
roles["user"] = self._feature_config.user_group
|
||||||
|
config_data["roles"] = roles
|
||||||
|
|
||||||
|
title = f"{self.current_provider_name}"
|
||||||
|
|
||||||
|
return self.async_create_entry(title=title, data=config_data)
|
||||||
|
|
||||||
|
# =================
|
||||||
|
# Allow reconfiguration of client ID and secret
|
||||||
|
# =================
|
||||||
|
|
||||||
|
async def _validate_reconfigure_input(
|
||||||
|
self, entry, user_input: dict[str, Any]
|
||||||
|
) -> tuple[dict[str, str], dict[str, Any] | None]:
|
||||||
|
"""Validate reconfigure input and return errors and data updates."""
|
||||||
|
errors = {}
|
||||||
|
|
||||||
|
# Validate client ID
|
||||||
|
client_id = user_input[CONF_CLIENT_ID].strip()
|
||||||
|
if not validate_client_id(client_id):
|
||||||
|
errors["client_id"] = "invalid_client_id"
|
||||||
|
return errors, None
|
||||||
|
|
||||||
|
# Build updated data
|
||||||
|
data_updates = {"client_id": client_id}
|
||||||
|
|
||||||
|
# The optional secret field is submitted explicitly when the form is used.
|
||||||
|
# An empty value means the user wants to keep the existing secret.
|
||||||
|
if CONF_CLIENT_SECRET in user_input:
|
||||||
|
client_secret = user_input.get(CONF_CLIENT_SECRET, "").strip()
|
||||||
|
|
||||||
|
if client_secret:
|
||||||
|
data_updates["client_secret"] = client_secret
|
||||||
|
elif "client_secret" in entry.data:
|
||||||
|
data_updates["client_secret"] = entry.data["client_secret"]
|
||||||
|
|
||||||
|
return errors, data_updates
|
||||||
|
|
||||||
|
def _build_reconfigure_schema(
|
||||||
|
self, current_data: dict[str, Any], _user_input: dict[str, Any] | None
|
||||||
|
) -> vol.Schema:
|
||||||
|
"""Build the reconfigure form schema."""
|
||||||
|
schema_dict = {
|
||||||
|
vol.Required(
|
||||||
|
CONF_CLIENT_ID, default=current_data.get("client_id", vol.UNDEFINED)
|
||||||
|
): str,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Always allow updating or clearing the client secret
|
||||||
|
schema_dict[vol.Optional(CONF_CLIENT_SECRET)] = str
|
||||||
|
|
||||||
|
return vol.Schema(schema_dict)
|
||||||
|
|
||||||
|
async def async_step_reconfigure(
|
||||||
|
self, user_input: dict[str, Any] | None = None
|
||||||
|
) -> FlowResult:
|
||||||
|
"""Handle reconfiguration of OIDC client credentials."""
|
||||||
|
errors = {}
|
||||||
|
entry = self._get_reconfigure_entry()
|
||||||
|
if entry is None:
|
||||||
|
return self.async_abort(reason="unknown")
|
||||||
|
|
||||||
|
if user_input is not None:
|
||||||
|
try:
|
||||||
|
errors, data_updates = await self._validate_reconfigure_input(
|
||||||
|
entry, user_input
|
||||||
|
)
|
||||||
|
|
||||||
|
if not errors:
|
||||||
|
# Update the config entry
|
||||||
|
await self.async_set_unique_id(entry.unique_id)
|
||||||
|
self._abort_if_unique_id_mismatch()
|
||||||
|
|
||||||
|
return self.async_update_reload_and_abort(
|
||||||
|
entry, data_updates=data_updates
|
||||||
|
)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
_LOGGER.exception("Unexpected error during reconfiguration")
|
||||||
|
errors["base"] = "unknown"
|
||||||
|
|
||||||
|
# Show form
|
||||||
|
current_data = entry.data
|
||||||
|
data_schema = self._build_reconfigure_schema(current_data, user_input)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="reconfigure",
|
||||||
|
data_schema=data_schema,
|
||||||
|
errors=errors,
|
||||||
|
description_placeholders={
|
||||||
|
"provider_name": get_provider_name(current_data.get("provider")),
|
||||||
|
"discovery_url": current_data.get("discovery_url", ""),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_reconfigure_entry(self):
|
||||||
|
"""Return the config entry being reconfigured if available.
|
||||||
|
|
||||||
|
Prefer the entry referenced by the flow context's entry_id. Fall back to the
|
||||||
|
first existing entry for this domain when only a single instance is allowed.
|
||||||
|
"""
|
||||||
|
# Try from flow context (preferred)
|
||||||
|
entry_id = None
|
||||||
|
context = getattr(self, "context", None)
|
||||||
|
if context and hasattr(context, "get"):
|
||||||
|
entry_id = context.get("entry_id")
|
||||||
|
|
||||||
|
if entry_id:
|
||||||
|
entry = self.hass.config_entries.async_get_entry(entry_id)
|
||||||
|
if entry and entry.domain == DOMAIN:
|
||||||
|
return entry
|
||||||
|
|
||||||
|
# Fallback: this integration allows a single instance; use the first
|
||||||
|
current = self._async_current_entries()
|
||||||
|
if current:
|
||||||
|
return current[0]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@callback
|
||||||
|
def async_get_options_flow(config_entry):
|
||||||
|
"""Get the options flow for this handler."""
|
||||||
|
return OIDCOptionsFlowHandler()
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCOptionsFlowHandler(config_entries.OptionsFlow):
|
||||||
|
"""Handle options flow for OIDC Authentication."""
|
||||||
|
|
||||||
|
async def async_step_init(self, user_input=None):
|
||||||
|
"""Handle options flow."""
|
||||||
|
if user_input is not None:
|
||||||
|
# Process the updated configuration
|
||||||
|
updated_features = {
|
||||||
|
"automatic_user_linking": user_input.get("enable_user_linking", False),
|
||||||
|
"include_groups_scope": user_input.get("enable_groups", False),
|
||||||
|
}
|
||||||
|
|
||||||
|
updated_roles = {}
|
||||||
|
if user_input.get("enable_groups", False):
|
||||||
|
if user_input.get("admin_group"):
|
||||||
|
updated_roles["admin"] = user_input["admin_group"]
|
||||||
|
if user_input.get("user_group"):
|
||||||
|
updated_roles["user"] = user_input["user_group"]
|
||||||
|
|
||||||
|
# Update the config entry data
|
||||||
|
new_data = self.config_entry.data.copy()
|
||||||
|
new_data["features"] = {**new_data.get("features", {}), **updated_features}
|
||||||
|
if updated_roles:
|
||||||
|
new_data["roles"] = updated_roles
|
||||||
|
elif "roles" in new_data:
|
||||||
|
# Remove roles if groups are disabled
|
||||||
|
if not user_input.get("enable_groups", False):
|
||||||
|
del new_data["roles"]
|
||||||
|
|
||||||
|
# Update the config entry
|
||||||
|
self.hass.config_entries.async_update_entry(
|
||||||
|
self.config_entry, data=new_data
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_create_entry(title="", data={})
|
||||||
|
|
||||||
|
current_config = self.config_entry.data
|
||||||
|
current_features = current_config.get("features", {})
|
||||||
|
current_roles = current_config.get("roles", {})
|
||||||
|
|
||||||
|
# Determine if this provider supports groups
|
||||||
|
provider = current_config.get("provider", "authentik")
|
||||||
|
provider_supports_groups = OIDC_PROVIDERS.get(provider, {}).get(
|
||||||
|
"supports_groups", True
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build schema based on provider capabilities
|
||||||
|
schema_dict = {
|
||||||
|
vol.Optional(
|
||||||
|
"enable_user_linking",
|
||||||
|
default=current_features.get("automatic_user_linking", False),
|
||||||
|
): bool
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add groups options if provider supports them
|
||||||
|
if provider_supports_groups:
|
||||||
|
enable_groups_default = current_features.get("include_groups_scope", False)
|
||||||
|
schema_dict[
|
||||||
|
vol.Optional("enable_groups", default=enable_groups_default)
|
||||||
|
] = bool
|
||||||
|
|
||||||
|
# Add group name fields if groups are currently enabled or being enabled
|
||||||
|
if enable_groups_default or (
|
||||||
|
user_input and user_input.get("enable_groups", False)
|
||||||
|
):
|
||||||
|
schema_dict.update(
|
||||||
|
{
|
||||||
|
vol.Optional(
|
||||||
|
"admin_group",
|
||||||
|
default=current_roles.get("admin", DEFAULT_ADMIN_GROUP),
|
||||||
|
): str,
|
||||||
|
vol.Optional(
|
||||||
|
"user_group", default=current_roles.get("user", "")
|
||||||
|
): str,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.async_show_form(
|
||||||
|
step_id="init",
|
||||||
|
data_schema=vol.Schema(schema_dict),
|
||||||
|
description_placeholders={
|
||||||
|
"provider_name": get_provider_name(provider),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def convert_ui_config_entry_to_internal_format(config_data: dict) -> dict:
|
||||||
|
"""Convert config entry data to internal configuration format."""
|
||||||
|
my_config = {}
|
||||||
|
|
||||||
|
# Required fields
|
||||||
|
my_config[CLIENT_ID] = config_data["client_id"]
|
||||||
|
my_config[DISCOVERY_URL] = config_data["discovery_url"]
|
||||||
|
|
||||||
|
# Optional fields
|
||||||
|
if "client_secret" in config_data:
|
||||||
|
my_config[CLIENT_SECRET] = config_data["client_secret"]
|
||||||
|
|
||||||
|
if "display_name" in config_data:
|
||||||
|
my_config[DISPLAY_NAME] = config_data["display_name"]
|
||||||
|
|
||||||
|
# Features configuration
|
||||||
|
if "features" in config_data:
|
||||||
|
my_config[FEATURES] = config_data["features"]
|
||||||
|
|
||||||
|
# Claims configuration
|
||||||
|
if "claims" in config_data:
|
||||||
|
my_config[CLAIMS] = config_data["claims"]
|
||||||
|
|
||||||
|
# Roles configuration
|
||||||
|
if "roles" in config_data:
|
||||||
|
my_config[ROLES] = config_data["roles"]
|
||||||
|
|
||||||
|
return my_config
|
||||||
5
custom_components/auth_oidc/config_flow.py
Normal file
5
custom_components/auth_oidc/config_flow.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""UI config flow re-export"""
|
||||||
|
|
||||||
|
# pylint: disable=useless-import-alias
|
||||||
|
# pylint: disable=unused-import
|
||||||
|
from .config.ui_flow import OIDCConfigFlow as OIDCConfigFlow
|
||||||
8
custom_components/auth_oidc/endpoints/__init__.py
Normal file
8
custom_components/auth_oidc/endpoints/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Imports manager"""
|
||||||
|
|
||||||
|
from .callback import OIDCCallbackView as OIDCCallbackView
|
||||||
|
from .finish import OIDCFinishView as OIDCFinishView
|
||||||
|
from .injected_auth_page import OIDCInjectedAuthPage as OIDCInjectedAuthPage
|
||||||
|
from .redirect import OIDCRedirectView as OIDCRedirectView
|
||||||
|
from .welcome import OIDCWelcomeView as OIDCWelcomeView
|
||||||
|
from .device_sse import OIDCDeviceSSE as OIDCDeviceSSE
|
||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from ..oidc_client import OIDCClient
|
from ..tools.oidc_client import OIDCClient
|
||||||
from ..provider import OpenIDAuthProvider
|
from ..provider import OpenIDAuthProvider
|
||||||
from ..helpers import get_url, get_view
|
from ..tools.helpers import error_response, get_url, get_valid_state_id
|
||||||
|
|
||||||
PATH = "/auth/oidc/callback"
|
PATH = "/auth/oidc/callback"
|
||||||
|
|
||||||
@@ -17,50 +17,61 @@ class OIDCCallbackView(HomeAssistantView):
|
|||||||
name = "auth:oidc:callback"
|
name = "auth:oidc:callback"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, oidc_client: OIDCClient, oidc_provider: OpenIDAuthProvider
|
self,
|
||||||
|
oidc_client: OIDCClient,
|
||||||
|
oidc_provider: OpenIDAuthProvider,
|
||||||
|
force_https: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.oidc_client = oidc_client
|
self.oidc_client = oidc_client
|
||||||
self.oidc_provider = oidc_provider
|
self.oidc_provider = oidc_provider
|
||||||
|
self.force_https = force_https
|
||||||
|
|
||||||
async def get(self, request: web.Request) -> web.Response:
|
async def get(self, request: web.Request) -> web.Response:
|
||||||
"""Receive response."""
|
"""Receive response."""
|
||||||
|
|
||||||
|
# Get cookie to get the state_id
|
||||||
|
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||||
|
if not state_id:
|
||||||
|
return await error_response("Missing state cookie, please restart login.")
|
||||||
|
|
||||||
|
# Get the OIDC query parameters
|
||||||
params = request.rel_url.query
|
params = request.rel_url.query
|
||||||
code = params.get("code")
|
code = params.get("code")
|
||||||
state = params.get("state")
|
state = params.get("state")
|
||||||
|
|
||||||
if not (code and state):
|
if not (code and state):
|
||||||
view_html = await get_view(
|
return await error_response("Missing code or state parameter.")
|
||||||
"error",
|
|
||||||
{
|
|
||||||
"error": "Missing code or state parameter.",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
|
||||||
|
|
||||||
redirect_uri = get_url("/auth/oidc/callback")
|
# Check if the states match
|
||||||
|
if state != state_id:
|
||||||
|
return await error_response(
|
||||||
|
"State parameter does not match, possible CSRF attack."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Complete the OIDC flow to get user details
|
||||||
|
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||||
user_details = await self.oidc_client.async_complete_token_flow(
|
user_details = await self.oidc_client.async_complete_token_flow(
|
||||||
redirect_uri, code, state
|
redirect_uri, code, state
|
||||||
)
|
)
|
||||||
if user_details is None:
|
if user_details is None:
|
||||||
view_html = await get_view(
|
return await error_response(
|
||||||
"error",
|
"Failed to get user details, see Home Assistant logs for more information.",
|
||||||
{
|
status=500,
|
||||||
"error": "Failed to get user details, "
|
|
||||||
+ "see Home Assistant logs for more information.",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
|
||||||
|
|
||||||
if user_details.get("role") == "invalid":
|
if user_details.get("role") == "invalid":
|
||||||
view_html = await get_view(
|
return await error_response(
|
||||||
"error",
|
"User is not in the correct group to access Home Assistant, "
|
||||||
{
|
|
||||||
"error": "User is not in the correct group to access Home Assistant, "
|
|
||||||
+ "contact your administrator!",
|
+ "contact your administrator!",
|
||||||
},
|
status=403,
|
||||||
)
|
)
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
|
||||||
|
|
||||||
code = await self.oidc_provider.async_save_user_info(user_details)
|
# Finalize on the state
|
||||||
return web.HTTPFound(get_url("/auth/oidc/finish?code=" + code))
|
success = await self.oidc_provider.async_save_user_info(state_id, user_details)
|
||||||
|
if not success:
|
||||||
|
return await error_response(
|
||||||
|
"Failed to save user information, session probably expired. Please sign in again.",
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
raise web.HTTPFound(get_url("/auth/oidc/finish", self.force_https))
|
||||||
|
|||||||
70
custom_components/auth_oidc/endpoints/device_sse.py
Normal file
70
custom_components/auth_oidc/endpoints/device_sse.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
"""SSE handler for OIDC device authentication."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from aiohttp import web
|
||||||
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
from ..provider import OpenIDAuthProvider
|
||||||
|
from ..tools.helpers import get_valid_state_id
|
||||||
|
|
||||||
|
PATH = "/auth/oidc/device-sse"
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCDeviceSSE(HomeAssistantView):
|
||||||
|
"""OIDC Plugin SSE Handler."""
|
||||||
|
|
||||||
|
requires_auth = False
|
||||||
|
url = PATH
|
||||||
|
name = "auth:oidc:device-sse"
|
||||||
|
|
||||||
|
def __init__(self, oidc_provider: OpenIDAuthProvider) -> None:
|
||||||
|
self.oidc_provider = oidc_provider
|
||||||
|
|
||||||
|
async def get(self, req: web.Request) -> web.Response:
|
||||||
|
"""Check for mobile sign-in completion with short server-side polling."""
|
||||||
|
state_id = await get_valid_state_id(req, self.oidc_provider)
|
||||||
|
if not state_id:
|
||||||
|
raise web.HTTPBadRequest(text="Missing session cookie")
|
||||||
|
|
||||||
|
timeout_seconds = 300
|
||||||
|
started_at = asyncio.get_running_loop().time()
|
||||||
|
|
||||||
|
response = web.StreamResponse(
|
||||||
|
status=200,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "text/event-stream",
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
await response.prepare(req)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
if (
|
||||||
|
await self.oidc_provider.async_get_redirect_uri_for_state(state_id)
|
||||||
|
is None
|
||||||
|
):
|
||||||
|
await response.write(b"event: expired\ndata: false\n\n")
|
||||||
|
break
|
||||||
|
|
||||||
|
ready = await self.oidc_provider.async_is_state_ready(state_id)
|
||||||
|
if ready:
|
||||||
|
await response.write(b"event: ready\ndata: true\n\n")
|
||||||
|
break
|
||||||
|
|
||||||
|
if asyncio.get_running_loop().time() - started_at >= timeout_seconds:
|
||||||
|
await response.write(b"event: timeout\ndata: false\n\n")
|
||||||
|
break
|
||||||
|
|
||||||
|
await response.write(b"event: waiting\ndata: false\n\n")
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
except (ConnectionResetError, RuntimeError):
|
||||||
|
# Client disconnected while listening for state changes.
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
try:
|
||||||
|
await response.write_eof()
|
||||||
|
except ConnectionResetError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return response
|
||||||
@@ -2,7 +2,12 @@
|
|||||||
|
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from ..helpers import get_view
|
from ..provider import OpenIDAuthProvider
|
||||||
|
from ..tools.helpers import (
|
||||||
|
error_response,
|
||||||
|
get_valid_state_id,
|
||||||
|
template_response,
|
||||||
|
)
|
||||||
|
|
||||||
PATH = "/auth/oidc/finish"
|
PATH = "/auth/oidc/finish"
|
||||||
|
|
||||||
@@ -14,41 +19,60 @@ class OIDCFinishView(HomeAssistantView):
|
|||||||
url = PATH
|
url = PATH
|
||||||
name = "auth:oidc:finish"
|
name = "auth:oidc:finish"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
oidc_provider: OpenIDAuthProvider,
|
||||||
|
) -> None:
|
||||||
|
self.oidc_provider = oidc_provider
|
||||||
|
|
||||||
async def get(self, request: web.Request) -> web.Response:
|
async def get(self, request: web.Request) -> web.Response:
|
||||||
"""Show the finish screen to allow the user to view their code."""
|
"""Show the finish screen to pick between login & device code."""
|
||||||
|
# Get cookie to get the state_id
|
||||||
|
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||||
|
if not state_id:
|
||||||
|
return await error_response("Missing state cookie, please restart login.")
|
||||||
|
|
||||||
code = request.query.get("code")
|
return await template_response("finish", {})
|
||||||
|
|
||||||
if not code:
|
|
||||||
view_html = await get_view(
|
|
||||||
"error",
|
|
||||||
{"error": "Missing code to show the finish screen."},
|
|
||||||
)
|
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
|
||||||
|
|
||||||
view_html = await get_view("finish", {"code": code})
|
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
|
||||||
|
|
||||||
async def post(self, request: web.Request) -> web.Response:
|
async def post(self, request: web.Request) -> web.Response:
|
||||||
"""Receive response."""
|
"""Receive response."""
|
||||||
|
|
||||||
# Get code from the message body
|
# Get cookie to get the state_id
|
||||||
data = await request.post()
|
state_id = await get_valid_state_id(request, self.oidc_provider)
|
||||||
code = data.get("code")
|
if not state_id:
|
||||||
|
return await error_response("Missing state cookie, please restart login.")
|
||||||
|
|
||||||
if not code:
|
# Get redirect_uri from the state
|
||||||
return web.Response(text="No code received", status=500)
|
redirect_uri = await self.oidc_provider.async_get_redirect_uri_for_state(
|
||||||
|
state_id
|
||||||
# Return redirect to the main page for sign in with a cookie
|
|
||||||
return web.HTTPFound(
|
|
||||||
location="/",
|
|
||||||
headers={
|
|
||||||
# Set a cookie to enable autologin on only the specific path used
|
|
||||||
# for the POST request, with all strict parameters set
|
|
||||||
# This cookie should not be read by any Javascript or any other paths.
|
|
||||||
# It can be really short lifetime as we redirect immediately (5 seconds)
|
|
||||||
"set-cookie": "auth_oidc_code="
|
|
||||||
+ code
|
|
||||||
+ "; Path=/auth/login_flow; SameSite=Strict; HttpOnly; Max-Age=5",
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not redirect_uri:
|
||||||
|
return await error_response("Invalid state, please restart login.")
|
||||||
|
|
||||||
|
# Get the message body
|
||||||
|
data = await request.post()
|
||||||
|
device_code = data.get("device_code")
|
||||||
|
|
||||||
|
# We are trying sign-in on this browser
|
||||||
|
if not device_code:
|
||||||
|
# Add to the URL correctly (also handle case where it's just the root)
|
||||||
|
separator = "?"
|
||||||
|
if "?" in redirect_uri:
|
||||||
|
separator = "&"
|
||||||
|
|
||||||
|
# Redirect to this new URL for login, make sure to skip OIDC to prevent loops
|
||||||
|
redirect_uri = f"{redirect_uri}{separator}skip_oidc_redirect=true"
|
||||||
|
raise web.HTTPFound(location=redirect_uri)
|
||||||
|
|
||||||
|
# Check if we can link this device
|
||||||
|
linked = await self.oidc_provider.async_link_state_to_code(
|
||||||
|
state_id, device_code
|
||||||
|
)
|
||||||
|
|
||||||
|
if not linked:
|
||||||
|
return await error_response(
|
||||||
|
"Failed to link state to device code, please restart login."
|
||||||
|
)
|
||||||
|
|
||||||
|
return await template_response("device_success", {})
|
||||||
|
|||||||
145
custom_components/auth_oidc/endpoints/injected_auth_page.py
Normal file
145
custom_components/auth_oidc/endpoints/injected_auth_page.py
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
"""Injected authorization page, replacing the original"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import logging
|
||||||
|
from functools import partial
|
||||||
|
from urllib.parse import quote, unquote
|
||||||
|
from aiohttp import web
|
||||||
|
from aiofiles import open as async_open
|
||||||
|
|
||||||
|
from homeassistant.components.http import HomeAssistantView, StaticPathConfig
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from .welcome import PATH as WELCOME_PATH
|
||||||
|
from ..tools.helpers import get_url
|
||||||
|
|
||||||
|
PATH = "/auth/authorize"
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def read_file(path: str) -> str:
|
||||||
|
"""Read a file from the static path."""
|
||||||
|
async with async_open(path, mode="r") as f:
|
||||||
|
return await f.read()
|
||||||
|
|
||||||
|
|
||||||
|
async def frontend_injection(hass: HomeAssistant, force_https: bool) -> None:
|
||||||
|
"""Inject new frontend code into /auth/authorize."""
|
||||||
|
router = hass.http.app.router
|
||||||
|
frontend_path = None
|
||||||
|
|
||||||
|
for resource in router.resources():
|
||||||
|
if resource.canonical != "/auth/authorize":
|
||||||
|
continue
|
||||||
|
|
||||||
|
# This path doesn't actually work, gives 404, effectively disabling the old matcher
|
||||||
|
resource.add_prefix("/auth/oidc/unused")
|
||||||
|
|
||||||
|
# Now get the original frontend path from this resource to obtain the GET route
|
||||||
|
routes = iter(resource)
|
||||||
|
route = next(
|
||||||
|
(r for r in routes if r.method == "GET"),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if route is not None:
|
||||||
|
if not route.handler or not isinstance(route.handler, partial):
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Unexpected route handler type %s for /auth/authorize",
|
||||||
|
type(route.handler),
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# The original frontend path is the first argument of the handler
|
||||||
|
frontend_path = route.handler.args[0]
|
||||||
|
break
|
||||||
|
|
||||||
|
# Get the path to the original frontend resource
|
||||||
|
if frontend_path is None:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Failed to find GET route for /auth/authorize, cannot inject OIDC frontend code"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Inject our new script into the existing frontend code
|
||||||
|
# First fetch the frontend path into memory
|
||||||
|
frontend_code = await read_file(frontend_path)
|
||||||
|
|
||||||
|
# Inject JS and register that route
|
||||||
|
injection_js = "<script src='/auth/oidc/static/injection.js?v=6'></script>"
|
||||||
|
frontend_code = frontend_code.replace("</body>", f"{injection_js}</body>")
|
||||||
|
|
||||||
|
await hass.http.async_register_static_paths(
|
||||||
|
[
|
||||||
|
StaticPathConfig(
|
||||||
|
"/auth/oidc/static/injection.js",
|
||||||
|
hass.config.path("custom_components/auth_oidc/static/injection.js"),
|
||||||
|
cache_headers=False,
|
||||||
|
),
|
||||||
|
StaticPathConfig(
|
||||||
|
"/auth/oidc/static/style.css",
|
||||||
|
hass.config.path("custom_components/auth_oidc/static/style.css"),
|
||||||
|
cache_headers=False,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# If everything is succesful, register a fake view that just returns the modified HTML
|
||||||
|
hass.http.register_view(OIDCInjectedAuthPage(frontend_code, force_https))
|
||||||
|
_LOGGER.info("Performed OIDC frontend injection")
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCInjectedAuthPage(HomeAssistantView):
|
||||||
|
"""OIDC Plugin Injected Auth Page."""
|
||||||
|
|
||||||
|
requires_auth = False
|
||||||
|
url = PATH
|
||||||
|
name = "auth:oidc:authorize_page"
|
||||||
|
|
||||||
|
def __init__(self, html: str, force_https: bool) -> None:
|
||||||
|
"""Initialize the injected auth page."""
|
||||||
|
self.html = html
|
||||||
|
self.force_https = force_https
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def inject(hass: HomeAssistant, force_https: bool) -> None:
|
||||||
|
"""Inject the OIDC auth page into the frontend."""
|
||||||
|
try:
|
||||||
|
await frontend_injection(hass, force_https)
|
||||||
|
except Exception as e: # pylint: disable=broad-except
|
||||||
|
_LOGGER.error("Failed to inject OIDC auth page: %s", e)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _should_do_oidc_redirect(req: web.Request) -> bool:
|
||||||
|
"""Check if we should redirect to the OIDC flow."""
|
||||||
|
# Set when we return from finish
|
||||||
|
if req.query.get("skip_oidc_redirect") == "true":
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Set whenever you directly do /?skip_oidc_redirect=true,
|
||||||
|
# for example when you click the "other" button on the welcome screen
|
||||||
|
redirect_uri = req.query.get("redirect_uri")
|
||||||
|
if not redirect_uri:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Handle both encoded and plain redirect_uri values.
|
||||||
|
decoded_redirect_uri = unquote(redirect_uri)
|
||||||
|
return "skip_oidc_redirect=true" not in decoded_redirect_uri
|
||||||
|
|
||||||
|
def _get_welcome_redirect_location(self, req: web.Request) -> str:
|
||||||
|
"""Build the welcome URL for the injected auth page redirect."""
|
||||||
|
encoded_current_url = quote(
|
||||||
|
base64.b64encode(str(req.url).encode("utf-8")).decode("ascii")
|
||||||
|
)
|
||||||
|
return get_url(
|
||||||
|
f"{WELCOME_PATH}?redirect_uri={encoded_current_url}",
|
||||||
|
self.force_https,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get(self, req: web.Request) -> web.Response:
|
||||||
|
"""Return the original page or redirect into the OIDC flow."""
|
||||||
|
if self._should_do_oidc_redirect(req):
|
||||||
|
raise web.HTTPFound(location=self._get_welcome_redirect_location(req))
|
||||||
|
|
||||||
|
return web.Response(text=self.html, content_type="text/html")
|
||||||
@@ -1,11 +1,13 @@
|
|||||||
"""Redirect route to redirect the user to the external OIDC server,
|
"""Redirect route to redirect the user to the external OIDC server,
|
||||||
can either be linked to directly or accessed through the welcome page."""
|
can either be linked to directly or accessed through the welcome page."""
|
||||||
|
|
||||||
|
from urllib.parse import quote
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
|
|
||||||
from ..oidc_client import OIDCClient
|
from ..provider import OpenIDAuthProvider
|
||||||
from ..helpers import get_url, get_view
|
from ..tools.oidc_client import OIDCClient
|
||||||
|
from ..tools.helpers import error_response, get_url, get_valid_state_id, get_view
|
||||||
|
|
||||||
PATH = "/auth/oidc/redirect"
|
PATH = "/auth/oidc/redirect"
|
||||||
|
|
||||||
@@ -17,24 +19,44 @@ class OIDCRedirectView(HomeAssistantView):
|
|||||||
url = PATH
|
url = PATH
|
||||||
name = "auth:oidc:redirect"
|
name = "auth:oidc:redirect"
|
||||||
|
|
||||||
def __init__(self, oidc_client: OIDCClient) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
oidc_client: OIDCClient,
|
||||||
|
oidc_provider: OpenIDAuthProvider,
|
||||||
|
force_https: bool,
|
||||||
|
) -> None:
|
||||||
self.oidc_client = oidc_client
|
self.oidc_client = oidc_client
|
||||||
|
self.oidc_provider = oidc_provider
|
||||||
|
self.force_https = force_https
|
||||||
|
|
||||||
async def get(self, _: web.Request) -> web.Response:
|
async def get(self, req: web.Request) -> web.Response:
|
||||||
"""Receive response."""
|
"""Receive response."""
|
||||||
|
|
||||||
redirect_uri = get_url("/auth/oidc/callback")
|
# Get cookie to get the state_id
|
||||||
auth_url = await self.oidc_client.async_get_authorization_url(redirect_uri)
|
state_id = await get_valid_state_id(req, self.oidc_provider)
|
||||||
|
|
||||||
|
if not state_id:
|
||||||
|
# Direct access to the redirect endpoint, go to welcome page instead
|
||||||
|
welcome_url = get_url("/auth/oidc/welcome", self.force_https)
|
||||||
|
raise web.HTTPFound(welcome_url)
|
||||||
|
|
||||||
|
try:
|
||||||
|
redirect_uri = get_url("/auth/oidc/callback", self.force_https)
|
||||||
|
auth_url = await self.oidc_client.async_get_authorization_url(
|
||||||
|
redirect_uri, state_id
|
||||||
|
)
|
||||||
|
|
||||||
if auth_url:
|
if auth_url:
|
||||||
return web.HTTPFound(auth_url)
|
view_html = await get_view("redirect", {"url": quote(auth_url)})
|
||||||
|
|
||||||
view_html = await get_view(
|
|
||||||
"error",
|
|
||||||
{"error": "Integration is misconfigured, discovery could not be obtained."},
|
|
||||||
)
|
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
return web.Response(text=view_html, content_type="text/html")
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
async def post(self, request: web.Request) -> web.Response:
|
return await error_response(
|
||||||
|
"Integration is misconfigured, discovery could not be obtained.",
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def post(self, req: web.Request) -> web.Response:
|
||||||
"""POST"""
|
"""POST"""
|
||||||
return await self.get(request)
|
return await self.get(req)
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
"""Welcome route to show the user the OIDC login button and give instructions."""
|
"""Welcome route to show the user the OIDC login button and give instructions."""
|
||||||
|
|
||||||
|
from ast import List
|
||||||
|
import base64
|
||||||
|
import binascii
|
||||||
|
from urllib.parse import urlparse, parse_qs, unquote, urlencode
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from homeassistant.components.http import HomeAssistantView
|
from homeassistant.components.http import HomeAssistantView
|
||||||
from ..helpers import get_view
|
from ..tools.helpers import error_response, get_url, template_response
|
||||||
|
from ..provider import OpenIDAuthProvider
|
||||||
|
|
||||||
PATH = "/auth/oidc/welcome"
|
PATH = "/auth/oidc/welcome"
|
||||||
|
|
||||||
@@ -14,10 +19,126 @@ class OIDCWelcomeView(HomeAssistantView):
|
|||||||
url = PATH
|
url = PATH
|
||||||
name = "auth:oidc:welcome"
|
name = "auth:oidc:welcome"
|
||||||
|
|
||||||
def __init__(self, name: str) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
oidc_provider: OpenIDAuthProvider,
|
||||||
|
name: str,
|
||||||
|
force_https: bool,
|
||||||
|
has_other_auth_providers: bool,
|
||||||
|
) -> None:
|
||||||
|
self.oidc_provider = oidc_provider
|
||||||
self.name = name
|
self.name = name
|
||||||
|
self.force_https = force_https
|
||||||
|
self.has_other_auth_providers = has_other_auth_providers
|
||||||
|
|
||||||
async def get(self, _: web.Request) -> web.Response:
|
async def _process_url(self, redirect_uri: str) -> List[str, bool]:
|
||||||
|
"""Processes the redirect URI to determine if we need setTokens and if this is mobile."""
|
||||||
|
# decodeURIComponent(btoa(...)) -> unquote first, then base64 decode
|
||||||
|
redirect_uri = base64.b64decode(unquote(redirect_uri), validate=True).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
|
||||||
|
oauth2_url = urlparse(redirect_uri)
|
||||||
|
oauth2_query = parse_qs(oauth2_url.query)
|
||||||
|
client_id = oauth2_query.get("client_id")[0]
|
||||||
|
original_redirect_uri = oauth2_query.get("redirect_uri")[0]
|
||||||
|
|
||||||
|
# If the client_id starts with https://home-assistant.io/
|
||||||
|
# we assume it's a mobile client
|
||||||
|
# Android = https://home-assistant.io/Android,
|
||||||
|
# iOS = https://home-assistant.io/iOS
|
||||||
|
is_mobile = client_id.startswith("https://home-assistant.io/")
|
||||||
|
|
||||||
|
# Check if we appear to be signing in to the web version,
|
||||||
|
# for which we want to store tokens.
|
||||||
|
# We don't want to set storeTokens on sign-in to Google for instance
|
||||||
|
base_url = get_url("/", self.force_https)
|
||||||
|
is_web_client = original_redirect_uri.startswith(base_url)
|
||||||
|
|
||||||
|
if is_web_client:
|
||||||
|
# Adjust the original_redirect_uri to include the storeTokens parameter
|
||||||
|
separator = "?"
|
||||||
|
if "?" in original_redirect_uri:
|
||||||
|
separator = "&"
|
||||||
|
|
||||||
|
original_redirect_uri = f"{original_redirect_uri}{separator}storeToken=true"
|
||||||
|
oauth2_query.update({"redirect_uri": original_redirect_uri})
|
||||||
|
|
||||||
|
# Create new redirect_uri with the updated query parameters
|
||||||
|
new_oauth2_url = oauth2_url._replace(
|
||||||
|
query=urlencode(oauth2_query, doseq=True)
|
||||||
|
)
|
||||||
|
redirect_uri = new_oauth2_url.geturl()
|
||||||
|
|
||||||
|
return redirect_uri, is_mobile
|
||||||
|
|
||||||
|
async def get(self, req: web.Request) -> web.Response:
|
||||||
"""Receive response."""
|
"""Receive response."""
|
||||||
view_html = await get_view("welcome", {"name": self.name})
|
|
||||||
return web.Response(text=view_html, content_type="text/html")
|
# Get the query parameter with the redirect_uri
|
||||||
|
redirect_uri = req.query.get("redirect_uri")
|
||||||
|
|
||||||
|
# Do some processing on the redirect_uri to correct it
|
||||||
|
# and determine if this is a mobile client.
|
||||||
|
if redirect_uri:
|
||||||
|
try:
|
||||||
|
redirect_uri, is_mobile = await self._process_url(redirect_uri)
|
||||||
|
except (
|
||||||
|
binascii.Error,
|
||||||
|
UnicodeDecodeError,
|
||||||
|
ValueError,
|
||||||
|
KeyError,
|
||||||
|
TypeError,
|
||||||
|
):
|
||||||
|
return await error_response(
|
||||||
|
"Invalid redirect_uri, please restart login."
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Backwards compatibility with older versions that directly go to /auth/oidc/welcome
|
||||||
|
# If not set, redirect back to the main page and assume that this is a web client
|
||||||
|
redirect_uri = get_url("/?storeToken=true", self.force_https)
|
||||||
|
is_mobile = False
|
||||||
|
|
||||||
|
# Create OIDC state with the redirect_uri so we can use it later in the flow
|
||||||
|
state_id = await self.oidc_provider.async_create_state(redirect_uri)
|
||||||
|
cookie_header = self.oidc_provider.get_cookie_header(
|
||||||
|
state_id, secure=self.force_https or req.url.scheme == "https"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If this is the only provider and we are on desktop,
|
||||||
|
# automatically go through the OIDC login
|
||||||
|
if not is_mobile and not self.has_other_auth_providers:
|
||||||
|
raise web.HTTPFound(
|
||||||
|
location=get_url("/auth/oidc/redirect", self.force_https),
|
||||||
|
headers=cookie_header,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Otherwise display the screen with either mobile sign in or the buttons
|
||||||
|
# First generate code if mobile
|
||||||
|
code = None
|
||||||
|
if is_mobile:
|
||||||
|
# Create a code to login
|
||||||
|
code = await self.oidc_provider.async_generate_device_code(state_id)
|
||||||
|
if not code:
|
||||||
|
return await error_response(
|
||||||
|
"Failed to generate device code, please restart login.",
|
||||||
|
status=500,
|
||||||
|
)
|
||||||
|
|
||||||
|
# And add the other link if we have other auth providers
|
||||||
|
other_link = None
|
||||||
|
if self.has_other_auth_providers:
|
||||||
|
other_link = get_url("/?skip_oidc_redirect=true", self.force_https)
|
||||||
|
|
||||||
|
# And display
|
||||||
|
response = await template_response(
|
||||||
|
"welcome",
|
||||||
|
{
|
||||||
|
"name": self.name,
|
||||||
|
"other_link": other_link,
|
||||||
|
"code": code,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.headers.update(cookie_header)
|
||||||
|
return response
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
"""Helper functions for the integration."""
|
|
||||||
|
|
||||||
from homeassistant.components import http
|
|
||||||
from .views.loader import AsyncTemplateRenderer
|
|
||||||
|
|
||||||
|
|
||||||
def get_url(path: str) -> str:
|
|
||||||
"""Returns the requested path appended to the current request base URL."""
|
|
||||||
if (req := http.current_request.get()) is None:
|
|
||||||
raise RuntimeError("No current request in context")
|
|
||||||
|
|
||||||
base_uri = str(req.url).split("/auth", 2)[0]
|
|
||||||
return f"{base_uri}{path}"
|
|
||||||
|
|
||||||
|
|
||||||
async def get_view(template: str, parameters: dict | None = None) -> str:
|
|
||||||
"""Returns the generated HTML of the requested view."""
|
|
||||||
if parameters is None:
|
|
||||||
parameters = {}
|
|
||||||
|
|
||||||
renderer = AsyncTemplateRenderer()
|
|
||||||
return await renderer.render_template(f"{template}.html", **parameters)
|
|
||||||
@@ -1,10 +1,10 @@
|
|||||||
{
|
{
|
||||||
"domain": "auth_oidc",
|
"domain": "auth_oidc",
|
||||||
"name": "OIDC Authentication",
|
"name": "OpenID Connect/SSO Authentication",
|
||||||
"codeowners": [
|
"codeowners": [
|
||||||
"@christiaangoossens"
|
"@christiaangoossens"
|
||||||
],
|
],
|
||||||
"config_flow": false,
|
"config_flow": true,
|
||||||
"dependencies": [
|
"dependencies": [
|
||||||
"auth",
|
"auth",
|
||||||
"http"
|
"http"
|
||||||
@@ -14,10 +14,9 @@
|
|||||||
"iot_class": "calculated",
|
"iot_class": "calculated",
|
||||||
"issue_tracker": "https://github.com/christiaangoossens/hass-oidc-auth/issues",
|
"issue_tracker": "https://github.com/christiaangoossens/hass-oidc-auth/issues",
|
||||||
"requirements": [
|
"requirements": [
|
||||||
"python-jose>=3.3.0",
|
"aiofiles",
|
||||||
"aiofiles>=24.1.0",
|
"jinja2",
|
||||||
"jinja2>=3.1.4",
|
"joserfc"
|
||||||
"bcrypt>=4.2.0"
|
|
||||||
],
|
],
|
||||||
"version": "0.6.2"
|
"version": "1.0.0-rc3"
|
||||||
}
|
}
|
||||||
@@ -6,7 +6,6 @@ import logging
|
|||||||
|
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
import bcrypt
|
|
||||||
from homeassistant.auth import EVENT_USER_ADDED
|
from homeassistant.auth import EVENT_USER_ADDED
|
||||||
from homeassistant.auth.providers import (
|
from homeassistant.auth.providers import (
|
||||||
AUTH_PROVIDERS,
|
AUTH_PROVIDERS,
|
||||||
@@ -22,21 +21,21 @@ from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
|
|||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.components import http, person
|
from homeassistant.components import http, person
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
import voluptuous as vol
|
|
||||||
|
|
||||||
from .config import (
|
from .config.const import (
|
||||||
FEATURES,
|
FEATURES,
|
||||||
FEATURES_AUTOMATIC_USER_LINKING,
|
FEATURES_AUTOMATIC_USER_LINKING,
|
||||||
FEATURES_AUTOMATIC_PERSON_CREATION,
|
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||||
DEFAULT_TITLE,
|
DEFAULT_TITLE,
|
||||||
)
|
)
|
||||||
from .stores.code_store import CodeStore
|
from .stores.state_store import StateStore
|
||||||
from .types import UserDetails
|
from .tools.types import UserDetails
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
PROVIDER_TYPE = "auth_oidc"
|
PROVIDER_TYPE = "auth_oidc"
|
||||||
HASS_PROVIDER_TYPE = "homeassistant"
|
HASS_PROVIDER_TYPE = "homeassistant"
|
||||||
|
COOKIE_NAME = "auth_oidc_state"
|
||||||
|
|
||||||
|
|
||||||
class InvalidAuthError(HomeAssistantError):
|
class InvalidAuthError(HomeAssistantError):
|
||||||
@@ -68,7 +67,7 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._user_meta: dict[UserDetails] = {}
|
self._user_meta: dict[UserDetails] = {}
|
||||||
self._code_store: CodeStore | None = None
|
self._state_store: StateStore | None = None
|
||||||
self._init_lock = asyncio.Lock()
|
self._init_lock = asyncio.Lock()
|
||||||
|
|
||||||
features = config.get(
|
features = config.get(
|
||||||
@@ -89,29 +88,120 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
async def async_initialize(self) -> None:
|
async def async_initialize(self) -> None:
|
||||||
"""Initialize the auth provider."""
|
"""Initialize the auth provider."""
|
||||||
|
|
||||||
# Init the code store first
|
# Init the store first
|
||||||
# Use the same technique as the HomeAssistant auth provider for storage
|
# Use the same technique as the HomeAssistant auth provider for storage
|
||||||
# (/auth/providers/homeassistant.py#L392)
|
# (/auth/providers/homeassistant.py#L392)
|
||||||
async with self._init_lock:
|
async with self._init_lock:
|
||||||
if self._code_store is not None:
|
if self._state_store is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
store = CodeStore(self.hass)
|
store = StateStore(self.hass)
|
||||||
await store.async_load()
|
await store.async_load()
|
||||||
self._code_store = store
|
self._state_store = store
|
||||||
self._user_meta = {}
|
self._user_meta = {}
|
||||||
|
|
||||||
# Listen for user creation events
|
# Listen for user creation events
|
||||||
self.hass.bus.async_listen(EVENT_USER_ADDED, self.async_user_created)
|
self.hass.bus.async_listen(EVENT_USER_ADDED, self.async_user_created)
|
||||||
|
|
||||||
async def async_get_subject(self, code: str) -> Optional[str]:
|
def _resolve_ip(self, ip: str | None = None) -> str | None:
|
||||||
"""Retrieve user from the code, return subject and save meta
|
"""Resolve client IP from explicit input or current request context."""
|
||||||
for later use with this provider instance."""
|
if ip:
|
||||||
if self._code_store is None:
|
return ip
|
||||||
await self.async_initialize()
|
|
||||||
assert self._code_store is not None
|
|
||||||
|
|
||||||
user_data = await self._code_store.receive_userinfo_for_code(code)
|
req = http.current_request.get()
|
||||||
|
if req and req.remote:
|
||||||
|
return req.remote
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_create_state(self, redirect_uri: str, ip: str | None = None) -> str:
|
||||||
|
"""Create a new OIDC state and return the state id."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_create_state_from_url(
|
||||||
|
redirect_uri, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_generate_device_code(self, state_id: str) -> Optional[str]:
|
||||||
|
"""Generate a device code for the state, used for device login."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_generate_code_for_state(state_id)
|
||||||
|
|
||||||
|
async def async_save_user_info(
|
||||||
|
self, state_id: str, user_info: dict[str, dict | str]
|
||||||
|
) -> bool:
|
||||||
|
"""Save user info to the given state."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_add_userinfo_to_state(state_id, user_info)
|
||||||
|
|
||||||
|
async def async_get_redirect_uri_for_state(
|
||||||
|
self, state_id: str, ip: str | None = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Get the redirect_uri for the given state."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_get_redirect_uri_for_state(
|
||||||
|
state_id, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_is_state_valid(self, state_id: str, ip: str | None = None) -> bool:
|
||||||
|
"""Check if a state exists, belongs to this IP, and is not expired."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return (
|
||||||
|
await self._state_store.async_get_redirect_uri_for_state(
|
||||||
|
state_id, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
|
is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_is_state_ready(self, state_id: str, ip: str | None = None) -> bool:
|
||||||
|
"""Check if the state has received the user info from the OIDC callback."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_is_state_ready(
|
||||||
|
state_id, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_link_state_to_code(
|
||||||
|
self, state_id: str, code: str, ip: str | None = None
|
||||||
|
) -> bool:
|
||||||
|
"""Link two states together by copying the user info from one to the other."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
return await self._state_store.async_link_state_to_code(
|
||||||
|
state_id, code, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_get_subject(
|
||||||
|
self, state_id: str, ip: str | None = None
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Retrieve user from the state_id, return subject and save meta
|
||||||
|
for later use with this provider instance."""
|
||||||
|
if self._state_store is None:
|
||||||
|
await self.async_initialize()
|
||||||
|
assert self._state_store is not None
|
||||||
|
|
||||||
|
# This also deletes the state as we are using it for sign-in
|
||||||
|
user_data = await self._state_store.async_receive_userinfo_for_state(
|
||||||
|
state_id, self._resolve_ip(ip)
|
||||||
|
)
|
||||||
if user_data is None:
|
if user_data is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -119,14 +209,6 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
self._user_meta[sub] = user_data
|
self._user_meta[sub] = user_data
|
||||||
return sub
|
return sub
|
||||||
|
|
||||||
async def async_save_user_info(self, user_info: dict[str, dict | str]) -> str:
|
|
||||||
"""Save user info and return a code."""
|
|
||||||
if self._code_store is None:
|
|
||||||
await self.async_initialize()
|
|
||||||
assert self._code_store is not None
|
|
||||||
|
|
||||||
return await self._code_store.async_generate_code_for_userinfo(user_info)
|
|
||||||
|
|
||||||
async def _async_find_user_by_username(self, username: str) -> Optional[User]:
|
async def _async_find_user_by_username(self, username: str) -> Optional[User]:
|
||||||
"""Find a user by username."""
|
"""Find a user by username."""
|
||||||
users = await self.store.async_get_users()
|
users = await self.store.async_get_users()
|
||||||
@@ -145,6 +227,18 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_cookie_header(self, state_id: str, secure: bool = False):
|
||||||
|
"""Get the cookie header to set the state_id cookie."""
|
||||||
|
secure_flag = "; Secure" if secure else ""
|
||||||
|
return {
|
||||||
|
# Set a cookie for the other pages to know the state_id
|
||||||
|
# Keep cookie lifetime aligned with state lifetime in storage (5 minutes).
|
||||||
|
"set-cookie": f"{COOKIE_NAME}="
|
||||||
|
+ state_id
|
||||||
|
+ "; Path=/auth/; SameSite=Lax; HttpOnly; Max-Age=300"
|
||||||
|
+ secure_flag,
|
||||||
|
}
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Handler for user created and related functions (person creation)
|
# Handler for user created and related functions (person creation)
|
||||||
# ====
|
# ====
|
||||||
@@ -177,9 +271,9 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
# If person creation is enabled, add a person for this user
|
# If person creation is enabled, add a person for this user
|
||||||
if self.create_persons:
|
if self.create_persons:
|
||||||
user_meta = await self.async_user_meta_for_credentials(credential)
|
user_meta = await self.async_user_meta_for_credentials(credential)
|
||||||
await self.async_create_person(user, user_meta.name)
|
await self._async_create_person(user, user_meta.name)
|
||||||
|
|
||||||
async def async_create_person(self, user: User, name: str) -> None:
|
async def _async_create_person(self, user: User, name: str) -> None:
|
||||||
"""Create a person for the user."""
|
"""Create a person for the user."""
|
||||||
_LOGGER.info("Automatically creating person for new user %s", user.id)
|
_LOGGER.info("Automatically creating person for new user %s", user.id)
|
||||||
|
|
||||||
@@ -194,7 +288,7 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
# pylint: disable=broad-exception-caught
|
# pylint: disable=broad-exception-caught
|
||||||
except Exception:
|
except Exception:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Requested automatic person creation, but person creation failed."
|
"Requested automatic person creation, but person creation failed"
|
||||||
)
|
)
|
||||||
# pylint: enable=broad-exception-caught
|
# pylint: enable=broad-exception-caught
|
||||||
|
|
||||||
@@ -271,16 +365,8 @@ class OpenIDAuthProvider(AuthProvider):
|
|||||||
class OpenIdLoginFlow(LoginFlow):
|
class OpenIdLoginFlow(LoginFlow):
|
||||||
"""Handler for the login flow."""
|
"""Handler for the login flow."""
|
||||||
|
|
||||||
async def _finalize_user(self, code: str) -> AuthFlowResult:
|
async def _finalize_user(self, state_id: str) -> AuthFlowResult:
|
||||||
# Verify a dummy hash to make it last a bit longer
|
sub = await self._auth_provider.async_get_subject(state_id)
|
||||||
# as security measure (limits the amount of attempts you have in 5 min)
|
|
||||||
# Similar to what the HomeAssistant auth provider does
|
|
||||||
dummy = b"$2b$12$CiuFGszHx9eNHxPuQcwBWez4CwDTOcLTX5CbOpV6gef2nYuXkY7BO"
|
|
||||||
bcrypt.checkpw(b"foo", dummy)
|
|
||||||
|
|
||||||
# Actually look up the auth provider after,
|
|
||||||
# this doesn't take a lot of time (regardless of it's in there or not)
|
|
||||||
sub = await self._auth_provider.async_get_subject(code)
|
|
||||||
if sub:
|
if sub:
|
||||||
return await self.async_finish(
|
return await self.async_finish(
|
||||||
{
|
{
|
||||||
@@ -290,53 +376,22 @@ class OpenIdLoginFlow(LoginFlow):
|
|||||||
|
|
||||||
raise InvalidAuthError
|
raise InvalidAuthError
|
||||||
|
|
||||||
def _show_login_form(
|
|
||||||
self, errors: Optional[dict[str, str]] = None
|
|
||||||
) -> AuthFlowResult:
|
|
||||||
if errors is None:
|
|
||||||
errors = {}
|
|
||||||
|
|
||||||
# Show the login form
|
|
||||||
# Abuses the MFA form, as it works better for our usecase
|
|
||||||
# UI suggestions are welcome (make a PR!)
|
|
||||||
return self.async_show_form(
|
|
||||||
step_id="mfa",
|
|
||||||
data_schema=vol.Schema(
|
|
||||||
{
|
|
||||||
vol.Required("code"): str,
|
|
||||||
}
|
|
||||||
),
|
|
||||||
errors=errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def async_step_init(
|
async def async_step_init(
|
||||||
self, user_input: dict[str, str] | None = None
|
self, user_input: dict[str, str] | None = None
|
||||||
) -> AuthFlowResult:
|
) -> AuthFlowResult:
|
||||||
"""Handle the step of the form."""
|
"""Handle the step of the form."""
|
||||||
|
|
||||||
# Try to use the user input first
|
# Check if the cookie is present to login
|
||||||
if user_input is not None:
|
|
||||||
try:
|
|
||||||
return await self._finalize_user(user_input["code"])
|
|
||||||
except InvalidAuthError:
|
|
||||||
return self._show_login_form({"base": "invalid_auth"})
|
|
||||||
|
|
||||||
# If not available, check the cookie
|
|
||||||
req = http.current_request.get()
|
req = http.current_request.get()
|
||||||
code_cookie = req.cookies.get("auth_oidc_code")
|
if req and req.cookies:
|
||||||
|
state_cookie = req.cookies.get(COOKIE_NAME)
|
||||||
|
|
||||||
if code_cookie:
|
if state_cookie:
|
||||||
_LOGGER.debug("Code cookie found on login: %s", code_cookie)
|
|
||||||
try:
|
try:
|
||||||
return await self._finalize_user(code_cookie)
|
return await self._finalize_user(state_cookie)
|
||||||
except InvalidAuthError:
|
except InvalidAuthError:
|
||||||
pass
|
return self.async_abort(reason="oidc_cookie_invalid")
|
||||||
|
|
||||||
# If none are available, just show the form
|
# If no cookie is found, abort.
|
||||||
return self._show_login_form()
|
# User should either be redirected or start manually on the welcome
|
||||||
|
return self.async_abort(reason="no_oidc_cookie_found")
|
||||||
async def async_step_mfa(
|
|
||||||
self, user_input: dict[str, str] | None = None
|
|
||||||
) -> AuthFlowResult:
|
|
||||||
# This is a dummy step function just to use the nicer MFA UI instead
|
|
||||||
return await self.async_step_init(user_input)
|
|
||||||
|
|||||||
61
custom_components/auth_oidc/static/injection.js
Normal file
61
custom_components/auth_oidc/static/injection.js
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
/**
|
||||||
|
* hass-oidc-auth - UX script to automatically select the Home Assistant auth provider when the "Login aborted" message is shown.
|
||||||
|
*/
|
||||||
|
|
||||||
|
let authFlowElement = null
|
||||||
|
|
||||||
|
function update() {
|
||||||
|
// Find ha-auth-flow
|
||||||
|
authFlowElement = document.querySelector('ha-auth-flow');
|
||||||
|
|
||||||
|
if (!authFlowElement) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the text "Login aborted" is present on the page
|
||||||
|
if (!authFlowElement.innerText.includes('Login aborted')) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the ha-pick-auth-provider element
|
||||||
|
const authProviderElement = document.querySelector('ha-pick-auth-provider');
|
||||||
|
|
||||||
|
if (!authProviderElement) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Click the first ha-list-item element inside the ha-pick-auth-provider
|
||||||
|
const firstListItem = authProviderElement.shadowRoot?.querySelector('ha-list-item');
|
||||||
|
if (!firstListItem) {
|
||||||
|
console.warn("[OIDC] No ha-list-item found inside ha-pick-auth-provider. Not automatically selecting HA provider.");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
firstListItem.click();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hide the content until ready
|
||||||
|
let ready = false
|
||||||
|
document.querySelector(".content").style.display = "none"
|
||||||
|
|
||||||
|
const observer = new MutationObserver((mutationsList, observer) => {
|
||||||
|
update();
|
||||||
|
|
||||||
|
if (!ready) {
|
||||||
|
ready = Boolean(authFlowElement)
|
||||||
|
if (ready) {
|
||||||
|
document.querySelector(".content").style.display = ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
observer.observe(document.body, { childList: true, subtree: true })
|
||||||
|
|
||||||
|
setTimeout(() => {
|
||||||
|
if (!ready) {
|
||||||
|
console.warn("[hass-oidc-auth]: Document was not ready after 300ms seconds, showing content anyway.")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Force display the content
|
||||||
|
document.querySelector(".content").style.display = "";
|
||||||
|
}, 300)
|
||||||
3
custom_components/auth_oidc/static/input.css
Normal file
3
custom_components/auth_oidc/static/input.css
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
@import "tailwindcss";
|
||||||
|
|
||||||
|
@source "../views/templates";
|
||||||
0
custom_components/auth_oidc/stores/__init__.py
Normal file
0
custom_components/auth_oidc/stores/__init__.py
Normal file
@@ -1,78 +0,0 @@
|
|||||||
"""Code Store, stores the codes and their associated authenticated user temporarily."""
|
|
||||||
|
|
||||||
import random
|
|
||||||
import string
|
|
||||||
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import cast, Optional
|
|
||||||
from homeassistant.helpers.storage import Store
|
|
||||||
from homeassistant.core import HomeAssistant
|
|
||||||
|
|
||||||
from ..types import UserDetails
|
|
||||||
|
|
||||||
STORAGE_VERSION = 1
|
|
||||||
STORAGE_KEY = "auth_provider.auth_oidc.codes"
|
|
||||||
|
|
||||||
|
|
||||||
class CodeStore:
|
|
||||||
"""Holds the codes and associated data"""
|
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
|
||||||
"""Initialize the user data store."""
|
|
||||||
self.hass = hass
|
|
||||||
self._store = Store[dict[str, UserDetails]](
|
|
||||||
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
|
||||||
)
|
|
||||||
self._data: dict[str, dict[str, dict | str]] | None = None
|
|
||||||
|
|
||||||
async def async_load(self) -> None:
|
|
||||||
"""Load stored data."""
|
|
||||||
if (data := await self._store.async_load()) is None:
|
|
||||||
data = cast(dict[str, UserDetails], {})
|
|
||||||
self._data = data
|
|
||||||
|
|
||||||
async def async_save(self) -> None:
|
|
||||||
"""Save data."""
|
|
||||||
if self._data is not None:
|
|
||||||
await self._store.async_save(self._data)
|
|
||||||
|
|
||||||
def _generate_code(self) -> str:
|
|
||||||
"""Generate a random six-digit code."""
|
|
||||||
return "".join(random.choices(string.digits, k=6))
|
|
||||||
|
|
||||||
async def async_generate_code_for_userinfo(self, user_info: UserDetails) -> str:
|
|
||||||
"""Generates a one time code and adds it to the database for 5 minutes."""
|
|
||||||
if self._data is None:
|
|
||||||
raise RuntimeError("Data not loaded")
|
|
||||||
|
|
||||||
code = self._generate_code()
|
|
||||||
expiration = datetime.utcnow() + timedelta(minutes=5)
|
|
||||||
|
|
||||||
self._data[code] = {
|
|
||||||
"user_info": user_info,
|
|
||||||
"code": code,
|
|
||||||
"expiration": expiration.isoformat(),
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.async_save()
|
|
||||||
return code
|
|
||||||
|
|
||||||
async def receive_userinfo_for_code(self, code: str) -> Optional[UserDetails]:
|
|
||||||
"""Retrieve user info based on the code."""
|
|
||||||
if self._data is None:
|
|
||||||
raise RuntimeError("Data not loaded")
|
|
||||||
|
|
||||||
user_data = self._data.get(code)
|
|
||||||
|
|
||||||
if user_data:
|
|
||||||
# We should now wipe it from the database, as it's one time use code
|
|
||||||
self._data.pop(code)
|
|
||||||
await self.async_save()
|
|
||||||
|
|
||||||
if (
|
|
||||||
user_data
|
|
||||||
and datetime.fromisoformat(user_data["expiration"]) > datetime.utcnow()
|
|
||||||
):
|
|
||||||
return user_data["user_info"]
|
|
||||||
|
|
||||||
return None
|
|
||||||
191
custom_components/auth_oidc/stores/state_store.py
Normal file
191
custom_components/auth_oidc/stores/state_store.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
"""State Store, store authentication states (redirect_uri)."""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import cast, Optional
|
||||||
|
from homeassistant.helpers.storage import Store
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from ..tools.types import OIDCState, UserDetails
|
||||||
|
|
||||||
|
STORAGE_VERSION = 1
|
||||||
|
STORAGE_KEY = "auth_provider.auth_oidc.states"
|
||||||
|
MAX_DEVICE_CODE_ATTEMPTS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class StateStore:
|
||||||
|
"""Holds the authentication states and associated data"""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant) -> None:
|
||||||
|
"""Initialize the user data store."""
|
||||||
|
self.hass = hass
|
||||||
|
self._store = Store[dict[str, OIDCState]](
|
||||||
|
hass, STORAGE_VERSION, STORAGE_KEY, private=True, atomic_writes=True
|
||||||
|
)
|
||||||
|
self._data: dict[str, OIDCState] | None = None
|
||||||
|
|
||||||
|
async def async_load(self) -> None:
|
||||||
|
"""Load stored data."""
|
||||||
|
if (data := await self._store.async_load()) is None:
|
||||||
|
data = cast(dict[str, OIDCState], {})
|
||||||
|
self._data = data
|
||||||
|
|
||||||
|
async def _async_save(self) -> None:
|
||||||
|
"""Save data."""
|
||||||
|
if self._data is not None:
|
||||||
|
await self._store.async_save(self._data)
|
||||||
|
|
||||||
|
def _generate_id(self) -> str:
|
||||||
|
"""Generate a random identifier."""
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
def _generate_code(self) -> str:
|
||||||
|
"""Generate a random six-digit code."""
|
||||||
|
return "".join(random.choices(string.digits, k=6))
|
||||||
|
|
||||||
|
def _is_expired(self, state: OIDCState) -> bool:
|
||||||
|
"""Check if a state is expired."""
|
||||||
|
return datetime.fromisoformat(state["expiration"]) < datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
def _is_valid(self, state: OIDCState, ip: str | None) -> bool:
|
||||||
|
"""Check if a state is valid"""
|
||||||
|
return (
|
||||||
|
not self._is_expired(state)
|
||||||
|
and bool(state["redirect_uri"])
|
||||||
|
and ip is not None
|
||||||
|
and state["ip_address"] == ip
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_create_state_from_url(self, redirect_uri: str, ip: str) -> str:
|
||||||
|
"""Generates a the OIDC state adds it to the database for 5 minutes."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
state_id = self._generate_id()
|
||||||
|
expiration = datetime.now(timezone.utc) + timedelta(minutes=5)
|
||||||
|
|
||||||
|
self._data[state_id] = {
|
||||||
|
"id": state_id,
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
"device_code": None,
|
||||||
|
"device_code_attempts": 0,
|
||||||
|
"user_details": None,
|
||||||
|
"expiration": expiration.isoformat(),
|
||||||
|
"ip_address": ip,
|
||||||
|
}
|
||||||
|
|
||||||
|
await self._async_save()
|
||||||
|
return state_id
|
||||||
|
|
||||||
|
async def async_generate_code_for_state(self, state_id: str) -> Optional[str]:
|
||||||
|
"""Generates a one time code for the state to link device clients."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
code = self._generate_code()
|
||||||
|
self._data[state_id]["device_code"] = code
|
||||||
|
await self._async_save()
|
||||||
|
return code
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_add_userinfo_to_state(
|
||||||
|
self, state_id: str, user_info: UserDetails
|
||||||
|
) -> bool:
|
||||||
|
"""Add userinfo to existing state to complete login"""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
try:
|
||||||
|
self._data[state_id]["user_details"] = user_info
|
||||||
|
await self._async_save()
|
||||||
|
return True
|
||||||
|
except KeyError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def async_get_redirect_uri_for_state(
|
||||||
|
self, state_id: str, ip: str
|
||||||
|
) -> Optional[str]:
|
||||||
|
"""Get the redirect_uri for a given state_id."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
state = self._data.get(state_id)
|
||||||
|
if state and self._is_valid(state, ip):
|
||||||
|
return state["redirect_uri"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def async_is_state_ready(self, state_id: str, ip: str) -> bool:
|
||||||
|
"""Check if the state has received the user info from the OIDC callback."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
state = self._data.get(state_id)
|
||||||
|
return (
|
||||||
|
state is not None
|
||||||
|
and state["user_details"] is not None
|
||||||
|
and self._is_valid(state, ip)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def async_link_state_to_code(
|
||||||
|
self, state_id: str, code: str, ip: str | None
|
||||||
|
) -> bool:
|
||||||
|
"""Link a state to a device code, used for mobile sign-in."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
state_data = self._data.get(state_id)
|
||||||
|
if (
|
||||||
|
state_data
|
||||||
|
and self._is_valid(state_data, ip)
|
||||||
|
and state_data["user_details"] is not None
|
||||||
|
):
|
||||||
|
attempts = state_data.get("device_code_attempts", 0)
|
||||||
|
if attempts >= MAX_DEVICE_CODE_ATTEMPTS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Find the state with the matching device code and link it
|
||||||
|
for state in self._data.values():
|
||||||
|
if state["device_code"] == code and not self._is_expired(state):
|
||||||
|
# Set user details on the device state to allow it to complete login
|
||||||
|
state["user_details"] = state_data["user_details"]
|
||||||
|
|
||||||
|
# Delete the 'donor' state as it's one time use
|
||||||
|
self._data.pop(state_id)
|
||||||
|
|
||||||
|
# Save and return true
|
||||||
|
await self._async_save()
|
||||||
|
return True
|
||||||
|
|
||||||
|
state_data["device_code_attempts"] = attempts + 1
|
||||||
|
await self._async_save()
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def async_receive_userinfo_for_state(
|
||||||
|
self, state_id: str, ip: str
|
||||||
|
) -> Optional[OIDCState]:
|
||||||
|
"""Retrieve user info based on the state_id."""
|
||||||
|
if self._data is None:
|
||||||
|
raise RuntimeError("Data not loaded")
|
||||||
|
|
||||||
|
user_data = self._data.get(state_id)
|
||||||
|
|
||||||
|
if user_data:
|
||||||
|
# We should now wipe it from the database, as it's one time use
|
||||||
|
self._data.pop(state_id)
|
||||||
|
await self._async_save()
|
||||||
|
|
||||||
|
if user_data and self._is_valid(user_data, ip):
|
||||||
|
return user_data["user_details"]
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_data(self):
|
||||||
|
"""Get the internal data for testing purposes."""
|
||||||
|
return self._data
|
||||||
104
custom_components/auth_oidc/strings.json
Normal file
104
custom_components/auth_oidc/strings.json
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"step": {
|
||||||
|
"user": {
|
||||||
|
"title": "Choose OIDC Provider",
|
||||||
|
"description": "Select your OpenID Connect identity provider to get started with the setup.",
|
||||||
|
"data": {
|
||||||
|
"provider": "Provider"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"discovery_url": {
|
||||||
|
"title": "Provider Configuration",
|
||||||
|
"description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's documentation and usually ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).",
|
||||||
|
"data": {
|
||||||
|
"discovery_url": "Discovery URL"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"client_config": {
|
||||||
|
"title": "Client Configuration",
|
||||||
|
"description": "Configure your OIDC client. You can find these details in your {provider_name} application settings.\n\n**Discovery URL:** {discovery_url}\n\n**Setup Instructions:**\n1. Register a new application in your OIDC provider\n2. Set the application type to 'Public Client' (recommended for most users)\n3. Add redirect URLs for Home Assistant\n4. Copy the Client ID below\n\n**Note:** If your provider requires a client secret, check 'Use Confidential Client' and provide your client secret below.\n\n**Need detailed setup instructions?** Check the [setup guide]({documentation_url}) for step-by-step instructions.",
|
||||||
|
"data": {
|
||||||
|
"client_id": "Client ID",
|
||||||
|
"client_secret": "Client Secret (optional; required by some providers)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"client_secret": {
|
||||||
|
"title": "Client Secret Configuration",
|
||||||
|
"description": "Since you selected 'Confidential Client', please provide your client secret.\n\n**Provider:** {provider_name}\n**Client ID:** {client_id}\n**Discovery URL:** {discovery_url}\n\n**Security Note:** The client secret will be stored securely in Home Assistant's configuration. Never share your client secret with others.",
|
||||||
|
"data": {
|
||||||
|
"client_secret": "Client Secret"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"validate_connection": {
|
||||||
|
"title": "Connection Validation",
|
||||||
|
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n**Client ID:** {client_id}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Client Settings:** Go back to change Client ID or secret\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
|
||||||
|
"data": {
|
||||||
|
"action": "Choose an action"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"groups_config": {
|
||||||
|
"title": "Groups & Role Configuration",
|
||||||
|
"description": "Configure how user groups from {provider_name} should be mapped to Home Assistant roles.\n\n**Groups Support:** Groups allow you to automatically assign admin or user roles based on group membership in your identity provider.\n\n**Admin Group:** Users in this group will have administrator access\n**User Group:** Users in this group will have standard user access (leave empty to allow all authenticated users)",
|
||||||
|
"data": {
|
||||||
|
"enable_groups": "Enable group-based role assignment",
|
||||||
|
"admin_group": "Admin group name",
|
||||||
|
"user_group": "User group name (optional)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"user_linking": {
|
||||||
|
"title": "User Linking Options",
|
||||||
|
"description": "Configure how OIDC users are linked to existing Home Assistant users.\n\n**⚠️ Important Security Information:**\n\n**User Linking Disabled (Recommended):** New OIDC accounts are created for each user. This is the most secure option.\n\n**User Linking Enabled:** OIDC users can be linked to existing Home Assistant users by username. **This has security implications:**\n- If someone can guess or obtain a Home Assistant username, they might gain access to that account\n- Only enable this if you're migrating from local Home Assistant accounts to OIDC\n- You can disable this later if needed",
|
||||||
|
"data": {
|
||||||
|
"enable_user_linking": "Enable automatic user linking (⚠️ Security Risk)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"finalize": {
|
||||||
|
"title": "Setup Complete",
|
||||||
|
"description": "Your OIDC authentication is now configured and ready to use.\n\n**Next Steps:**\n1. Save this configuration\n2. Restart Home Assistant if prompted\n3. The OIDC login option will appear on your login screen\n\n**Advanced Configuration:**\nAdvanced options like custom networking settings, specific claim configurations, or custom scopes are only available through YAML configuration. See the documentation for details.",
|
||||||
|
"data": {}
|
||||||
|
},
|
||||||
|
"reconfigure": {
|
||||||
|
"title": "Reconfigure OIDC Authentication",
|
||||||
|
"description": "Update your OIDC client credentials for {provider_name}.\n\n**Discovery URL:** {discovery_url}\n\n**What you can change:**\n- **Client ID**: Update your application's client identifier\n- **Client Type**: Switch between Public and Confidential client types\n- **Client Secret**: Update or add a client secret (for confidential clients)\n\n**Note:** Changes will be validated against your OIDC provider before being saved. Your existing settings will be preserved if validation fails.\n\n**Security:** For confidential clients, leave the client secret field empty to keep your existing secret unchanged.",
|
||||||
|
"data": {
|
||||||
|
"client_id": "Client ID",
|
||||||
|
"client_secret": "Client Secret (leave empty to keep current)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"cannot_connect": "Failed to connect to the OIDC provider. Please check your network connection and discovery URL.",
|
||||||
|
"discovery_invalid": "The discovery document could not be retrieved or is invalid. Please verify the discovery URL is correct.",
|
||||||
|
"jwks_invalid": "Failed to retrieve or validate the JWKS (JSON Web Key Set). Please check your provider configuration.",
|
||||||
|
"invalid_client_credentials": "The client ID or client secret appears to be invalid. Please check your OIDC application settings and ensure the credentials are correct.",
|
||||||
|
"client_secret_required": "Client secret is required when using confidential client mode.",
|
||||||
|
"invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL.",
|
||||||
|
"invalid_client_id": "Client ID cannot be empty and must contain valid characters.",
|
||||||
|
"no_url_available": "Unable to determine Home Assistant URL for OAuth redirect. Please check your network configuration.",
|
||||||
|
"auth_url_failed": "Failed to generate authorization URL for OAuth test.",
|
||||||
|
"unknown": "An unexpected error occurred. Please check the logs for more details."
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"already_configured": "This OIDC provider is already configured.",
|
||||||
|
"cannot_connect": "Unable to connect to the OIDC provider.",
|
||||||
|
"invalid_discovery": "Invalid discovery document received from the provider.",
|
||||||
|
"reconfigure_successful": "OIDC Authentication has been successfully reconfigured with the updated client credentials.",
|
||||||
|
"single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured (either through YAML or the UI). To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"step": {
|
||||||
|
"init": {
|
||||||
|
"title": "OIDC Authentication Options",
|
||||||
|
"description": "Update configuration options for your {provider_name} OIDC authentication.\n\n**User Linking:** Control how OIDC users are linked to existing Home Assistant accounts (⚠️ security implications).\n\n**Groups Configuration:** Configure role assignment based on group membership from your identity provider.\n\n**Note:** Changes take effect immediately but may require users to log out and back in.",
|
||||||
|
"data": {
|
||||||
|
"enable_user_linking": "Enable automatic user linking (⚠️ Security Risk)",
|
||||||
|
"enable_groups": "Enable group-based role assignment",
|
||||||
|
"admin_group": "Admin group name",
|
||||||
|
"user_group": "User group name (optional - leave empty to allow all authenticated users)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
0
custom_components/auth_oidc/tools/__init__.py
Normal file
0
custom_components/auth_oidc/tools/__init__.py
Normal file
69
custom_components/auth_oidc/tools/helpers.py
Normal file
69
custom_components/auth_oidc/tools/helpers.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Helper functions for the integration."""
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from homeassistant.components import http
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from ..views.loader import AsyncTemplateRenderer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from ..provider import OpenIDAuthProvider
|
||||||
|
|
||||||
|
STATE_COOKIE_NAME = "auth_oidc_state"
|
||||||
|
|
||||||
|
|
||||||
|
def get_url(path: str, force_https: bool) -> str:
|
||||||
|
"""Returns the requested path appended to the current request base URL."""
|
||||||
|
if (req := http.current_request.get()) is None:
|
||||||
|
raise RuntimeError("No current request in context")
|
||||||
|
|
||||||
|
base_uri = str(req.url).split("/auth", 2)[0]
|
||||||
|
if force_https:
|
||||||
|
base_uri = base_uri.replace("http://", "https://")
|
||||||
|
return f"{base_uri}{path}"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_view(template: str, parameters: dict | None = None) -> str:
|
||||||
|
"""Returns the generated HTML of the requested view."""
|
||||||
|
if parameters is None:
|
||||||
|
parameters = {}
|
||||||
|
|
||||||
|
renderer = AsyncTemplateRenderer()
|
||||||
|
return await renderer.render_template(f"{template}.html", **parameters)
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_id(request: web.Request) -> str | None:
|
||||||
|
"""Return the current OIDC state cookie, if present."""
|
||||||
|
return request.cookies.get(STATE_COOKIE_NAME)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_valid_state_id(
|
||||||
|
request: web.Request, oidc_provider: "OpenIDAuthProvider"
|
||||||
|
) -> str | None:
|
||||||
|
"""Return state id only when cookie exists and state is still valid."""
|
||||||
|
state_id = get_state_id(request)
|
||||||
|
if not state_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not await oidc_provider.async_is_state_valid(state_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return state_id
|
||||||
|
|
||||||
|
|
||||||
|
def html_response(html: str, status: int = 200) -> web.Response:
|
||||||
|
"""Return an HTML response with the standard content type."""
|
||||||
|
return web.Response(text=html, content_type="text/html", status=status)
|
||||||
|
|
||||||
|
|
||||||
|
async def template_response(
|
||||||
|
template: str, parameters: dict | None = None
|
||||||
|
) -> web.Response:
|
||||||
|
"""Render a template and return it as an HTML response."""
|
||||||
|
return html_response(await get_view(template, parameters))
|
||||||
|
|
||||||
|
|
||||||
|
async def error_response(message: str, status: int = 400) -> web.Response:
|
||||||
|
"""Render the shared error view."""
|
||||||
|
return html_response(await get_view("error", {"error": message}), status=status)
|
||||||
@@ -9,11 +9,11 @@ import ssl
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import aiohttp
|
import aiohttp
|
||||||
from jose import jwt, jwk
|
from joserfc import jwt, jwk, jws, errors as joserfc_errors
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from .types import UserDetails
|
from .types import UserDetails
|
||||||
from .config import (
|
from ..config.const import (
|
||||||
FEATURES_DISABLE_PKCE,
|
FEATURES_DISABLE_PKCE,
|
||||||
CLAIMS_DISPLAY_NAME,
|
CLAIMS_DISPLAY_NAME,
|
||||||
CLAIMS_USERNAME,
|
CLAIMS_USERNAME,
|
||||||
@@ -22,7 +22,9 @@ from .config import (
|
|||||||
ROLE_USERS,
|
ROLE_USERS,
|
||||||
NETWORK_TLS_VERIFY,
|
NETWORK_TLS_VERIFY,
|
||||||
NETWORK_TLS_CA_PATH,
|
NETWORK_TLS_CA_PATH,
|
||||||
|
DEFAULT_ID_TOKEN_SIGNING_ALGORITHM,
|
||||||
)
|
)
|
||||||
|
from .validation import validate_url
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -34,6 +36,28 @@ class OIDCClientException(Exception):
|
|||||||
class OIDCDiscoveryInvalid(OIDCClientException):
|
class OIDCDiscoveryInvalid(OIDCClientException):
|
||||||
"Raised when the discovery document is not found, invalid or otherwise malformed."
|
"Raised when the discovery document is not found, invalid or otherwise malformed."
|
||||||
|
|
||||||
|
type: Optional[str]
|
||||||
|
details: Optional[dict]
|
||||||
|
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.message = "OIDC Discovery document is invalid"
|
||||||
|
self.type = kwargs.pop("type", None)
|
||||||
|
self.details = kwargs.pop("details", None)
|
||||||
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
def get_detail_string(self) -> str:
|
||||||
|
"""Returns a detailed string for logging purposes."""
|
||||||
|
string = []
|
||||||
|
|
||||||
|
if self.type:
|
||||||
|
string.append(f"type: {self.type}")
|
||||||
|
|
||||||
|
if self.details:
|
||||||
|
for key, value in self.details.items():
|
||||||
|
string.append(f"{key}: {value}")
|
||||||
|
|
||||||
|
return ", ".join(string)
|
||||||
|
|
||||||
|
|
||||||
class OIDCTokenResponseInvalid(OIDCClientException):
|
class OIDCTokenResponseInvalid(OIDCClientException):
|
||||||
"Raised when the token request returns invalid."
|
"Raised when the token request returns invalid."
|
||||||
@@ -68,16 +92,209 @@ class HTTPClientError(aiohttp.ClientResponseError):
|
|||||||
return f"{self.status} ({self.message}) with response body: {self.body}"
|
return f"{self.status} ({self.message}) with response body: {self.body}"
|
||||||
|
|
||||||
|
|
||||||
|
async def http_raise_for_status(response: aiohttp.ClientResponse) -> None:
|
||||||
|
"""Raises an exception if the response is not OK."""
|
||||||
|
if not response.ok:
|
||||||
|
# reason should always be not None for a started response
|
||||||
|
assert response.reason is not None
|
||||||
|
body = await response.text()
|
||||||
|
|
||||||
|
raise HTTPClientError(
|
||||||
|
response.request_info,
|
||||||
|
response.history,
|
||||||
|
status=response.status,
|
||||||
|
message=response.reason,
|
||||||
|
headers=response.headers,
|
||||||
|
body=body,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCDiscoveryClient:
|
||||||
|
"""OIDC Discovery Client implementation for Python"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
discovery_url: str,
|
||||||
|
http_session: aiohttp.ClientSession,
|
||||||
|
verification_context: dict,
|
||||||
|
):
|
||||||
|
self.discovery_url = discovery_url
|
||||||
|
self.http_session = http_session
|
||||||
|
self.verification_context = verification_context
|
||||||
|
|
||||||
|
async def _fetch_discovery_document(self):
|
||||||
|
"""Fetches discovery document from the given URL."""
|
||||||
|
try:
|
||||||
|
async with self.http_session.get(self.discovery_url) as response:
|
||||||
|
await http_raise_for_status(response)
|
||||||
|
return await response.json()
|
||||||
|
except HTTPClientError as e:
|
||||||
|
if e.status == 404:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document not found at %s", self.discovery_url
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_LOGGER.warning("Error fetching discovery: %s", e)
|
||||||
|
raise OIDCDiscoveryInvalid(type="fetch_error") from e
|
||||||
|
|
||||||
|
async def _fetch_jwks(self, jwks_uri):
|
||||||
|
"""Fetches JWKS from the given URL."""
|
||||||
|
try:
|
||||||
|
async with self.http_session.get(jwks_uri) as response:
|
||||||
|
await http_raise_for_status(response)
|
||||||
|
return await response.json()
|
||||||
|
except HTTPClientError as e:
|
||||||
|
_LOGGER.warning("Error fetching JWKS: %s", e)
|
||||||
|
raise OIDCJWKSInvalid from e
|
||||||
|
|
||||||
|
# pylint: disable=too-many-branches
|
||||||
|
async def _validate_discovery_document(self, document):
|
||||||
|
"""Validates the discovery document."""
|
||||||
|
|
||||||
|
# Verify that required endpoints are present
|
||||||
|
required_endpoints = [
|
||||||
|
"issuer",
|
||||||
|
"authorization_endpoint",
|
||||||
|
"token_endpoint",
|
||||||
|
"jwks_uri",
|
||||||
|
]
|
||||||
|
|
||||||
|
for endpoint in required_endpoints:
|
||||||
|
if endpoint not in document:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s is missing required endpoint: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
endpoint,
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="missing_endpoint", details={"endpoint": endpoint}
|
||||||
|
)
|
||||||
|
if validate_url(document[endpoint]) is False:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s has invalid URL in endpoint: %s (%s)",
|
||||||
|
self.discovery_url,
|
||||||
|
endpoint,
|
||||||
|
document[endpoint],
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="invalid_endpoint",
|
||||||
|
details={"endpoint": endpoint, "url": document[endpoint]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify optional response_modes_supported
|
||||||
|
if "response_modes_supported" in document:
|
||||||
|
if "query" not in document["response_modes_supported"]:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not support required 'query' "
|
||||||
|
"response mode, only supports: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
document["response_modes_supported"],
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="does_not_support_response_mode",
|
||||||
|
details={"modes": document["response_modes_supported"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
# If grant_types_supported is set, should support 'authorization_code'
|
||||||
|
if "grant_types_supported" in document:
|
||||||
|
if "authorization_code" not in document["grant_types_supported"]:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not support required "
|
||||||
|
"'authorization_code' grant type, only supports: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
document["grant_types_supported"],
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="does_not_support_grant_type",
|
||||||
|
details={
|
||||||
|
"required": "authorization_code",
|
||||||
|
"supported": document["grant_types_supported"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# If response_types_supported is set, should support 'code'
|
||||||
|
if "response_types_supported" in document:
|
||||||
|
if "code" not in document["response_types_supported"]:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not support required "
|
||||||
|
"'code' response type, only supports: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
document["response_types_supported"],
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="does_not_support_response_type",
|
||||||
|
details={
|
||||||
|
"required": "code",
|
||||||
|
"supported": document["response_types_supported"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# If code_challenge_methods_supported is present, check that it contains S256
|
||||||
|
if "code_challenge_methods_supported" in document:
|
||||||
|
if "S256" not in document["code_challenge_methods_supported"]:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not support required "
|
||||||
|
"'S256' code challenge method, only supports: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
document["code_challenge_methods_supported"],
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="does_not_support_required_code_challenge_method",
|
||||||
|
details={
|
||||||
|
"required": "S256",
|
||||||
|
"supported": document["code_challenge_methods_supported"],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the id_token_signing_alg_values_supported field is present and filled
|
||||||
|
signing_values = document.get("id_token_signing_alg_values_supported", None)
|
||||||
|
if signing_values is None:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not have "
|
||||||
|
"'id_token_signing_alg_values_supported' field",
|
||||||
|
self.discovery_url,
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(type="missing_id_token_signing_alg_values")
|
||||||
|
|
||||||
|
# Verify that the requested id_token_signing_alg is supported
|
||||||
|
requested_alg = self.verification_context.get("id_token_signing_alg", None)
|
||||||
|
if requested_alg is not None and requested_alg not in signing_values:
|
||||||
|
_LOGGER.warning(
|
||||||
|
"Error: Discovery document %s does not support requested "
|
||||||
|
"id_token_signing_alg '%s', only supports: %s",
|
||||||
|
self.discovery_url,
|
||||||
|
requested_alg,
|
||||||
|
signing_values,
|
||||||
|
)
|
||||||
|
raise OIDCDiscoveryInvalid(
|
||||||
|
type="does_not_support_id_token_signing_alg",
|
||||||
|
details={"requested": requested_alg, "supported": signing_values},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_discovery_document(self):
|
||||||
|
"""Fetches discovery document."""
|
||||||
|
document = await self._fetch_discovery_document()
|
||||||
|
await self._validate_discovery_document(document)
|
||||||
|
return document
|
||||||
|
|
||||||
|
async def fetch_jwks(self, jwks_uri: str | None = None):
|
||||||
|
"""Fetches JWKS."""
|
||||||
|
if jwks_uri is None:
|
||||||
|
discovery_document = await self._fetch_discovery_document()
|
||||||
|
jwks_uri = discovery_document["jwks_uri"]
|
||||||
|
return await self._fetch_jwks(jwks_uri)
|
||||||
|
|
||||||
|
|
||||||
# pylint: disable=too-many-instance-attributes
|
# pylint: disable=too-many-instance-attributes
|
||||||
class OIDCClient:
|
class OIDCClient:
|
||||||
"""OIDC Client implementation for Python, including PKCE."""
|
"""OIDC Client implementation for Python, including PKCE."""
|
||||||
|
|
||||||
# Flows stores the state, code_verifier and nonce of all current flows.
|
|
||||||
flows = {}
|
|
||||||
|
|
||||||
# HTTP session to be used
|
# HTTP session to be used
|
||||||
http_session: aiohttp.ClientSession = None
|
http_session: aiohttp.ClientSession = None
|
||||||
|
|
||||||
|
# OIDC Discovery tool to be used
|
||||||
|
discovery_class: OIDCDiscoveryClient = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@@ -92,13 +309,16 @@ class OIDCClient:
|
|||||||
self.client_id = client_id
|
self.client_id = client_id
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
|
# Stores code_verifier and nonce for active authorization flows.
|
||||||
|
self.flows: dict[str, dict[str, str]] = {}
|
||||||
|
|
||||||
# Optional parameters
|
# Optional parameters
|
||||||
self.client_secret = kwargs.get("client_secret")
|
self.client_secret = kwargs.get("client_secret")
|
||||||
|
|
||||||
# Default id_token_signing_alg to RS256 if not specified
|
# Default id_token_signing_alg to RS256 if not specified
|
||||||
self.id_token_signing_alg = kwargs.get("id_token_signing_alg")
|
self.id_token_signing_alg = kwargs.get("id_token_signing_alg")
|
||||||
if self.id_token_signing_alg is None:
|
if self.id_token_signing_alg is None:
|
||||||
self.id_token_signing_alg = "RS256"
|
self.id_token_signing_alg = DEFAULT_ID_TOKEN_SIGNING_ALGORITHM
|
||||||
|
|
||||||
features = kwargs.get("features")
|
features = kwargs.get("features")
|
||||||
claims = kwargs.get("claims")
|
claims = kwargs.get("claims")
|
||||||
@@ -122,23 +342,6 @@ class OIDCClient:
|
|||||||
_LOGGER.debug("Closing HTTP session")
|
_LOGGER.debug("Closing HTTP session")
|
||||||
self.http_session.close()
|
self.http_session.close()
|
||||||
|
|
||||||
async def http_raise_for_status(self, response: aiohttp.ClientResponse) -> None:
|
|
||||||
"""Raises an exception if the response is not OK."""
|
|
||||||
if not response.ok:
|
|
||||||
# reason should always be not None for a started response
|
|
||||||
assert response.reason is not None
|
|
||||||
|
|
||||||
body = await response.text()
|
|
||||||
|
|
||||||
raise HTTPClientError(
|
|
||||||
response.request_info,
|
|
||||||
response.history,
|
|
||||||
status=response.status,
|
|
||||||
message=response.reason,
|
|
||||||
headers=response.headers,
|
|
||||||
body=body,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _base64url_encode(self, value: str) -> str:
|
def _base64url_encode(self, value: str) -> str:
|
||||||
"""Uses base64url encoding on a given string"""
|
"""Uses base64url encoding on a given string"""
|
||||||
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
|
return base64.urlsafe_b64encode(value).rstrip(b"=").decode("utf-8")
|
||||||
@@ -173,42 +376,13 @@ class OIDCClient:
|
|||||||
)
|
)
|
||||||
return self.http_session
|
return self.http_session
|
||||||
|
|
||||||
async def _fetch_discovery_document(self):
|
|
||||||
"""Fetches discovery document from the given URL."""
|
|
||||||
try:
|
|
||||||
session = await self._get_http_session()
|
|
||||||
|
|
||||||
async with session.get(self.discovery_url) as response:
|
|
||||||
await self.http_raise_for_status(response)
|
|
||||||
return await response.json()
|
|
||||||
except HTTPClientError as e:
|
|
||||||
if e.status == 404:
|
|
||||||
_LOGGER.warning(
|
|
||||||
"Error: Discovery document not found at %s", self.discovery_url
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_LOGGER.warning("Error fetching discovery: %s", e)
|
|
||||||
raise OIDCDiscoveryInvalid from e
|
|
||||||
|
|
||||||
async def _get_jwks(self, jwks_uri):
|
|
||||||
"""Fetches JWKS from the given URL."""
|
|
||||||
try:
|
|
||||||
session = await self._get_http_session()
|
|
||||||
|
|
||||||
async with session.get(jwks_uri) as response:
|
|
||||||
await self.http_raise_for_status(response)
|
|
||||||
return await response.json()
|
|
||||||
except HTTPClientError as e:
|
|
||||||
_LOGGER.warning("Error fetching JWKS: %s", e)
|
|
||||||
raise OIDCJWKSInvalid from e
|
|
||||||
|
|
||||||
async def _make_token_request(self, token_endpoint, query_params):
|
async def _make_token_request(self, token_endpoint, query_params):
|
||||||
"""Performs the token POST call"""
|
"""Performs the token POST call"""
|
||||||
try:
|
try:
|
||||||
session = await self._get_http_session()
|
session = await self._get_http_session()
|
||||||
|
|
||||||
async with session.post(token_endpoint, data=query_params) as response:
|
async with session.post(token_endpoint, data=query_params) as response:
|
||||||
await self.http_raise_for_status(response)
|
await http_raise_for_status(response)
|
||||||
return await response.json()
|
return await response.json()
|
||||||
except HTTPClientError as e:
|
except HTTPClientError as e:
|
||||||
if e.status == 400:
|
if e.status == 400:
|
||||||
@@ -231,25 +405,46 @@ class OIDCClient:
|
|||||||
headers = {"Authorization": "Bearer " + access_token}
|
headers = {"Authorization": "Bearer " + access_token}
|
||||||
|
|
||||||
async with session.get(userinfo_uri, headers=headers) as response:
|
async with session.get(userinfo_uri, headers=headers) as response:
|
||||||
await self.http_raise_for_status(response)
|
await http_raise_for_status(response)
|
||||||
return await response.json()
|
return await response.json()
|
||||||
except HTTPClientError as e:
|
except HTTPClientError as e:
|
||||||
_LOGGER.warning("Error fetching userinfo: %s", e)
|
_LOGGER.warning("Error fetching userinfo: %s", e)
|
||||||
raise OIDCUserinfoInvalid from e
|
raise OIDCUserinfoInvalid from e
|
||||||
|
|
||||||
async def _parse_id_token(
|
async def _fetch_discovery_document(self):
|
||||||
self, id_token: str, access_token: str | None
|
"""Fetches discovery document."""
|
||||||
) -> Optional[dict]:
|
if self.discovery_document is not None:
|
||||||
|
return self.discovery_document
|
||||||
|
|
||||||
|
if self.discovery_class is None:
|
||||||
|
session = await self._get_http_session()
|
||||||
|
self.discovery_class = OIDCDiscoveryClient(
|
||||||
|
discovery_url=self.discovery_url,
|
||||||
|
http_session=session,
|
||||||
|
verification_context={
|
||||||
|
"id_token_signing_alg": self.id_token_signing_alg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.discovery_document = await self.discovery_class.fetch_discovery_document()
|
||||||
|
return self.discovery_document
|
||||||
|
|
||||||
|
async def _fetch_jwks(self, jwks_uri: str):
|
||||||
|
"""Fetches JWKS."""
|
||||||
|
return await self.discovery_class.fetch_jwks(jwks_uri)
|
||||||
|
|
||||||
|
async def _parse_id_token(self, id_token: str) -> Optional[dict]:
|
||||||
"""Parses the ID token into a dict containing token contents."""
|
"""Parses the ID token into a dict containing token contents."""
|
||||||
if self.discovery_document is None:
|
if self.discovery_document is None:
|
||||||
self.discovery_document = await self._fetch_discovery_document()
|
self.discovery_document = await self._fetch_discovery_document()
|
||||||
|
|
||||||
jwks_uri = self.discovery_document["jwks_uri"]
|
jwks_uri = self.discovery_document["jwks_uri"]
|
||||||
jwks_data = await self._get_jwks(jwks_uri)
|
jwks_data = await self._fetch_jwks(jwks_uri)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Obtain the id_token header
|
# Obtain the id_token header
|
||||||
unverified_header = jwt.get_unverified_header(id_token)
|
token_obj = jws.extract_compact(id_token.encode())
|
||||||
|
unverified_header = token_obj.protected
|
||||||
if not unverified_header:
|
if not unverified_header:
|
||||||
_LOGGER.warning("Could not get header from received id_token.")
|
_LOGGER.warning("Could not get header from received id_token.")
|
||||||
return None
|
return None
|
||||||
@@ -278,7 +473,7 @@ class OIDCClient:
|
|||||||
)
|
)
|
||||||
raise OIDCIdTokenSigningAlgorithmInvalid()
|
raise OIDCIdTokenSigningAlgorithmInvalid()
|
||||||
|
|
||||||
jwk_obj = jwk.construct(
|
jwk_obj = jwk.import_key(
|
||||||
{
|
{
|
||||||
"kty": "oct",
|
"kty": "oct",
|
||||||
"k": base64.urlsafe_b64encode(
|
"k": base64.urlsafe_b64encode(
|
||||||
@@ -311,9 +506,9 @@ class OIDCClient:
|
|||||||
signing_key["alg"] = alg
|
signing_key["alg"] = alg
|
||||||
|
|
||||||
# Construct the JWK from the RSA key
|
# Construct the JWK from the RSA key
|
||||||
jwk_obj = jwk.construct(signing_key)
|
jwk_obj = jwk.import_key(signing_key)
|
||||||
|
|
||||||
# Verify the token
|
# Decode the token, decode does not verify it
|
||||||
decoded_token = jwt.decode(
|
decoded_token = jwt.decode(
|
||||||
id_token,
|
id_token,
|
||||||
jwk_obj,
|
jwk_obj,
|
||||||
@@ -322,61 +517,43 @@ class OIDCClient:
|
|||||||
# according to JWS [JWS] using the algorithm specified in the JWT
|
# according to JWS [JWS] using the algorithm specified in the JWT
|
||||||
# alg Header Parameter.
|
# alg Header Parameter.
|
||||||
algorithms=[self.id_token_signing_alg],
|
algorithms=[self.id_token_signing_alg],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create Claims Registry for validation
|
||||||
|
id_token_validator = jwt.JWTClaimsRegistry(
|
||||||
|
leeway=5,
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.3
|
# OpenID Connect Core 1.0 Section 3.1.3.7.3
|
||||||
# The Client MUST validate that the aud (audience) Claim contains
|
# The Client MUST validate that the aud (audience) Claim contains
|
||||||
# its client_id value registered at the Issuer identified by the
|
# its client_id value registered at the Issuer identified by the
|
||||||
# iss (issuer) Claim as an audience.
|
# iss (issuer) Claim as an audience.
|
||||||
audience=self.client_id,
|
aud={"essential": True, "value": self.client_id},
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.2
|
# OpenID Connect Core 1.0 Section 3.1.3.7.2
|
||||||
# The Issuer Identifier for the OpenID Provider MUST exactly
|
# The Issuer Identifier for the OpenID Provider MUST exactly
|
||||||
# match the value of the iss (issuer) Claim.
|
# match the value of the iss (issuer) Claim.
|
||||||
issuer=self.discovery_document["issuer"],
|
iss={"essential": True, "value": self.discovery_document["issuer"]},
|
||||||
access_token=access_token,
|
|
||||||
options={
|
|
||||||
# Verify everything if present
|
|
||||||
"verify_signature": True,
|
|
||||||
"verify_aud": True,
|
|
||||||
"verify_iat": True,
|
|
||||||
"verify_exp": True,
|
|
||||||
"verify_nbf": True,
|
|
||||||
"verify_iss": True,
|
|
||||||
"verify_sub": True,
|
|
||||||
"verify_jti": True,
|
|
||||||
"verify_at_hash": True,
|
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.3
|
|
||||||
"require_aud": True,
|
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.10
|
|
||||||
"require_iat": True,
|
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.9
|
# OpenID Connect Core 1.0 Section 3.1.3.7.9
|
||||||
"require_exp": True,
|
# OpenID Connect Core 1.0 Section 3.1.3.7.10
|
||||||
# OpenID Connect Core 1.0 Section 3.1.3.7.2
|
# No need to specify exp, nbf, iat, they are in here by default
|
||||||
"require_iss": True,
|
sub={"essential": True},
|
||||||
# We need the sub as it's used to identify the user
|
|
||||||
"require_sub": True,
|
|
||||||
# Other values, not required.
|
|
||||||
"require_nbf": False,
|
|
||||||
"require_jti": False,
|
|
||||||
"require_at_hash": False,
|
|
||||||
"leeway": 5,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return decoded_token
|
|
||||||
|
|
||||||
except jwt.JWTError as e:
|
id_token_validator.validate(decoded_token.claims)
|
||||||
_LOGGER.warning("JWT Verification failed: %s", e)
|
return decoded_token.claims
|
||||||
|
|
||||||
|
except joserfc_errors.JoseError as e:
|
||||||
|
_LOGGER.warning("JWT verification failed: %s", e)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_get_authorization_url(self, redirect_uri: str) -> Optional[str]:
|
async def async_get_authorization_url(
|
||||||
|
self, redirect_uri: str, state: str
|
||||||
|
) -> Optional[str]:
|
||||||
"""Generates the authorization URL for the OIDC flow."""
|
"""Generates the authorization URL for the OIDC flow."""
|
||||||
try:
|
try:
|
||||||
if self.discovery_document is None:
|
discovery_document = await self._fetch_discovery_document()
|
||||||
self.discovery_document = await self._fetch_discovery_document()
|
auth_endpoint = discovery_document["authorization_endpoint"]
|
||||||
|
|
||||||
auth_endpoint = self.discovery_document["authorization_endpoint"]
|
|
||||||
|
|
||||||
# Generate random nonce & state
|
# Generate random nonce & state
|
||||||
nonce = self._generate_random_url_string()
|
nonce = self._generate_random_url_string()
|
||||||
state = self._generate_random_url_string()
|
|
||||||
|
|
||||||
# Generate PKCE (RFC 7636) parameters
|
# Generate PKCE (RFC 7636) parameters
|
||||||
code_verifier = self._generate_random_url_string(32)
|
code_verifier = self._generate_random_url_string(32)
|
||||||
@@ -417,8 +594,9 @@ class OIDCClient:
|
|||||||
|
|
||||||
# Fetch userinfo if there is an userinfo_endpoint available
|
# Fetch userinfo if there is an userinfo_endpoint available
|
||||||
# and use the data to supply the missing values in id_token
|
# and use the data to supply the missing values in id_token
|
||||||
if "userinfo_endpoint" in self.discovery_document:
|
discovery_document = await self._fetch_discovery_document()
|
||||||
userinfo_endpoint = self.discovery_document["userinfo_endpoint"]
|
if "userinfo_endpoint" in discovery_document:
|
||||||
|
userinfo_endpoint = discovery_document["userinfo_endpoint"]
|
||||||
userinfo = await self._get_userinfo(userinfo_endpoint, access_token)
|
userinfo = await self._get_userinfo(userinfo_endpoint, access_token)
|
||||||
|
|
||||||
# Replace missing claims in the id_token with their userinfo version
|
# Replace missing claims in the id_token with their userinfo version
|
||||||
@@ -451,9 +629,7 @@ class OIDCClient:
|
|||||||
# Only unique per issuer, so we combine it with the issuer and hash it.
|
# Only unique per issuer, so we combine it with the issuer and hash it.
|
||||||
# This might allow multiple OIDC providers to be used with this integration.
|
# This might allow multiple OIDC providers to be used with this integration.
|
||||||
"sub": hashlib.sha256(
|
"sub": hashlib.sha256(
|
||||||
f"{self.discovery_document['issuer']}.{id_token.get('sub')}".encode(
|
f"{discovery_document['issuer']}.{id_token.get('sub')}".encode("utf-8")
|
||||||
"utf-8"
|
|
||||||
)
|
|
||||||
).hexdigest(),
|
).hexdigest(),
|
||||||
# Display name, configurable
|
# Display name, configurable
|
||||||
"display_name": id_token.get(self.display_name_claim),
|
"display_name": id_token.get(self.display_name_claim),
|
||||||
@@ -469,15 +645,12 @@ class OIDCClient:
|
|||||||
"""Completes the OIDC token flow to obtain a user's details."""
|
"""Completes the OIDC token flow to obtain a user's details."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if state not in self.flows:
|
flow = self.flows.pop(state, None)
|
||||||
|
if flow is None:
|
||||||
raise OIDCStateInvalid
|
raise OIDCStateInvalid
|
||||||
|
|
||||||
flow = self.flows[state]
|
discovery_document = await self._fetch_discovery_document()
|
||||||
|
token_endpoint = discovery_document["token_endpoint"]
|
||||||
if self.discovery_document is None:
|
|
||||||
self.discovery_document = await self._fetch_discovery_document()
|
|
||||||
|
|
||||||
token_endpoint = self.discovery_document["token_endpoint"]
|
|
||||||
|
|
||||||
# Construct the params
|
# Construct the params
|
||||||
query_params = {
|
query_params = {
|
||||||
@@ -501,11 +674,9 @@ class OIDCClient:
|
|||||||
)
|
)
|
||||||
|
|
||||||
id_token = token_response.get("id_token")
|
id_token = token_response.get("id_token")
|
||||||
access_token = token_response.get("access_token")
|
|
||||||
|
|
||||||
# Parse the id token to obtain the relevant details
|
# Parse the id token to obtain the relevant details
|
||||||
# Access token is supplied to check at_hash if present
|
id_token = await self._parse_id_token(id_token)
|
||||||
id_token = await self._parse_id_token(id_token, access_token)
|
|
||||||
|
|
||||||
if id_token is None:
|
if id_token is None:
|
||||||
_LOGGER.warning("ID token could not be parsed!")
|
_LOGGER.warning("ID token could not be parsed!")
|
||||||
@@ -519,6 +690,7 @@ class OIDCClient:
|
|||||||
_LOGGER.warning("Nonce mismatch!")
|
_LOGGER.warning("Nonce mismatch!")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
access_token = token_response.get("access_token")
|
||||||
data = await self.parse_user_details(id_token, access_token)
|
data = await self.parse_user_details(id_token, access_token)
|
||||||
|
|
||||||
# Log which details were obtained for debugging
|
# Log which details were obtained for debugging
|
||||||
@@ -16,3 +16,26 @@ class UserDetails(dict):
|
|||||||
username: str
|
username: str
|
||||||
# Home Assistant role to assign to this user
|
# Home Assistant role to assign to this user
|
||||||
role: Literal["system-admin", "system-users", "invalid"]
|
role: Literal["system-admin", "system-users", "invalid"]
|
||||||
|
|
||||||
|
|
||||||
|
class OIDCState(dict):
|
||||||
|
"""OIDC State representation"""
|
||||||
|
|
||||||
|
# ID of this state
|
||||||
|
id: str
|
||||||
|
|
||||||
|
# User friendly device code
|
||||||
|
device_code: str | None
|
||||||
|
|
||||||
|
# The redirect_uri associated with this state,
|
||||||
|
# to be able to redirect the user back after authentication
|
||||||
|
redirect_uri: str
|
||||||
|
|
||||||
|
# User details, if available
|
||||||
|
user_details: UserDetails | None
|
||||||
|
|
||||||
|
# Expiration time of this state, in ISO format
|
||||||
|
expiration: str
|
||||||
|
|
||||||
|
# IP address
|
||||||
|
ip_address: str | None
|
||||||
37
custom_components/auth_oidc/tools/validation.py
Normal file
37
custom_components/auth_oidc/tools/validation.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Validation and sanitization helpers for config flow inputs."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
|
|
||||||
|
def validate_url(url: str) -> bool:
|
||||||
|
"""Validate that a URL is properly formatted."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url.strip())
|
||||||
|
return bool(parsed.scheme in ("http", "https") and parsed.netloc)
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def validate_discovery_url(url: str) -> bool:
|
||||||
|
"""Validate that a URL is properly formatted for OIDC discovery."""
|
||||||
|
try:
|
||||||
|
parsed = urlparse(url.strip())
|
||||||
|
return bool(
|
||||||
|
parsed.scheme in ("http", "https")
|
||||||
|
and parsed.netloc
|
||||||
|
and parsed.path.endswith("/.well-known/openid-configuration")
|
||||||
|
)
|
||||||
|
except (ValueError, TypeError, AttributeError):
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def sanitize_client_secret(secret: str) -> str:
|
||||||
|
"""Sanitize client secret input."""
|
||||||
|
return secret.strip() if secret else ""
|
||||||
|
|
||||||
|
|
||||||
|
def validate_client_id(client_id: str) -> bool:
|
||||||
|
"""Validate client ID format."""
|
||||||
|
return bool(client_id and client_id.strip())
|
||||||
94
custom_components/auth_oidc/translations/en.json
Normal file
94
custom_components/auth_oidc/translations/en.json
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
{
|
||||||
|
"config": {
|
||||||
|
"step": {
|
||||||
|
"user": {
|
||||||
|
"title": "Choose OIDC Provider",
|
||||||
|
"description": "Select your OpenID Connect identity provider to get started with the setup.\n\nIf you want to use a provider that isn't listed, try the Generic OpenID Connect provider or use the advanced YAML configuration instead.",
|
||||||
|
"data": {
|
||||||
|
"provider": "Provider"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"discovery_url": {
|
||||||
|
"title": "Provider Configuration",
|
||||||
|
"description": "Enter the discovery URL for {provider_name}. This is typically found in your provider's admin interface and ends with '/.well-known/openid-configuration'.\n\nNeed detailed setup instructions? See the [provider guide]({documentation_url}).",
|
||||||
|
"data": {
|
||||||
|
"discovery_url": "Discovery URL"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"validate_connection": {
|
||||||
|
"title": "Connection Validation",
|
||||||
|
"description": "Testing connection to your {provider_name} OIDC provider...\n\n**Discovery URL:** {discovery_url}\n\n{discovery_details}\n\n**What to do next:**\n- **Continue Setup:** Proceed with the configuration (when validation succeeds)\n- **Retry Validation:** Test the connection again with current settings\n- **Modify Discovery URL:** Go back to change the discovery URL\n- **Change Provider:** Start over with a different provider\n\n**Need Help?** Check the [setup documentation]({documentation_url}) for detailed configuration instructions.",
|
||||||
|
"data": {
|
||||||
|
"action": "Choose an action"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"client_config": {
|
||||||
|
"title": "Client Configuration",
|
||||||
|
"description": "Configure your OIDC client. You can find these details in your {provider_name} application settings.\n\n**Discovery URL:** {discovery_url}\n\n**Setup Instructions:**\n1. Register a new application in your OIDC provider\n2. Set the application type to 'Public Client' (recommended for most users)\n3. Add redirect URLs for Home Assistant\n4. Copy the Client ID below\n\n**Note:** If your provider requires a client secret, check 'Use Confidential Client' and provide your client secret below.\n\n**Need detailed setup instructions?** Check the [setup guide]({documentation_url}) for step-by-step instructions.",
|
||||||
|
"data": {
|
||||||
|
"client_id": "Client ID",
|
||||||
|
"client_secret": "Client Secret (optional; required by some providers)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"groups_config": {
|
||||||
|
"title": "Groups & Role Configuration",
|
||||||
|
"description": "Configure how user groups from {provider_name} should be mapped to Home Assistant roles.\n\n**Groups Support:** Groups allow you to automatically assign admin or user roles based on group membership in your identity provider.\n\n**Admin Group:** Users in this group will have administrator access\n**User Group:** Users in this group will have standard user access (leave empty to allow all authenticated users)",
|
||||||
|
"data": {
|
||||||
|
"enable_groups": "Enable group-based role assignment",
|
||||||
|
"admin_group": "Admin group name",
|
||||||
|
"user_group": "User group name (optional)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"user_linking": {
|
||||||
|
"title": "User Linking Options",
|
||||||
|
"description": "Configure how OIDC users are linked to existing Home Assistant users.\n\n**⚠️ Important Security Information:**\n\n**User Linking Disabled (Recommended):** New OIDC accounts are created for each user. This is the most secure option.\n\n**User Linking Enabled:** OIDC users can be linked to existing Home Assistant users by username. **This has security implications:**\n- If someone can guess or obtain a Home Assistant username, they might gain access to that account\n- Only enable this if you're migrating from local Home Assistant accounts to OIDC\n- You can disable this later if needed",
|
||||||
|
"data": {
|
||||||
|
"enable_user_linking": "Enable automatic user linking (⚠️ Security Risk)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"finalize": {
|
||||||
|
"title": "Setup Complete",
|
||||||
|
"description": "Your OIDC authentication is now configured and ready to use.\n\n**Next Steps:**\n1. Save this configuration\n2. Restart Home Assistant if prompted\n3. The OIDC login option will appear on your login screen\n\n**Advanced Configuration:**\nAdvanced options like custom networking settings, specific claim configurations, or custom scopes are only available through YAML configuration. See the documentation for details.",
|
||||||
|
"data": {}
|
||||||
|
},
|
||||||
|
"reconfigure": {
|
||||||
|
"title": "Reconfigure OIDC Authentication",
|
||||||
|
"description": "Update your OIDC client credentials for {provider_name}.\n\n**Discovery URL:** {discovery_url}\n\n**What you can change:**\n- **Client ID**: Update your application's client identifier\n- **Client Type**: Switch between Public and Confidential client types\n- **Client Secret**: Update or add a client secret (for confidential clients)\n\n**Note:** Changes will be validated against your OIDC provider before being saved. Your existing settings will be preserved if validation fails.\n\n**Security:** For confidential clients, leave the client secret field empty to keep your existing secret unchanged.",
|
||||||
|
"data": {
|
||||||
|
"client_id": "Client ID",
|
||||||
|
"client_secret": "Client Secret (leave empty to keep current)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"error": {
|
||||||
|
"cannot_connect": "Failed to connect to the OIDC provider. Please check your network connection and discovery URL.",
|
||||||
|
"discovery_invalid": "The discovery document could not be retrieved or is invalid. Please verify the discovery URL is correct.",
|
||||||
|
"jwks_invalid": "Failed to retrieve or validate the JWKS (JSON Web Key Set). Please check your provider configuration.",
|
||||||
|
"invalid_url_format": "The discovery URL must be a valid HTTP or HTTPS URL and should end with '/.well-known/openid-configuration'",
|
||||||
|
"invalid_client_id": "Client ID cannot be empty and must contain valid characters.",
|
||||||
|
"unknown": "An unexpected error occurred. Please check the logs for more details."
|
||||||
|
},
|
||||||
|
"abort": {
|
||||||
|
"already_configured": "This OIDC provider is already configured.",
|
||||||
|
"cannot_connect": "Unable to connect to the OIDC provider.",
|
||||||
|
"invalid_discovery": "Invalid discovery document received from the provider.",
|
||||||
|
"reconfigure_successful": "OIDC Authentication has been successfully reconfigured with the updated client credentials.",
|
||||||
|
"single_instance_allowed": "OIDC Authentication only supports a single configuration. You already have OIDC configured in the UI. To modify your existing configuration, go to Settings > Devices & Services > OIDC Authentication and click 'Configure'. To replace your configuration, first remove the existing one.",
|
||||||
|
"yaml_configured": "You are currently using YAML configuration for this integration. To switch to UI configuration, please remove the YAML configuration first. Note that some advanced options configured via YAML may not be available in the UI."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"options": {
|
||||||
|
"step": {
|
||||||
|
"init": {
|
||||||
|
"title": "OIDC Authentication Options",
|
||||||
|
"description": "Update configuration options for your {provider_name} OIDC authentication.\n\n**User Linking:** Control how OIDC users are linked to existing Home Assistant accounts (⚠️ security implications).\n\n**Groups Configuration:** Configure role assignment based on group membership from your identity provider.\n\n**Note:** Changes take effect immediately but may require users to log out and back in.",
|
||||||
|
"data": {
|
||||||
|
"enable_user_linking": "Enable automatic user linking (⚠️ Security Risk)",
|
||||||
|
"enable_groups": "Enable group-based role assignment",
|
||||||
|
"admin_group": "Admin group name",
|
||||||
|
"user_group": "User group name (optional - leave empty to allow all authenticated users)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
0
custom_components/auth_oidc/views/__init__.py
Normal file
0
custom_components/auth_oidc/views/__init__.py
Normal file
@@ -40,7 +40,7 @@ class AsyncTemplateRenderer:
|
|||||||
) as f:
|
) as f:
|
||||||
content = await f.read()
|
content = await f.read()
|
||||||
templates[filename] = content
|
templates[filename] = content
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e: # pragma: no cover
|
||||||
_LOGGER.warning("Error reading template file %s: %s", filename, e)
|
_LOGGER.warning("Error reading template file %s: %s", filename, e)
|
||||||
|
|
||||||
async def render_template(self, template_name: str, **kwargs: Any) -> str:
|
async def render_template(self, template_name: str, **kwargs: Any) -> str:
|
||||||
@@ -54,7 +54,9 @@ class AsyncTemplateRenderer:
|
|||||||
if template_name not in templates:
|
if template_name not in templates:
|
||||||
raise ValueError(f"Template '{template_name}' not found.")
|
raise ValueError(f"Template '{template_name}' not found.")
|
||||||
|
|
||||||
env = Environment(loader=DictLoader(templates), enable_async=True)
|
env = Environment(
|
||||||
|
loader=DictLoader(templates), enable_async=True, autoescape=True
|
||||||
|
)
|
||||||
template = env.get_template(template_name)
|
template = env.get_template(template_name)
|
||||||
|
|
||||||
# Render template
|
# Render template
|
||||||
|
|||||||
@@ -6,7 +6,7 @@
|
|||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
<title>{% block title %}{% endblock %}</title>
|
<title>{% block title %}{% endblock %}</title>
|
||||||
<script src="https://cdn.tailwindcss.com"></script>
|
<link rel="stylesheet" href="/auth/oidc/static/style.css">
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
</head>
|
</head>
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,14 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}Done!{% endblock %}
|
||||||
|
{% block head %}
|
||||||
|
{{ super() }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<div class="text-center">
|
||||||
|
<p id="mobile-success-message" class="mb-4">You have successfully logged in on your mobile device. It should continue the login soon. <br/><br/>You have been logged out on this device.</p>
|
||||||
|
<div class="my-6">
|
||||||
|
<a id="restart-login-button" href='/auth/oidc/redirect'
|
||||||
|
class="w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg shadow-md hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400 focus:ring-opacity-75">Restart</a>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
||||||
@@ -4,28 +4,63 @@
|
|||||||
{{ super() }}
|
{{ super() }}
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
{% block content %}
|
{% block content %}
|
||||||
<div class="text-center">
|
<div>
|
||||||
<div class="my-6">
|
<h1 class="text-2xl font-bold mb-4 text-center">Logged in!</h1>
|
||||||
<h2 class="text-xl font-semibold mb-6 text-gray-800">I want to login to this browser</h2>
|
|
||||||
|
<div class="mb-4 rounded-lg border border-gray-300 bg-gray-50 p-6 text-left">
|
||||||
|
<h2 class="mb-2 text-lg font-semibold text-gray-800">Continue on this device</h2>
|
||||||
|
<p class="mb-4 text-sm text-gray-600">Tap Continue to login to Home Assistant on this device.</p>
|
||||||
<form method="post">
|
<form method="post">
|
||||||
<input type="hidden" name="code" value="{{ code }}">
|
<button
|
||||||
<button type="submit"
|
id="continue-on-this-device"
|
||||||
class="w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg shadow-md hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400 focus:ring-opacity-75">
|
type="submit"
|
||||||
Login to Home Assistant in this browser
|
class="w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg
|
||||||
|
shadow-md hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400
|
||||||
|
focus:ring-opacity-75 hover:cursor-pointer"
|
||||||
|
>
|
||||||
|
Continue on this device
|
||||||
</button>
|
</button>
|
||||||
</form>
|
</form>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<hr class="my-12">
|
<div class="rounded-lg border border-gray-300 bg-white p-6 text-left">
|
||||||
|
<div class="mb-4 flex items-center justify-between text-gray-700">
|
||||||
<div class="my-6">
|
<span class="text-lg font-semibold">Use a code from another device</span>
|
||||||
<h2 class="text-xl font-semibold mb-4 text-gray-800">I am on a mobile device</h2>
|
</div>
|
||||||
<p class="mb-4">Your one-time code is: <b class="text-blue-600 text-xl">{{ code }}</b></p>
|
<div class="border-t border-gray-200 pt-4">
|
||||||
<p class="mb-4 text-sm">You have 5 minutes to use this code on any device.<br />The code can only
|
<p class="mb-2 text-sm text-gray-600">On your other device, open the Home Assistant app. You will see a
|
||||||
be used once.</p>
|
6-digit code.</p>
|
||||||
<p class="mb-4 text-sm">Please type the code into your app manually. If you don't see a code input, select
|
<p class="mb-4 text-sm text-gray-600">Input that code here and click Approve to login on the other device.
|
||||||
'Login with
|
</p>
|
||||||
OpenID Connect (SSO)' first.</p>
|
<form method="post">
|
||||||
|
<div>
|
||||||
|
<input
|
||||||
|
type="tel"
|
||||||
|
id="device-code-input"
|
||||||
|
name="device_code"
|
||||||
|
required
|
||||||
|
minlength="6"
|
||||||
|
maxlength="6"
|
||||||
|
pattern="[0-9]{6}"
|
||||||
|
inputmode="numeric"
|
||||||
|
placeholder="123456"
|
||||||
|
class="mb-2 w-full rounded-md border border-gray-300 px-5 py-3 text-center text-base
|
||||||
|
tracking-[0.15em] text-gray-800 focus:outline-none focus:ring-2 focus:ring-blue-400
|
||||||
|
focus:ring-opacity-75"
|
||||||
|
>
|
||||||
|
</div>
|
||||||
|
<button
|
||||||
|
id="approve-login-button"
|
||||||
|
type="submit"
|
||||||
|
class="w-full py-2 px-4 bg-white text-blue-600
|
||||||
|
font-semibold rounded-lg border border-blue-500 shadow-md hover:bg-gray-100
|
||||||
|
hover:text-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400
|
||||||
|
focus:ring-opacity-75 hover:cursor-pointer"
|
||||||
|
>
|
||||||
|
Approve login on the other device
|
||||||
|
</button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
{% endblock %}
|
{% endblock %}
|
||||||
28
custom_components/auth_oidc/views/templates/redirect.html
Normal file
28
custom_components/auth_oidc/views/templates/redirect.html
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
{% extends "base.html" %}
|
||||||
|
{% block title %}OIDC Redirect{% endblock %}
|
||||||
|
{% block head %}
|
||||||
|
{{ super() }}
|
||||||
|
{% endblock %}
|
||||||
|
{% block content %}
|
||||||
|
<div class="text-center">
|
||||||
|
<div role="status" id="loader" class="items-center justify-center flex">
|
||||||
|
<svg aria-hidden="true" class="w-10 h-10 text-gray-200 animate-spin fill-blue-600" viewBox="0 0 100 101"
|
||||||
|
fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||||
|
<path
|
||||||
|
d="M100 50.5908C100 78.2051 77.6142 100.591 50 100.591C22.3858 100.591 0 78.2051 0 50.5908C0 22.9766 22.3858 0.59082 50 0.59082C77.6142 0.59082 100 22.9766 100 50.5908ZM9.08144 50.5908C9.08144 73.1895 27.4013 91.5094 50 91.5094C72.5987 91.5094 90.9186 73.1895 90.9186 50.5908C90.9186 27.9921 72.5987 9.67226 50 9.67226C27.4013 9.67226 9.08144 27.9921 9.08144 50.5908Z"
|
||||||
|
fill="currentColor" />
|
||||||
|
<path
|
||||||
|
d="M93.9676 39.0409C96.393 38.4038 97.8624 35.9116 97.0079 33.5539C95.2932 28.8227 92.871 24.3692 89.8167 20.348C85.8452 15.1192 80.8826 10.7238 75.2124 7.41289C69.5422 4.10194 63.2754 1.94025 56.7698 1.05124C51.7666 0.367541 46.6976 0.446843 41.7345 1.27873C39.2613 1.69328 37.813 4.19778 38.4501 6.62326C39.0873 9.04874 41.5694 10.4717 44.0505 10.1071C47.8511 9.54855 51.7191 9.52689 55.5402 10.0491C60.8642 10.7766 65.9928 12.5457 70.6331 15.2552C75.2735 17.9648 79.3347 21.5619 82.5849 25.841C84.9175 28.9121 86.7997 32.2913 88.1811 35.8758C89.083 38.2158 91.5421 39.6781 93.9676 39.0409Z"
|
||||||
|
fill="currentFill" />
|
||||||
|
</svg>
|
||||||
|
<span class="sr-only">Redirecting...</span>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
|
// Redirect after loading the page to show the redirect visual
|
||||||
|
setTimeout(() => {
|
||||||
|
auth_url = decodeURIComponent("{{ url }}");
|
||||||
|
window.location.href = auth_url;
|
||||||
|
}, 0);
|
||||||
|
</script>
|
||||||
|
{% endblock %}
|
||||||
@@ -12,41 +12,53 @@
|
|||||||
dashboard</a></p>
|
dashboard</a></p>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<h1 class="text-2xl font-bold mb-4">Home Assistant</h1>
|
{% if code %}
|
||||||
<p class="mb-4">You have been invited to login to Home Assistant.<br />Start the login process below.</p>
|
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<button id="oidc-login-btn"
|
<p id="device-instructions">Please login to Home Assistant on another device and enter this code when asked:</p>
|
||||||
class="w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg shadow-md hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-blue-400 focus:ring-opacity-75">
|
<div class="mt-4 text-3xl tracking-wide font-bold bg-gray-100 border border-gray-300 rounded-lg py-4 px-6 inline-block" id="device-code">
|
||||||
Login with {{ name }}
|
{{ code }}
|
||||||
</button>
|
|
||||||
|
|
||||||
<div role="status" id="loader" class="items-center justify-center flex hidden">
|
|
||||||
<svg aria-hidden="true" class="w-10 h-10 text-gray-200 animate-spin fill-blue-600" viewBox="0 0 100 101"
|
|
||||||
fill="none" xmlns="http://www.w3.org/2000/svg">
|
|
||||||
<path
|
|
||||||
d="M100 50.5908C100 78.2051 77.6142 100.591 50 100.591C22.3858 100.591 0 78.2051 0 50.5908C0 22.9766 22.3858 0.59082 50 0.59082C77.6142 0.59082 100 22.9766 100 50.5908ZM9.08144 50.5908C9.08144 73.1895 27.4013 91.5094 50 91.5094C72.5987 91.5094 90.9186 73.1895 90.9186 50.5908C90.9186 27.9921 72.5987 9.67226 50 9.67226C27.4013 9.67226 9.08144 27.9921 9.08144 50.5908Z"
|
|
||||||
fill="currentColor" />
|
|
||||||
<path
|
|
||||||
d="M93.9676 39.0409C96.393 38.4038 97.8624 35.9116 97.0079 33.5539C95.2932 28.8227 92.871 24.3692 89.8167 20.348C85.8452 15.1192 80.8826 10.7238 75.2124 7.41289C69.5422 4.10194 63.2754 1.94025 56.7698 1.05124C51.7666 0.367541 46.6976 0.446843 41.7345 1.27873C39.2613 1.69328 37.813 4.19778 38.4501 6.62326C39.0873 9.04874 41.5694 10.4717 44.0505 10.1071C47.8511 9.54855 51.7191 9.52689 55.5402 10.0491C60.8642 10.7766 65.9928 12.5457 70.6331 15.2552C75.2735 17.9648 79.3347 21.5619 82.5849 25.841C84.9175 28.9121 86.7997 32.2913 88.1811 35.8758C89.083 38.2158 91.5421 39.6781 93.9676 39.0409Z"
|
|
||||||
fill="currentFill" />
|
|
||||||
</svg>
|
|
||||||
<span class="sr-only">Redirecting...</span>
|
|
||||||
</div>
|
</div>
|
||||||
|
<p class="mt-4 text-sm text-gray-600">
|
||||||
|
The login will continue automatically when you complete the login on your other device. Please keep the app open.
|
||||||
|
</p>
|
||||||
</div>
|
</div>
|
||||||
|
<script>
|
||||||
|
const source = new EventSource('/auth/oidc/device-sse');
|
||||||
|
|
||||||
<p class="mt-6 text-sm">After login, you will be granted a one-time code to login to any device. You may complete
|
source.addEventListener('ready', function () {
|
||||||
this login on your desktop or any mobile browser and then use the token for any desktop or the Home Assistant
|
source.close();
|
||||||
app.</p>
|
|
||||||
</div>
|
// Perform a POST request to the finish endpoint to complete the login.
|
||||||
<script>
|
const form = document.createElement('form');
|
||||||
// Hide the login button and show the loader when clicked
|
form.method = 'POST';
|
||||||
document.getElementById('oidc-login-btn').addEventListener('click', function () {
|
form.action = '/auth/oidc/finish';
|
||||||
this.classList.add('hidden');
|
document.body.appendChild(form);
|
||||||
document.getElementById('loader').classList.remove('hidden');
|
form.submit();
|
||||||
window.location.href = '/auth/oidc/redirect';
|
|
||||||
});
|
});
|
||||||
|
|
||||||
|
source.addEventListener('error', function () {
|
||||||
|
source.close();
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
{% else %}
|
||||||
|
<div>
|
||||||
|
<a id="login-button" href="/auth/oidc/redirect" class="
|
||||||
|
w-full py-2 px-4 bg-blue-500 text-white font-semibold rounded-lg shadow-md hover:bg-blue-700
|
||||||
|
focus:outline-none focus:ring-2 focus:ring-blue-400 focus:ring-opacity-75
|
||||||
|
hover:cursor-pointer
|
||||||
|
">
|
||||||
|
Login with {{ name }}
|
||||||
|
</a>
|
||||||
|
</div>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if other_link %}
|
||||||
|
<p class=" mt-4 text-sm text-center">
|
||||||
|
<a id="alternative-sign-in-link" href="{{ other_link }}" class="text-gray-600 hover:underline">Use alternative sign-in method</a>
|
||||||
|
</p>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
|
<script>
|
||||||
// Show the direct login button if we already have a token
|
// Show the direct login button if we already have a token
|
||||||
if (localStorage.getItem('hassTokens')) {
|
if (localStorage.getItem('hassTokens')) {
|
||||||
document.getElementById('signed-in').classList.remove('hidden');
|
document.getElementById('signed-in').classList.remove('hidden');
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ Here are some documentation links for specific providers that you may want to fo
|
|||||||
* [Pocket ID](./provider-configurations/pocket-id.md)
|
* [Pocket ID](./provider-configurations/pocket-id.md)
|
||||||
* [Kanidm](./provider-configurations/kanidm.md)
|
* [Kanidm](./provider-configurations/kanidm.md)
|
||||||
* [Microsoft Entra ID](./provider-configurations/microsoft-entra.md)
|
* [Microsoft Entra ID](./provider-configurations/microsoft-entra.md)
|
||||||
|
* [Zitadel](./provider-configurations/zitadel.md)
|
||||||
|
|
||||||
_Missing a provider? Submit your guide using a PR._
|
_Missing a provider? Submit your guide using a PR._
|
||||||
|
|
||||||
@@ -74,6 +75,28 @@ auth_oidc:
|
|||||||
|
|
||||||
This will show the provider on the login screen as: "Login with Example".
|
This will show the provider on the login screen as: "Login with Example".
|
||||||
|
|
||||||
|
### Forcing HTTPS
|
||||||
|
First check if you are setting the header `X-Forwarded-Proto` in your proxy and if the [proxy settings for Home Assistant](https://www.home-assistant.io/integrations/http/#use_x_forwarded_for) are configured correctly. You should also check if IP addresses in your logs actually match the origin IP (instead of proxy IP). If you cannot find any mistakes, you may use the following config option to force HTTPS regardless:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
auth_oidc:
|
||||||
|
features:
|
||||||
|
force_https: true
|
||||||
|
```
|
||||||
|
|
||||||
|
### Disabling registration for new users
|
||||||
|
This integration does not allow disabling registration for new users, as there is no way to abort registration that late in the process while providing a good user experience.
|
||||||
|
You can however set both roles to groups that only contain certain users or to a non-existant group.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
auth_oidc:
|
||||||
|
roles:
|
||||||
|
user: "non_existent"
|
||||||
|
admin: "admins"
|
||||||
|
```
|
||||||
|
|
||||||
|
Note that if you put both on non-existent groups, no users will be able to login.
|
||||||
|
|
||||||
### Migrating from HA username/password users to OIDC users
|
### Migrating from HA username/password users to OIDC users
|
||||||
If you already have users created within Home Assistant and would like to re-use the current user profile for your OIDC login, you can (temporarily) enable `features.automatic_user_linking`, with the following config (example):
|
If you already have users created within Home Assistant and would like to re-use the current user profile for your OIDC login, you can (temporarily) enable `features.automatic_user_linking`, with the following config (example):
|
||||||
|
|
||||||
@@ -93,6 +116,8 @@ Upon login, OIDC users will then automatically be linked to the HA user with the
|
|||||||
> [!CAUTION]
|
> [!CAUTION]
|
||||||
> MFA is ignored when using this setting, thus bypassing any MFA configuration the user has originally configured, as long as the username is an exact match. This is dangerous if you are not aware of it!
|
> MFA is ignored when using this setting, thus bypassing any MFA configuration the user has originally configured, as long as the username is an exact match. This is dangerous if you are not aware of it!
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
### Using a private certificate authority
|
### Using a private certificate authority
|
||||||
If you use a private certificate authority to secure your OIDC provider, you must configure the root certificates of your private certificate authority. Otherwise you will get an error (`[SSL: CERTIFICATE_VERIFY_FAILED]`) when connecting to the OIDC provider.
|
If you use a private certificate authority to secure your OIDC provider, you must configure the root certificates of your private certificate authority. Otherwise you will get an error (`[SSL: CERTIFICATE_VERIFY_FAILED]`) when connecting to the OIDC provider.
|
||||||
|
|
||||||
@@ -132,6 +157,8 @@ Here's a table of all options that you can set:
|
|||||||
| `features.automatic_person_creation` | `boolean` | No | `true` | Automatically creates a person entry for new user profiles created by this integration. Recommended if you would like to assign presence detection to OIDC users. |
|
| `features.automatic_person_creation` | `boolean` | No | `true` | Automatically creates a person entry for new user profiles created by this integration. Recommended if you would like to assign presence detection to OIDC users. |
|
||||||
| `features.disable_rfc7636` | `boolean`| No | `false` | Disables PKCE (RFC 7636) for OIDC providers that don't support it. You should not need this with most providers. |
|
| `features.disable_rfc7636` | `boolean`| No | `false` | Disables PKCE (RFC 7636) for OIDC providers that don't support it. You should not need this with most providers. |
|
||||||
| `features.include_groups_scope` | `boolean` | No | `true` | Include the 'groups' scope in the OIDC request. Set to `false` to exclude it. |
|
| `features.include_groups_scope` | `boolean` | No | `true` | Include the 'groups' scope in the OIDC request. Set to `false` to exclude it. |
|
||||||
|
| `features.disable_frontend_changes` | `boolean` | No | `false` | Set to `true` to disable all changes made to the HA frontend for better compatbility with future HA versions, or if you are not comfortable with injecting Javascript into the existing frontend code. |
|
||||||
|
| `features.force_https` | `boolean` | No | `false` | Set to `true` to force all URLs generated to use `https` instead of automatically determining based on the request scheme or `X-Forwarded-Proto`. |
|
||||||
| `claims.display_name` | `string` | No | `name` | The claim to use to obtain the display name.
|
| `claims.display_name` | `string` | No | `name` | The claim to use to obtain the display name.
|
||||||
| `claims.username` | `string` | No | `preferred_username` | The claim to use to obtain the username.
|
| `claims.username` | `string` | No | `preferred_username` | The claim to use to obtain the username.
|
||||||
| `claims.groups` | `string` | No | `groups` | The claim to use to obtain the user's group(s). |
|
| `claims.groups` | `string` | No | `groups` | The claim to use to obtain the user's group(s). |
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# Microsoft Entra ID
|
# Microsoft Entra ID
|
||||||
> [!WARNING]
|
> [!WARNING]
|
||||||
> Microsoft Entra ID does not support public clients that are not Single Page Applications (SPA's). Therefore, you will have to use a client secret.
|
> Microsoft Entra ID does not support public clients that are not Single Page Applications (SPA's). Therefore, you will have to use a client secret.
|
||||||
|
## Basic configuration
|
||||||
1. Go to app registrations in Entra ID.
|
1. Go to app registrations in Entra ID.
|
||||||
2. Create a new app, use the "Web" type for the redirect URI and fill in your URL: `<ha url>/auth/oidc/callback`. Note that you either have to use localhost, or HTTPS.
|
2. Create a new app, use the "Web" type for the redirect URI and fill in your URL: `<ha url>/auth/oidc/callback`. Note that you either have to use localhost, or HTTPS.
|
||||||
3. Copy the 'Application (client) ID' on the overview page of your app and use it as your `client_id`.
|
3. Copy the 'Application (client) ID' on the overview page of your app and use it as your `client_id`.
|
||||||
@@ -25,3 +25,27 @@ auth_oidc:
|
|||||||
|
|
||||||
> [!CAUTION]
|
> [!CAUTION]
|
||||||
> Be careful! Configuring Entra ID wrong may leave your Home Assistant install open for anyone with a Microsoft account. Please use "Single tenant" account types only. Do not enable "Accounts in any organizational directory (Any Microsoft Entra ID tenant - Multitenant)" or personal account modes without enabling the mode to only allow specific accounts first!
|
> Be careful! Configuring Entra ID wrong may leave your Home Assistant install open for anyone with a Microsoft account. Please use "Single tenant" account types only. Do not enable "Accounts in any organizational directory (Any Microsoft Entra ID tenant - Multitenant)" or personal account modes without enabling the mode to only allow specific accounts first!
|
||||||
|
|
||||||
|
## Configuring user roles
|
||||||
|
If you like to configure the Home Assistant users roles based on your Entra ID settings, you have to create 2 roles within your Entra ID app registration.
|
||||||
|
Go to "App registrations" and select app roles. Create two new roles for admins and users, giving them sensible names and values (the example uses `users` and `admins`), that you will need later in your HA configuration.
|
||||||
|
|
||||||
|
<img width="1205" height="965" alt="Entra-HA-Roles" src="https://github.com/user-attachments/assets/568a1526-0607-4f88-945f-7c4f1fcc0ac2" />
|
||||||
|
|
||||||
|
Then you need to create the users and assign them a role of your choice.
|
||||||
|
Go to "Enterprise apps" chose your app registration again and select "Users and groups" within the manage section. Add users, or groups from your tenant or AD-sync and assign them a role, from the ones you created before.
|
||||||
|
|
||||||
|
<img width="1112" height="570" alt="Entra-HA-Users" src="https://github.com/user-attachments/assets/13a49cee-798b-4b53-8fee-d2792ccd7763" />
|
||||||
|
|
||||||
|
Last thing to do is to include
|
||||||
|
```
|
||||||
|
claims:
|
||||||
|
groups: "roles"
|
||||||
|
roles:
|
||||||
|
admin: "admins"
|
||||||
|
user: "users"
|
||||||
|
```
|
||||||
|
in your auth_oidc config, where the roles values correspond to the ones you chose in your Entra ID roles.
|
||||||
|
Make sure, you keep the "include_groups_scope: False" from the basic configuration, as the claim needed for Entra ID is "roles".
|
||||||
|
|
||||||
|
Newly created users will get the role assigned in Entra ID, but there is no update to user roles. A user created with user role in HA will not get the admin role, if you change the assignment later on in Entra ID.
|
||||||
|
|||||||
27
docs/provider-configurations/zitadel.md
Normal file
27
docs/provider-configurations/zitadel.md
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# Zitadel
|
||||||
|
|
||||||
|
## Zitadel configuration
|
||||||
|
|
||||||
|
1. From the Zitadel home screen, go to `Projects` and click `Create New Project`
|
||||||
|
2. Enter "Home Assistant" or your preferred name
|
||||||
|
3. Click on `New` to create a new Application
|
||||||
|
4. Enter "Home Assistant" or your preferred name
|
||||||
|
5. Select `Web` and `Continue`
|
||||||
|
6. Select `CODE` (not `PKCE`) and `Continue`
|
||||||
|
7. Enter https://hass.example.com/auth/oidc/callback as the Redirect URI, and click `Continue`
|
||||||
|
8. Click `Create`. A pop-up will dispay the `ClientId` and `ClientSecret`
|
||||||
|
|
||||||
|
## Home Assistant configuration
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> For HTTPS configuration make sure to have a public valid SSL certificate (i.e. LetsEncrypt), if not, use HTTP instead (more insecure) or add your Zitadel CA certificate to `network.tls_ca_path`.
|
||||||
|
|
||||||
|
After installing this HACS addon, edit your `configuration.yaml` file and add:
|
||||||
|
```yaml
|
||||||
|
auth_oidc:
|
||||||
|
client_id: <ClientID from above>
|
||||||
|
client_secret: <ClientSecret from above>
|
||||||
|
discovery_url: "https://auth.example.com/.well-known/openid-configuration"
|
||||||
|
```
|
||||||
|
|
||||||
|
Restart Home Assistant and go to https://hass.example.com/auth/oidc/welcome
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
{
|
{
|
||||||
"name": "OpenID Connect",
|
"name": "OpenID Connect/SSO Authentication",
|
||||||
"hide_default_branch": true,
|
"hide_default_branch": true,
|
||||||
"render_readme": true,
|
"render_readme": true,
|
||||||
"homeassistant": "2024.12"
|
"homeassistant": "2025.11",
|
||||||
|
"zip_release": true,
|
||||||
|
"filename": "hass-oidc-auth.zip"
|
||||||
}
|
}
|
||||||
1092
package-lock.json
generated
Normal file
1092
package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load Diff
11
package.json
Normal file
11
package.json
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
{
|
||||||
|
"name": "hass-oidc-auth",
|
||||||
|
"scripts": {
|
||||||
|
"css": "tailwindcss -i ./custom_components/auth_oidc/static/input.css -o ./custom_components/auth_oidc/static/style.css --minify",
|
||||||
|
"css:watch": "tailwindcss -i ./custom_components/auth_oidc/static/input.css -o ./custom_components/auth_oidc/static/style.css --watch --minify"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@tailwindcss/cli": "^4.1.14",
|
||||||
|
"tailwindcss": "^4.1.14"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,29 +1,42 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "hass-oidc-auth"
|
name = "hass-oidc-auth"
|
||||||
version = "0.6.2"
|
version = "1.0.0"
|
||||||
description = "OIDC component for Home Assistant"
|
description = "OIDC component for Home Assistant"
|
||||||
authors = [
|
authors = [
|
||||||
{ name = "Christiaan Goossens", email = "contact@christiaangoossens.nl" }
|
{ name = "Christiaan Goossens", email = "contact@christiaangoossens.nl" }
|
||||||
]
|
]
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"python-jose>=3.3.0",
|
"aiofiles~=25.1",
|
||||||
"aiofiles>=24.1.0",
|
"jinja2~=3.1",
|
||||||
"jinja2>=3.1.4",
|
"joserfc~=1.6.0",
|
||||||
"bcrypt>=4.2.0",
|
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">= 3.13"
|
requires-python = "~=3.14.4"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"homeassistant~=2026.4",
|
||||||
|
"pylint~=4.0",
|
||||||
|
"pytest~=9.0.0",
|
||||||
|
"pytest-asyncio~=1.3.0",
|
||||||
|
"pytest-cov~=7.0.0",
|
||||||
|
"pytest-homeassistant-custom-component~=0.13.308",
|
||||||
|
"ruff~=0.12",
|
||||||
|
]
|
||||||
|
|
||||||
[build-system]
|
[build-system]
|
||||||
requires = ["hatchling"]
|
requires = ["hatchling"]
|
||||||
build-backend = "hatchling.build"
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
[tool.rye]
|
[tool.uv]
|
||||||
managed = true
|
managed = true
|
||||||
dev-dependencies = [
|
override-dependencies = [
|
||||||
"homeassistant~=2024.12",
|
"orjson>=3.11.8,<3.12.0",
|
||||||
"pylint~=3.3",
|
"pyjwt>=2.12.0,<2.13.0",
|
||||||
|
"pillow>=12.2.0,<12.3.0",
|
||||||
|
"pytest>=9.0.3,<9.1.0",
|
||||||
|
"uv>=0.11.6,<0.12.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.hatch.metadata]
|
[tool.hatch.metadata]
|
||||||
@@ -32,11 +45,10 @@ allow-direct-references = true
|
|||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["custom_components/auth_oidc"]
|
packages = ["custom_components/auth_oidc"]
|
||||||
|
|
||||||
[tool.rye.scripts]
|
[tool.pytest.ini_options]
|
||||||
check = { chain = ["check-lint", "check-fmt", "check-pylint" ] }
|
asyncio_mode = "auto"
|
||||||
"check-lint" = "rye lint"
|
addopts = "--cov=custom_components --cov-fail-under=0"
|
||||||
"check-fmt" = "rye fmt --check"
|
log_level = "DEBUG"
|
||||||
"check-pylint" = "pylint custom_components"
|
|
||||||
fix = { chain = ["fix-lint", "fix-fmt" ] }
|
[tool.ruff]
|
||||||
"fix-lint" = "rye lint --fix"
|
target-version = "py313"
|
||||||
"fix-fmt" = "rye fmt"
|
|
||||||
@@ -14,11 +14,16 @@
|
|||||||
],
|
],
|
||||||
"prCreation": "immediate"
|
"prCreation": "immediate"
|
||||||
},
|
},
|
||||||
|
"lockFileMaintenance": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
"packageRules": [
|
"packageRules": [
|
||||||
{
|
{
|
||||||
"description": "Group all GitHub Actions updates",
|
"description": "Group all GitHub Actions updates",
|
||||||
"matchDatasources": [
|
"matchDatasources": [
|
||||||
"github-actions"
|
"github-actions",
|
||||||
|
"github-tags",
|
||||||
|
"github-runners"
|
||||||
],
|
],
|
||||||
"groupName": "Github Actions Updates",
|
"groupName": "Github Actions Updates",
|
||||||
"automerge": true
|
"automerge": true
|
||||||
@@ -34,7 +39,7 @@
|
|||||||
"automerge": false
|
"automerge": false
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"description": "Version updates for other pip packages",
|
"description": "Version updates for other Python packages",
|
||||||
"matchDatasources": [
|
"matchDatasources": [
|
||||||
"pypi"
|
"pypi"
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,286 +0,0 @@
|
|||||||
# generated by rye
|
|
||||||
# use `rye lock` or `rye sync` to update this lockfile
|
|
||||||
#
|
|
||||||
# last locked with the following flags:
|
|
||||||
# pre: false
|
|
||||||
# features: []
|
|
||||||
# all-features: false
|
|
||||||
# with-sources: false
|
|
||||||
# generate-hashes: false
|
|
||||||
# universal: false
|
|
||||||
|
|
||||||
-e file:.
|
|
||||||
acme==3.0.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
aiodns==3.2.0
|
|
||||||
# via homeassistant
|
|
||||||
aiofiles==24.1.0
|
|
||||||
# via hass-oidc-auth
|
|
||||||
aiohappyeyeballs==2.4.4
|
|
||||||
# via aiohttp
|
|
||||||
aiohasupervisor==0.2.1
|
|
||||||
# via homeassistant
|
|
||||||
aiohttp==3.11.11
|
|
||||||
# via aiohasupervisor
|
|
||||||
# via aiohttp-cors
|
|
||||||
# via aiohttp-fast-zlib
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
# via snitun
|
|
||||||
aiohttp-cors==0.7.0
|
|
||||||
# via homeassistant
|
|
||||||
aiohttp-fast-zlib==0.2.0
|
|
||||||
# via homeassistant
|
|
||||||
aiooui==0.1.7
|
|
||||||
# via bluetooth-adapters
|
|
||||||
aiosignal==1.3.2
|
|
||||||
# via aiohttp
|
|
||||||
aiozoneinfo==0.2.1
|
|
||||||
# via homeassistant
|
|
||||||
anyio==4.7.0
|
|
||||||
# via httpx
|
|
||||||
astral==2.2
|
|
||||||
# via homeassistant
|
|
||||||
astroid==3.3.8
|
|
||||||
# via pylint
|
|
||||||
async-interrupt==1.2.0
|
|
||||||
# via habluetooth
|
|
||||||
# via homeassistant
|
|
||||||
async-timeout==5.0.1
|
|
||||||
# via snitun
|
|
||||||
atomicwrites-homeassistant==1.4.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
attrs==24.2.0
|
|
||||||
# via aiohttp
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
# via snitun
|
|
||||||
audioop-lts==0.2.1
|
|
||||||
# via homeassistant
|
|
||||||
# via standard-aifc
|
|
||||||
awesomeversion==24.6.0
|
|
||||||
# via homeassistant
|
|
||||||
bcrypt==4.2.0
|
|
||||||
# via hass-oidc-auth
|
|
||||||
# via homeassistant
|
|
||||||
bleak==0.22.3
|
|
||||||
# via bleak-retry-connector
|
|
||||||
# via bluetooth-adapters
|
|
||||||
# via habluetooth
|
|
||||||
bleak-retry-connector==3.6.0
|
|
||||||
# via habluetooth
|
|
||||||
bluetooth-adapters==0.20.2
|
|
||||||
# via bleak-retry-connector
|
|
||||||
# via bluetooth-auto-recovery
|
|
||||||
# via habluetooth
|
|
||||||
bluetooth-auto-recovery==1.4.2
|
|
||||||
# via habluetooth
|
|
||||||
bluetooth-data-tools==1.20.0
|
|
||||||
# via habluetooth
|
|
||||||
boto3==1.35.87
|
|
||||||
# via pycognito
|
|
||||||
botocore==1.35.87
|
|
||||||
# via boto3
|
|
||||||
# via s3transfer
|
|
||||||
btsocket==0.3.0
|
|
||||||
# via bluetooth-auto-recovery
|
|
||||||
certifi==2024.12.14
|
|
||||||
# via homeassistant
|
|
||||||
# via httpcore
|
|
||||||
# via httpx
|
|
||||||
# via requests
|
|
||||||
cffi==1.17.1
|
|
||||||
# via cryptography
|
|
||||||
# via pycares
|
|
||||||
charset-normalizer==3.4.0
|
|
||||||
# via requests
|
|
||||||
ciso8601==2.3.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
cryptography==43.0.1
|
|
||||||
# via acme
|
|
||||||
# via bluetooth-data-tools
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
# via josepy
|
|
||||||
# via pyjwt
|
|
||||||
# via pyopenssl
|
|
||||||
# via securetar
|
|
||||||
# via snitun
|
|
||||||
dbus-fast==2.24.4
|
|
||||||
# via bleak
|
|
||||||
# via bleak-retry-connector
|
|
||||||
# via bluetooth-adapters
|
|
||||||
dill==0.3.9
|
|
||||||
# via pylint
|
|
||||||
ecdsa==0.19.0
|
|
||||||
# via python-jose
|
|
||||||
envs==1.4
|
|
||||||
# via pycognito
|
|
||||||
fnv-hash-fast==1.0.2
|
|
||||||
# via homeassistant
|
|
||||||
fnvhash==0.1.0
|
|
||||||
# via fnv-hash-fast
|
|
||||||
frozenlist==1.5.0
|
|
||||||
# via aiohttp
|
|
||||||
# via aiosignal
|
|
||||||
h11==0.14.0
|
|
||||||
# via httpcore
|
|
||||||
habluetooth==3.6.0
|
|
||||||
# via home-assistant-bluetooth
|
|
||||||
hass-nabucasa==0.86.0
|
|
||||||
# via homeassistant
|
|
||||||
home-assistant-bluetooth==1.13.0
|
|
||||||
# via homeassistant
|
|
||||||
homeassistant==2024.12.5
|
|
||||||
httpcore==1.0.7
|
|
||||||
# via httpx
|
|
||||||
httpx==0.27.2
|
|
||||||
# via homeassistant
|
|
||||||
idna==3.10
|
|
||||||
# via anyio
|
|
||||||
# via httpx
|
|
||||||
# via requests
|
|
||||||
# via yarl
|
|
||||||
ifaddr==0.2.0
|
|
||||||
# via homeassistant
|
|
||||||
isort==5.13.2
|
|
||||||
# via pylint
|
|
||||||
jinja2==3.1.4
|
|
||||||
# via hass-oidc-auth
|
|
||||||
# via homeassistant
|
|
||||||
jmespath==1.0.1
|
|
||||||
# via boto3
|
|
||||||
# via botocore
|
|
||||||
josepy==1.14.0
|
|
||||||
# via acme
|
|
||||||
lru-dict==1.3.0
|
|
||||||
# via homeassistant
|
|
||||||
markupsafe==3.0.2
|
|
||||||
# via jinja2
|
|
||||||
mashumaro==3.15
|
|
||||||
# via aiohasupervisor
|
|
||||||
# via webrtc-models
|
|
||||||
mccabe==0.7.0
|
|
||||||
# via pylint
|
|
||||||
multidict==6.1.0
|
|
||||||
# via aiohttp
|
|
||||||
# via yarl
|
|
||||||
orjson==3.10.12
|
|
||||||
# via aiohasupervisor
|
|
||||||
# via homeassistant
|
|
||||||
# via webrtc-models
|
|
||||||
packaging==24.2
|
|
||||||
# via homeassistant
|
|
||||||
pillow==11.0.0
|
|
||||||
# via homeassistant
|
|
||||||
platformdirs==4.3.6
|
|
||||||
# via pylint
|
|
||||||
propcache==0.2.1
|
|
||||||
# via aiohttp
|
|
||||||
# via homeassistant
|
|
||||||
# via yarl
|
|
||||||
psutil==6.1.1
|
|
||||||
# via psutil-home-assistant
|
|
||||||
psutil-home-assistant==0.0.1
|
|
||||||
# via homeassistant
|
|
||||||
pyasn1==0.6.1
|
|
||||||
# via python-jose
|
|
||||||
# via rsa
|
|
||||||
pycares==4.5.0
|
|
||||||
# via aiodns
|
|
||||||
pycognito==2024.5.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
pycparser==2.22
|
|
||||||
# via cffi
|
|
||||||
pyjwt==2.10.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
# via pycognito
|
|
||||||
pylint==3.3.3
|
|
||||||
pyopenssl==24.2.1
|
|
||||||
# via acme
|
|
||||||
# via homeassistant
|
|
||||||
# via josepy
|
|
||||||
pyrfc3339==2.0.1
|
|
||||||
# via acme
|
|
||||||
pyric==0.1.6.3
|
|
||||||
# via bluetooth-auto-recovery
|
|
||||||
python-dateutil==2.9.0.post0
|
|
||||||
# via botocore
|
|
||||||
python-jose==3.3.0
|
|
||||||
# via hass-oidc-auth
|
|
||||||
python-slugify==8.0.4
|
|
||||||
# via homeassistant
|
|
||||||
pytz==2024.2
|
|
||||||
# via acme
|
|
||||||
# via astral
|
|
||||||
pyyaml==6.0.2
|
|
||||||
# via homeassistant
|
|
||||||
requests==2.32.3
|
|
||||||
# via acme
|
|
||||||
# via homeassistant
|
|
||||||
# via pycognito
|
|
||||||
rsa==4.9
|
|
||||||
# via python-jose
|
|
||||||
s3transfer==0.10.4
|
|
||||||
# via boto3
|
|
||||||
securetar==2024.11.0
|
|
||||||
# via homeassistant
|
|
||||||
setuptools==75.6.0
|
|
||||||
# via acme
|
|
||||||
six==1.17.0
|
|
||||||
# via ecdsa
|
|
||||||
# via python-dateutil
|
|
||||||
sniffio==1.3.1
|
|
||||||
# via anyio
|
|
||||||
# via httpx
|
|
||||||
snitun==0.39.1
|
|
||||||
# via hass-nabucasa
|
|
||||||
sqlalchemy==2.0.36
|
|
||||||
# via homeassistant
|
|
||||||
standard-aifc==3.13.0
|
|
||||||
# via homeassistant
|
|
||||||
standard-chunk==3.13.0
|
|
||||||
# via standard-aifc
|
|
||||||
standard-telnetlib==3.13.0
|
|
||||||
# via homeassistant
|
|
||||||
text-unidecode==1.3
|
|
||||||
# via python-slugify
|
|
||||||
tomlkit==0.13.2
|
|
||||||
# via pylint
|
|
||||||
typing-extensions==4.12.2
|
|
||||||
# via homeassistant
|
|
||||||
# via mashumaro
|
|
||||||
# via sqlalchemy
|
|
||||||
tzdata==2024.2
|
|
||||||
# via aiozoneinfo
|
|
||||||
uart-devices==0.1.0
|
|
||||||
# via bluetooth-adapters
|
|
||||||
ulid-transform==1.0.2
|
|
||||||
# via homeassistant
|
|
||||||
urllib3==1.26.20
|
|
||||||
# via botocore
|
|
||||||
# via homeassistant
|
|
||||||
# via requests
|
|
||||||
usb-devices==0.4.5
|
|
||||||
# via bluetooth-adapters
|
|
||||||
# via bluetooth-auto-recovery
|
|
||||||
uv==0.5.4
|
|
||||||
# via homeassistant
|
|
||||||
voluptuous==0.15.2
|
|
||||||
# via homeassistant
|
|
||||||
# via voluptuous-openapi
|
|
||||||
# via voluptuous-serialize
|
|
||||||
voluptuous-openapi==0.0.5
|
|
||||||
# via homeassistant
|
|
||||||
voluptuous-serialize==2.6.0
|
|
||||||
# via homeassistant
|
|
||||||
webrtc-models==0.3.0
|
|
||||||
# via hass-nabucasa
|
|
||||||
# via homeassistant
|
|
||||||
yarl==1.18.3
|
|
||||||
# via aiohasupervisor
|
|
||||||
# via aiohttp
|
|
||||||
# via homeassistant
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
# generated by rye
|
|
||||||
# use `rye lock` or `rye sync` to update this lockfile
|
|
||||||
#
|
|
||||||
# last locked with the following flags:
|
|
||||||
# pre: false
|
|
||||||
# features: []
|
|
||||||
# all-features: false
|
|
||||||
# with-sources: false
|
|
||||||
# generate-hashes: false
|
|
||||||
# universal: false
|
|
||||||
|
|
||||||
-e file:.
|
|
||||||
aiofiles==24.1.0
|
|
||||||
# via hass-oidc-auth
|
|
||||||
bcrypt==4.2.1
|
|
||||||
# via hass-oidc-auth
|
|
||||||
ecdsa==0.19.0
|
|
||||||
# via python-jose
|
|
||||||
jinja2==3.1.5
|
|
||||||
# via hass-oidc-auth
|
|
||||||
markupsafe==3.0.2
|
|
||||||
# via jinja2
|
|
||||||
pyasn1==0.6.1
|
|
||||||
# via python-jose
|
|
||||||
# via rsa
|
|
||||||
python-jose==3.3.0
|
|
||||||
# via hass-oidc-auth
|
|
||||||
rsa==4.9
|
|
||||||
# via python-jose
|
|
||||||
six==1.17.0
|
|
||||||
# via ecdsa
|
|
||||||
10
scripts/build
Executable file
10
scripts/build
Executable file
@@ -0,0 +1,10 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
|
||||||
|
# Build the plugin CSS
|
||||||
|
npm install --frozen-lockfile
|
||||||
|
npm run css
|
||||||
|
|
||||||
|
# Create zip from the custom_components/auth_oidc directory
|
||||||
|
# HACS wants only the contents of this dir in a zip
|
||||||
|
cd custom_components/auth_oidc/
|
||||||
|
zip -r ../../hass-oidc-auth.zip ./*
|
||||||
4
scripts/check
Executable file
4
scripts/check
Executable file
@@ -0,0 +1,4 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uv run ruff check
|
||||||
|
uv run ruff format --check
|
||||||
|
uv run pylint custom_components --allow-reexport-from-package true
|
||||||
3
scripts/coverage-report
Executable file
3
scripts/coverage-report
Executable file
@@ -0,0 +1,3 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uv run pytest --cov-report html tests/
|
||||||
|
uv run python -m http.server 8000 -d htmlcov
|
||||||
3
scripts/fix
Executable file
3
scripts/fix
Executable file
@@ -0,0 +1,3 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uv run ruff check --fix
|
||||||
|
uv run ruff format
|
||||||
2
scripts/security-check
Executable file
2
scripts/security-check
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uvx pysentry-rs .
|
||||||
2
scripts/sync
Executable file
2
scripts/sync
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uv sync --locked
|
||||||
2
scripts/test
Executable file
2
scripts/test
Executable file
@@ -0,0 +1,2 @@
|
|||||||
|
#! /bin/bash
|
||||||
|
uv run pytest --cov-report term:skip-covered tests/
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
8
tests/conftest.py
Normal file
8
tests/conftest.py
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
"""Fixtures for testing."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def auto_enable_custom_integrations(enable_custom_integrations):
|
||||||
|
yield
|
||||||
0
tests/mocks/__init__.py
Normal file
0
tests/mocks/__init__.py
Normal file
14
tests/mocks/auth_page.html
Normal file
14
tests/mocks/auth_page.html
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||||
|
<title>Test</title>
|
||||||
|
</head>
|
||||||
|
|
||||||
|
<body>
|
||||||
|
Test page
|
||||||
|
</body>
|
||||||
|
|
||||||
|
</html>
|
||||||
197
tests/mocks/oidc_server.py
Normal file
197
tests/mocks/oidc_server.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""A simple mock OIDC server for testing purposes."""
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
import time
|
||||||
|
import logging
|
||||||
|
import hashlib
|
||||||
|
import random
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from urllib.parse import urlparse, parse_qs
|
||||||
|
from joserfc import jwt
|
||||||
|
from joserfc.jwk import RSAKey, KeySet
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
BASE_URL = "https://oidc.example.com"
|
||||||
|
SUBJECT = "testuser"
|
||||||
|
|
||||||
|
|
||||||
|
class MockOIDCServer:
|
||||||
|
"""A simple mock OIDC server for testing purposes."""
|
||||||
|
|
||||||
|
_code_storage = {}
|
||||||
|
_scenario = {}
|
||||||
|
|
||||||
|
def __init__(self, scenario: str | None = None):
|
||||||
|
"""Initialize the mock OIDC server."""
|
||||||
|
# Create a JWK private key
|
||||||
|
self._jwk = RSAKey.generate_key(
|
||||||
|
2048, {"alg": "RS256", "use": "sig"}, private=True, auto_kid=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if scenario:
|
||||||
|
# Load scenario JSON file from disk
|
||||||
|
scenario_path = os.path.join(
|
||||||
|
os.path.dirname(__file__), "scenarios", f"{scenario}.json"
|
||||||
|
)
|
||||||
|
with open(scenario_path, "r", encoding="utf-8") as f:
|
||||||
|
self._scenario = json.load(f)
|
||||||
|
|
||||||
|
# Log it
|
||||||
|
_LOGGER.debug("Loaded scenario: %s", self._scenario)
|
||||||
|
|
||||||
|
def get_random_code(self):
|
||||||
|
"""Return a random authorization code."""
|
||||||
|
return "".join(str(random.randint(0, 9)) for _ in range(6))
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_discovery_url():
|
||||||
|
"""Return the discovery URL for the given base URL."""
|
||||||
|
return f"{BASE_URL}/.well-known/openid-configuration"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_authorize_url():
|
||||||
|
"""Return the authorization URL for the given base URL."""
|
||||||
|
return f"{BASE_URL}/authorize"
|
||||||
|
|
||||||
|
def process_request(self, url: str, method: str, body: dict) -> tuple[dict, int]:
|
||||||
|
"""Process a request to the mock OIDC server."""
|
||||||
|
_LOGGER.debug("Received %s request to %s in OIDC mock server", method, url)
|
||||||
|
|
||||||
|
if url == self.get_discovery_url() and method == "GET":
|
||||||
|
response = self._get_discovery_document()
|
||||||
|
elif url.startswith(self.get_authorize_url()) and method == "GET":
|
||||||
|
response = self._get_authorize_response(url)
|
||||||
|
elif url == f"{BASE_URL}/token" and method == "POST":
|
||||||
|
response = self._get_token_response(body)
|
||||||
|
elif url == f"{BASE_URL}/jwks" and method == "GET":
|
||||||
|
response = self._get_jwks_response()
|
||||||
|
else:
|
||||||
|
response = {"error": "Unknown endpoint"}, 404
|
||||||
|
|
||||||
|
_LOGGER.debug("Responding with: %s", response)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def _get_discovery_document(self) -> tuple[dict, int]:
|
||||||
|
"""Return a mock discovery document."""
|
||||||
|
|
||||||
|
if "discovery" in self._scenario:
|
||||||
|
return self._scenario["discovery"], 200
|
||||||
|
|
||||||
|
return {
|
||||||
|
"issuer": BASE_URL,
|
||||||
|
"authorization_endpoint": self.get_authorize_url(),
|
||||||
|
"token_endpoint": f"{BASE_URL}/token",
|
||||||
|
"userinfo_endpoint": f"{BASE_URL}/userinfo",
|
||||||
|
"jwks_uri": f"{BASE_URL}/jwks",
|
||||||
|
"id_token_signing_alg_values_supported": ["RS256"],
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
def _get_authorize_response(self, url: str) -> tuple[dict, int]:
|
||||||
|
"""Return a mock authorization response."""
|
||||||
|
# Parse the url
|
||||||
|
parsed_url = urlparse(url)
|
||||||
|
query_params = parse_qs(parsed_url.query)
|
||||||
|
|
||||||
|
code = self.get_random_code()
|
||||||
|
self._code_storage[code] = query_params
|
||||||
|
|
||||||
|
return {"code": code, "state": "xyz"}, 200
|
||||||
|
|
||||||
|
def _get_token_response(self, body: dict) -> tuple[dict, int]:
|
||||||
|
"""Return a mock token response."""
|
||||||
|
|
||||||
|
if body.get("code") in self._code_storage:
|
||||||
|
# TODO: Verify PKCE?
|
||||||
|
return {
|
||||||
|
"access_token": "exampleAccessToken",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"expires_in": 3600,
|
||||||
|
"id_token": self._create_id_token(body.get("code")),
|
||||||
|
}, 200
|
||||||
|
else:
|
||||||
|
return {"error": "invalid_request"}, 400
|
||||||
|
|
||||||
|
def _create_id_token(self, code: str) -> str:
|
||||||
|
"""Create a mock ID token."""
|
||||||
|
# Get the query params
|
||||||
|
if code not in self._code_storage:
|
||||||
|
raise ValueError("Invalid code")
|
||||||
|
query_params = self._code_storage[code]
|
||||||
|
_LOGGER.debug("Creating ID token with query params: %s", query_params)
|
||||||
|
|
||||||
|
# Get username
|
||||||
|
if "username" in self._scenario:
|
||||||
|
username = self._scenario["username"]
|
||||||
|
else:
|
||||||
|
username = "testuser"
|
||||||
|
|
||||||
|
# Create a simple signed JWT with our JWK
|
||||||
|
header = {"alg": self._jwk.alg, "kid": self._jwk.kid}
|
||||||
|
claims = {
|
||||||
|
"iss": BASE_URL,
|
||||||
|
"sub": SUBJECT,
|
||||||
|
"aud": query_params.get("client_id", [""])[0],
|
||||||
|
"nonce": query_params.get("nonce", [""])[0],
|
||||||
|
"name": "Test Name",
|
||||||
|
"preferred_username": username,
|
||||||
|
}
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
claims["nbf"] = now
|
||||||
|
claims["iat"] = now
|
||||||
|
claims["exp"] = now + 3600 # 1 hour expiry
|
||||||
|
|
||||||
|
return jwt.encode(header, claims, self._jwk)
|
||||||
|
|
||||||
|
def _get_jwks_response(self) -> tuple[dict, int]:
|
||||||
|
"""Return a mock JWKS response."""
|
||||||
|
private_key = self._jwk
|
||||||
|
public_key_dict = private_key.as_dict(private=False)
|
||||||
|
public_key = RSAKey.import_key(
|
||||||
|
public_key_dict, {"use": "sig", "alg": "RS256", "kid": private_key.kid}
|
||||||
|
)
|
||||||
|
|
||||||
|
key_set = KeySet([public_key])
|
||||||
|
|
||||||
|
return key_set.as_dict(), 200
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_final_subject():
|
||||||
|
"""Return the subject that's returned to HA."""
|
||||||
|
return hashlib.sha256(f"{BASE_URL}.{SUBJECT}".encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def mock_oidc_responses(scenario: str | None = None):
|
||||||
|
"""Mock OIDC responses for testing."""
|
||||||
|
|
||||||
|
mock_oidc_server = MockOIDCServer(scenario)
|
||||||
|
|
||||||
|
def make_mock_response(json_data, status):
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.__aenter__.return_value = mock_response
|
||||||
|
mock_response.__aexit__.return_value = None
|
||||||
|
mock_response.json = AsyncMock(return_value=json_data)
|
||||||
|
mock_response.status = status
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
def default_handler(method, url, *args, **kwargs):
|
||||||
|
_LOGGER.debug("Mocked %s request to %s", method, url)
|
||||||
|
body = kwargs.get("data") or kwargs.get("json") or None
|
||||||
|
response = mock_oidc_server.process_request(url, method, body)
|
||||||
|
return make_mock_response(response[0], response[1])
|
||||||
|
|
||||||
|
def get_side_effect(url, *args, **kwargs):
|
||||||
|
return default_handler("GET", url, *args, **kwargs)
|
||||||
|
|
||||||
|
def post_side_effect(url, *args, **kwargs):
|
||||||
|
return default_handler("POST", url, *args, **kwargs)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("aiohttp.ClientSession.get", side_effect=get_side_effect) as get_patch,
|
||||||
|
patch("aiohttp.ClientSession.post", side_effect=post_side_effect) as post_patch,
|
||||||
|
):
|
||||||
|
yield (get_patch, post_patch, default_handler)
|
||||||
5
tests/mocks/scenarios/empty.json
Normal file
5
tests/mocks/scenarios/empty.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
10
tests/mocks/scenarios/invalid_code_challenge_types.json
Normal file
10
tests/mocks/scenarios/invalid_code_challenge_types.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"code_challenge_methods_supported": ["plain"]
|
||||||
|
}
|
||||||
|
}
|
||||||
10
tests/mocks/scenarios/invalid_grant_types.json
Normal file
10
tests/mocks/scenarios/invalid_grant_types.json
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||||
|
"response_types_supported": ["code"],
|
||||||
|
"grant_types_supported": ["refresh_token"]
|
||||||
|
}
|
||||||
|
}
|
||||||
8
tests/mocks/scenarios/invalid_id_token_signing_alg.json
Normal file
8
tests/mocks/scenarios/invalid_id_token_signing_alg.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks"
|
||||||
|
}
|
||||||
|
}
|
||||||
9
tests/mocks/scenarios/invalid_response_modes.json
Normal file
9
tests/mocks/scenarios/invalid_response_modes.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||||
|
"response_modes_supported": ["post"]
|
||||||
|
}
|
||||||
|
}
|
||||||
9
tests/mocks/scenarios/invalid_response_types.json
Normal file
9
tests/mocks/scenarios/invalid_response_types.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||||
|
"response_types_supported": ["token"]
|
||||||
|
}
|
||||||
|
}
|
||||||
8
tests/mocks/scenarios/invalid_url.json
Normal file
8
tests/mocks/scenarios/invalid_url.json
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "/jwks"
|
||||||
|
}
|
||||||
|
}
|
||||||
7
tests/mocks/scenarios/missing_jwks.json
Normal file
7
tests/mocks/scenarios/missing_jwks.json
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token"
|
||||||
|
}
|
||||||
|
}
|
||||||
6
tests/mocks/scenarios/missing_token.json
Normal file
6
tests/mocks/scenarios/missing_token.json
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize"
|
||||||
|
}
|
||||||
|
}
|
||||||
5
tests/mocks/scenarios/only_issuer.json
Normal file
5
tests/mocks/scenarios/only_issuer.json
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local"
|
||||||
|
}
|
||||||
|
}
|
||||||
3
tests/mocks/scenarios/username.json
Normal file
3
tests/mocks/scenarios/username.json
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
{
|
||||||
|
"username": "foobar"
|
||||||
|
}
|
||||||
9
tests/mocks/scenarios/wrong_id_token_signing_alg.json
Normal file
9
tests/mocks/scenarios/wrong_id_token_signing_alg.json
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
{
|
||||||
|
"discovery": {
|
||||||
|
"issuer": "https://mock-oidc-server.local",
|
||||||
|
"authorization_endpoint": "https://mock-oidc-server.local/authorize",
|
||||||
|
"token_endpoint": "https://mock-oidc-server.local/token",
|
||||||
|
"jwks_uri": "https://mock-oidc-server.local/jwks",
|
||||||
|
"id_token_signing_alg_values_supported": ["HS256"]
|
||||||
|
}
|
||||||
|
}
|
||||||
1
tests/resources/fake_templates/index.html
Normal file
1
tests/resources/fake_templates/index.html
Normal file
@@ -0,0 +1 @@
|
|||||||
|
<p>Example template</p>
|
||||||
377
tests/test_hass_auth_provider.py
Normal file
377
tests/test_hass_auth_provider.py
Normal file
@@ -0,0 +1,377 @@
|
|||||||
|
"""Tests for the Auth Provider registration in HA"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import re
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from urllib.parse import parse_qs, unquote, urlparse
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
|
from homeassistant.components.person import DOMAIN as PERSON_DOMAIN
|
||||||
|
|
||||||
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.config.const import (
|
||||||
|
DISCOVERY_URL,
|
||||||
|
CLIENT_ID,
|
||||||
|
FEATURES,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING,
|
||||||
|
)
|
||||||
|
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||||
|
|
||||||
|
FAKE_REDIR_URL = "http://example.com/auth/authorize?response_type=code&redirect_uri=http%3A%2F%2Fexample.com%3A8123%2F%3Fauth_callback%3D1&client_id=http%3A%2F%2Fexample.com%3A8123%2F&state=example"
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
||||||
|
"""Set up the auth_oidc component."""
|
||||||
|
result = await async_setup_component(hass, DOMAIN, {DOMAIN: config})
|
||||||
|
|
||||||
|
if expect_success:
|
||||||
|
assert result
|
||||||
|
assert DOMAIN in hass.data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_success_auth_provider_registration(hass: HomeAssistant):
|
||||||
|
"""Test successful setup"""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure the auth provider is registered
|
||||||
|
auth_providers = hass.auth.get_auth_providers(DOMAIN)
|
||||||
|
assert len(auth_providers) == 1
|
||||||
|
|
||||||
|
# Public auth-provider contract: OIDC provider does not support HA MFA
|
||||||
|
assert auth_providers[0].support_mfa is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_ip_fallback_fails_closed_without_request_context(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Provider should not invent a shared IP when request context is missing."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = None
|
||||||
|
assert provider._resolve_ip() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_provider_cookie_header_sets_secure_when_requested(hass: HomeAssistant):
|
||||||
|
"""Cookie header should include Secure when HTTPS is in use."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
cookie_header = provider.get_cookie_header("state-id", secure=True)["set-cookie"]
|
||||||
|
|
||||||
|
assert "SameSite=Lax" in cookie_header
|
||||||
|
assert "HttpOnly" in cookie_header
|
||||||
|
assert "Secure" in cookie_header
|
||||||
|
|
||||||
|
|
||||||
|
async def login_user(hass: HomeAssistant, state_id: str):
|
||||||
|
"""Helper to login a user from the stored OIDC state."""
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
# This helper runs outside an HTTP request, so pass the known local test IP.
|
||||||
|
sub = await provider.async_get_subject(state_id, "127.0.0.1")
|
||||||
|
assert sub == MockOIDCServer.get_final_subject()
|
||||||
|
|
||||||
|
# Get credentials
|
||||||
|
credentials = await provider.async_get_or_create_credentials({"sub": sub})
|
||||||
|
assert credentials is not None
|
||||||
|
assert credentials.data["sub"] == sub
|
||||||
|
|
||||||
|
user = await hass.auth.async_get_or_create_user(credentials)
|
||||||
|
assert user.is_active
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_login_state(hass: HomeAssistant, hass_client):
|
||||||
|
"""Helper to complete the browser login flow and return the OIDC state id."""
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
redirect_uri = FAKE_REDIR_URL
|
||||||
|
encoded_redirect_uri = base64.b64encode(redirect_uri.encode("utf-8")).decode(
|
||||||
|
"utf-8"
|
||||||
|
)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded_redirect_uri}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
state_id = resp.cookies["auth_oidc_state"].value
|
||||||
|
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
html = await resp.text()
|
||||||
|
match = re.search(r'decodeURIComponent\("([^"]+)"\)', html)
|
||||||
|
assert match is not None
|
||||||
|
auth_url = unquote(match.group(1))
|
||||||
|
|
||||||
|
parsed_url = urlparse(auth_url)
|
||||||
|
query_params = parse_qs(parsed_url.query)
|
||||||
|
assert query_params["state"][0] == state_id
|
||||||
|
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp = session.get(auth_url, allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
# Mock OIDC returns JSON
|
||||||
|
json_parsed = await resp.json()
|
||||||
|
assert "code" in json_parsed and json_parsed["code"]
|
||||||
|
|
||||||
|
code = json_parsed["code"]
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={code}&state={state_id}", allow_redirects=False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 302
|
||||||
|
assert resp.headers["Location"].endswith("/auth/oidc/finish")
|
||||||
|
|
||||||
|
return state_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_login(hass: HomeAssistant, hass_client):
|
||||||
|
"""Test a full login flow."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
# Actually start the login and get a code
|
||||||
|
state_id = await get_login_state(hass, hass_client)
|
||||||
|
|
||||||
|
# Use the stored state to login directly with the registered auth provider
|
||||||
|
# Inspired by tests for the built-in providers
|
||||||
|
user = await login_user(hass, state_id)
|
||||||
|
assert user.name == "Test Name"
|
||||||
|
|
||||||
|
# Login again to see if we trigger the re-use path
|
||||||
|
state_id2 = await get_login_state(hass, hass_client)
|
||||||
|
user2 = await login_user(hass, state_id2)
|
||||||
|
assert user2.id == user.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_with_linking(hass: HomeAssistant, hass_client):
|
||||||
|
"""Test a linking login."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock_oidc_responses("username"):
|
||||||
|
# Create a user first with username 'foobar'
|
||||||
|
user = await hass.auth.async_create_user("Foo Bar")
|
||||||
|
assert user.is_active
|
||||||
|
|
||||||
|
hass_provider = hass.auth.get_auth_providers("homeassistant")[0]
|
||||||
|
credential = await hass_provider.async_get_or_create_credentials(
|
||||||
|
{"username": "foobar"}
|
||||||
|
)
|
||||||
|
await hass.auth.async_link_user(user, credential)
|
||||||
|
|
||||||
|
# Actually start the login and get a code
|
||||||
|
state_id = await get_login_state(hass, hass_client)
|
||||||
|
|
||||||
|
# Use the stored state to login directly with the registered auth provider
|
||||||
|
user2 = await login_user(hass, state_id)
|
||||||
|
assert user2.id == user.id # Assert that the user was linked
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_with_person_create(hass: HomeAssistant, hass_client):
|
||||||
|
"""Test a person create."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await async_setup_component(hass, PERSON_DOMAIN, {})
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
state_id = await get_login_state(hass, hass_client)
|
||||||
|
user = await login_user(hass, state_id)
|
||||||
|
assert user.is_active
|
||||||
|
|
||||||
|
# Find the person associated to this user using the PersonRegistry API
|
||||||
|
person_store = hass.data[PERSON_DOMAIN][1]
|
||||||
|
persons = person_store.async_items()
|
||||||
|
assert len(persons) == 1
|
||||||
|
|
||||||
|
person = persons[0]
|
||||||
|
assert person["user_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_without_person_create_does_not_create_person(
|
||||||
|
hass: HomeAssistant, hass_client
|
||||||
|
):
|
||||||
|
"""Test that person creation can be disabled."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await async_setup_component(hass, PERSON_DOMAIN, {})
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
state_id = await get_login_state(hass, hass_client)
|
||||||
|
user = await login_user(hass, state_id)
|
||||||
|
assert user.is_active
|
||||||
|
|
||||||
|
person_store = hass.data[PERSON_DOMAIN][1]
|
||||||
|
persons = person_store.async_items()
|
||||||
|
assert len(persons) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_shows_form(hass: HomeAssistant):
|
||||||
|
"""Test a login"""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
flow = await provider.async_login_flow({})
|
||||||
|
|
||||||
|
result = await flow.async_step_init({})
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "no_oidc_cookie_found"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_with_invalid_cookie_aborts(hass: HomeAssistant):
|
||||||
|
"""A cookie that does not map to a valid state should fail closed."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
flow = await provider.async_login_flow({})
|
||||||
|
|
||||||
|
fake_request = SimpleNamespace(
|
||||||
|
cookies={"auth_oidc_state": "missing-state"}, remote="127.0.0.1"
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = fake_request
|
||||||
|
|
||||||
|
result = await flow.async_step_init({})
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "oidc_cookie_invalid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_with_no_cookie_aborts(hass: HomeAssistant):
|
||||||
|
"""Missing cookie should fail closed."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: False,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = hass.auth.get_auth_providers(DOMAIN)[0]
|
||||||
|
flow = await provider.async_login_flow({})
|
||||||
|
|
||||||
|
fake_request = SimpleNamespace(cookies={}, remote="127.0.0.1")
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.http.current_request"
|
||||||
|
) as current_request:
|
||||||
|
current_request.get.return_value = fake_request
|
||||||
|
|
||||||
|
result = await flow.async_step_init({})
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "no_oidc_cookie_found"
|
||||||
690
tests/test_hass_oidc_client_integration.py
Normal file
690
tests/test_hass_oidc_client_integration.py
Normal file
@@ -0,0 +1,690 @@
|
|||||||
|
"""Tests for the OIDC client"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from urllib.parse import parse_qs, unquote, urlparse, urlencode
|
||||||
|
import pytest
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
|
|
||||||
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.tools.oidc_client import (
|
||||||
|
OIDCDiscoveryClient,
|
||||||
|
OIDCDiscoveryInvalid,
|
||||||
|
)
|
||||||
|
from custom_components.auth_oidc.config.const import (
|
||||||
|
DISCOVERY_URL,
|
||||||
|
CLIENT_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||||
|
|
||||||
|
EXAMPLE_CLIENT_ID = "http://example.com/"
|
||||||
|
WEB_CLIENT_ID = "https://example.com"
|
||||||
|
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
|
||||||
|
|
||||||
|
def encode_redirect_uri(redirect_uri: str) -> str:
|
||||||
|
"""Helper to encode redirect URI for welcome page."""
|
||||||
|
return base64.b64encode(redirect_uri.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def create_redirect_uri(client_id: str) -> str:
|
||||||
|
"""Create a redirect URI for Home Assistant Android app."""
|
||||||
|
params = {
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": client_id,
|
||||||
|
"client_id": client_id,
|
||||||
|
"state": "example",
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"http://example.com/auth/authorize?{urlencode(params)}"
|
||||||
|
|
||||||
|
|
||||||
|
async def get_welcome_for_client(client, redirect_uri: str) -> tuple[str, str, int]:
|
||||||
|
"""Go to welcome page and return state cookie, HTML content, and status.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (state_id, html_content, status_code)
|
||||||
|
"""
|
||||||
|
encoded_uri = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded_uri}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
state = resp.cookies["auth_oidc_state"].value
|
||||||
|
html = await resp.text() if resp.status == 200 else ""
|
||||||
|
return state, html, resp.status
|
||||||
|
|
||||||
|
|
||||||
|
async def get_redirect_auth_url(client) -> str:
|
||||||
|
"""Go to redirect page and extract the authorization URL.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The full authorization URL to send to the OIDC provider
|
||||||
|
"""
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
html = await resp.text()
|
||||||
|
|
||||||
|
match = re.search(r'decodeURIComponent\("([^"]+)"\)', html)
|
||||||
|
assert match is not None, "Authorization URL not found in redirect page"
|
||||||
|
return unquote(match.group(1))
|
||||||
|
|
||||||
|
|
||||||
|
async def complete_callback_and_finish(client, code: str, state: str):
|
||||||
|
"""Complete the callback and finish flow.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The state_id cookie value after completion
|
||||||
|
"""
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={code}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 302
|
||||||
|
assert resp.headers["Location"].endswith("/auth/oidc/finish")
|
||||||
|
|
||||||
|
resp_finish = await client.get("/auth/oidc/finish", allow_redirects=False)
|
||||||
|
assert resp_finish.status == 200
|
||||||
|
finish_html = await resp_finish.text()
|
||||||
|
assert 'id="continue-on-this-device"' in finish_html
|
||||||
|
assert 'id="device-code-input"' in finish_html
|
||||||
|
assert 'id="approve-login-button"' in finish_html
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_back_redirect(client, expected_redirect_uri: str):
|
||||||
|
"""Verify that POST to finish without body redirects back to the original redirect_uri."""
|
||||||
|
resp_finish_post = await client.post("/auth/oidc/finish", allow_redirects=False)
|
||||||
|
assert resp_finish_post.status == 302
|
||||||
|
|
||||||
|
location = resp_finish_post.headers["Location"]
|
||||||
|
assert location.startswith(unquote(expected_redirect_uri))
|
||||||
|
assert "skip_oidc_redirect=true" in location
|
||||||
|
|
||||||
|
|
||||||
|
async def listen_for_sse_events(
|
||||||
|
resp_sse,
|
||||||
|
expected_event: str,
|
||||||
|
timeout_seconds: int = 5,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Listen for SSE events and return once the expected event is received.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resp_sse: The SSE response stream
|
||||||
|
expected_event: The event type to listen for (e.g., "waiting" or "ready")
|
||||||
|
timeout_seconds: Maximum time to wait for the event
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of received event lines
|
||||||
|
"""
|
||||||
|
|
||||||
|
if resp_sse is None:
|
||||||
|
raise ValueError("resp_sse cannot be None")
|
||||||
|
|
||||||
|
received_events = []
|
||||||
|
|
||||||
|
async def stream_reader():
|
||||||
|
try:
|
||||||
|
async for line in resp_sse.content:
|
||||||
|
decoded_line = line.decode("utf-8").strip()
|
||||||
|
if not decoded_line:
|
||||||
|
continue
|
||||||
|
|
||||||
|
received_events.append(decoded_line)
|
||||||
|
|
||||||
|
# Check if this is an event line
|
||||||
|
if decoded_line.startswith("event:"):
|
||||||
|
event_type = decoded_line.split(":", 1)[1].strip()
|
||||||
|
if event_type == expected_event:
|
||||||
|
# Found the expected event, return successfully.
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Device SSE may emit multiple waiting events before ready.
|
||||||
|
if expected_event == "ready" and event_type == "waiting":
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise AssertionError(
|
||||||
|
f"Unexpected event type '{event_type}'. Expected: {expected_event}"
|
||||||
|
)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(stream_reader(), timeout=timeout_seconds)
|
||||||
|
if result:
|
||||||
|
return received_events
|
||||||
|
except asyncio.TimeoutError as exc:
|
||||||
|
raise AssertionError(
|
||||||
|
f"Timeout after {timeout_seconds}s waiting for '{expected_event}' event"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
raise AssertionError(f"Failed to receive '{expected_event}' event")
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(hass: HomeAssistant):
|
||||||
|
"""Set up the integration within Home Assistant"""
|
||||||
|
mock_config = {
|
||||||
|
DOMAIN: {
|
||||||
|
CLIENT_ID: EXAMPLE_CLIENT_ID,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await async_setup_component(hass, DOMAIN, mock_config)
|
||||||
|
assert result
|
||||||
|
|
||||||
|
|
||||||
|
# Actual tests
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_oidc_flow(hass: HomeAssistant, hass_client):
|
||||||
|
"""Test that one full OIDC flow works if OIDC is mocked."""
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
# Go to welcome and get state cookie
|
||||||
|
state, _, status = await get_welcome_for_client(client, redirect_uri)
|
||||||
|
assert status == 200
|
||||||
|
assert state is not None
|
||||||
|
|
||||||
|
# Get authorization URL from redirect page
|
||||||
|
authorization_url = await get_redirect_auth_url(client)
|
||||||
|
assert authorization_url.startswith(MockOIDCServer.get_authorize_url())
|
||||||
|
|
||||||
|
# Parse the rendered redirect URL and test the query params for correctness
|
||||||
|
parsed_url = urlparse(authorization_url)
|
||||||
|
query_params = parse_qs(parsed_url.query)
|
||||||
|
|
||||||
|
assert "response_type" in query_params and query_params.get(
|
||||||
|
"response_type"
|
||||||
|
) == ["code"]
|
||||||
|
assert "client_id" in query_params and query_params.get("client_id") == [
|
||||||
|
EXAMPLE_CLIENT_ID
|
||||||
|
]
|
||||||
|
assert "scope" in query_params and query_params.get("scope") == [
|
||||||
|
"openid profile groups"
|
||||||
|
]
|
||||||
|
assert "state" in query_params and query_params["state"]
|
||||||
|
assert query_params["state"][0] == state
|
||||||
|
assert len(query_params["state"][0]) >= 16 # Ensure state is sufficiently long
|
||||||
|
assert (
|
||||||
|
"redirect_uri" in query_params
|
||||||
|
and query_params["redirect_uri"]
|
||||||
|
and query_params["redirect_uri"][0].endswith("/auth/oidc/callback")
|
||||||
|
)
|
||||||
|
assert "nonce" in query_params and query_params["nonce"]
|
||||||
|
assert "code_challenge_method" in query_params and query_params.get(
|
||||||
|
"code_challenge_method"
|
||||||
|
) == ["S256"]
|
||||||
|
assert "code_challenge" in query_params and query_params["code_challenge"]
|
||||||
|
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp = session.get(authorization_url, allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
# JSON response from mock server, normally would be interactive
|
||||||
|
json_parsed = await resp.json()
|
||||||
|
assert "code" in json_parsed and json_parsed["code"]
|
||||||
|
|
||||||
|
# Now go back to the callback with a sample code
|
||||||
|
code = json_parsed["code"]
|
||||||
|
|
||||||
|
await complete_callback_and_finish(client, code, state)
|
||||||
|
|
||||||
|
# POST to finish without any POST body should result in 302 back to the original redirect_uri
|
||||||
|
await verify_back_redirect(client, redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
async def discovery_test_through_redirect(
|
||||||
|
hass_client, caplog, scenario: str, match_log_line: str
|
||||||
|
):
|
||||||
|
"""Test that discovery document retrieval fails gracefully through redirect endpoint."""
|
||||||
|
with mock_oidc_responses(scenario):
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encode_redirect_uri(redirect_uri)}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
|
||||||
|
# Find matching log line
|
||||||
|
assert match_log_line in caplog.text
|
||||||
|
|
||||||
|
# Assert that we get an error response with an error message
|
||||||
|
assert resp.status == 500
|
||||||
|
text = await resp.text()
|
||||||
|
assert "Integration is misconfigured, discovery could not be obtained." in text
|
||||||
|
|
||||||
|
|
||||||
|
async def direct_discovery_test(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
scenario: str,
|
||||||
|
match_type: str,
|
||||||
|
match_log_line: str | None = None,
|
||||||
|
):
|
||||||
|
"""Test that discovery document retrieval fails with nice error directly."""
|
||||||
|
with mock_oidc_responses(scenario):
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
client = OIDCDiscoveryClient(
|
||||||
|
MockOIDCServer.get_discovery_url(),
|
||||||
|
session,
|
||||||
|
{
|
||||||
|
"id_token_signing_alg": "RS256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(OIDCDiscoveryInvalid) as exc_info:
|
||||||
|
await client.fetch_discovery_document()
|
||||||
|
|
||||||
|
assert exc_info.value.type == match_type
|
||||||
|
assert exc_info.value.get_detail_string().startswith("type: " + match_type)
|
||||||
|
|
||||||
|
if match_log_line:
|
||||||
|
assert match_log_line in exc_info.value.get_detail_string()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discovery_failures(hass: HomeAssistant, hass_client, caplog):
|
||||||
|
"""Test that discovery document retrieval fails gracefully."""
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
# Empty scenario
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client, caplog, "empty", "is missing required endpoint: issuer"
|
||||||
|
)
|
||||||
|
await direct_discovery_test(hass, "empty", "missing_endpoint", "endpoint: issuer")
|
||||||
|
|
||||||
|
# Missing authorization_endpoint
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"only_issuer",
|
||||||
|
"is missing required endpoint: authorization_endpoint",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "only_issuer", "missing_endpoint", "endpoint: authorization_endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Missing token_endpoint
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"missing_token",
|
||||||
|
"is missing required endpoint: token_endpoint",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "missing_token", "missing_endpoint", "endpoint: token_endpoint"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Missing jwks_uri
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"missing_jwks",
|
||||||
|
"is missing required endpoint: jwks_uri",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "missing_jwks", "missing_endpoint", "endpoint: jwks_uri"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid response_modes_supported
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_response_modes",
|
||||||
|
"does not support required 'query' response mode, only supports: ['post']",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "invalid_response_modes", "does_not_support_response_mode", "post"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid grant_types supported
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_grant_types",
|
||||||
|
"does not support required 'authorization_code' grant type, only supports: ['refresh_token']",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "invalid_grant_types", "does_not_support_grant_type", "refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid response types
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_response_types",
|
||||||
|
"does not support required 'code' response type, only supports: ['token']",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "invalid_response_types", "does_not_support_response_type", "token"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid code_challenge types
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_code_challenge_types",
|
||||||
|
"does not support required 'S256' code challenge method, only supports: ['plain']",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass,
|
||||||
|
"invalid_code_challenge_types",
|
||||||
|
"does_not_support_required_code_challenge_method",
|
||||||
|
"plain",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid id_token_signing alg
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_id_token_signing_alg",
|
||||||
|
"does not have 'id_token_signing_alg_values_supported' field",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass, "invalid_id_token_signing_alg", "missing_id_token_signing_alg_values"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Not matching id_token_signing alg
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"wrong_id_token_signing_alg",
|
||||||
|
"does not support requested id_token_signing_alg 'RS256', only supports: ['HS256']",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass,
|
||||||
|
"wrong_id_token_signing_alg",
|
||||||
|
"does_not_support_id_token_signing_alg",
|
||||||
|
"requested: RS256, supported: ['HS256']",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid URL
|
||||||
|
await discovery_test_through_redirect(
|
||||||
|
hass_client,
|
||||||
|
caplog,
|
||||||
|
"invalid_url",
|
||||||
|
"has invalid URL in endpoint: jwks_uri (/jwks)",
|
||||||
|
)
|
||||||
|
await direct_discovery_test(
|
||||||
|
hass,
|
||||||
|
"invalid_url",
|
||||||
|
"invalid_endpoint",
|
||||||
|
"endpoint: jwks_uri, url: /jwks",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_direct_jwks_fetch(hass: HomeAssistant):
|
||||||
|
"""Test direct fetch of JWKS."""
|
||||||
|
with mock_oidc_responses():
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
client = OIDCDiscoveryClient(
|
||||||
|
MockOIDCServer.get_discovery_url(),
|
||||||
|
session,
|
||||||
|
{
|
||||||
|
"id_token_signing_alg": "RS256",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await client.fetch_discovery_document()
|
||||||
|
jwks = await client.fetch_jwks()
|
||||||
|
assert "keys" in jwks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_login_flow_two_browsers(hass: HomeAssistant, hass_client):
|
||||||
|
"""Test device login flow with two separate browser sessions.
|
||||||
|
|
||||||
|
This simulates:
|
||||||
|
- Mobile device (Device 1) generating a device code and waiting via SSE
|
||||||
|
- Desktop browser (Device 2) completing full OAuth flow and linking the code
|
||||||
|
- Mobile device receiving ready event after code is linked
|
||||||
|
"""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
# ==================== DEVICE 1: Mobile ====================
|
||||||
|
# Mobile client starts the login flow
|
||||||
|
mobile_client = await hass_client()
|
||||||
|
mobile_redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
|
||||||
|
mobile_state, mobile_html, status = await get_welcome_for_client(
|
||||||
|
mobile_client, mobile_redirect_uri
|
||||||
|
)
|
||||||
|
assert status == 200
|
||||||
|
assert mobile_state is not None
|
||||||
|
assert 'id="device-instructions"' in mobile_html
|
||||||
|
assert 'id="device-code"' in mobile_html
|
||||||
|
|
||||||
|
# Extract device code from the welcome page.
|
||||||
|
# The code is rendered in a div with id="device-code".
|
||||||
|
device_code_match = re.search(
|
||||||
|
r'id=["\']device-code["\'][^>]*>\s*([^<\s]+)\s*<',
|
||||||
|
mobile_html,
|
||||||
|
)
|
||||||
|
assert device_code_match is not None, (
|
||||||
|
"Device code should be generated for mobile client"
|
||||||
|
)
|
||||||
|
mobile_device_code = device_code_match.group(1)
|
||||||
|
assert len(mobile_device_code) > 0
|
||||||
|
|
||||||
|
# ==================== DEVICE 2: Desktop ====================
|
||||||
|
# Desktop client in a separate session
|
||||||
|
desktop_client = await hass_client()
|
||||||
|
desktop_redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
desktop_state, _, status = await get_welcome_for_client(
|
||||||
|
desktop_client, desktop_redirect_uri
|
||||||
|
)
|
||||||
|
assert status in [200, 302]
|
||||||
|
assert desktop_state is not None
|
||||||
|
|
||||||
|
# Desktop goes through redirect to get the authorization URL
|
||||||
|
authorization_url = await get_redirect_auth_url(desktop_client)
|
||||||
|
assert authorization_url.startswith(MockOIDCServer.get_authorize_url())
|
||||||
|
|
||||||
|
# Desktop gets the authorization code from OIDC provider
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp_auth = session.get(authorization_url, allow_redirects=False)
|
||||||
|
assert resp_auth.status == 200
|
||||||
|
json_auth = await resp_auth.json()
|
||||||
|
assert "code" in json_auth
|
||||||
|
desktop_code = json_auth["code"]
|
||||||
|
|
||||||
|
await complete_callback_and_finish(desktop_client, desktop_code, desktop_state)
|
||||||
|
|
||||||
|
# ==================== Mobile Device Finalizes Flow ====================
|
||||||
|
# Mobile device polls SSE and keeps the connection open throughout
|
||||||
|
resp_sse = await mobile_client.get(
|
||||||
|
"/auth/oidc/device-sse", allow_redirects=False
|
||||||
|
)
|
||||||
|
assert resp_sse.status == 200
|
||||||
|
|
||||||
|
# Listen for waiting events for up to 5 seconds
|
||||||
|
await listen_for_sse_events(resp_sse, "waiting", timeout_seconds=5)
|
||||||
|
|
||||||
|
# Actually submit the mobile code using POST
|
||||||
|
resp_code = await desktop_client.post(
|
||||||
|
"/auth/oidc/finish",
|
||||||
|
data={"device_code": mobile_device_code},
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_code.status == 200
|
||||||
|
assert resp_code.headers.get("Content-Type", "").startswith("text/html")
|
||||||
|
html_code = await resp_code.text()
|
||||||
|
assert 'id="mobile-success-message"' in html_code
|
||||||
|
assert 'id="restart-login-button"' in html_code
|
||||||
|
|
||||||
|
# ==================== Mobile Device Receives Ready Event ====================
|
||||||
|
# After desktop flow is completed, mobile SSE should receive a ready event on same connection
|
||||||
|
await listen_for_sse_events(resp_sse, "ready", timeout_seconds=5)
|
||||||
|
|
||||||
|
# POST to finish without any POST body should result in 302 back to the original redirect_uri
|
||||||
|
await verify_back_redirect(mobile_client, mobile_redirect_uri)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finish_rejects_device_code_when_state_not_ready(
|
||||||
|
hass: HomeAssistant, hass_client
|
||||||
|
):
|
||||||
|
"""Submitting a device code must fail if callback did not complete for this browser."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
# Device session that owns the device code.
|
||||||
|
mobile_client = await hass_client()
|
||||||
|
mobile_redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
_, mobile_html, status = await get_welcome_for_client(
|
||||||
|
mobile_client, mobile_redirect_uri
|
||||||
|
)
|
||||||
|
assert status == 200
|
||||||
|
|
||||||
|
device_code_match = re.search(
|
||||||
|
r'id=["\']device-code["\'][^>]*>\s*([^<\s]+)\s*<',
|
||||||
|
mobile_html,
|
||||||
|
)
|
||||||
|
assert device_code_match is not None
|
||||||
|
mobile_device_code = device_code_match.group(1)
|
||||||
|
|
||||||
|
# Separate browser starts but does not complete callback flow.
|
||||||
|
desktop_client = await hass_client()
|
||||||
|
desktop_redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
_, _, desktop_status = await get_welcome_for_client(
|
||||||
|
desktop_client, desktop_redirect_uri
|
||||||
|
)
|
||||||
|
assert desktop_status in [200, 302]
|
||||||
|
|
||||||
|
# Negative branch: try to finalize before desktop state has user info.
|
||||||
|
resp = await desktop_client.post(
|
||||||
|
"/auth/oidc/finish",
|
||||||
|
data={"device_code": mobile_device_code},
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
text = await resp.text()
|
||||||
|
assert "Failed to link state to device code" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_shows_error_if_userinfo_save_fails(
|
||||||
|
hass: HomeAssistant, hass_client
|
||||||
|
):
|
||||||
|
"""Callback should return error page when state save fails after successful token flow."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock_oidc_responses(),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_save_user_info",
|
||||||
|
new=AsyncMock(return_value=False),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
state, _, status = await get_welcome_for_client(client, redirect_uri)
|
||||||
|
assert status == 200
|
||||||
|
|
||||||
|
authorization_url = await get_redirect_auth_url(client)
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp_auth = session.get(authorization_url, allow_redirects=False)
|
||||||
|
json_auth = await resp_auth.json()
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 500
|
||||||
|
text = await resp.text()
|
||||||
|
assert "Failed to save user information, session probably expired." in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_rejects_nonce_mismatch(hass: HomeAssistant, hass_client):
|
||||||
|
"""Callback should fail closed when the returned nonce does not match the stored flow nonce."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
mock_oidc_responses(),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.OIDCClient._parse_id_token",
|
||||||
|
new=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"sub": "test-user",
|
||||||
|
"nonce": "mismatched-nonce",
|
||||||
|
"name": "Test Name",
|
||||||
|
"preferred_username": "testuser",
|
||||||
|
"groups": [],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
state, _, status = await get_welcome_for_client(client, redirect_uri)
|
||||||
|
assert status == 200
|
||||||
|
|
||||||
|
authorization_url = await get_redirect_auth_url(client)
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp_auth = session.get(authorization_url, allow_redirects=False)
|
||||||
|
json_auth = await resp_auth.json()
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={json_auth['code']}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 500
|
||||||
|
text = await resp.text()
|
||||||
|
assert "Failed to get user details" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_replay_is_rejected(hass: HomeAssistant, hass_client):
|
||||||
|
"""A callback replay with the same state should be rejected after first successful use."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(WEB_CLIENT_ID)
|
||||||
|
|
||||||
|
state, _, status = await get_welcome_for_client(client, redirect_uri)
|
||||||
|
assert status == 200
|
||||||
|
|
||||||
|
authorization_url = await get_redirect_auth_url(client)
|
||||||
|
session = async_get_clientsession(hass)
|
||||||
|
resp_auth = session.get(authorization_url, allow_redirects=False)
|
||||||
|
json_auth = await resp_auth.json()
|
||||||
|
code = json_auth["code"]
|
||||||
|
|
||||||
|
# First callback should succeed.
|
||||||
|
first = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={code}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert first.status == 302
|
||||||
|
|
||||||
|
# Replay should fail because the state flow has already been consumed.
|
||||||
|
replay = await client.get(
|
||||||
|
f"/auth/oidc/callback?code={code}&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert replay.status == 500
|
||||||
|
replay_text = await replay.text()
|
||||||
|
assert "Failed to get user details" in replay_text
|
||||||
818
tests/test_hass_oidc_client_unit.py
Normal file
818
tests/test_hass_oidc_client_unit.py
Normal file
@@ -0,0 +1,818 @@
|
|||||||
|
"""Unit tests for OIDC client token and security behavior."""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import base64
|
||||||
|
import time
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from joserfc import errors as joserfc_errors, jwt, jwk
|
||||||
|
|
||||||
|
from custom_components.auth_oidc.tools.oidc_client import (
|
||||||
|
HTTPClientError,
|
||||||
|
OIDCClient,
|
||||||
|
OIDCDiscoveryInvalid,
|
||||||
|
OIDCIdTokenSigningAlgorithmInvalid,
|
||||||
|
OIDCTokenResponseInvalid,
|
||||||
|
OIDCUserinfoInvalid,
|
||||||
|
http_raise_for_status,
|
||||||
|
)
|
||||||
|
|
||||||
|
# List from https://jose.authlib.org/en/guide/algorithms/#json-web-signature
|
||||||
|
ALL_ID_TOKEN_SIGNING_ALGORITHMS = (
|
||||||
|
"HS256",
|
||||||
|
"HS384",
|
||||||
|
"HS512",
|
||||||
|
"RS256",
|
||||||
|
"RS384",
|
||||||
|
"RS512",
|
||||||
|
"ES256",
|
||||||
|
"ES384",
|
||||||
|
"ES512",
|
||||||
|
"PS256",
|
||||||
|
"PS384",
|
||||||
|
"PS512",
|
||||||
|
"ES256K",
|
||||||
|
"Ed25519",
|
||||||
|
"Ed448",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_client(hass: HomeAssistant, **kwargs) -> OIDCClient:
|
||||||
|
"""Build an OIDC client with explicit defaults for unit testing."""
|
||||||
|
return OIDCClient(
|
||||||
|
hass=hass,
|
||||||
|
discovery_url="https://issuer/.well-known/openid-configuration",
|
||||||
|
client_id="test-client",
|
||||||
|
scope="openid profile",
|
||||||
|
features=kwargs.pop("features", {}),
|
||||||
|
claims=kwargs.pop("claims", {}),
|
||||||
|
roles=kwargs.pop("roles", {}),
|
||||||
|
network=kwargs.pop("network", {}),
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_jwt(
|
||||||
|
header: dict | None,
|
||||||
|
payload: dict | None = None,
|
||||||
|
signature: str = "sig",
|
||||||
|
) -> str:
|
||||||
|
"""Build a compact JWT string for parser-focused tests."""
|
||||||
|
|
||||||
|
def _b64url_json(data: dict) -> str:
|
||||||
|
encoded = json.dumps(data, separators=(",", ":")).encode("utf-8")
|
||||||
|
return base64.urlsafe_b64encode(encoded).rstrip(b"=").decode("utf-8")
|
||||||
|
|
||||||
|
protected = _b64url_json(header) if header is not None else ""
|
||||||
|
claims = _b64url_json(payload or {"sub": "subject"})
|
||||||
|
return f"{protected}.{claims}.{signature}"
|
||||||
|
|
||||||
|
|
||||||
|
def make_signed_hs256_jwt(secret: str, claims: dict) -> str:
|
||||||
|
"""Build a real HS256 signed JWT for parser validation tests."""
|
||||||
|
jwk_obj = jwk.import_key(
|
||||||
|
{
|
||||||
|
"kty": "oct",
|
||||||
|
"k": base64.urlsafe_b64encode(secret.encode()).decode().rstrip("="),
|
||||||
|
"alg": "HS256",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return jwt.encode({"alg": "HS256"}, claims, jwk_obj)
|
||||||
|
|
||||||
|
|
||||||
|
def build_real_signed_token(
|
||||||
|
algorithm: str, claims: dict, secret: str
|
||||||
|
) -> tuple[str, dict]:
|
||||||
|
"""Build a real signed token and matching JWKS payload for a given algorithm."""
|
||||||
|
if algorithm.startswith("HS"):
|
||||||
|
signing_key = jwk.import_key(
|
||||||
|
{
|
||||||
|
"kty": "oct",
|
||||||
|
"k": base64.urlsafe_b64encode(secret.encode()).decode().rstrip("="),
|
||||||
|
"alg": algorithm,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
token = jwt.encode(
|
||||||
|
{"alg": algorithm}, claims, signing_key, algorithms=[algorithm]
|
||||||
|
)
|
||||||
|
return token, {"keys": []}
|
||||||
|
|
||||||
|
if algorithm in ("RS256", "RS384", "RS512", "PS256", "PS384", "PS512"):
|
||||||
|
key = jwk.generate_key(
|
||||||
|
"RSA", 2048, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True
|
||||||
|
)
|
||||||
|
elif algorithm in ("ES256", "ES384", "ES512", "ES256K"):
|
||||||
|
curve = {
|
||||||
|
"ES256": "P-256",
|
||||||
|
"ES384": "P-384",
|
||||||
|
"ES512": "P-521",
|
||||||
|
"ES256K": "secp256k1",
|
||||||
|
}[algorithm]
|
||||||
|
key = jwk.generate_key(
|
||||||
|
"EC", curve, {"alg": algorithm, "use": "sig"}, private=True, auto_kid=True
|
||||||
|
)
|
||||||
|
elif algorithm in ("Ed25519", "Ed448"):
|
||||||
|
key = jwk.generate_key(
|
||||||
|
"OKP",
|
||||||
|
algorithm,
|
||||||
|
{"alg": algorithm, "use": "sig"},
|
||||||
|
private=True,
|
||||||
|
auto_kid=True,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported test algorithm: {algorithm}")
|
||||||
|
|
||||||
|
kid = key.kid
|
||||||
|
token = jwt.encode(
|
||||||
|
{"alg": algorithm, "kid": kid},
|
||||||
|
claims,
|
||||||
|
key,
|
||||||
|
algorithms=[algorithm],
|
||||||
|
)
|
||||||
|
public_key = key.as_dict(private=False)
|
||||||
|
return token, {"keys": [public_key]}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_token_flow_rejects_missing_state(hass: HomeAssistant):
|
||||||
|
"""Flow state must exist; missing state should fail closed."""
|
||||||
|
client = make_client(hass)
|
||||||
|
|
||||||
|
result = await client.async_complete_token_flow(
|
||||||
|
"https://example.com/callback", "code", "missing-state"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_token_flow_rejects_nonce_mismatch(hass: HomeAssistant):
|
||||||
|
"""Nonce mismatch should reject the token flow."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.flows["state-1"] = {"code_verifier": "verifier", "nonce": "expected"}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_make_token_request",
|
||||||
|
new=AsyncMock(return_value={"id_token": "id", "access_token": "access"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_parse_id_token",
|
||||||
|
new=AsyncMock(return_value={"sub": "abc", "nonce": "wrong"}),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await client.async_complete_token_flow(
|
||||||
|
"https://example.com/callback", "code", "state-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert "state-1" not in client.flows
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_token_flow_handles_token_request_failure(hass: HomeAssistant):
|
||||||
|
"""Token endpoint failures should return None to caller."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.flows["state-2"] = {"code_verifier": "verifier", "nonce": "nonce"}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_make_token_request",
|
||||||
|
new=AsyncMock(side_effect=OIDCTokenResponseInvalid()),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await client.async_complete_token_flow(
|
||||||
|
"https://example.com/callback", "code", "state-2"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_user_details_handles_non_list_groups(hass: HomeAssistant):
|
||||||
|
"""Non-list groups should not accidentally grant roles."""
|
||||||
|
client = make_client(hass, roles={"user": "users", "admin": "admins"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"issuer": "https://issuer"}),
|
||||||
|
):
|
||||||
|
details = await client.parse_user_details(
|
||||||
|
{
|
||||||
|
"sub": "subject",
|
||||||
|
"name": "Display Name",
|
||||||
|
"preferred_username": "username",
|
||||||
|
"groups": "admins",
|
||||||
|
},
|
||||||
|
"access-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert details["role"] == "invalid"
|
||||||
|
assert details["display_name"] == "Display Name"
|
||||||
|
assert details["username"] == "username"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_user_details_uses_userinfo_for_missing_claims(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Missing claims in id_token should be filled from userinfo when available."""
|
||||||
|
client = make_client(hass)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"userinfo_endpoint": "https://issuer/userinfo",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_get_userinfo",
|
||||||
|
new=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"name": "From UserInfo",
|
||||||
|
"preferred_username": "userinfo-user",
|
||||||
|
"groups": ["admins"],
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
details = await client.parse_user_details({"sub": "subject"}, "access-token")
|
||||||
|
|
||||||
|
expected_sub = hashlib.sha256("https://issuer.subject".encode("utf-8")).hexdigest()
|
||||||
|
assert details["sub"] == expected_sub
|
||||||
|
assert details["display_name"] == "From UserInfo"
|
||||||
|
assert details["username"] == "userinfo-user"
|
||||||
|
assert details["role"] == "system-admin"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_user_details_assigns_system_users_role(hass: HomeAssistant):
|
||||||
|
"""Configured user role should map to system-users when group is present."""
|
||||||
|
client = make_client(hass, roles={"user": "users", "admin": "admins"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"issuer": "https://issuer"}),
|
||||||
|
):
|
||||||
|
details = await client.parse_user_details(
|
||||||
|
{
|
||||||
|
"sub": "subject",
|
||||||
|
"name": "Display Name",
|
||||||
|
"preferred_username": "username",
|
||||||
|
"groups": ["users"],
|
||||||
|
},
|
||||||
|
"access-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert details["role"] == "system-users"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_user_details_admin_role_overrides_user_role(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Admin group should take precedence when both user and admin groups are present."""
|
||||||
|
client = make_client(hass, roles={"user": "users", "admin": "admins"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"issuer": "https://issuer"}),
|
||||||
|
):
|
||||||
|
details = await client.parse_user_details(
|
||||||
|
{
|
||||||
|
"sub": "subject",
|
||||||
|
"name": "Display Name",
|
||||||
|
"preferred_username": "username",
|
||||||
|
"groups": ["users", "admins"],
|
||||||
|
},
|
||||||
|
"access-token",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert details["role"] == "system-admin"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_authorization_url_omits_pkce_when_disabled(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Authorization URL should omit PKCE params when compatibility mode disables PKCE."""
|
||||||
|
client = make_client(hass, features={"disable_rfc7636": True})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(
|
||||||
|
return_value={"authorization_endpoint": "https://issuer/authorize"}
|
||||||
|
),
|
||||||
|
):
|
||||||
|
url = await client.async_get_authorization_url(
|
||||||
|
"https://example.com/callback", "state-xyz"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert url is not None
|
||||||
|
parsed = urlparse(url)
|
||||||
|
query = parse_qs(parsed.query)
|
||||||
|
|
||||||
|
assert query["state"] == ["state-xyz"]
|
||||||
|
assert "nonce" in query
|
||||||
|
assert "code_challenge" not in query
|
||||||
|
assert "code_challenge_method" not in query
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_returns_none_when_kid_missing(hass: HomeAssistant):
|
||||||
|
"""ID token without kid should be rejected."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt({"alg": "RS256"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": []}),
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_returns_none_when_kid_not_found(hass: HomeAssistant):
|
||||||
|
"""ID token with unknown kid should be rejected."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt({"alg": "RS256", "kid": "missing"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": [{"kid": "other"}]}),
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_rejects_hs_without_client_secret(hass: HomeAssistant):
|
||||||
|
"""HMAC-signed id_token requires client_secret and must fail otherwise."""
|
||||||
|
client = make_client(hass, id_token_signing_alg="HS256")
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt({"alg": "HS256"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": []}),
|
||||||
|
):
|
||||||
|
with pytest.raises(OIDCIdTokenSigningAlgorithmInvalid):
|
||||||
|
await client._parse_id_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_returns_none_when_decode_fails_jose(hass: HomeAssistant):
|
||||||
|
"""Jose decode/verification failures should be handled without raising to callers."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt({"alg": "RS256", "kid": "kid1"})
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": [{"kid": "kid1", "kty": "RSA"}]}),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
|
||||||
|
return_value=object(),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.jwt.decode",
|
||||||
|
side_effect=joserfc_errors.JoseError("bad token"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_rejects_wrong_signing_algorithm(hass: HomeAssistant):
|
||||||
|
"""ID token signed with unexpected alg should be rejected."""
|
||||||
|
client = make_client(hass, id_token_signing_alg="RS256")
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt({"alg": "HS256"})
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": []}),
|
||||||
|
):
|
||||||
|
with pytest.raises(OIDCIdTokenSigningAlgorithmInvalid):
|
||||||
|
await client._parse_id_token(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_rejects_missing_header(hass: HomeAssistant):
|
||||||
|
"""ID token without protected header should be rejected."""
|
||||||
|
client = make_client(hass)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
token = make_jwt(None)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": []}),
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_parse_id_token_rejects_invalid_registered_claims(hass: HomeAssistant):
|
||||||
|
"""Invalid aud/iss/sub style claim validation should fail closed."""
|
||||||
|
hs_secret = "top-secret-value"
|
||||||
|
|
||||||
|
client = make_client(
|
||||||
|
hass,
|
||||||
|
id_token_signing_alg="HS256",
|
||||||
|
client_secret=hs_secret,
|
||||||
|
)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
token = make_signed_hs256_jwt(
|
||||||
|
hs_secret,
|
||||||
|
{
|
||||||
|
"sub": "abc",
|
||||||
|
"aud": "wrong-audience",
|
||||||
|
"iss": "https://wrong-issuer",
|
||||||
|
"nbf": now,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_jwks",
|
||||||
|
new=AsyncMock(return_value={"keys": []}),
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize("algorithm", ALL_ID_TOKEN_SIGNING_ALGORITHMS)
|
||||||
|
async def test_parse_id_token_validates_real_signed_tokens_and_decode_inputs(
|
||||||
|
hass: HomeAssistant, algorithm: str
|
||||||
|
):
|
||||||
|
"""Use real signatures and verify token/key/algorithm passed into joserfc."""
|
||||||
|
secret = "top-secret-value"
|
||||||
|
client_kwargs = {"id_token_signing_alg": algorithm}
|
||||||
|
if algorithm.startswith("HS"):
|
||||||
|
client_kwargs["client_secret"] = secret
|
||||||
|
|
||||||
|
client = make_client(hass, **client_kwargs)
|
||||||
|
client.discovery_document = {
|
||||||
|
"issuer": "https://issuer",
|
||||||
|
"jwks_uri": "https://issuer/jwks",
|
||||||
|
}
|
||||||
|
|
||||||
|
now = int(time.time())
|
||||||
|
claims = {
|
||||||
|
"sub": "subject-1",
|
||||||
|
"aud": "test-client",
|
||||||
|
"iss": "https://issuer",
|
||||||
|
"nbf": now,
|
||||||
|
"iat": now,
|
||||||
|
"exp": now + 3600,
|
||||||
|
}
|
||||||
|
|
||||||
|
token, jwks_payload = build_real_signed_token(algorithm, claims, secret)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(client, "_fetch_jwks", new=AsyncMock(return_value=jwks_payload)),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.jwt.decode",
|
||||||
|
wraps=jwt.decode,
|
||||||
|
) as decode_spy,
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.jwk.import_key",
|
||||||
|
wraps=jwk.import_key,
|
||||||
|
) as import_key_spy,
|
||||||
|
):
|
||||||
|
parsed = await client._parse_id_token(token)
|
||||||
|
|
||||||
|
assert parsed == claims
|
||||||
|
decode_spy.assert_called_once()
|
||||||
|
assert decode_spy.call_args.args[0] == token
|
||||||
|
assert decode_spy.call_args.kwargs["algorithms"] == [algorithm]
|
||||||
|
|
||||||
|
import_key_spy.assert_called()
|
||||||
|
imported_key_payload = import_key_spy.call_args.args[0]
|
||||||
|
assert imported_key_payload["alg"] == algorithm
|
||||||
|
if algorithm.startswith("HS"):
|
||||||
|
assert imported_key_payload["kty"] == "oct"
|
||||||
|
else:
|
||||||
|
assert imported_key_payload["kid"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_authorization_url_returns_none_when_discovery_fails(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Discovery failures should return None from authorization URL generation."""
|
||||||
|
client = make_client(hass)
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(side_effect=OIDCDiscoveryInvalid()),
|
||||||
|
):
|
||||||
|
url = await client.async_get_authorization_url(
|
||||||
|
"https://example.com/callback", "state-1"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert url is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_token_flow_omits_code_verifier_when_pkce_disabled(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""When PKCE is disabled, token request should omit code_verifier."""
|
||||||
|
client = make_client(hass, features={"disable_rfc7636": True})
|
||||||
|
client.flows["state-3"] = {"code_verifier": "verifier", "nonce": "nonce"}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_fetch_discovery_document",
|
||||||
|
new=AsyncMock(return_value={"token_endpoint": "https://issuer/token"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_make_token_request",
|
||||||
|
new=AsyncMock(return_value={"id_token": "id", "access_token": "access"}),
|
||||||
|
) as make_token_request,
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"_parse_id_token",
|
||||||
|
new=AsyncMock(return_value={"sub": "abc", "nonce": "nonce"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
client,
|
||||||
|
"parse_user_details",
|
||||||
|
new=AsyncMock(
|
||||||
|
return_value={
|
||||||
|
"sub": "abc",
|
||||||
|
"display_name": "n",
|
||||||
|
"username": "u",
|
||||||
|
"role": "system-users",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
result = await client.async_complete_token_flow(
|
||||||
|
"https://example.com/callback", "code", "state-3"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
token_params = make_token_request.await_args.args[1]
|
||||||
|
assert "code_verifier" not in token_params
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_raise_for_status_noop_on_ok_response():
|
||||||
|
"""Status helper should not raise for successful responses."""
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = True
|
||||||
|
|
||||||
|
await http_raise_for_status(response)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_http_raise_for_status_raises_http_client_error_with_body():
|
||||||
|
"""Status helper should include response body in raised exception."""
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = False
|
||||||
|
response.reason = "Bad Request"
|
||||||
|
response.status = 400
|
||||||
|
response.request_info = MagicMock()
|
||||||
|
response.history = ()
|
||||||
|
response.headers = {}
|
||||||
|
response.text = AsyncMock(return_value="problem details")
|
||||||
|
|
||||||
|
with pytest.raises(HTTPClientError) as exc_info:
|
||||||
|
await http_raise_for_status(response)
|
||||||
|
|
||||||
|
assert "400 (Bad Request)" in str(exc_info.value)
|
||||||
|
assert "problem details" in str(exc_info.value)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_http_session_reuses_existing_session(hass: HomeAssistant):
|
||||||
|
"""Session helper should return existing session when already created."""
|
||||||
|
client = make_client(hass)
|
||||||
|
existing_session = MagicMock()
|
||||||
|
client.http_session = existing_session
|
||||||
|
|
||||||
|
session = await client._get_http_session()
|
||||||
|
|
||||||
|
assert session is existing_session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_http_session_applies_tls_verify_flag(hass: HomeAssistant):
|
||||||
|
"""Session helper should pass tls_verify setting into TCP connector."""
|
||||||
|
client = make_client(hass, network={"tls_verify": False})
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
|
||||||
|
return_value=MagicMock(),
|
||||||
|
) as tcp_connector,
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
|
||||||
|
return_value=MagicMock(),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await client._get_http_session()
|
||||||
|
|
||||||
|
tcp_connector.assert_called_once_with(verify_ssl=False)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_http_session_uses_custom_ca_path(hass: HomeAssistant):
|
||||||
|
"""Session helper should create SSL context when custom CA path is configured."""
|
||||||
|
client = make_client(
|
||||||
|
hass,
|
||||||
|
network={"tls_verify": True, "tls_ca_path": "/tmp/test-ca.pem"},
|
||||||
|
)
|
||||||
|
fake_ssl_context = object()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
hass.loop,
|
||||||
|
"run_in_executor",
|
||||||
|
new=AsyncMock(return_value=fake_ssl_context),
|
||||||
|
) as run_in_executor,
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.aiohttp.TCPConnector",
|
||||||
|
return_value=MagicMock(),
|
||||||
|
) as tcp_connector,
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.aiohttp.ClientSession",
|
||||||
|
return_value=MagicMock(),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
await client._get_http_session()
|
||||||
|
|
||||||
|
run_in_executor.assert_awaited_once()
|
||||||
|
tcp_connector.assert_called_once_with(verify_ssl=True, ssl=fake_ssl_context)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_make_token_request_returns_json_on_success(hass: HomeAssistant):
|
||||||
|
"""Token request helper should return JSON payload for successful responses."""
|
||||||
|
client = make_client(hass)
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = True
|
||||||
|
response.json = AsyncMock(return_value={"access_token": "token"})
|
||||||
|
|
||||||
|
context_manager = AsyncMock()
|
||||||
|
context_manager.__aenter__.return_value = response
|
||||||
|
session = MagicMock()
|
||||||
|
session.post.return_value = context_manager
|
||||||
|
|
||||||
|
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
|
||||||
|
payload = await client._make_token_request(
|
||||||
|
"https://issuer/token", {"code": "abc"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert payload == {"access_token": "token"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_make_token_request_raises_invalid_on_non_400_http_error(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
"""Token request helper should map upstream HTTP errors to OIDCTokenResponseInvalid."""
|
||||||
|
client = make_client(hass)
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = False
|
||||||
|
response.reason = "Server Error"
|
||||||
|
response.status = 500
|
||||||
|
response.request_info = MagicMock()
|
||||||
|
response.history = ()
|
||||||
|
response.headers = {}
|
||||||
|
response.text = AsyncMock(return_value="boom")
|
||||||
|
|
||||||
|
context_manager = AsyncMock()
|
||||||
|
context_manager.__aenter__.return_value = response
|
||||||
|
session = MagicMock()
|
||||||
|
session.post.return_value = context_manager
|
||||||
|
|
||||||
|
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
|
||||||
|
with pytest.raises(OIDCTokenResponseInvalid):
|
||||||
|
await client._make_token_request("https://issuer/token", {"code": "abc"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_userinfo_returns_json_on_success(hass: HomeAssistant):
|
||||||
|
"""Userinfo helper should return JSON payload for successful responses."""
|
||||||
|
client = make_client(hass)
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = True
|
||||||
|
response.json = AsyncMock(return_value={"sub": "abc"})
|
||||||
|
|
||||||
|
context_manager = AsyncMock()
|
||||||
|
context_manager.__aenter__.return_value = response
|
||||||
|
session = MagicMock()
|
||||||
|
session.get.return_value = context_manager
|
||||||
|
|
||||||
|
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
|
||||||
|
payload = await client._get_userinfo("https://issuer/userinfo", "access")
|
||||||
|
|
||||||
|
assert payload == {"sub": "abc"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_userinfo_raises_invalid_on_http_error(hass: HomeAssistant):
|
||||||
|
"""Userinfo helper should map upstream HTTP errors to OIDCUserinfoInvalid."""
|
||||||
|
client = make_client(hass)
|
||||||
|
response = MagicMock()
|
||||||
|
response.ok = False
|
||||||
|
response.reason = "Unavailable"
|
||||||
|
response.status = 503
|
||||||
|
response.request_info = MagicMock()
|
||||||
|
response.history = ()
|
||||||
|
response.headers = {}
|
||||||
|
response.text = AsyncMock(return_value="oops")
|
||||||
|
|
||||||
|
context_manager = AsyncMock()
|
||||||
|
context_manager.__aenter__.return_value = response
|
||||||
|
session = MagicMock()
|
||||||
|
session.get.return_value = context_manager
|
||||||
|
|
||||||
|
with patch.object(client, "_get_http_session", new=AsyncMock(return_value=session)):
|
||||||
|
with pytest.raises(OIDCUserinfoInvalid):
|
||||||
|
await client._get_userinfo("https://issuer/userinfo", "access")
|
||||||
649
tests/test_hass_ui_config_flow.py
Normal file
649
tests/test_hass_ui_config_flow.py
Normal file
@@ -0,0 +1,649 @@
|
|||||||
|
"""Tests for the UI config flow"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant import config_entries
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.data_entry_flow import FlowResultType
|
||||||
|
|
||||||
|
|
||||||
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.config.const import (
|
||||||
|
OIDC_PROVIDERS,
|
||||||
|
CLIENT_ID,
|
||||||
|
CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL,
|
||||||
|
DISPLAY_NAME,
|
||||||
|
FEATURES,
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE,
|
||||||
|
CLAIMS,
|
||||||
|
CLAIMS_DISPLAY_NAME,
|
||||||
|
CLAIMS_GROUPS,
|
||||||
|
CLAIMS_USERNAME,
|
||||||
|
ROLES,
|
||||||
|
ROLE_ADMINS,
|
||||||
|
ROLE_USERS,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .mocks.oidc_server import MockOIDCServer, mock_oidc_responses
|
||||||
|
|
||||||
|
DEMO_CLIENT_ID = "testing_example_client_id"
|
||||||
|
DEMO_CLIENT_SECRET = "faz"
|
||||||
|
DEMO_ADMIN_ROLE = "boo"
|
||||||
|
DEMO_USER_ROLE = "far"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_full_config_flow_success(hass: HomeAssistant):
|
||||||
|
"""Test a successful full config flow."""
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
# 1. Start the user step
|
||||||
|
# This simulates clicking "Add Integration" in the UI.
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it's a form and expects user input for the 'user' step
|
||||||
|
# 'user' is always the first step if it is user triggered
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "user"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
schema = result["data_schema"]
|
||||||
|
# Extract the schema dict from voluptuous Schema
|
||||||
|
schema_dict = schema.schema
|
||||||
|
# Assert 'provider' is a key in the schema
|
||||||
|
assert "provider" in schema_dict
|
||||||
|
# Assert 'authentik' is one of the allowed values for 'provider'
|
||||||
|
provider_field = schema_dict["provider"]
|
||||||
|
# If provider_field is a voluptuous In validator, get its container
|
||||||
|
allowed_providers = getattr(provider_field, "container", None)
|
||||||
|
assert "authentik" in OIDC_PROVIDERS
|
||||||
|
assert allowed_providers is not None and "authentik" in allowed_providers
|
||||||
|
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# 2. Submit user input for the 'user' step
|
||||||
|
# This simulates the user filling out host/port
|
||||||
|
user_input_step_user = {"provider": "authentik"}
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input_step_user
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it proceeds to the 'auth' step
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "discovery_url"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# Fill in the discovery URL
|
||||||
|
user_input_step_discovery = {
|
||||||
|
"discovery_url": MockOIDCServer.get_discovery_url()
|
||||||
|
}
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input_step_discovery
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that it proceeds to the 'credentials' step
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "validate_connection"
|
||||||
|
|
||||||
|
# Assert that it validates correctly with our mock
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# Send in continue
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"action": "continue"}
|
||||||
|
)
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "client_config"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# Fill in the client config
|
||||||
|
user_input_step_client_config = {
|
||||||
|
"client_id": DEMO_CLIENT_ID,
|
||||||
|
"client_secret": DEMO_CLIENT_SECRET,
|
||||||
|
}
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input_step_client_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that we are at groups_config
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "groups_config"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# Fill in the groups config
|
||||||
|
user_input_step_groups_config = {
|
||||||
|
"admin_group": DEMO_ADMIN_ROLE,
|
||||||
|
"user_group": DEMO_USER_ROLE,
|
||||||
|
}
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input_step_groups_config
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert that were are at user_linking config
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "user_linking"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
assert result["errors"] == {}
|
||||||
|
|
||||||
|
# Fill in the user linking config
|
||||||
|
user_input_step_user_linking = {"enable_user_linking": False}
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], user_input_step_user_linking
|
||||||
|
)
|
||||||
|
|
||||||
|
# Finally, assert that the flow is complete and a config entry is created
|
||||||
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
|
assert result["title"] == OIDC_PROVIDERS["authentik"]["name"]
|
||||||
|
|
||||||
|
expected_data = {
|
||||||
|
"provider": "authentik",
|
||||||
|
CLIENT_ID: DEMO_CLIENT_ID,
|
||||||
|
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||||
|
},
|
||||||
|
CLAIMS: {
|
||||||
|
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"][
|
||||||
|
"display_name"
|
||||||
|
],
|
||||||
|
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||||
|
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||||
|
},
|
||||||
|
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||||
|
}
|
||||||
|
|
||||||
|
assert result["data"] == expected_data
|
||||||
|
|
||||||
|
# Verify that the config entry was loaded into Home Assistant
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0].data == expected_data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_options_flow_success(hass: HomeAssistant):
|
||||||
|
"""Test a successful options flow."""
|
||||||
|
|
||||||
|
# First, set up an initial config entry as in the full config flow
|
||||||
|
initial_data = {
|
||||||
|
"provider": "authentik",
|
||||||
|
CLIENT_ID: DEMO_CLIENT_ID,
|
||||||
|
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||||
|
},
|
||||||
|
CLAIMS: {
|
||||||
|
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||||
|
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||||
|
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||||
|
},
|
||||||
|
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = config_entries.ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
minor_version=0,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
data=initial_data,
|
||||||
|
source=config_entries.SOURCE_USER,
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="test_unique_id",
|
||||||
|
options={},
|
||||||
|
pref_disable_new_entities=False,
|
||||||
|
pref_disable_polling=False,
|
||||||
|
discovery_keys=[],
|
||||||
|
subentries_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.config_entries.async_add(entry)
|
||||||
|
|
||||||
|
# Start the reconfigure flow
|
||||||
|
result = await hass.config_entries.options.async_init(entry.entry_id)
|
||||||
|
|
||||||
|
# Should start the options flow
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "init"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
|
||||||
|
# Assert that the schema is as expected
|
||||||
|
# Schema contains enable_user_linking, enable_groups, admin_group & user_groups and no other keys
|
||||||
|
schema = result["data_schema"]
|
||||||
|
schema_dict = schema.schema
|
||||||
|
# Assert that the schema contains the expected keys
|
||||||
|
expected_keys = {
|
||||||
|
"admin_group",
|
||||||
|
"enable_user_linking",
|
||||||
|
"enable_groups",
|
||||||
|
"user_group",
|
||||||
|
}
|
||||||
|
assert set(schema_dict.keys()) == expected_keys
|
||||||
|
|
||||||
|
# Change the client_id and client_secret
|
||||||
|
new_enable_linking = True
|
||||||
|
new_enable_groups = True
|
||||||
|
new_admin_group = "bazzbbb"
|
||||||
|
new_user_group = "foobar"
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
"enable_user_linking": new_enable_linking,
|
||||||
|
"enable_groups": new_enable_groups,
|
||||||
|
"admin_group": new_admin_group,
|
||||||
|
"user_group": new_user_group,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should finish and update the entry options
|
||||||
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
|
|
||||||
|
# Optionally, check that the entry options are updated
|
||||||
|
updated_entry = hass.config_entries.async_get_entry(entry.entry_id)
|
||||||
|
assert updated_entry is not None
|
||||||
|
|
||||||
|
# Verify that the config entry was loaded into Home Assistant
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
entries[0].data[FEATURES][FEATURES_AUTOMATIC_USER_LINKING] == new_enable_linking
|
||||||
|
)
|
||||||
|
assert entries[0].data[FEATURES][FEATURES_INCLUDE_GROUPS_SCOPE] == new_enable_groups
|
||||||
|
assert entries[0].data[ROLES][ROLE_ADMINS] == new_admin_group
|
||||||
|
assert entries[0].data[ROLES][ROLE_USERS] == new_user_group
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconfigure_flow_success(hass: HomeAssistant):
|
||||||
|
"""Test a successful reconfigure flow."""
|
||||||
|
|
||||||
|
# First, set up an initial config entry as in the full config flow
|
||||||
|
initial_data = {
|
||||||
|
"provider": "authentik",
|
||||||
|
CLIENT_ID: DEMO_CLIENT_ID,
|
||||||
|
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||||
|
},
|
||||||
|
CLAIMS: {
|
||||||
|
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||||
|
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||||
|
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||||
|
},
|
||||||
|
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = config_entries.ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
minor_version=0,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
data=initial_data,
|
||||||
|
source=config_entries.SOURCE_USER,
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="test_unique_id",
|
||||||
|
options={},
|
||||||
|
pref_disable_new_entities=False,
|
||||||
|
pref_disable_polling=False,
|
||||||
|
discovery_keys=[],
|
||||||
|
subentries_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.config_entries.async_add(entry)
|
||||||
|
|
||||||
|
# Start async_step_reconfigure to reconfigure the entry
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={
|
||||||
|
"source": config_entries.SOURCE_RECONFIGURE,
|
||||||
|
"entry_id": entry.entry_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should start the reconfigure flow
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "reconfigure"
|
||||||
|
assert result["data_schema"] is not None
|
||||||
|
|
||||||
|
# Assert that the schema is client_id & client_secret
|
||||||
|
schema = result["data_schema"]
|
||||||
|
schema_dict = schema.schema
|
||||||
|
# Assert that the schema contains the expected keys
|
||||||
|
expected_keys = {
|
||||||
|
"client_id",
|
||||||
|
"client_secret",
|
||||||
|
}
|
||||||
|
assert set(schema_dict.keys()) == expected_keys
|
||||||
|
|
||||||
|
# Change the client_id and client_secret
|
||||||
|
new_client_id = "newclientid"
|
||||||
|
new_client_secret = "newclientsecret"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{
|
||||||
|
"client_id": new_client_id,
|
||||||
|
"client_secret": new_client_secret,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Should finish and update the entry data
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "reconfigure_successful"
|
||||||
|
|
||||||
|
# Verify that the config entry was loaded into Home Assistant
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
assert entries[0].data[CLIENT_ID] == new_client_id
|
||||||
|
assert entries[0].data[CLIENT_SECRET] == new_client_secret
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconfigure_flow_rejects_invalid_client_id(hass: HomeAssistant):
|
||||||
|
"""Reconfigure should keep the form open when the client ID is invalid."""
|
||||||
|
initial_data = {
|
||||||
|
"provider": "authentik",
|
||||||
|
CLIENT_ID: DEMO_CLIENT_ID,
|
||||||
|
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||||
|
},
|
||||||
|
CLAIMS: {
|
||||||
|
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||||
|
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||||
|
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||||
|
},
|
||||||
|
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = config_entries.ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
minor_version=0,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
data=initial_data,
|
||||||
|
source=config_entries.SOURCE_USER,
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="test_unique_id",
|
||||||
|
options={},
|
||||||
|
pref_disable_new_entities=False,
|
||||||
|
pref_disable_polling=False,
|
||||||
|
discovery_keys=[],
|
||||||
|
subentries_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.config_entries.async_add(entry)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={
|
||||||
|
"source": config_entries.SOURCE_RECONFIGURE,
|
||||||
|
"entry_id": entry.entry_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"client_id": " ", "client_secret": DEMO_CLIENT_SECRET},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "reconfigure"
|
||||||
|
assert result["errors"]["client_id"] == "invalid_client_id"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reconfigure_flow_keeps_client_secret_when_blank(hass: HomeAssistant):
|
||||||
|
"""Submitting a blank secret should keep the existing client secret."""
|
||||||
|
initial_data = {
|
||||||
|
"provider": "authentik",
|
||||||
|
CLIENT_ID: DEMO_CLIENT_ID,
|
||||||
|
CLIENT_SECRET: DEMO_CLIENT_SECRET,
|
||||||
|
DISCOVERY_URL: MockOIDCServer.get_discovery_url(),
|
||||||
|
DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
FEATURES: {
|
||||||
|
FEATURES_AUTOMATIC_USER_LINKING: False,
|
||||||
|
FEATURES_AUTOMATIC_PERSON_CREATION: True,
|
||||||
|
FEATURES_INCLUDE_GROUPS_SCOPE: True,
|
||||||
|
},
|
||||||
|
CLAIMS: {
|
||||||
|
CLAIMS_DISPLAY_NAME: OIDC_PROVIDERS["authentik"]["claims"]["display_name"],
|
||||||
|
CLAIMS_USERNAME: OIDC_PROVIDERS["authentik"]["claims"]["username"],
|
||||||
|
CLAIMS_GROUPS: OIDC_PROVIDERS["authentik"]["claims"]["groups"],
|
||||||
|
},
|
||||||
|
ROLES: {ROLE_ADMINS: DEMO_ADMIN_ROLE, ROLE_USERS: DEMO_USER_ROLE},
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = config_entries.ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
minor_version=0,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
data=initial_data,
|
||||||
|
source=config_entries.SOURCE_USER,
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="test_unique_id",
|
||||||
|
options={},
|
||||||
|
pref_disable_new_entities=False,
|
||||||
|
pref_disable_polling=False,
|
||||||
|
discovery_keys=[],
|
||||||
|
subentries_data=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.config_entries.async_add(entry)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN,
|
||||||
|
context={
|
||||||
|
"source": config_entries.SOURCE_RECONFIGURE,
|
||||||
|
"entry_id": entry.entry_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"client_id": DEMO_CLIENT_ID, "client_secret": ""},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "reconfigure_successful"
|
||||||
|
|
||||||
|
updated_entry = hass.config_entries.async_get_entry(entry.entry_id)
|
||||||
|
assert updated_entry is not None
|
||||||
|
assert updated_entry.data[CLIENT_SECRET] == DEMO_CLIENT_SECRET
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validation_actions_route_to_other_steps(hass: HomeAssistant):
|
||||||
|
"""Validation actions should route to the requested flow step."""
|
||||||
|
with mock_oidc_responses():
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"provider": "authentik"}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"discovery_url": MockOIDCServer.get_discovery_url()},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "validate_connection"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"action": "fix_discovery"}
|
||||||
|
)
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "discovery_url"
|
||||||
|
|
||||||
|
with mock_oidc_responses():
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"provider": "authentik"}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"discovery_url": MockOIDCServer.get_discovery_url()},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "validate_connection"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"action": "change_provider"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "user"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_flow_aborts_when_yaml_configured(hass: HomeAssistant):
|
||||||
|
"""The user flow should abort when YAML config already owns the provider."""
|
||||||
|
hass.data[DOMAIN] = {"yaml_config": {"client_id": DEMO_CLIENT_ID}}
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "yaml_configured"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_flow_aborts_when_entry_already_exists(hass: HomeAssistant):
|
||||||
|
"""The flow should not create a second OIDC config entry."""
|
||||||
|
entry = config_entries.ConfigEntry(
|
||||||
|
version=1,
|
||||||
|
minor_version=0,
|
||||||
|
domain=DOMAIN,
|
||||||
|
title=OIDC_PROVIDERS["authentik"]["name"],
|
||||||
|
data={"provider": "authentik"},
|
||||||
|
source=config_entries.SOURCE_USER,
|
||||||
|
entry_id="1",
|
||||||
|
unique_id="test_unique_id",
|
||||||
|
options={},
|
||||||
|
pref_disable_new_entities=False,
|
||||||
|
pref_disable_polling=False,
|
||||||
|
discovery_keys=[],
|
||||||
|
subentries_data=None,
|
||||||
|
)
|
||||||
|
await hass.config_entries.async_add(entry)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.ABORT
|
||||||
|
assert result["reason"] == "single_instance_allowed"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_discovery_url_validation_rejects_invalid_url(hass: HomeAssistant):
|
||||||
|
"""Discovery URL validation should reject malformed inputs."""
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"provider": "authentik"}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"discovery_url": "not-a-valid-oidc-url"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "discovery_url"
|
||||||
|
assert result["errors"]["discovery_url"] == "invalid_url_format"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_generic_provider_skips_groups_config(hass: HomeAssistant):
|
||||||
|
"""Providers without group support should go straight to user linking."""
|
||||||
|
with mock_oidc_responses():
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"provider": "generic"}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"discovery_url": MockOIDCServer.get_discovery_url()},
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"action": "continue"}
|
||||||
|
)
|
||||||
|
assert result["step_id"] == "client_config"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"client_id": DEMO_CLIENT_ID, "client_secret": DEMO_CLIENT_SECRET},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.FORM
|
||||||
|
assert result["step_id"] == "user_linking"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_groups_disabled_skips_roles_and_creates_entry(hass: HomeAssistant):
|
||||||
|
"""Disabling groups should skip role configuration and omit roles from entry data."""
|
||||||
|
with mock_oidc_responses():
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"provider": "authentik"}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"discovery_url": MockOIDCServer.get_discovery_url()},
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"action": "continue"}
|
||||||
|
)
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
{"client_id": DEMO_CLIENT_ID, "client_secret": DEMO_CLIENT_SECRET},
|
||||||
|
)
|
||||||
|
assert result["step_id"] == "groups_config"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"enable_groups": False}
|
||||||
|
)
|
||||||
|
assert result["step_id"] == "user_linking"
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_configure(
|
||||||
|
result["flow_id"], {"enable_user_linking": True}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == FlowResultType.CREATE_ENTRY
|
||||||
|
assert "roles" not in result["data"]
|
||||||
|
assert result["data"][FEATURES][FEATURES_INCLUDE_GROUPS_SCOPE] is False
|
||||||
813
tests/test_hass_webserver.py
Normal file
813
tests/test_hass_webserver.py
Normal file
@@ -0,0 +1,813 @@
|
|||||||
|
"""Tests for the registered webpages"""
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
from urllib.parse import parse_qs, quote, unquote, urlparse, urlencode
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from auth_oidc.config.const import DISCOVERY_URL, CLIENT_ID
|
||||||
|
|
||||||
|
from pytest_homeassistant_custom_component.typing import ClientSessionGenerator
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
from homeassistant.components.http import StaticPathConfig, DOMAIN as HTTP_DOMAIN
|
||||||
|
|
||||||
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.endpoints.injected_auth_page import (
|
||||||
|
OIDCInjectedAuthPage,
|
||||||
|
frontend_injection,
|
||||||
|
)
|
||||||
|
|
||||||
|
MOBILE_CLIENT_ID = "https://home-assistant.io/Android"
|
||||||
|
|
||||||
|
|
||||||
|
def create_redirect_uri(client_id: str) -> str:
|
||||||
|
"""Build a redirect URI that includes a client_id query parameter."""
|
||||||
|
params = {
|
||||||
|
"response_type": "code",
|
||||||
|
"redirect_uri": client_id,
|
||||||
|
"client_id": client_id,
|
||||||
|
"state": "example",
|
||||||
|
}
|
||||||
|
|
||||||
|
return f"http://example.com/auth/authorize?{urlencode(params)}"
|
||||||
|
|
||||||
|
|
||||||
|
def encode_redirect_uri(redirect_uri: str) -> str:
|
||||||
|
"""Encode redirect_uri in the same way as frontend btoa()."""
|
||||||
|
return base64.b64encode(redirect_uri.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
):
|
||||||
|
mock_config = {
|
||||||
|
DOMAIN: {
|
||||||
|
CLIENT_ID: "dummy",
|
||||||
|
DISCOVERY_URL: "https://example.com/.well-known/openid-configuration",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await async_setup_component(hass, DOMAIN, mock_config)
|
||||||
|
assert result
|
||||||
|
|
||||||
|
|
||||||
|
async def setup_mock_authorize_route(hass: HomeAssistant) -> None:
|
||||||
|
"""Register a mock /auth/authorize page so frontend injection can hook into it."""
|
||||||
|
await async_setup_component(hass, HTTP_DOMAIN, {})
|
||||||
|
|
||||||
|
mock_html_path = os.path.join(os.path.dirname(__file__), "mocks", "auth_page.html")
|
||||||
|
await hass.http.async_register_static_paths(
|
||||||
|
[
|
||||||
|
StaticPathConfig(
|
||||||
|
"/auth/authorize",
|
||||||
|
mock_html_path,
|
||||||
|
cache_headers=False,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_page_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Test that welcome page is present."""
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/oidc/welcome", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirect_page_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Test that redirect page can be reached."""
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 302
|
||||||
|
|
||||||
|
resp2 = await client.post("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp2.status == 302
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_rejects_invalid_encoded_redirect_uri(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Welcome should reject malformed base64 redirect_uri values."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get(
|
||||||
|
"/auth/oidc/welcome?redirect_uri=%25%25%25",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Invalid redirect_uri, please restart login." in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"redirect_uri",
|
||||||
|
[
|
||||||
|
"http://example.com/auth/authorize?client_id=https://example.com",
|
||||||
|
"http://example.com/auth/authorize?redirect_uri=https://example.com",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_welcome_rejects_redirect_uris_missing_required_query_params(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator, redirect_uri: str
|
||||||
|
):
|
||||||
|
"""Welcome should reject redirect URIs that decode but are incomplete."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Invalid redirect_uri, please restart login." in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("client_id", "should_store_token", "is_mobile"),
|
||||||
|
[
|
||||||
|
("", True, False),
|
||||||
|
(MOBILE_CLIENT_ID, False, True),
|
||||||
|
("https://random.example", False, False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_welcome_only_adds_store_token_for_web_clients(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
client_id: str,
|
||||||
|
should_store_token: bool,
|
||||||
|
is_mobile: bool,
|
||||||
|
):
|
||||||
|
"""Welcome should only append storeToken for clients aligned with the base URL."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
captured_redirect_uri = {}
|
||||||
|
|
||||||
|
async def fake_create_state(state_redirect_uri: str, *_args):
|
||||||
|
captured_redirect_uri["value"] = state_redirect_uri
|
||||||
|
return "state-id"
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_create_state",
|
||||||
|
new=AsyncMock(side_effect=fake_create_state),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_generate_device_code",
|
||||||
|
new=AsyncMock(return_value="123456"),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
|
||||||
|
if client_id == "":
|
||||||
|
# If not present, set it to the root URL to
|
||||||
|
# emulate the normal website/Lovelace/dashboard
|
||||||
|
client_id = str(client.make_url("/?test=true"))
|
||||||
|
|
||||||
|
redirect_uri = create_redirect_uri(client_id)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status in (200, 302)
|
||||||
|
assert "value" in captured_redirect_uri
|
||||||
|
|
||||||
|
parsed_state_redirect = urlparse(captured_redirect_uri["value"])
|
||||||
|
state_redirect_query = parse_qs(parsed_state_redirect.query)
|
||||||
|
nested_redirect_uri = unquote(state_redirect_query["redirect_uri"][0])
|
||||||
|
|
||||||
|
if should_store_token:
|
||||||
|
assert "storeToken=true" in nested_redirect_uri
|
||||||
|
else:
|
||||||
|
assert "storeToken=true" not in nested_redirect_uri
|
||||||
|
|
||||||
|
if is_mobile:
|
||||||
|
assert "https://home-assistant.io/" in nested_redirect_uri
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_sets_secure_state_cookie_flags(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Welcome should set secure cookie flags for the OIDC state cookie."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status in (200, 302)
|
||||||
|
assert "auth_oidc_state" in resp.cookies
|
||||||
|
|
||||||
|
set_cookie = resp.headers.get("Set-Cookie", "")
|
||||||
|
assert "Path=/auth/" in set_cookie
|
||||||
|
assert "SameSite=Lax" in set_cookie
|
||||||
|
assert "HttpOnly" in set_cookie
|
||||||
|
assert "Max-Age=300" in set_cookie
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_mobile_device_code_generation_failure(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Welcome should error if device code generation fails for mobile clients."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_generate_device_code",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 500
|
||||||
|
assert (
|
||||||
|
"Failed to generate device code, please restart login." in await resp.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_shows_alternative_sign_in_link_when_other_providers_exist(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Welcome should render fallback auth link when other providers are present."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
text = await resp.text()
|
||||||
|
assert 'id="login-button"' in text
|
||||||
|
assert 'id="alternative-sign-in-link"' in text
|
||||||
|
assert "skip_oidc_redirect=true" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_welcome_desktop_auto_redirects_without_other_providers(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Welcome should auto-redirect desktop clients when no other providers exist."""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
hass.auth._providers = [] # Clear initial providers out
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 302
|
||||||
|
assert "/auth/oidc/redirect" in resp.headers["Location"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirect_without_cookie_goes_to_welcome(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Redirect endpoint should bounce to welcome when no state cookie exists."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 302
|
||||||
|
assert "/auth/oidc/welcome" in resp.headers["Location"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirect_shows_error_on_oidc_runtime_error(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Redirect should show a configuration error when OIDC URL generation raises."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status in (200, 302)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_get_authorization_url",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("broken discovery")),
|
||||||
|
):
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 500
|
||||||
|
assert (
|
||||||
|
"Integration is misconfigured, discovery could not be obtained."
|
||||||
|
in await resp.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_redirect_shows_error_when_auth_url_empty(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Redirect should show error page if OIDC returns no authorization URL."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status in (200, 302)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_get_authorization_url",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
resp = await client.get("/auth/oidc/redirect", allow_redirects=False)
|
||||||
|
assert resp.status == 500
|
||||||
|
assert (
|
||||||
|
"Integration is misconfigured, discovery could not be obtained."
|
||||||
|
in await resp.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_registration(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Test that callback page is reachable."""
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/oidc/callback", allow_redirects=False)
|
||||||
|
assert resp.status == 400
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_rejects_missing_code_or_state(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Callback must reject requests missing either code or state."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
state = resp_welcome.cookies["auth_oidc_state"].value
|
||||||
|
|
||||||
|
resp_missing_code = await client.get(
|
||||||
|
f"/auth/oidc/callback?state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_missing_code.status == 400
|
||||||
|
assert "Missing code or state parameter." in await resp_missing_code.text()
|
||||||
|
|
||||||
|
resp_missing_state = await client.get(
|
||||||
|
"/auth/oidc/callback?code=testcode",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_missing_state.status == 400
|
||||||
|
assert "Missing code or state parameter." in await resp_missing_state.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_rejects_state_mismatch(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Callback must reject state mismatch to protect against CSRF."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
state = resp_welcome.cookies["auth_oidc_state"].value
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code=testcode&state={state}-other",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "State parameter does not match, possible CSRF attack." in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_rejects_when_user_details_fetch_fails(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Callback should error when token exchange/userinfo retrieval fails."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
state = resp_welcome.cookies["auth_oidc_state"].value
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_complete_token_flow",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code=testcode&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 500
|
||||||
|
assert (
|
||||||
|
"Failed to get user details, see Home Assistant logs for more information."
|
||||||
|
in await resp.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_callback_rejects_invalid_role(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Callback should reject users marked with invalid role."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
state = resp_welcome.cookies["auth_oidc_state"].value
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.oidc_client.OIDCClient.async_complete_token_flow",
|
||||||
|
new=AsyncMock(return_value={"sub": "abc", "role": "invalid"}),
|
||||||
|
):
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/oidc/callback?code=testcode&state={state}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 403
|
||||||
|
assert (
|
||||||
|
"User is not in the correct group to access Home Assistant"
|
||||||
|
in await resp.text()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("method", "data"),
|
||||||
|
[
|
||||||
|
("get", None),
|
||||||
|
("post", {}),
|
||||||
|
("post", {"device_code": "456888"}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_finish_requires_state_cookie(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
method: str,
|
||||||
|
data: dict | None,
|
||||||
|
):
|
||||||
|
"""Finish endpoint should require the OIDC state cookie for both GET and POST."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
request = getattr(client, method)
|
||||||
|
if data is None:
|
||||||
|
resp = await request("/auth/oidc/finish", allow_redirects=False)
|
||||||
|
else:
|
||||||
|
resp = await request("/auth/oidc/finish", data=data, allow_redirects=False)
|
||||||
|
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Missing state cookie" in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finish_post_rejects_invalid_state(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Finish POST should error when the state cookie does not resolve to redirect_uri."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(client.make_url("/"))
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status in (200, 302)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
resp = await client.post("/auth/oidc/finish", allow_redirects=False)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Invalid state, please restart login." in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_sse_requires_state_cookie(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""SSE endpoint should reject requests without state cookie."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert "Missing session cookie" in await resp.text()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_sse_emits_expired_for_unknown_state(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""SSE should emit expired when the state can no longer be resolved."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status == 200
|
||||||
|
|
||||||
|
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
payload = await resp.text()
|
||||||
|
assert "event: expired" in payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_sse_emits_timeout(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""SSE should emit timeout if the polling window is exceeded."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status == 200
|
||||||
|
|
||||||
|
fake_loop = MagicMock()
|
||||||
|
fake_loop.time.side_effect = [0, 301]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
|
||||||
|
new=AsyncMock(return_value=redirect_uri),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_is_state_ready",
|
||||||
|
new=AsyncMock(return_value=False),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.endpoints.device_sse.asyncio.get_running_loop",
|
||||||
|
return_value=fake_loop,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
payload = await resp.text()
|
||||||
|
assert "event: timeout" in payload
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_sse_handles_runtime_error_and_returns_cleanly(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""SSE should swallow runtime errors from stream loop and finish response."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status == 200
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
|
||||||
|
new=AsyncMock(return_value=redirect_uri),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_is_state_ready",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("disconnect")),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_device_sse_ignores_write_eof_connection_reset(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""SSE should ignore ConnectionResetError while closing the stream."""
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
redirect_uri = create_redirect_uri(MOBILE_CLIENT_ID)
|
||||||
|
encoded = encode_redirect_uri(redirect_uri)
|
||||||
|
resp_welcome = await client.get(
|
||||||
|
f"/auth/oidc/welcome?redirect_uri={encoded}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp_welcome.status == 200
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.provider.OpenIDAuthProvider.async_get_redirect_uri_for_state",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"custom_components.auth_oidc.endpoints.device_sse.web.StreamResponse.write_eof",
|
||||||
|
new=AsyncMock(side_effect=ConnectionResetError),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
resp = await client.get("/auth/oidc/device-sse", allow_redirects=False)
|
||||||
|
assert resp.status == 200
|
||||||
|
|
||||||
|
|
||||||
|
# Test the frontend injection
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_frontend_injection(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Test that frontend injection works."""
|
||||||
|
|
||||||
|
# Because there is no frontend in the test setup,
|
||||||
|
# we'll have to fake /auth/authorize for the changes to register.
|
||||||
|
await setup_mock_authorize_route(hass)
|
||||||
|
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
resp = await client.get("/auth/authorize", allow_redirects=False)
|
||||||
|
assert resp.status == 200 # 200 because there is no redirect_uri
|
||||||
|
text = await resp.text()
|
||||||
|
|
||||||
|
assert "<script src='/auth/oidc/static/injection.js" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_frontend_injection_logs_and_returns_when_route_handler_is_unexpected(
|
||||||
|
hass: HomeAssistant, caplog
|
||||||
|
):
|
||||||
|
"""frontend_injection should log and return if the GET handler shape is unexpected."""
|
||||||
|
|
||||||
|
await async_setup_component(hass, HTTP_DOMAIN, {})
|
||||||
|
|
||||||
|
class FakeRoute:
|
||||||
|
method = "GET"
|
||||||
|
handler = object()
|
||||||
|
|
||||||
|
class FakeResource:
|
||||||
|
canonical = "/auth/authorize"
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.prefix = None
|
||||||
|
|
||||||
|
def add_prefix(self, prefix):
|
||||||
|
self.prefix = prefix
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return iter([FakeRoute()])
|
||||||
|
|
||||||
|
with patch.object(hass.http.app.router, "resources", return_value=[FakeResource()]):
|
||||||
|
await frontend_injection(hass, force_https=False)
|
||||||
|
|
||||||
|
assert "Unexpected route handler type" in caplog.text
|
||||||
|
assert (
|
||||||
|
"Failed to find GET route for /auth/authorize, cannot inject OIDC frontend code"
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_injected_auth_page_inject_logs_errors(hass: HomeAssistant, caplog):
|
||||||
|
"""OIDCInjectedAuthPage.inject should swallow unexpected injection errors."""
|
||||||
|
|
||||||
|
await async_setup_component(hass, HTTP_DOMAIN, {})
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.endpoints.injected_auth_page.frontend_injection",
|
||||||
|
side_effect=RuntimeError("boom"),
|
||||||
|
):
|
||||||
|
await OIDCInjectedAuthPage.inject(hass, force_https=False)
|
||||||
|
|
||||||
|
assert "Failed to inject OIDC auth page: boom" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_injected_auth_page_redirects_to_welcome_when_not_skipped(
|
||||||
|
hass: HomeAssistant, hass_client: ClientSessionGenerator
|
||||||
|
):
|
||||||
|
"""Injected auth page should redirect into OIDC when skip flags are absent."""
|
||||||
|
|
||||||
|
await setup_mock_authorize_route(hass)
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
encoded_redirect_uri = quote(create_redirect_uri(client.make_url("/")), safe="")
|
||||||
|
|
||||||
|
resp = await client.get(
|
||||||
|
f"/auth/authorize?redirect_uri={encoded_redirect_uri}",
|
||||||
|
allow_redirects=False,
|
||||||
|
)
|
||||||
|
assert resp.status == 302
|
||||||
|
|
||||||
|
location = resp.headers["Location"]
|
||||||
|
parsed_location = urlparse(location)
|
||||||
|
assert parsed_location.path == "/auth/oidc/welcome"
|
||||||
|
|
||||||
|
query = parse_qs(parsed_location.query)
|
||||||
|
assert "redirect_uri" in query
|
||||||
|
|
||||||
|
original_url = base64.b64decode(unquote(query["redirect_uri"][0]), validate=True)
|
||||||
|
original_url = original_url.decode("utf-8")
|
||||||
|
assert "/auth/authorize?redirect_uri=" in original_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"request_target",
|
||||||
|
[
|
||||||
|
"/auth/authorize?skip_oidc_redirect=true",
|
||||||
|
"/auth/authorize?redirect_uri=http%3A%2F%2Fexample.com%2Fauth%2Fauthorize%3Fskip_oidc_redirect%3Dtrue",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_injected_auth_page_returns_original_html_when_skipped(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client,
|
||||||
|
request_target: str,
|
||||||
|
):
|
||||||
|
"""Injected auth page should render HTML when redirect suppression is requested."""
|
||||||
|
|
||||||
|
await setup_mock_authorize_route(hass)
|
||||||
|
await setup(hass)
|
||||||
|
|
||||||
|
client = await hass_client()
|
||||||
|
response = await client.get(request_target, allow_redirects=False)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
assert "<script src='/auth/oidc/static/injection.js" in await response.text()
|
||||||
90
tests/test_hass_yaml_init.py
Normal file
90
tests/test_hass_yaml_init.py
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
"""Tests for the YAML config setup of OIDC"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from custom_components.auth_oidc import DOMAIN
|
||||||
|
from custom_components.auth_oidc.config.const import ADDITIONAL_SCOPES
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(hass: HomeAssistant, config: dict, expect_success: bool) -> bool:
|
||||||
|
"""Set up the auth_oidc component."""
|
||||||
|
result = await async_setup_component(hass, DOMAIN, {DOMAIN: config})
|
||||||
|
|
||||||
|
if expect_success:
|
||||||
|
assert result
|
||||||
|
assert DOMAIN in hass.data
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"config",
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"client_id": "dummy",
|
||||||
|
"discovery_url": "https://example.com/.well-known/openid-configuration",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"client_id": "dummy",
|
||||||
|
"discovery_url": "https://example.com/.well-known/openid-configuration",
|
||||||
|
ADDITIONAL_SCOPES: "email phone",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_setup_success_yaml(hass: HomeAssistant, config: dict):
|
||||||
|
"""YAML setup should succeed for minimal and optional-scope configurations."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
config,
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_failure_empty_yaml(hass: HomeAssistant, caplog):
|
||||||
|
"""Test failure setup of an empty YAML configuration."""
|
||||||
|
await setup(hass, {}, False)
|
||||||
|
|
||||||
|
assert "required key 'client_id' not provided" in caplog.text
|
||||||
|
assert "required key 'discovery_url' not provided" in caplog.text
|
||||||
|
assert (
|
||||||
|
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_failure_partial_empty_yaml_discovery(hass: HomeAssistant, caplog):
|
||||||
|
"""Test failure setup of an partial YAML configuration."""
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{"discovery_url": "https://example.com/.well-known/openid-configuration"},
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "required key 'client_id' not provided" in caplog.text
|
||||||
|
assert "required key 'discovery_url' not provided" not in caplog.text
|
||||||
|
assert (
|
||||||
|
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_setup_failure_partial_empty_yaml_client(hass: HomeAssistant, caplog):
|
||||||
|
"""Test failure setup of an partial YAML configuration."""
|
||||||
|
|
||||||
|
await setup(
|
||||||
|
hass,
|
||||||
|
{"client_id": "test"},
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "required key 'client_id' not provided" not in caplog.text
|
||||||
|
assert "required key 'discovery_url' not provided" in caplog.text
|
||||||
|
assert (
|
||||||
|
"Setup failed for custom integration 'auth_oidc': Invalid config."
|
||||||
|
in caplog.text
|
||||||
|
)
|
||||||
174
tests/test_helpers.py
Normal file
174
tests/test_helpers.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Tests for the helpers and validation tools"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from aiohttp.test_utils import make_mocked_request
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from custom_components.auth_oidc.tools.helpers import (
|
||||||
|
STATE_COOKIE_NAME,
|
||||||
|
error_response,
|
||||||
|
get_state_id,
|
||||||
|
get_url,
|
||||||
|
get_valid_state_id,
|
||||||
|
get_view,
|
||||||
|
html_response,
|
||||||
|
template_response,
|
||||||
|
)
|
||||||
|
from custom_components.auth_oidc.tools.validation import (
|
||||||
|
validate_client_id,
|
||||||
|
sanitize_client_secret,
|
||||||
|
validate_discovery_url,
|
||||||
|
validate_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_url():
|
||||||
|
"""Test the get_url helper."""
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
get_url("https://example.com", "/test")
|
||||||
|
assert str(excinfo.value) == "No current request in context"
|
||||||
|
|
||||||
|
# Mock homeassistant.components.http.current_request.get() to test the force HTTP flag
|
||||||
|
with patch("homeassistant.components.http.current_request") as mock_current_request:
|
||||||
|
fake_request = make_mocked_request("GET", "http://example.com")
|
||||||
|
mock_current_request.get.return_value = fake_request
|
||||||
|
result = get_url("/test", True)
|
||||||
|
assert result == "https://example.com/test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_view():
|
||||||
|
"""Test the get_view helper."""
|
||||||
|
|
||||||
|
data = await get_view("welcome")
|
||||||
|
assert data.startswith("<!DOCTYPE html>")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_state_id():
|
||||||
|
"""State cookie helper should return cookie value when present."""
|
||||||
|
request = make_mocked_request(
|
||||||
|
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=abc"}
|
||||||
|
)
|
||||||
|
assert get_state_id(request) == "abc"
|
||||||
|
|
||||||
|
request_without_cookie = make_mocked_request("GET", "/")
|
||||||
|
assert get_state_id(request_without_cookie) is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_valid_state_id():
|
||||||
|
"""Valid-state helper should return only existing and valid cookie states."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.async_is_state_valid = AsyncMock(return_value=True)
|
||||||
|
|
||||||
|
request = make_mocked_request(
|
||||||
|
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=state-1"}
|
||||||
|
)
|
||||||
|
state_id = await get_valid_state_id(request, provider)
|
||||||
|
|
||||||
|
assert state_id == "state-1"
|
||||||
|
provider.async_is_state_valid.assert_awaited_once_with("state-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_valid_state_id_invalid_or_missing_cookie():
|
||||||
|
"""Valid-state helper should reject missing and invalid states."""
|
||||||
|
provider = MagicMock()
|
||||||
|
provider.async_is_state_valid = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
request = make_mocked_request(
|
||||||
|
"GET", "/", headers={"Cookie": f"{STATE_COOKIE_NAME}=state-2"}
|
||||||
|
)
|
||||||
|
assert await get_valid_state_id(request, provider) is None
|
||||||
|
provider.async_is_state_valid.assert_awaited_once_with("state-2")
|
||||||
|
|
||||||
|
request_without_cookie = make_mocked_request("GET", "/")
|
||||||
|
provider.async_is_state_valid.reset_mock()
|
||||||
|
assert await get_valid_state_id(request_without_cookie, provider) is None
|
||||||
|
provider.async_is_state_valid.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_html_response_and_template_helpers():
|
||||||
|
"""Response helpers should preserve status and render HTML views."""
|
||||||
|
response = html_response("<p>ok</p>", status=418)
|
||||||
|
assert isinstance(response, web.Response)
|
||||||
|
assert response.status == 418
|
||||||
|
assert response.content_type == "text/html"
|
||||||
|
assert response.text == "<p>ok</p>"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.helpers.get_view",
|
||||||
|
new=AsyncMock(return_value="<p>rendered</p>"),
|
||||||
|
) as mocked_get_view:
|
||||||
|
rendered = await template_response("welcome", {"name": "OIDC"})
|
||||||
|
|
||||||
|
assert rendered.status == 200
|
||||||
|
assert rendered.text == "<p>rendered</p>"
|
||||||
|
mocked_get_view.assert_awaited_once_with("welcome", {"name": "OIDC"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_error_response():
|
||||||
|
"""Error response helper should render the shared error template with status."""
|
||||||
|
with patch(
|
||||||
|
"custom_components.auth_oidc.tools.helpers.get_view",
|
||||||
|
new=AsyncMock(return_value="<p>error</p>"),
|
||||||
|
) as mocked_get_view:
|
||||||
|
rendered = await error_response("boom", status=500)
|
||||||
|
|
||||||
|
assert rendered.status == 500
|
||||||
|
assert rendered.text == "<p>error</p>"
|
||||||
|
mocked_get_view.assert_awaited_once_with("error", {"error": "boom"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_url():
|
||||||
|
"""Test the validate_url helper."""
|
||||||
|
|
||||||
|
assert not validate_url("ftp://example.com")
|
||||||
|
assert validate_url("http://example.com")
|
||||||
|
assert validate_url("https://example.com")
|
||||||
|
assert not validate_url("example.com")
|
||||||
|
assert not validate_url(42)
|
||||||
|
assert not validate_url([])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validate_discovery_url():
|
||||||
|
"""Test the validate_discovery_url helper."""
|
||||||
|
|
||||||
|
assert not validate_discovery_url("ftp://example.com")
|
||||||
|
assert not validate_discovery_url("http://example.com")
|
||||||
|
assert not validate_discovery_url("https://example.com")
|
||||||
|
assert not validate_discovery_url("example.com")
|
||||||
|
assert not validate_discovery_url(
|
||||||
|
"https://example.com/.well-known/openid_configuration"
|
||||||
|
)
|
||||||
|
assert validate_discovery_url(
|
||||||
|
"https://example.com/.well-known/openid-configuration"
|
||||||
|
)
|
||||||
|
assert not validate_discovery_url(2)
|
||||||
|
assert not validate_discovery_url([])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_secret():
|
||||||
|
"""Test the sanitize_client_secret helper."""
|
||||||
|
|
||||||
|
assert sanitize_client_secret("test ") == "test"
|
||||||
|
assert sanitize_client_secret("test2") == "test2"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_client_id():
|
||||||
|
"""Test the validate_client_id helper."""
|
||||||
|
|
||||||
|
assert not validate_client_id(" ")
|
||||||
|
assert validate_client_id("test4")
|
||||||
|
assert validate_client_id("test4 ")
|
||||||
53
tests/test_provider_catalog.py
Normal file
53
tests/test_provider_catalog.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Tests for the provider catalog helpers."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from custom_components.auth_oidc.config.const import OIDC_PROVIDERS, REPO_ROOT_URL
|
||||||
|
from custom_components.auth_oidc.config.provider_catalog import (
|
||||||
|
get_provider_config,
|
||||||
|
get_provider_docs_url,
|
||||||
|
get_provider_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider_key", "expected_name", "expected_supports_groups"),
|
||||||
|
[
|
||||||
|
("authentik", "Authentik", True),
|
||||||
|
("generic", "OpenID Connect (SSO)", False),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_get_provider_config_and_name(
|
||||||
|
provider_key, expected_name, expected_supports_groups
|
||||||
|
):
|
||||||
|
"""Known providers should resolve to their configured metadata."""
|
||||||
|
config = get_provider_config(provider_key)
|
||||||
|
|
||||||
|
assert config == OIDC_PROVIDERS[provider_key]
|
||||||
|
assert get_provider_name(provider_key) == expected_name
|
||||||
|
assert config["supports_groups"] is expected_supports_groups
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("provider_key", [None, "unknown", ""])
|
||||||
|
def test_provider_fallbacks(provider_key):
|
||||||
|
"""Unknown providers should fall back to neutral defaults."""
|
||||||
|
assert get_provider_config(provider_key or "unknown") == {}
|
||||||
|
assert get_provider_name(provider_key) == "Unknown Provider"
|
||||||
|
assert (
|
||||||
|
get_provider_docs_url(provider_key) == f"{REPO_ROOT_URL}/docs/configuration.md"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("provider_key", "expected_suffix"),
|
||||||
|
[
|
||||||
|
("authentik", "/docs/provider-configurations/authentik.md"),
|
||||||
|
("authelia", "/docs/provider-configurations/authelia.md"),
|
||||||
|
("pocketid", "/docs/provider-configurations/pocket-id.md"),
|
||||||
|
("kanidm", "/docs/provider-configurations/kanidm.md"),
|
||||||
|
("microsoft", "/docs/provider-configurations/microsoft-entra.md"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_provider_docs_urls(provider_key, expected_suffix):
|
||||||
|
"""Known providers should point to provider-specific docs."""
|
||||||
|
assert get_provider_docs_url(provider_key) == f"{REPO_ROOT_URL}{expected_suffix}"
|
||||||
260
tests/test_state_store.py
Normal file
260
tests/test_state_store.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
"""Tests for the state store."""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
|
from auth_oidc.stores.state_store import MAX_DEVICE_CODE_ATTEMPTS, StateStore
|
||||||
|
|
||||||
|
TEST_IP = "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_generate_and_receive_state(hass: HomeAssistant):
|
||||||
|
"""Test creating a state, storing user info, and receiving it once."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
assert state_store.get_data() == {}
|
||||||
|
|
||||||
|
redirect_uri = "https://example.com/callback"
|
||||||
|
state_id = await state_store.async_create_state_from_url(redirect_uri, TEST_IP)
|
||||||
|
assert state_id in state_store.get_data()
|
||||||
|
assert (
|
||||||
|
await state_store.async_get_redirect_uri_for_state(state_id, TEST_IP)
|
||||||
|
== redirect_uri
|
||||||
|
)
|
||||||
|
|
||||||
|
user_info = {
|
||||||
|
"sub": "user1",
|
||||||
|
"display_name": "Test User",
|
||||||
|
"username": "testuser",
|
||||||
|
"role": "system-users",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
await state_store.async_add_userinfo_to_state(state_id, user_info) is True
|
||||||
|
)
|
||||||
|
assert state_id in state_store.get_data()
|
||||||
|
assert await state_store.async_is_state_ready(state_id, TEST_IP) is True
|
||||||
|
assert state_id in state_store.get_data()
|
||||||
|
|
||||||
|
result = await state_store.async_receive_userinfo_for_state(state_id, TEST_IP)
|
||||||
|
assert result == user_info
|
||||||
|
assert state_id not in state_store.get_data()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_generate_code_and_link_state(hass: HomeAssistant):
|
||||||
|
"""Test generating a device code and linking another state to it."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
|
||||||
|
donor_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/donor", TEST_IP
|
||||||
|
)
|
||||||
|
target_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/target", TEST_IP
|
||||||
|
)
|
||||||
|
|
||||||
|
code = await state_store.async_generate_code_for_state(target_state)
|
||||||
|
assert code is not None
|
||||||
|
assert len(code) == 6
|
||||||
|
assert code.isdigit()
|
||||||
|
|
||||||
|
user_info = {
|
||||||
|
"sub": "user2",
|
||||||
|
"display_name": "Device User",
|
||||||
|
"username": "deviceuser",
|
||||||
|
"role": "system-admin",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
await state_store.async_add_userinfo_to_state(donor_state, user_info)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert donor_state in state_store.get_data()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await state_store.async_link_state_to_code(donor_state, code, TEST_IP)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
assert donor_state not in state_store.get_data()
|
||||||
|
assert await state_store.async_is_state_ready(target_state, TEST_IP) is True
|
||||||
|
assert target_state in state_store.get_data()
|
||||||
|
assert (
|
||||||
|
await state_store.async_receive_userinfo_for_state(target_state, TEST_IP)
|
||||||
|
== user_info
|
||||||
|
)
|
||||||
|
assert target_state not in state_store.get_data()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_link_state_returns_false_for_wrong_code(hass: HomeAssistant):
|
||||||
|
"""Test linking fails when the device code does not match any state."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
|
||||||
|
donor_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/donor", TEST_IP
|
||||||
|
)
|
||||||
|
target_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/target", TEST_IP
|
||||||
|
)
|
||||||
|
await state_store.async_generate_code_for_state(target_state)
|
||||||
|
|
||||||
|
user_info = {
|
||||||
|
"sub": "user3",
|
||||||
|
"display_name": "Wrong Code User",
|
||||||
|
"username": "wrongcode",
|
||||||
|
"role": "system-users",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
await state_store.async_add_userinfo_to_state(donor_state, user_info)
|
||||||
|
is True
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await state_store.async_link_state_to_code(donor_state, "000000", TEST_IP)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
assert donor_state in state_store.get_data()
|
||||||
|
assert await state_store.async_is_state_ready(target_state, TEST_IP) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_throttles_device_code_link_attempts(hass: HomeAssistant):
|
||||||
|
"""Test that repeated wrong device codes are throttled per state."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
|
||||||
|
donor_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/donor", TEST_IP
|
||||||
|
)
|
||||||
|
target_state = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/target", TEST_IP
|
||||||
|
)
|
||||||
|
code = await state_store.async_generate_code_for_state(target_state)
|
||||||
|
assert code is not None
|
||||||
|
|
||||||
|
user_info = {
|
||||||
|
"sub": "user-throttle",
|
||||||
|
"display_name": "Throttle User",
|
||||||
|
"username": "throttle",
|
||||||
|
"role": "system-users",
|
||||||
|
}
|
||||||
|
assert await state_store.async_add_userinfo_to_state(donor_state, user_info)
|
||||||
|
|
||||||
|
for _ in range(MAX_DEVICE_CODE_ATTEMPTS):
|
||||||
|
assert (
|
||||||
|
await state_store.async_link_state_to_code(
|
||||||
|
donor_state, "000000", TEST_IP
|
||||||
|
)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await state_store.async_link_state_to_code(donor_state, code, TEST_IP)
|
||||||
|
is False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_expired_state(hass: HomeAssistant):
|
||||||
|
"""Test that expired states are treated as invalid."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
|
||||||
|
state_id = await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com/expired", TEST_IP
|
||||||
|
)
|
||||||
|
state_store.get_data()[state_id]["expiration"] = (
|
||||||
|
datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||||
|
).isoformat()
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await state_store.async_get_redirect_uri_for_state(state_id, TEST_IP)
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
assert await state_store.async_is_state_ready(state_id, TEST_IP) is False
|
||||||
|
assert (
|
||||||
|
await state_store.async_receive_userinfo_for_state(state_id, TEST_IP)
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_data_not_loaded(hass: HomeAssistant):
|
||||||
|
"""Test that using the store before loading raises RuntimeError."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_create_state_from_url(
|
||||||
|
"https://example.com", TEST_IP
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_generate_code_for_state("state")
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_add_userinfo_to_state(
|
||||||
|
"state",
|
||||||
|
{
|
||||||
|
"sub": "user4",
|
||||||
|
"display_name": "Not Loaded",
|
||||||
|
"username": "notloaded",
|
||||||
|
"role": "system-users",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_get_redirect_uri_for_state("state", TEST_IP)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_is_state_ready("state", TEST_IP)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_link_state_to_code("state", "123456", TEST_IP)
|
||||||
|
with pytest.raises(RuntimeError):
|
||||||
|
await state_store.async_receive_userinfo_for_state("state", TEST_IP)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_state_store_missing_keys(hass: HomeAssistant):
|
||||||
|
"""Test that missing keys raise correct responses."""
|
||||||
|
store_mock = AsyncMock()
|
||||||
|
with patch("homeassistant.helpers.storage.Store", return_value=store_mock):
|
||||||
|
state_store = StateStore(hass)
|
||||||
|
|
||||||
|
# async_generate_code_for_state returns None if state_id is not found
|
||||||
|
store_mock.async_load.return_value = {}
|
||||||
|
await state_store.async_load()
|
||||||
|
assert await state_store.async_generate_code_for_state("nonexistent") is None
|
||||||
|
|
||||||
|
# async_add_userinfo_to_state returns False if state_id is not found
|
||||||
|
user_info = {
|
||||||
|
"sub": "user5",
|
||||||
|
"display_name": "Missing Keys",
|
||||||
|
"username": "missingkeys",
|
||||||
|
"role": "system-users",
|
||||||
|
}
|
||||||
|
assert (
|
||||||
|
await state_store.async_add_userinfo_to_state("nonexistent", user_info)
|
||||||
|
is False
|
||||||
|
)
|
||||||
54
tests/test_view_template.py
Normal file
54
tests/test_view_template.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Tests for the view templates"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from os import path
|
||||||
|
|
||||||
|
from custom_components.auth_oidc.views.loader import AsyncTemplateRenderer
|
||||||
|
|
||||||
|
FAKE_TEMPLATE_PATH = path.join(
|
||||||
|
path.dirname(path.abspath(__file__)), "resources", "fake_templates"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_real_template_render():
|
||||||
|
"""Test that view template can render an real existing template."""
|
||||||
|
|
||||||
|
renderer = AsyncTemplateRenderer()
|
||||||
|
await renderer.fetch_templates()
|
||||||
|
rendered = await renderer.render_template(
|
||||||
|
"welcome.html", name="<script>alert(1)</script>"
|
||||||
|
)
|
||||||
|
assert "<!DOCTYPE html>" in rendered
|
||||||
|
assert "<script>alert(1)</script>" in rendered
|
||||||
|
assert "<script>alert(1)</script>" not in rendered
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fake_template_render():
|
||||||
|
"""Test that view template can render an fake existing template."""
|
||||||
|
|
||||||
|
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||||
|
await renderer.fetch_templates()
|
||||||
|
rendered = await renderer.render_template("index.html")
|
||||||
|
assert "<p>Example template</p>" in rendered
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_dir_render_error():
|
||||||
|
"""Test that view template sends correct error if you try to render directory."""
|
||||||
|
|
||||||
|
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||||
|
await renderer.fetch_templates()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await renderer.render_template("folder.html")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_random_render_error():
|
||||||
|
"""Test that view template sends correct error if you try to render non-existing."""
|
||||||
|
|
||||||
|
renderer = AsyncTemplateRenderer(template_dir=FAKE_TEMPLATE_PATH)
|
||||||
|
await renderer.fetch_templates()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await renderer.render_template("non_existing.html")
|
||||||
Reference in New Issue
Block a user