from __future__ import annotations import ssl from typing import Any from ldap3 import ALL, Connection, Server, Tls from app.security.ldap_config import RuntimeLdapConfig, get_runtime_ldap_config class LdapClient: def __init__(self, config: RuntimeLdapConfig | None = None) -> None: self.config = config if config is not None else get_runtime_ldap_config() self._server: Server | None = None if not self.config: return tls = None if self.config.uri.startswith("ldaps://") or self.config.use_starttls: tls = Tls(validate=ssl.CERT_REQUIRED if self.config.verify_tls else ssl.CERT_NONE) self._server = Server(self.config.uri, use_ssl=self.config.uri.startswith("ldaps://"), get_info=ALL, tls=tls) def is_enabled(self) -> bool: return bool(self.config and self.config.enabled and self.config.uri and self.config.base_dn) def _format_filter(self, template: str, **values: str) -> str: # 兼容 AD 配置中常见的 sAMAccountName 占位写法。 if "username" in values: values.setdefault("sAMAccountName", values["username"]) return template.format(**values) def _connect(self, *, bind_dn: str | None = None, password: str | None = None) -> Connection: if not self._server: raise RuntimeError("LDAP is not configured") conn = Connection(self._server, user=bind_dn, password=password, auto_bind=False) conn.open() if self.config and self.config.use_starttls and self._server.ssl is False: if not conn.start_tls(): raise RuntimeError(f"LDAP StartTLS failed: {conn.result}") if not conn.bind(): raise RuntimeError(f"LDAP bind failed: {conn.result}") return conn def _service_conn(self) -> Connection: if not self.config: raise RuntimeError("LDAP is not configured") return self._connect(bind_dn=self.config.bind_dn, password=self.config.bind_password) def find_user_dn(self, username: str) -> str | None: if not self.is_enabled() or not self.config: return None conn = self._service_conn() try: search_filter = self._format_filter(self.config.user_filter, username=username) conn.search(self.config.base_dn, search_filter, attributes=["dn"]) if not conn.entries: return None return conn.entries[0].entry_dn finally: conn.unbind() def authenticate(self, username: str, password: str) -> dict[str, Any] | None: user_dn = self.find_user_dn(username) if not user_dn: return None try: conn = self._connect(bind_dn=user_dn, password=password) conn.unbind() return {"user_dn": user_dn} except Exception: return None def get_user_groups(self, *, user_dn: str, username: str) -> list[dict[str, str]]: if not self.is_enabled() or not self.config: return [] conn = self._service_conn() try: search_filter = self._format_filter(self.config.group_filter, user_dn=user_dn, username=username) conn.search(self.config.base_dn, search_filter, attributes=["cn"]) results: list[dict[str, str]] = [] for entry in conn.entries: dn = entry.entry_dn name = "" try: name = str(entry.cn.value) if hasattr(entry, "cn") else "" except Exception: name = "" results.append({"dn": dn, "name": name or dn}) return results finally: conn.unbind() def test_connection(self) -> dict[str, Any]: if not self.is_enabled() or not self.config: return {"ok": False, "message": "LDAP 未启用或配置不完整"} conn = None try: conn = self._service_conn() conn.search(self.config.base_dn, "(objectClass=*)", attributes=["dn"], size_limit=1) return { "ok": True, "message": "LDAP 服务账号连接成功", "entries_found": len(conn.entries), } except Exception as exc: return {"ok": False, "message": str(exc)} finally: if conn is not None: conn.unbind()