109 lines
4.3 KiB
Python
109 lines
4.3 KiB
Python
from __future__ import annotations

import ssl
from typing import Any

from ldap3 import ALL, NO_ATTRIBUTES, Connection, Server, Tls
from ldap3.utils.conv import escape_filter_chars

from app.security.ldap_config import RuntimeLdapConfig, get_runtime_ldap_config
|
|
|
|
|
|
class LdapClient:
    """Thin ldap3 wrapper: locate users, verify passwords by binding, and list groups.

    All settings come from a ``RuntimeLdapConfig``. When no usable config is present,
    ``is_enabled()`` is False and the lookup methods return empty results.
    """

    def __init__(self, config: RuntimeLdapConfig | None = None) -> None:
        """Build the ldap3 ``Server`` descriptor from *config*.

        Args:
            config: explicit LDAP settings; when ``None`` they are loaded with
                ``get_runtime_ldap_config()``.

        Connections themselves are opened per operation by ``_connect``.
        """
        self.config = config if config is not None else get_runtime_ldap_config()
        self._server: Server | None = None
        # Nothing to build without a config or a URI; _connect() will then raise
        # "LDAP is not configured" instead of this constructor failing.
        if not self.config or not self.config.uri:
            return
        # URI schemes are case-insensitive, so detect ldaps:// regardless of case.
        use_ssl = self.config.uri.lower().startswith("ldaps://")
        tls = None
        if use_ssl or self.config.use_starttls:
            tls = Tls(validate=ssl.CERT_REQUIRED if self.config.verify_tls else ssl.CERT_NONE)
        self._server = Server(self.config.uri, use_ssl=use_ssl, get_info=ALL, tls=tls)

    def is_enabled(self) -> bool:
        """Return True when LDAP is switched on and has both a URI and a base DN."""
        return bool(self.config and self.config.enabled and self.config.uri and self.config.base_dn)

    def _format_filter(self, template: str, **values: str) -> str:
        """Fill the ``{name}`` placeholders of *template* with RFC 4515-escaped *values*.

        Every value goes through ``escape_filter_chars`` so that characters with filter
        meaning (asterisk, parentheses, backslash, NUL) in a user-typed login name or a
        DN cannot change the structure of the search filter (LDAP injection).
        """
        # AD configurations commonly use a {sAMAccountName} placeholder; alias it to username.
        if "username" in values:
            values.setdefault("sAMAccountName", values["username"])
        escaped = {key: escape_filter_chars(value) for key, value in values.items()}
        return template.format(**escaped)

    @staticmethod
    def _safe_unbind(conn: Connection) -> None:
        """Unbind *conn*, ignoring errors so cleanup never masks a result or the original error."""
        try:
            conn.unbind()
        except Exception:
            pass

    def _connect(self, *, bind_dn: str | None = None, password: str | None = None) -> Connection:
        """Open a connection, upgrade it with StartTLS when configured, and bind.

        Args:
            bind_dn: DN to bind as; ``None`` together with ``password=None`` binds anonymously.
            password: password for *bind_dn*.

        Returns:
            A bound ``Connection``. The caller owns it and must unbind it.

        Raises:
            RuntimeError: no server is configured, StartTLS failed, or the bind was rejected.
            Any other exception from the ldap3 calls propagates after the connection is released.
        """
        if not self._server:
            raise RuntimeError("LDAP is not configured")
        conn = Connection(self._server, user=bind_dn, password=password, auto_bind=False)
        try:
            conn.open()
            # StartTLS only applies to a plain (non-ldaps) connection.
            if self.config and self.config.use_starttls and self._server.ssl is False:
                if not conn.start_tls():
                    raise RuntimeError(f"LDAP StartTLS failed: {conn.result}")
            if not conn.bind():
                raise RuntimeError(f"LDAP bind failed: {conn.result}")
        except Exception:
            # Release the socket opened above; otherwise a failed StartTLS/bind leaves it open.
            self._safe_unbind(conn)
            raise
        return conn

    def _service_conn(self) -> Connection:
        """Return a connection bound with the configured service account (bind_dn/bind_password)."""
        if not self.config:
            raise RuntimeError("LDAP is not configured")
        return self._connect(bind_dn=self.config.bind_dn, password=self.config.bind_password)

    def find_user_dn(self, username: str) -> str | None:
        """Return the DN of the entry matching ``user_filter`` for *username*, or None.

        Searches under ``base_dn`` as the service account. If several entries match, the
        first one returned by the server is used. Returns None when LDAP is disabled or
        nothing matches; connection/bind errors propagate.
        """
        if not self.is_enabled() or not self.config:
            return None
        conn = self._service_conn()
        try:
            search_filter = self._format_filter(self.config.user_filter, username=username)
            # Only the DN is needed and it comes back with every entry, so request no
            # attributes ("1.1"). "dn" is not an attribute type and may be rejected by
            # ldap3's schema-based attribute name check.
            conn.search(self.config.base_dn, search_filter, attributes=[NO_ATTRIBUTES])
            if not conn.entries:
                return None
            return conn.entries[0].entry_dn
        finally:
            self._safe_unbind(conn)

    def authenticate(self, username: str, password: str) -> dict[str, Any] | None:
        """Verify *password* for *username* by binding as the user's own DN.

        Returns:
            ``{"user_dn": <dn>}`` when the user bind succeeds; None for an empty password,
            an unknown user, or any failure while connecting/binding as the user.

        Raises:
            Errors from the service-account lookup in ``find_user_dn`` propagate.
        """
        # An empty password may be treated as an anonymous/unauthenticated bind
        # (RFC 4513 section 5.1.2), which some servers accept without checking
        # credentials. Reject it before touching the directory.
        if not password:
            return None
        user_dn = self.find_user_dn(username)
        if not user_dn:
            return None
        try:
            conn = self._connect(bind_dn=user_dn, password=password)
        except Exception:
            # Any failure to connect or bind as the user counts as "not authenticated".
            return None
        # The bind already succeeded; a failing unbind must not change the outcome.
        self._safe_unbind(conn)
        return {"user_dn": user_dn}

    def get_user_groups(self, *, user_dn: str, username: str) -> list[dict[str, str]]:
        """Return the groups matched by ``group_filter`` for the given user.

        The filter may use ``{user_dn}``, ``{username}`` or ``{sAMAccountName}``
        placeholders. Each result is ``{"dn": <group dn>, "name": <first cn, else the dn>}``.
        Returns an empty list when LDAP is disabled.
        """
        if not self.is_enabled() or not self.config:
            return []
        conn = self._service_conn()
        try:
            search_filter = self._format_filter(self.config.group_filter, user_dn=user_dn, username=username)
            conn.search(self.config.base_dn, search_filter, attributes=["cn"])
            return [
                {"dn": entry.entry_dn, "name": self._first_cn(entry) or entry.entry_dn}
                for entry in conn.entries
            ]
        finally:
            self._safe_unbind(conn)

    @staticmethod
    def _first_cn(entry: Any) -> str:
        """Return the first ``cn`` value of *entry* as a string, or "" when it has none."""
        try:
            values = entry.cn.values
        except Exception:
            # The entry may carry no cn attribute; treat that as "no name".
            return ""
        # cn can be multi-valued; use the first value rather than the list's repr.
        return str(values[0]) if values else ""

    def test_connection(self) -> dict[str, Any]:
        """Check that the service account can bind and search ``base_dn``.

        Returns a dict with ``ok`` and ``message``; on success it also has
        ``entries_found`` (the search is limited to at most one entry). Errors from
        connecting, binding or searching are reported in ``message`` instead of raised.
        """
        if not self.is_enabled() or not self.config:
            return {"ok": False, "message": "LDAP 未启用或配置不完整"}
        conn = None
        try:
            conn = self._service_conn()
            conn.search(self.config.base_dn, "(objectClass=*)", attributes=[NO_ATTRIBUTES], size_limit=1)
            return {
                "ok": True,
                "message": "LDAP 服务账号连接成功",
                "entries_found": len(conn.entries),
            }
        except Exception as exc:
            return {"ok": False, "message": str(exc)}
        finally:
            if conn is not None:
                self._safe_unbind(conn)
|