from __future__ import annotations import ssl from typing import Any from ldap3 import ALL, Connection, Server, Tls from app.core.config import settings class LdapClient: def __init__(self) -> None: tls = None if settings.ldap_uri.startswith("ldaps://") or settings.ldap_use_starttls: tls = Tls(validate=ssl.CERT_REQUIRED if settings.ldap_verify_tls else ssl.CERT_NONE) self._server = Server(settings.ldap_uri, use_ssl=settings.ldap_uri.startswith("ldaps://"), get_info=ALL, tls=tls) def _connect(self, *, bind_dn: str | None = None, password: str | None = None) -> Connection: conn = Connection(self._server, user=bind_dn, password=password, auto_bind=False) conn.open() if settings.ldap_use_starttls and self._server.ssl is False: conn.start_tls() conn.bind() return conn def _service_conn(self) -> Connection: return self._connect(bind_dn=settings.ldap_bind_dn, password=settings.ldap_bind_password) def find_user_dn(self, username: str) -> str | None: if not settings.ldap_base_dn: return None conn = self._service_conn() try: search_filter = settings.ldap_user_filter.format(username=username) conn.search(settings.ldap_base_dn, search_filter, attributes=["dn"]) if not conn.entries: return None return conn.entries[0].entry_dn finally: conn.unbind() def authenticate(self, username: str, password: str) -> dict[str, Any] | None: user_dn = self.find_user_dn(username) if not user_dn: return None try: conn = self._connect(bind_dn=user_dn, password=password) conn.unbind() return {"user_dn": user_dn} except Exception: return None def get_user_groups(self, *, user_dn: str, username: str) -> list[dict[str, str]]: if not settings.ldap_base_dn: return [] conn = self._service_conn() try: search_filter = settings.ldap_group_filter.format(user_dn=user_dn, username=username) conn.search(settings.ldap_base_dn, search_filter, attributes=["cn"]) results: list[dict[str, str]] = [] for entry in conn.entries: dn = entry.entry_dn name = "" try: name = str(entry.cn.value) if hasattr(entry, "cn") else "" except Exception: name = "" results.append({"dn": dn, "name": name or dn}) return results finally: conn.unbind()