Files
Davide Scaini 42bc476882 feat: OIDC Identity Provider — Phase 1 endpoints
Add OIDC/OAuth2 endpoints to bincio-auth so it acts as a full IdP:
  GET  /.well-known/openid-configuration
  GET  /.well-known/jwks.json
  GET  /oauth2/authorize  (auth-code flow, redirects to /login/ if no session)
  POST /oauth2/token      (exchanges code for RS256 id_token; PKCE supported)
  GET  /oauth2/userinfo   (Bearer token → profile claims)

Infrastructure:
  - oauth2_clients + oauth2_codes tables in db.py with CRUD helpers
  - RS256 sign/verify helpers in tokens.py (create_id_token, get_jwks)
  - oidc_private_key_pem / oidc_issuer state + _issue_id_token in deps.py
  - serve_cmd reads BINCIO_OIDC_PRIVATE_KEY_FILE / BINCIO_OIDC_ISSUER env vars
  - `bincio-auth client add/list` commands for managing OAuth2 clients
2026-06-03 15:11:43 +02:00

66 lines
2.5 KiB
Python

"""JWT helpers for bincio-auth.
HS256 tokens: used for session cookies (existing behaviour, shared secret).
RS256 tokens: used for OIDC id_tokens (asymmetric, public key via JWKS).
"""
from __future__ import annotations
import base64
import time
import jwt
from cryptography.hazmat.primitives.serialization import load_pem_private_key
_KID = "bincio-oidc-1"
# ── HS256 (session cookies) ───────────────────────────────────────────────────
def create_token(payload: dict, secret: str, expires_in: int) -> str:
    """Issue an HS256-signed JWT carrying *payload* plus an expiry claim.

    The ``exp`` claim is set to the current Unix time plus *expires_in*
    seconds; it replaces any ``exp`` key already present in *payload*.
    The caller's dict is not modified.
    """
    token_claims = {**payload}
    token_claims["exp"] = int(time.time()) + expires_in
    return jwt.encode(token_claims, secret, algorithm="HS256")
def decode_token(token: str, secret: str) -> dict:
    """Verify an HS256 JWT against *secret* and return its claims.

    Only the HS256 algorithm is accepted. Verification failures propagate
    from ``jwt.decode`` as ``jwt.PyJWTError`` (or a subclass).
    """
    accepted_algorithms = ["HS256"]
    return jwt.decode(token, secret, algorithms=accepted_algorithms)
# ── RS256 (OIDC id_tokens) ────────────────────────────────────────────────────
def create_id_token(payload: dict, private_key_pem: str, expires_in: int) -> str:
    """Produce an RS256-signed OIDC id_token.

    *payload* is expected to already carry ``iss``, ``sub`` and ``aud``.
    ``iat`` is set to the current Unix time and ``exp`` to ``iat`` plus
    *expires_in* seconds, overriding same-named keys in *payload*.
    *private_key_pem* is an unencrypted PEM private key; the JWT header
    carries the module-level key id so verifiers can pick the matching
    JWKS entry.
    """
    issued_at = int(time.time())
    id_claims = {**payload}
    id_claims.update(iat=issued_at, exp=issued_at + expires_in)
    signing_key = load_pem_private_key(private_key_pem.encode(), password=None)
    jwt_headers = {"kid": _KID}
    return jwt.encode(id_claims, signing_key, algorithm="RS256", headers=jwt_headers)
def get_jwks(private_key_pem: str) -> dict:
    """Build the JWKS document exposing the public half of an RSA key.

    *private_key_pem* is an unencrypted PEM private key. The result is a
    dict with a single-entry ``keys`` list describing the public key
    (modulus ``n`` and exponent ``e``) tagged with the module-level key id.
    """
    signing_key = load_pem_private_key(private_key_pem.encode(), password=None)
    public_numbers = signing_key.public_key().public_numbers()

    def _encode_uint(value: int) -> str:
        # Shortest big-endian byte string for the integer, then base64url
        # with the trailing "=" padding stripped.
        byte_count = (value.bit_length() + 7) // 8
        raw = value.to_bytes(byte_count, "big")
        encoded = base64.urlsafe_b64encode(raw)
        return encoded.rstrip(b"=").decode()

    jwk = {
        "kty": "RSA",
        "use": "sig",
        "alg": "RS256",
        "kid": _KID,
        "n": _encode_uint(public_numbers.n),
        "e": _encode_uint(public_numbers.e),
    }
    return {"keys": [jwk]}
def decode_id_token(token: str, private_key_pem: str, audience: str) -> dict:
    """Verify an RS256 id_token and return its claims.

    The verification key is the public half of *private_key_pem* (an
    unencrypted PEM private key). Only RS256 is accepted, and the token's
    ``aud`` claim is checked against *audience*. Intended for the userinfo
    endpoint.
    """
    verification_key = load_pem_private_key(
        private_key_pem.encode(), password=None
    ).public_key()
    return jwt.decode(
        token,
        verification_key,
        algorithms=["RS256"],
        audience=audience,
    )