# Vastai-ConnectHub/app/security/ldap_client.py
#
# 72 lines
# 2.6 KiB
# Python

from __future__ import annotations

import ssl
from typing import Any

from ldap3 import ALL, Connection, Server, Tls
from ldap3.core.exceptions import LDAPBindError
from ldap3.utils.conv import escape_filter_chars

from app.core.config import settings
class LdapClient:
    """Small ldap3 wrapper for user lookup, password verification and group lookup.

    All connection parameters (URI, TLS options, service-account credentials,
    base DN and filter templates) are read from ``settings``.
    """

    def __init__(self) -> None:
        """Build the ldap3 ``Server`` object from ``settings``."""
        use_ssl = settings.ldap_uri.startswith("ldaps://")
        tls = None
        if use_ssl or settings.ldap_use_starttls:
            validate = ssl.CERT_REQUIRED if settings.ldap_verify_tls else ssl.CERT_NONE
            tls = Tls(validate=validate)
        self._server = Server(settings.ldap_uri, use_ssl=use_ssl, get_info=ALL, tls=tls)

    @staticmethod
    def _safe_unbind(conn: Connection) -> None:
        """Close *conn*, ignoring errors so cleanup never masks the real failure."""
        try:
            conn.unbind()
        except Exception:
            # Best-effort cleanup only; there is nothing useful to do on failure.
            pass

    @staticmethod
    def _entry_name(entry: Any) -> str:
        """Return the ``cn`` of a search entry as a string, or ``""`` if unavailable.

        A multi-valued ``cn`` yields its first value; a missing or empty ``cn``
        yields ``""`` so the caller can fall back to the DN.
        """
        try:
            value = entry.cn.value
        except Exception:
            # Attribute access on an entry may raise when ``cn`` was not returned.
            return ""
        if isinstance(value, (list, tuple)):
            value = value[0] if value else None
        return "" if value is None else str(value)

    def _connect(self, *, bind_dn: str | None = None, password: str | None = None) -> Connection:
        """Open a connection, optionally upgrade it with StartTLS, and bind.

        Args:
            bind_dn: DN to bind as; ``None`` is passed through to ldap3 unchanged.
            password: Password for ``bind_dn``.

        Returns:
            A bound ``Connection``. The caller is responsible for ``unbind()``.

        Raises:
            LDAPBindError: The server rejected the bind.
            Exception: Whatever ldap3 raises while opening, negotiating TLS or
                binding; the connection is closed before re-raising.
        """
        conn = Connection(self._server, user=bind_dn, password=password, auto_bind=False)
        try:
            conn.open()
            if settings.ldap_use_starttls and not self._server.ssl:
                conn.start_tls()
            conn.bind()
        except Exception:
            self._safe_unbind(conn)
            raise
        # With ldap3's default raise_exceptions=False a rejected bind does not
        # raise; it only leaves ``bound`` False. It must be checked explicitly,
        # otherwise wrong credentials look like a successful login.
        if not conn.bound:
            self._safe_unbind(conn)
            raise LDAPBindError(f"LDAP bind failed for {bind_dn or 'anonymous'}")
        return conn

    def _service_conn(self) -> Connection:
        """Return a connection bound with the service account from ``settings``."""
        return self._connect(bind_dn=settings.ldap_bind_dn, password=settings.ldap_bind_password)

    def find_user_dn(self, username: str) -> str | None:
        """Look up the DN of *username* below ``settings.ldap_base_dn``.

        Returns:
            The DN of the first matching entry, or ``None`` when the username is
            empty, no base DN is configured, or nothing matches.

        Raises:
            LDAPBindError: The service-account bind was rejected.
        """
        if not username:
            return None
        if not settings.ldap_base_dn:
            return None
        # Escape the user-supplied value so it cannot alter the filter structure.
        search_filter = settings.ldap_user_filter.format(username=escape_filter_chars(username))
        conn = self._service_conn()
        try:
            # No attributes are requested: only the entry DN is needed, and it is
            # part of every search result ("dn" is not an attribute type).
            conn.search(settings.ldap_base_dn, search_filter)
            if not conn.entries:
                return None
            return conn.entries[0].entry_dn
        finally:
            conn.unbind()

    def authenticate(self, username: str, password: str) -> dict[str, Any] | None:
        """Verify *password* for *username* by binding as the user's DN.

        Returns:
            ``{"user_dn": <dn>}`` on success, ``None`` on any failure (unknown
            user, empty credentials, rejected bind, connection error during the
            user bind).
        """
        # An empty password would be an LDAP "unauthenticated bind", which many
        # servers accept; never treat that as a successful login.
        if not username or not password:
            return None
        user_dn = self.find_user_dn(username)
        if not user_dn:
            return None
        try:
            conn = self._connect(bind_dn=user_dn, password=password)
        except Exception:
            # Deliberately broad: any failure of the user bind means "not authenticated".
            return None
        self._safe_unbind(conn)
        return {"user_dn": user_dn}

    def get_user_groups(self, *, user_dn: str, username: str) -> list[dict[str, str]]:
        """Return the groups matched by ``settings.ldap_group_filter`` for a user.

        Args:
            user_dn: DN substituted for ``{user_dn}`` in the filter template.
            username: Name substituted for ``{username}`` in the filter template.

        Returns:
            One ``{"dn": ..., "name": ...}`` dict per matching entry; ``name`` is
            the entry's ``cn`` or, when that is unavailable, its DN. Empty when
            no base DN is configured.

        Raises:
            LDAPBindError: The service-account bind was rejected.
        """
        if not settings.ldap_base_dn:
            return []
        # Both values end up inside a filter assertion and must be escaped.
        search_filter = settings.ldap_group_filter.format(
            user_dn=escape_filter_chars(user_dn),
            username=escape_filter_chars(username),
        )
        conn = self._service_conn()
        try:
            conn.search(settings.ldap_base_dn, search_filter, attributes=["cn"])
            results: list[dict[str, str]] = []
            for entry in conn.entries:
                dn = entry.entry_dn
                results.append({"dn": dn, "name": self._entry_name(entry) or dn})
            return results
        finally:
            conn.unbind()