from __future__ import annotations import logging import ssl from typing import Any from ldap3 import ALL, BASE, MODIFY_REPLACE, SUBTREE, Connection, Server, Tls from ldap3.utils.conv import escape_filter_chars from extensions.sync_ehr_to_oa.api import SyncEhrToOaApi logger = logging.getLogger("connecthub.extensions.sync_ehr_to_ad") class SyncEhrToAdApi(SyncEhrToOaApi): """北森 EHR -> AD 同步用 EHR API 封装。""" def __init__( self, *, secret_params: dict[str, str], base_url: str = "https://openapi.italent.cn", timeout_s: float = 10.0, retries: int = 2, retry_backoff_s: float = 0.5, ) -> None: super().__init__( secret_params=secret_params, sqlserver_params=None, base_url=base_url, timeout_s=timeout_s, retries=retries, retry_backoff_s=retry_backoff_s, ) class ActiveDirectoryClient: """使用 ldap3 直接更新本地 AD 用户属性。""" def __init__( self, *, ldap_uri: str, bind_dn: str, bind_password: str, base_dn: str, user_filter: str = "(sAMAccountName={sAMAccountName})", use_starttls: bool = False, verify_tls: bool = True, connect_timeout_s: int = 10, ) -> None: self.ldap_uri = str(ldap_uri or "").strip() self.bind_dn = str(bind_dn or "").strip() self.bind_password = str(bind_password or "") self.base_dn = str(base_dn or "").strip() self.user_filter = str(user_filter or "(sAMAccountName={sAMAccountName})").strip() self.use_starttls = bool(use_starttls) self.verify_tls = bool(verify_tls) self.connect_timeout_s = int(connect_timeout_s or 10) if not self.ldap_uri: raise ValueError("ldap_uri is required") if not self.bind_dn or not self.bind_password: raise ValueError("bind_dn and bind_password are required") if not self.base_dn: raise ValueError("base_dn is required") tls = None if self.ldap_uri.startswith("ldaps://") or self.use_starttls: tls = Tls(validate=ssl.CERT_REQUIRED if self.verify_tls else ssl.CERT_NONE) self._server = Server( self.ldap_uri, use_ssl=self.ldap_uri.startswith("ldaps://"), get_info=ALL, tls=tls, connect_timeout=self.connect_timeout_s, ) def _connect(self) -> Connection: conn = Connection(self._server, user=self.bind_dn, password=self.bind_password, auto_bind=False) conn.open() if self.use_starttls and self._server.ssl is False: conn.start_tls() if not conn.bind(): raise RuntimeError(f"AD bind failed: {conn.result}") return conn def ping(self) -> bool: conn = self._connect() try: return bool(conn.bound) finally: conn.unbind() def _format_user_filter(self, sam_account_name: str) -> str: escaped = escape_filter_chars(str(sam_account_name or "").strip()) return self.user_filter.format( sAMAccountName=escaped, samAccountName=escaped, username=escaped, sam=escaped, ) def find_user(self, sam_account_name: str, *, attributes: list[str] | None = None) -> dict[str, Any] | None: sam = str(sam_account_name or "").strip() if not sam: return None conn = self._connect() try: conn.search( self.base_dn, self._format_user_filter(sam), search_scope=SUBTREE, attributes=attributes or ["distinguishedName", "sAMAccountName"], size_limit=2, ) if not conn.entries: return None entry = conn.entries[0] attrs = dict(entry.entry_attributes_as_dict) return {"dn": entry.entry_dn, "attributes": attrs} finally: conn.unbind() def read_user_by_dn(self, dn: str, *, attributes: list[str] | None = None) -> dict[str, Any] | None: clean_dn = str(dn or "").strip() if not clean_dn: return None conn = self._connect() try: conn.search( clean_dn, "(objectClass=*)", search_scope=BASE, attributes=attributes or ["distinguishedName", "sAMAccountName"], ) if not conn.entries: return None entry = conn.entries[0] return {"dn": entry.entry_dn, "attributes": dict(entry.entry_attributes_as_dict)} finally: conn.unbind() @staticmethod def _normalize_change_value(value: Any) -> list[Any]: if isinstance(value, (list, tuple, set)): return [x for x in value if x is not None and str(x).strip() != ""] if value is None or str(value).strip() == "": return [] return [value] def modify_user(self, dn: str, attributes: dict[str, Any], *, dry_run: bool = False) -> bool: clean_dn = str(dn or "").strip() if not clean_dn: raise ValueError("dn is required") changes: dict[str, Any] = {} for attr, value in attributes.items(): attr_name = str(attr or "").strip() values = self._normalize_change_value(value) if not attr_name or not values: continue changes[attr_name] = [(MODIFY_REPLACE, values)] if not changes: return False if dry_run: logger.info("AD dry_run: dn=%s changes=%s", clean_dn, sorted(changes.keys())) return True conn = self._connect() try: ok = bool(conn.modify(clean_dn, changes)) if not ok: raise RuntimeError(f"AD modify failed dn={clean_dn!r} result={conn.result!r}") logger.info("AD modify success: dn=%s result=%s changed_attrs=%s", clean_dn, conn.result, sorted(changes.keys())) return True finally: conn.unbind()