"""JWT helpers for bincio-auth. HS256 tokens: used for session cookies (existing behaviour, shared secret). RS256 tokens: used for OIDC id_tokens (asymmetric, public key via JWKS). """ from __future__ import annotations import base64 import time import jwt from cryptography.hazmat.primitives.serialization import load_pem_private_key _KID = "bincio-oidc-1" # ── HS256 (session cookies) ─────────────────────────────────────────────────── def create_token(payload: dict, secret: str, expires_in: int) -> str: claims = {**payload, "exp": int(time.time()) + expires_in} return jwt.encode(claims, secret, algorithm="HS256") def decode_token(token: str, secret: str) -> dict: """Decode and verify an HS256 JWT. Raises jwt.PyJWTError on any failure.""" return jwt.decode(token, secret, algorithms=["HS256"]) # ── RS256 (OIDC id_tokens) ──────────────────────────────────────────────────── def create_id_token(payload: dict, private_key_pem: str, expires_in: int) -> str: """Sign an OIDC id_token with RS256. payload should include iss, sub, aud.""" now = int(time.time()) claims = {**payload, "iat": now, "exp": now + expires_in} private_key = load_pem_private_key(private_key_pem.encode(), password=None) return jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": _KID}) def get_jwks(private_key_pem: str) -> dict: """Return the JWKS document for the given RSA private key.""" private_key = load_pem_private_key(private_key_pem.encode(), password=None) pub = private_key.public_key().public_numbers() def b64url(n: int) -> str: length = (n.bit_length() + 7) // 8 return base64.urlsafe_b64encode(n.to_bytes(length, "big")).rstrip(b"=").decode() return { "keys": [{ "kty": "RSA", "use": "sig", "alg": "RS256", "kid": _KID, "n": b64url(pub.n), "e": b64url(pub.e), }] } def decode_id_token(token: str, private_key_pem: str, audience: str) -> dict: """Decode and verify an RS256 id_token (used by userinfo endpoint).""" private_key = load_pem_private_key(private_key_pem.encode(), password=None) public_key = private_key.public_key() return jwt.decode(token, public_key, algorithms=["RS256"], audience=audience)