181 lines
6.1 KiB
Python
181 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import ssl
|
|
from typing import Any
|
|
|
|
from ldap3 import ALL, BASE, MODIFY_REPLACE, SUBTREE, Connection, Server, Tls
|
|
from ldap3.utils.conv import escape_filter_chars
|
|
|
|
from extensions.sync_ehr_to_oa.api import SyncEhrToOaApi
|
|
|
|
|
|
# Module-wide logger. The name is spelled out explicitly (rather than derived
# from __name__), so records are emitted under "connecthub.extensions.sync_ehr_to_ad".
logger = logging.getLogger(name="connecthub.extensions.sync_ehr_to_ad")
|
|
|
|
|
|
class SyncEhrToAdApi(SyncEhrToOaApi):
    """EHR API client used by the Beisen EHR -> AD sync.

    A thin specialisation of ``SyncEhrToOaApi``. It forwards the HTTP settings
    unchanged and always passes ``sqlserver_params=None`` to the parent.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Initialise the parent client without any SQL Server parameters.

        Args:
            secret_params: Forwarded as-is to ``SyncEhrToOaApi``; its expected
                keys are defined by the parent class.
            base_url: Root URL of the EHR open API.
            timeout_s: Forwarded to the parent as ``timeout_s``.
            retries: Forwarded to the parent as ``retries``.
            retry_backoff_s: Forwarded to the parent as ``retry_backoff_s``.
        """
        http_options = {
            "base_url": base_url,
            "timeout_s": timeout_s,
            "retries": retries,
            "retry_backoff_s": retry_backoff_s,
        }
        # No SQL Server settings are supplied for the AD flavour of the client;
        # presumably the AD sync does not need the parent's SQL Server access (confirm).
        super().__init__(secret_params=secret_params, sqlserver_params=None, **http_options)
