72 lines
2.6 KiB
Python
72 lines
2.6 KiB
Python
from __future__ import annotations

import ssl
from typing import Any

from ldap3 import ALL, Connection, Server, Tls
from ldap3.core.exceptions import LDAPBindError, LDAPStartTLSError
from ldap3.utils.conv import escape_filter_chars

from app.core.config import settings
|
|
|
|
|
|
class LdapClient:
    """LDAP helper built on ldap3 and configured from ``settings``.

    Reads ``ldap_uri``, ``ldap_use_starttls``, ``ldap_verify_tls``,
    ``ldap_bind_dn``, ``ldap_bind_password``, ``ldap_base_dn``,
    ``ldap_user_filter`` and ``ldap_group_filter`` from ``settings``.
    Every operation opens its own connection and unbinds it when done.
    """

    def __init__(self) -> None:
        """Create the ``Server`` object shared by all connections."""
        # Compare the scheme case-insensitively so "LDAPS://host" also gets
        # the TLS configuration below.
        uses_ldaps = settings.ldap_uri.lower().startswith("ldaps://")

        tls = None
        if uses_ldaps or settings.ldap_use_starttls:
            validate = ssl.CERT_REQUIRED if settings.ldap_verify_tls else ssl.CERT_NONE
            tls = Tls(validate=validate)

        self._server = Server(
            settings.ldap_uri,
            use_ssl=uses_ldaps,
            get_info=ALL,
            tls=tls,
        )

    def _connect(self, *, bind_dn: str | None = None, password: str | None = None) -> Connection:
        """Open a connection, upgrade it with StartTLS if configured, and bind.

        Args:
            bind_dn: DN to bind as, passed to ldap3 as ``user``.
            password: Password for ``bind_dn``.

        Returns:
            A bound ``Connection``; the caller is responsible for ``unbind()``.

        Raises:
            LDAPStartTLSError: StartTLS was requested but not established.
            LDAPBindError: The server rejected the bind.
            Exception: Any other error raised by ldap3 while connecting.
        """
        conn = Connection(self._server, user=bind_dn, password=password, auto_bind=False)
        try:
            conn.open()
            if settings.ldap_use_starttls and not self._server.ssl:
                # A falsy result means TLS is not active; binding anyway would
                # send the credentials unencrypted.
                if not conn.start_tls():
                    raise LDAPStartTLSError("StartTLS negotiation failed")
            # ldap3 reports rejected credentials through the return value of
            # bind() unless raise_exceptions=True, so the result must be
            # checked explicitly.
            if not conn.bind():
                raise LDAPBindError(f"LDAP bind failed: {conn.last_error}")
        except Exception:
            # Do not leak the socket when any step of the handshake fails.
            try:
                conn.unbind()
            except Exception:
                pass  # best-effort cleanup; the original error is re-raised
            raise
        return conn

    def _service_conn(self) -> Connection:
        """Return a connection bound with the configured service account."""
        return self._connect(
            bind_dn=settings.ldap_bind_dn,
            password=settings.ldap_bind_password,
        )

    @staticmethod
    def _group_name(entry: Any) -> str:
        """Return the entry's ``cn`` as text, or ``""`` when it has none."""
        try:
            value = entry.cn.value
        except Exception:
            # The entry has no usable cn attribute.
            return ""
        if isinstance(value, (list, tuple)):
            # Multi-valued cn: use the first value as the display name.
            value = value[0] if value else None
        return "" if value is None else str(value)

    def find_user_dn(self, username: str) -> str | None:
        """Look up the DN of ``username`` under ``settings.ldap_base_dn``.

        Args:
            username: Login name substituted into ``settings.ldap_user_filter``.

        Returns:
            The DN of the first matching entry, or ``None`` when no base DN is
            configured or nothing matches.
        """
        if not settings.ldap_base_dn:
            return None

        conn = self._service_conn()
        try:
            # Escape filter metacharacters (RFC 4515) so the caller-supplied
            # name cannot alter the structure of the search filter.
            search_filter = settings.ldap_user_filter.format(
                username=escape_filter_chars(username),
            )
            # NOTE(review): "dn" is not a regular attribute name; confirm the
            # target servers accept it, or request no attributes instead.
            conn.search(settings.ldap_base_dn, search_filter, attributes=["dn"])
            if not conn.entries:
                return None
            return conn.entries[0].entry_dn
        finally:
            conn.unbind()

    def authenticate(self, username: str, password: str) -> dict[str, Any] | None:
        """Verify credentials by binding to the directory as the user.

        Args:
            username: Login name to resolve to a DN.
            password: Password to bind with.

        Returns:
            ``{"user_dn": <dn>}`` when the bind succeeds, otherwise ``None``.
        """
        # A simple bind with an empty password is an "unauthenticated bind"
        # (RFC 4513 section 5.1.2) that servers may accept, so it must never
        # count as a successful login.
        if not username or not password:
            return None

        user_dn = self.find_user_dn(username)
        if not user_dn:
            return None

        try:
            conn = self._connect(bind_dn=user_dn, password=password)
            conn.unbind()
            return {"user_dn": user_dn}
        except Exception:
            # Deliberately broad: any failure to bind as the user means the
            # credentials could not be verified.
            return None

    def get_user_groups(self, *, user_dn: str, username: str) -> list[dict[str, str]]:
        """List the groups matched by ``settings.ldap_group_filter``.

        Args:
            user_dn: DN of the user, available to the filter as ``{user_dn}``.
            username: Login name, available to the filter as ``{username}``.

        Returns:
            One ``{"dn": ..., "name": ...}`` dict per matching entry; ``name``
            is the entry's ``cn`` and falls back to the DN when it has none.
            Empty when no base DN is configured.
        """
        if not settings.ldap_base_dn:
            return []

        conn = self._service_conn()
        try:
            # Both values are data inside the filter and need RFC 4515 escaping.
            search_filter = settings.ldap_group_filter.format(
                user_dn=escape_filter_chars(user_dn),
                username=escape_filter_chars(username),
            )
            conn.search(settings.ldap_base_dn, search_filter, attributes=["cn"])

            results: list[dict[str, str]] = []
            for entry in conn.entries:
                dn = entry.entry_dn
                results.append({"dn": dn, "name": self._group_name(entry) or dn})
            return results
        finally:
            conn.unbind()
|