# Vastai-ConnectHub/app/security/ldap_config.py
#
# 105 lines
# 2.8 KiB
# Python

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from app.core.config import settings
from app.db.engine import get_session
from app.db.models import LdapConfig
from app.security.fernet import decrypt_json, encrypt_json
# Name of the single LdapConfig row that this module looks up and seeds.
LDAP_CONFIG_NAME: str = "default"
@dataclass(frozen=True)
class RuntimeLdapConfig:
    """Immutable snapshot of the LDAP settings in effect at runtime.

    Instances are built by ``to_runtime_config`` from an ``LdapConfig`` row.
    ``bind_password`` holds the *decrypted* plaintext secret, so it is
    excluded from ``repr()`` to keep it out of logs, tracebacks and debug
    output. It still takes part in ``==`` comparisons and remains a normal
    required constructor argument in the same position.
    """

    enabled: bool
    uri: str
    bind_dn: str
    # Plaintext secret: hidden from repr() so logging the config cannot leak it.
    bind_password: str = field(repr=False)
    base_dn: str
    user_filter: str
    group_filter: str
    use_starttls: bool
    verify_tls: bool
def encrypt_bind_password(password: str) -> str:
    """Encrypt an LDAP bind password for storage.

    The password is wrapped as ``{"password": ...}`` and passed to
    ``encrypt_json``. A falsy password (e.g. ``""``) is stored as ``""``
    without invoking encryption.
    """
    return encrypt_json({"password": password}) if password else ""
def decrypt_bind_password(token: str) -> str:
    """Recover a bind password produced by ``encrypt_bind_password``.

    This is best-effort and never raises. It returns ``""`` when the token is
    empty, fails to decrypt, decrypts to something other than a dict, or has
    no ``"password"`` entry. Otherwise it returns ``str()`` of the stored
    value.

    Decryption failures are logged (exception type only, never the token), so
    a rotated key or corrupted value does not go unnoticed.
    """
    if not token:
        return ""
    try:
        data = decrypt_json(token)
    except Exception as exc:
        # Deliberately best-effort: fall back to an empty password instead of
        # raising, but leave a trace. Only the exception type is logged to
        # avoid echoing secret material.
        logging.getLogger(__name__).warning(
            "Could not decrypt stored LDAP bind password (%s)",
            type(exc).__name__,
        )
        return ""
    # Guard against an unexpected payload shape. Previously a non-dict here
    # raised AttributeError outside the try block.
    if not isinstance(data, dict):
        return ""
    value = data.get("password")
    return str(value) if value is not None else ""
def get_default_ldap_config(session: Session) -> LdapConfig | None:
    """Fetch the ``LdapConfig`` row named ``LDAP_CONFIG_NAME`` via *session*.

    Returns ``None`` when no such row exists.
    """
    statement = select(LdapConfig).where(LdapConfig.name == LDAP_CONFIG_NAME)
    return session.scalar(statement)
def _ldap_config_from_settings() -> LdapConfig:
    """Build an unsaved default ``LdapConfig`` seeded from application settings."""
    return LdapConfig(
        name=LDAP_CONFIG_NAME,
        # LDAP starts enabled only when a base DN was configured in settings.
        enabled=bool(settings.ldap_base_dn),
        uri=settings.ldap_uri,
        bind_dn=settings.ldap_bind_dn,
        bind_password_encrypted=encrypt_bind_password(settings.ldap_bind_password),
        base_dn=settings.ldap_base_dn,
        user_filter=settings.ldap_user_filter,
        group_filter=settings.ldap_group_filter,
        use_starttls=settings.ldap_use_starttls,
        verify_tls=settings.ldap_verify_tls,
        last_test_result={},
    )


def ensure_default_ldap_config() -> None:
    """Seed the default LDAP config row from settings if it is missing.

    If a row named ``LDAP_CONFIG_NAME`` already exists, nothing is written.
    The session is closed on every path, including errors.

    NOTE(review): two processes starting at once could both see "missing"
    and both insert. Whether that fails depends on a unique constraint on
    ``LdapConfig.name``, which is not visible here; confirm.
    """
    session = get_session()
    try:
        if not get_default_ldap_config(session):
            session.add(_ldap_config_from_settings())
            session.commit()
    finally:
        session.close()
def to_runtime_config(config: LdapConfig | None) -> RuntimeLdapConfig | None:
    """Convert a stored ``LdapConfig`` into a ``RuntimeLdapConfig``.

    Returns ``None`` when *config* is missing or not enabled. Otherwise the
    stored encrypted bind password is decrypted (best-effort, ``""`` on
    failure) and the boolean flags are coerced with ``bool()``.
    """
    if config and config.enabled:
        return RuntimeLdapConfig(
            enabled=bool(config.enabled),
            uri=config.uri,
            bind_dn=config.bind_dn,
            bind_password=decrypt_bind_password(config.bind_password_encrypted),
            base_dn=config.base_dn,
            user_filter=config.user_filter,
            group_filter=config.group_filter,
            use_starttls=bool(config.use_starttls),
            verify_tls=bool(config.verify_tls),
        )
    return None
def get_runtime_ldap_config() -> RuntimeLdapConfig | None:
    """Load the default LDAP config and return its runtime form, or ``None``.

    A dedicated session is opened for the lookup and always closed. The
    conversion runs before ``close()``, so the row's attributes are read
    while the session is still open.
    """
    session = get_session()
    try:
        row = get_default_ldap_config(session)
        return to_runtime_config(row)
    finally:
        session.close()
_SENSITIVE_RESULT_KEYS: frozenset[str] = frozenset({"password", "bind_password"})


def _strip_sensitive_keys(value: Any, keys: frozenset[str]) -> Any:
    """Return a copy of *value* with *keys* removed from every nested dict."""
    if isinstance(value, dict):
        return {
            k: _strip_sensitive_keys(v, keys)
            for k, v in value.items()
            if k not in keys
        }
    if isinstance(value, list):
        return [_strip_sensitive_keys(item, keys) for item in value]
    return value


def mask_sensitive_result(
    result: dict[str, Any],
    *,
    sensitive_keys: frozenset[str] = _SENSITIVE_RESULT_KEYS,
) -> dict[str, Any]:
    """Return a copy of *result* with secret-bearing keys removed.

    By default ``"password"`` and ``"bind_password"`` are dropped. Removal
    happens at the top level and also inside nested dicts and lists, so a
    secret tucked into a sub-structure is not leaked. The input is never
    mutated. Pass ``sensitive_keys`` to override which keys are stripped.
    """
    # dict() keeps the original's acceptance of any mapping or pair iterable
    # at the top level.
    top_level = dict(result)
    return {
        k: _strip_sensitive_keys(v, sensitive_keys)
        for k, v in top_level.items()
        if k not in sensitive_keys
    }