|
|
|
|
|
|
class ActiveDirectoryClient:
    """Update on-premises Active Directory user attributes directly via ldap3.

    Every public operation opens its own connection, binds as ``bind_dn``,
    performs its work and unbinds again in a ``finally`` clause; no connection
    is kept between calls.
    """

    def __init__(
        self,
        *,
        ldap_uri: str,
        bind_dn: str,
        bind_password: str,
        base_dn: str,
        user_filter: str = "(sAMAccountName={sAMAccountName})",
        use_starttls: bool = False,
        verify_tls: bool = True,
        connect_timeout_s: int = 10,
    ) -> None:
        """Validate the settings and build the ldap3 ``Server`` descriptor.

        Args:
            ldap_uri: ``ldap://`` or ``ldaps://`` URI of the directory server.
                An ``ldaps://`` URI enables SSL on the ``Server``.
            bind_dn: Account the client binds as.
            bind_password: Password for ``bind_dn``.
            base_dn: Search base used by :meth:`find_user`.
            user_filter: LDAP filter template. The placeholders
                ``{sAMAccountName}``, ``{samAccountName}``, ``{username}`` and
                ``{sam}`` are replaced with the escaped account name.
            use_starttls: Upgrade a non-``ldaps://`` connection with StartTLS
                before binding.
            verify_tls: Require a valid server certificate when TLS is in use
                (``ssl.CERT_REQUIRED``); otherwise ``ssl.CERT_NONE``.
            connect_timeout_s: Connect timeout in seconds; falsy values fall
                back to 10.

        Raises:
            ValueError: If ``ldap_uri``, ``bind_dn``, ``bind_password`` or
                ``base_dn`` is empty.
        """
        self.ldap_uri = str(ldap_uri or "").strip()
        self.bind_dn = str(bind_dn or "").strip()
        # The password is deliberately not stripped: whitespace may be significant.
        self.bind_password = str(bind_password or "")
        self.base_dn = str(base_dn or "").strip()
        self.user_filter = str(user_filter or "(sAMAccountName={sAMAccountName})").strip()
        self.use_starttls = bool(use_starttls)
        self.verify_tls = bool(verify_tls)
        self.connect_timeout_s = int(connect_timeout_s or 10)

        if not self.ldap_uri:
            raise ValueError("ldap_uri is required")
        if not self.bind_dn or not self.bind_password:
            raise ValueError("bind_dn and bind_password are required")
        if not self.base_dn:
            raise ValueError("base_dn is required")

        # A Tls object is only needed when the channel will be encrypted,
        # either implicitly (ldaps://) or via StartTLS.
        tls = None
        if self.ldap_uri.startswith("ldaps://") or self.use_starttls:
            tls = Tls(validate=ssl.CERT_REQUIRED if self.verify_tls else ssl.CERT_NONE)
        self._server = Server(
            self.ldap_uri,
            use_ssl=self.ldap_uri.startswith("ldaps://"),
            get_info=ALL,
            tls=tls,
            connect_timeout=self.connect_timeout_s,
        )

    def _connect(self) -> Connection:
        """Open a connection, optionally upgrade it with StartTLS, and bind.

        Returns:
            A bound ``Connection``. The caller must ``unbind()`` it.

        Raises:
            RuntimeError: If StartTLS or the bind fails. The connection opened
                here is closed before the error propagates.
        """
        conn = Connection(self._server, user=self.bind_dn, password=self.bind_password, auto_bind=False)
        conn.open()
        try:
            # ldaps:// is already encrypted; StartTLS only applies to a plain channel.
            if self.use_starttls and self._server.ssl is False:
                # Never continue to a simple bind if the upgrade did not happen:
                # the bind password would otherwise travel in clear text.
                if not conn.start_tls():
                    raise RuntimeError(f"AD StartTLS failed: {conn.result}")
            if not conn.bind():
                raise RuntimeError(f"AD bind failed: {conn.result}")
        except Exception:
            # Callers only unbind connections that were returned to them, so
            # release the socket here on any failure.
            conn.unbind()
            raise
        return conn

    def ping(self) -> bool:
        """Connect and bind once, returning whether the connection ended up bound.

        Raises:
            RuntimeError: Propagated from :meth:`_connect` when StartTLS or the bind fails.
        """
        conn = self._connect()
        try:
            return bool(conn.bound)
        finally:
            conn.unbind()

    def _format_user_filter(self, sam_account_name: str) -> str:
        """Fill ``user_filter`` with the LDAP-escaped, stripped account name."""
        escaped = escape_filter_chars(str(sam_account_name or "").strip())
        # Several placeholder spellings are accepted so configured templates
        # can use whichever name they prefer.
        return self.user_filter.format(
            sAMAccountName=escaped,
            samAccountName=escaped,
            username=escaped,
            sam=escaped,
        )

    def find_user(self, sam_account_name: str, *, attributes: list[str] | None = None) -> dict[str, Any] | None:
        """Look up exactly one user under ``base_dn`` (subtree search).

        Args:
            sam_account_name: Account name substituted into ``user_filter``.
            attributes: Attributes to fetch; empty/None means
                ``["distinguishedName", "sAMAccountName"]``.

        Returns:
            ``{"dn": <entry DN>, "attributes": <dict of attribute values>}`` for
            the single match, or ``None`` when the name is blank or nothing matches.

        Raises:
            RuntimeError: If more than one entry matches (picking one arbitrarily
                could lead to modifying the wrong account), or from
                :meth:`_connect`.
        """
        sam = str(sam_account_name or "").strip()
        if not sam:
            return None
        conn = self._connect()
        try:
            # size_limit=2 is enough to distinguish "exactly one" from "ambiguous".
            conn.search(
                self.base_dn,
                self._format_user_filter(sam),
                search_scope=SUBTREE,
                attributes=attributes or ["distinguishedName", "sAMAccountName"],
                size_limit=2,
            )
            entries = conn.entries
            if not entries:
                return None
            if len(entries) > 1:
                matched = [e.entry_dn for e in entries]
                raise RuntimeError(f"AD user lookup is ambiguous: sAMAccountName={sam!r} matched {matched!r}")
            entry = entries[0]
            return {"dn": entry.entry_dn, "attributes": dict(entry.entry_attributes_as_dict)}
        finally:
            conn.unbind()

    def read_user_by_dn(self, dn: str, *, attributes: list[str] | None = None) -> dict[str, Any] | None:
        """Read a single entry by its DN (base-scope search).

        Args:
            dn: Distinguished name of the entry; surrounding whitespace is ignored.
            attributes: Attributes to fetch; empty/None means
                ``["distinguishedName", "sAMAccountName"]``.

        Returns:
            ``{"dn": ..., "attributes": {...}}`` for the entry, or ``None`` when
            ``dn`` is blank or the search yields no entry.
        """
        clean_dn = str(dn or "").strip()
        if not clean_dn:
            return None
        conn = self._connect()
        try:
            conn.search(
                clean_dn,
                "(objectClass=*)",
                search_scope=BASE,
                attributes=attributes or ["distinguishedName", "sAMAccountName"],
            )
            if not conn.entries:
                return None
            entry = conn.entries[0]
            return {"dn": entry.entry_dn, "attributes": dict(entry.entry_attributes_as_dict)}
        finally:
            conn.unbind()

    @staticmethod
    def _normalize_change_value(value: Any) -> list[Any]:
        """Turn a value into the list form used for an LDAP replace.

        Lists, tuples and sets are filtered to drop ``None`` and blank-string
        items. A scalar becomes a one-item list, unless it is ``None`` or blank,
        which yields ``[]``.
        """
        if isinstance(value, (list, tuple, set)):
            return [x for x in value if x is not None and str(x).strip() != ""]
        if value is None or str(value).strip() == "":
            return []
        return [value]

    def modify_user(self, dn: str, attributes: dict[str, Any], *, dry_run: bool = False) -> bool:
        """Replace the given attributes on the entry at ``dn``.

        Attributes whose name is blank, or whose value normalises to an empty
        list, are skipped. This method therefore never clears an attribute.

        Args:
            dn: Distinguished name of the entry to modify.
            attributes: Mapping of attribute name to new value(s). ``None`` is
                treated as an empty mapping.
            dry_run: Only log which attributes would change; do not connect.

        Returns:
            ``True`` if a modification was sent (or would be, in dry-run mode),
            ``False`` when there was nothing to change.

        Raises:
            ValueError: If ``dn`` is blank.
            RuntimeError: If the modify operation fails, or from :meth:`_connect`.
        """
        clean_dn = str(dn or "").strip()
        if not clean_dn:
            raise ValueError("dn is required")

        changes: dict[str, Any] = {}
        for attr, value in (attributes or {}).items():
            attr_name = str(attr or "").strip()
            values = self._normalize_change_value(value)
            if not attr_name or not values:
                continue
            changes[attr_name] = [(MODIFY_REPLACE, values)]

        if not changes:
            return False
        if dry_run:
            logger.info("AD dry_run: dn=%s changes=%s", clean_dn, sorted(changes.keys()))
            return True

        conn = self._connect()
        try:
            ok = bool(conn.modify(clean_dn, changes))
            if not ok:
                raise RuntimeError(f"AD modify failed dn={clean_dn!r} result={conn.result!r}")
            logger.info("AD modify success: dn=%s result=%s changed_attrs=%s", clean_dn, conn.result, sorted(changes.keys()))
            return True
        finally:
            conn.unbind()
|