feat: 4th extension

This commit is contained in:
Marsway 2026-04-29 10:39:12 +08:00
parent ed76f85f12
commit ea5e6cca0e
3 changed files with 690 additions and 0 deletions

View File

@ -0,0 +1,6 @@
"""Sync extension that pushes EHR data into the local (on-premises) AD."""
from extensions.sync_ehr_to_ad.api import ActiveDirectoryClient, SyncEhrToAdApi
from extensions.sync_ehr_to_ad.job import SyncEhrToAdUserJob
# Names re-exported as the public surface of this package.
__all__ = ["ActiveDirectoryClient", "SyncEhrToAdApi", "SyncEhrToAdUserJob"]

View File

@ -0,0 +1,160 @@
from __future__ import annotations
import logging
import ssl
from typing import Any
from ldap3 import ALL, MODIFY_REPLACE, SUBTREE, Connection, Server, Tls
from ldap3.utils.conv import escape_filter_chars
from extensions.sync_ehr_to_oa.api import SyncEhrToOaApi
logger = logging.getLogger("connecthub.extensions.sync_ehr_to_ad")
class SyncEhrToAdApi(SyncEhrToOaApi):
    """Beisen EHR API wrapper used by the EHR -> AD sync.

    Behaves like the EHR -> OA client it inherits from; the only difference
    visible here is that ``sqlserver_params`` is always passed as ``None``.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Forward every setting to the parent client.

        Args:
            secret_params: Credentials handed to the parent client unchanged.
            base_url: Root URL of the EHR open API.
            timeout_s: Timeout value forwarded to the parent client.
            retries: Retry count forwarded to the parent client.
            retry_backoff_s: Retry back-off forwarded to the parent client.
        """
        parent_kwargs: dict[str, Any] = {
            "secret_params": secret_params,
            # The AD sync has no SQL Server side, so nothing is supplied here.
            "sqlserver_params": None,
            "base_url": base_url,
            "timeout_s": timeout_s,
            "retries": retries,
            "retry_backoff_s": retry_backoff_s,
        }
        super().__init__(**parent_kwargs)
class ActiveDirectoryClient:
    """Update attributes of existing local AD users directly through ldap3.

    Each public operation (``ping``, ``find_user``, ``modify_user``) opens its
    own connection via ``_connect`` and unbinds it before returning.
    """

    def __init__(
        self,
        *,
        ldap_uri: str,
        bind_dn: str,
        bind_password: str,
        base_dn: str,
        user_filter: str = "(sAMAccountName={sAMAccountName})",
        use_starttls: bool = False,
        verify_tls: bool = True,
        connect_timeout_s: int = 10,
    ) -> None:
        """Normalize the settings and build the ldap3 ``Server`` definition.

        Args:
            ldap_uri: Server URI; an ``ldaps://`` prefix switches SSL on.
            bind_dn: DN (or account) used to bind.
            bind_password: Password for ``bind_dn``.
            base_dn: Search base for user lookups.
            user_filter: LDAP filter template; accepts the placeholders
                ``{sAMAccountName}``, ``{samAccountName}``, ``{username}``
                and ``{sam}``.
            use_starttls: Upgrade a plain connection with StartTLS.
            verify_tls: Require a valid server certificate when TLS is used.
            connect_timeout_s: Connect timeout passed to ldap3.

        Raises:
            ValueError: If the URI, bind credentials or base DN are missing.
        """
        self.ldap_uri = str(ldap_uri or "").strip()
        self.bind_dn = str(bind_dn or "").strip()
        # The password is deliberately not stripped: whitespace may be part of it.
        self.bind_password = str(bind_password or "")
        self.base_dn = str(base_dn or "").strip()
        self.user_filter = str(user_filter or "(sAMAccountName={sAMAccountName})").strip()
        self.use_starttls = bool(use_starttls)
        self.verify_tls = bool(verify_tls)
        self.connect_timeout_s = int(connect_timeout_s or 10)
        if not self.ldap_uri:
            raise ValueError("ldap_uri is required")
        if not self.bind_dn or not self.bind_password:
            raise ValueError("bind_dn and bind_password are required")
        if not self.base_dn:
            raise ValueError("base_dn is required")
        tls = None
        # A Tls object is only needed when some form of TLS will be negotiated.
        if self.ldap_uri.startswith("ldaps://") or self.use_starttls:
            tls = Tls(validate=ssl.CERT_REQUIRED if self.verify_tls else ssl.CERT_NONE)
        self._server = Server(
            self.ldap_uri,
            use_ssl=self.ldap_uri.startswith("ldaps://"),
            get_info=ALL,
            tls=tls,
            connect_timeout=self.connect_timeout_s,
        )

    def _connect(self) -> Connection:
        """Open a connection, optionally upgrade it with StartTLS, and bind.

        Returns:
            A bound connection; the caller is responsible for ``unbind()``.

        Raises:
            RuntimeError: If the server rejects the bind.
        """
        conn = Connection(self._server, user=self.bind_dn, password=self.bind_password, auto_bind=False)
        try:
            conn.open()
            # StartTLS only makes sense on a connection that is not already SSL.
            if self.use_starttls and self._server.ssl is False:
                conn.start_tls()
            if not conn.bind():
                raise RuntimeError(f"AD bind failed: {conn.result}")
        except Exception:
            # The socket may already be open at this point; release it so a
            # failed StartTLS/bind does not leak a connection per attempt.
            try:
                conn.unbind()
            except Exception:  # noqa: BLE001 - cleanup must not mask the real error
                logger.debug("AD connection cleanup failed", exc_info=True)
            raise
        return conn

    def ping(self) -> bool:
        """Return whether a fresh connection could be bound."""
        conn = self._connect()
        try:
            return bool(conn.bound)
        finally:
            conn.unbind()

    def _format_user_filter(self, sam_account_name: str) -> str:
        """Render ``user_filter`` with the account name escaped for LDAP filters."""
        escaped = escape_filter_chars(str(sam_account_name or "").strip())
        # Several placeholder spellings are accepted so templates stay flexible.
        return self.user_filter.format(
            sAMAccountName=escaped,
            samAccountName=escaped,
            username=escaped,
            sam=escaped,
        )

    def find_user(self, sam_account_name: str, *, attributes: list[str] | None = None) -> dict[str, Any] | None:
        """Look up one user under ``base_dn`` by account name.

        Args:
            sam_account_name: Account name substituted into ``user_filter``.
            attributes: Attributes to fetch; defaults to ``distinguishedName``
                and ``sAMAccountName``.

        Returns:
            ``{"dn": ..., "attributes": {...}}`` for the first match, or
            ``None`` when the name is blank or nothing matches.
        """
        sam = str(sam_account_name or "").strip()
        if not sam:
            return None
        conn = self._connect()
        try:
            conn.search(
                self.base_dn,
                self._format_user_filter(sam),
                search_scope=SUBTREE,
                attributes=attributes or ["distinguishedName", "sAMAccountName"],
                # Two results are requested so an ambiguous filter is detectable.
                size_limit=2,
            )
            if not conn.entries:
                return None
            if len(conn.entries) > 1:
                logger.warning(
                    "AD search matched more than one entry for sAMAccountName=%s; using the first",
                    sam,
                )
            entry = conn.entries[0]
            attrs = dict(entry.entry_attributes_as_dict)
            return {"dn": entry.entry_dn, "attributes": attrs}
        finally:
            conn.unbind()

    @staticmethod
    def _normalize_change_value(value: Any) -> list[Any]:
        """Turn a scalar or collection into a list with blank items removed.

        ``None`` and values whose string form is blank are dropped; other
        falsy values such as ``0`` are kept.
        """
        if isinstance(value, (list, tuple, set)):
            return [x for x in value if x is not None and str(x).strip() != ""]
        if value is None or str(value).strip() == "":
            return []
        return [value]

    def modify_user(self, dn: str, attributes: dict[str, Any], *, dry_run: bool = False) -> bool:
        """Replace the given attributes on the entry ``dn``.

        Attributes whose name or normalized value is empty are skipped, so a
        blank source value never clears an existing AD attribute.

        Args:
            dn: Distinguished name of the entry to modify.
            attributes: Attribute name -> new value (scalar or collection).
            dry_run: Log the attribute names instead of contacting the server.

        Returns:
            ``False`` when nothing was left to change, ``True`` otherwise.

        Raises:
            ValueError: If ``dn`` is blank.
            RuntimeError: If the server rejects the modification.
        """
        clean_dn = str(dn or "").strip()
        if not clean_dn:
            raise ValueError("dn is required")
        changes: dict[str, Any] = {}
        for attr, value in attributes.items():
            attr_name = str(attr or "").strip()
            values = self._normalize_change_value(value)
            if not attr_name or not values:
                continue
            changes[attr_name] = [(MODIFY_REPLACE, values)]
        if not changes:
            return False
        if dry_run:
            logger.info("AD dry_run: dn=%s changes=%s", clean_dn, sorted(changes.keys()))
            return True
        conn = self._connect()
        try:
            ok = bool(conn.modify(clean_dn, changes))
            if not ok:
                raise RuntimeError(f"AD modify failed dn={clean_dn!r} result={conn.result!r}")
            return True
        finally:
            conn.unbind()

View File

@ -0,0 +1,524 @@
from __future__ import annotations
import json
import logging
import re
from typing import Any
from app.jobs.base import BaseJob
from extensions.sync_ehr_to_ad.api import ActiveDirectoryClient, SyncEhrToAdApi
logger = logging.getLogger("connecthub.extensions.sync_ehr_to_ad")
# Default EHR field keys. Each one can be overridden through the job params
# domain_account_custom_key / work_location_text_key / street_address_key.
# Field holding the employee's AD account name (used as sAMAccountName).
_EHR_AD_ACCOUNT_KEY = "extADAccountName_606508_511687157"
# Field read for the work-location text that drives the location attributes.
_EHR_WORK_LOCATION_TEXT_KEY = "extgzddxx1_606508_892394263Text"
# Field read for the value written to the AD streetAddress attribute.
_EHR_STREET_ADDRESS_KEY = "extgzddxx_606508_618643707"
def _to_bool_or_none(v: Any) -> bool | None:
if v is None:
return None
if isinstance(v, bool):
return v
s = str(v).strip().lower()
if s in ("1", "true", "yes", "y", "on"):
return True
if s in ("0", "false", "no", "n", "off", ""):
return False
return bool(v)
def _to_int_safe(v: Any) -> int:
try:
return int(str(v).strip())
except Exception:
return 0
def _scalar_value(raw: Any) -> str:
if isinstance(raw, dict):
val = raw.get("value")
if val is None or str(val).strip() == "":
val = raw.get("showValue")
return str(val or "").strip()
return str(raw or "").strip()
def _custom_prop_value(custom_props: Any, key: str | None) -> str:
    """Read ``key`` from a custom-properties dict as a trimmed string.

    Returns ``""`` when the key is empty or ``custom_props`` is not a dict.
    """
    usable = bool(key) and isinstance(custom_props, dict)
    return _scalar_value(custom_props.get(key)) if usable else ""
def _translate_value(node: dict[str, Any], key: str | None) -> str:
if not key or not isinstance(node, dict):
return ""
translate = node.get("translateProperties")
if not isinstance(translate, dict):
return ""
candidates = [key, f"{key}Text"]
if key.endswith("Text"):
candidates.insert(0, key)
candidates.append(key[:-4])
for candidate in candidates:
s = str(translate.get(candidate) or "").strip()
if s:
return s
return ""
def _field_value(item: dict[str, Any], key: str) -> str:
    """Find ``key`` on an employee row, preferring raw values.

    ``employeeInfo`` is searched before ``recordInfo``. Within each node the
    plain field is tried first, then its translated text, then the node's
    ``customProperties``. Returns ``""`` when nothing yields a value.
    """
    for node in (item.get("employeeInfo") or {}, item.get("recordInfo") or {}):
        if not isinstance(node, dict):
            continue
        # ``or`` short-circuits, so later lookups only run when earlier ones are empty.
        found = (
            _scalar_value(node.get(key))
            or _translate_value(node, key)
            or _custom_prop_value(node.get("customProperties"), key)
        )
        if found:
            return found
    return ""
def _field_translate_or_value(item: dict[str, Any], key: str) -> str:
    """Find ``key`` on an employee row, preferring translated display text.

    The translated text from ``employeeInfo`` and then ``recordInfo`` is
    tried first; if neither has one, the lookup falls back to ``_field_value``.
    """
    nodes = [item.get("employeeInfo") or {}, item.get("recordInfo") or {}]
    for node in nodes:
        if not isinstance(node, dict):
            continue
        translated = _translate_value(node, key)
        if translated:
            return translated
    return _field_value(item, key)
def _extract_mobile_phone(emp_info: dict[str, Any]) -> str:
    """Pick a phone number from ``emp_info`` by trying common key spellings.

    Direct fields are checked first, in the order listed below; only when all
    of them are empty are the same keys looked up in ``translateProperties``.
    """
    phone_keys = (
        "mobile",
        "mobilePhone",
        "MobilePhone",
        "phone",
        "phoneNumber",
        "PhoneNumber",
        "tel",
        "telephone",
    )
    direct_values = (_scalar_value(emp_info.get(k)) for k in phone_keys)
    direct = next((value for value in direct_values if value), "")
    if direct:
        return direct
    translate = emp_info.get("translateProperties")
    if not isinstance(translate, dict):
        return ""
    translated_values = (str(translate.get(k) or "").strip() for k in phone_keys)
    return next((value for value in translated_values if value), "")
def _choose_better_record(current: dict[str, Any], candidate: dict[str, Any]) -> dict[str, Any]:
def _score(item: dict[str, Any]) -> str:
record = item.get("recordInfo") or {}
emp = item.get("employeeInfo") or {}
return "|".join(
[
str(record.get("businessModifiedTime") or ""),
str(record.get("modifiedTime") or ""),
str(emp.get("businessModifiedTime") or ""),
str(emp.get("modifiedTime") or ""),
str(record.get("createdTime") or ""),
str(emp.get("createdTime") or ""),
]
)
return candidate if _score(candidate) >= _score(current) else current
def _has_cjk(s: str) -> bool:
return bool(re.search(r"[\u4e00-\u9fff]", str(s or "")))
def _display_name(given_name: str, sn: str, name: str) -> str:
    """Compose the AD display name.

    The romanized "given surname" pair comes first; when ``name`` contains
    CJK characters it is appended after it. Without a CJK name the result is
    the romanized pair, or ``name`` itself if that pair is empty.
    """
    latin = " ".join(part for part in (given_name.strip(), sn.strip()) if part)
    native = str(name or "").strip()
    if native and _has_cjk(native):
        return f"{latin} {native}" if latin else native
    return latin or native
def _proxy_addresses(email: str, sam: str, alias_domain: str | None) -> list[str]:
clean_email = str(email or "").strip()
clean_sam = str(sam or "").strip()
if not clean_email:
return []
domain = str(alias_domain or "").strip()
if not domain and "@" in clean_email:
domain = clean_email.split("@", 1)[1]
values = [f"SMTP:{clean_email}"]
if clean_sam and domain:
values.append(f"smtp:{clean_sam}@{domain}")
out: list[str] = []
seen: set[str] = set()
for value in values:
key = value.lower()
if key in seen:
continue
seen.add(key)
out.append(value)
return out
def _location_from_workplace(workplace: str, mappings: dict[str, Any] | None = None) -> dict[str, Any]:
text = str(workplace or "").strip()
defaults: dict[str, dict[str, Any]] = {
"上海": {"co": "China", "c": "CN", "countryCode": 156, "st": "Shanghai", "l": "Shanghai"},
"shanghai": {"co": "China", "c": "CN", "countryCode": 156, "st": "Shanghai", "l": "Shanghai"},
"北京": {"co": "China", "c": "CN", "countryCode": 156, "st": "Beijing", "l": "Beijing"},
"beijing": {"co": "China", "c": "CN", "countryCode": 156, "st": "Beijing", "l": "Beijing"},
"深圳": {"co": "China", "c": "CN", "countryCode": 156, "st": "Guangdong", "l": "Shenzhen"},
"shenzhen": {"co": "China", "c": "CN", "countryCode": 156, "st": "Guangdong", "l": "Shenzhen"},
}
merged = dict(defaults)
if isinstance(mappings, dict):
for k, v in mappings.items():
if isinstance(v, dict):
merged[str(k).strip().lower()] = v
lower_text = text.lower()
for needle, value in merged.items():
if needle and needle in lower_text:
return dict(value)
if "中国" in text or "china" in lower_text:
return {"co": "China", "c": "CN", "countryCode": 156}
return {}
def _org_name(org: dict[str, Any]) -> str:
for key in ("name", "Name", "shortName", "ShortName"):
s = str(org.get(key) or "").strip()
if s:
return s
return ""
def _org_code(org: dict[str, Any]) -> str:
for key in ("code", "Code", "orgCode", "OrgCode"):
s = str(org.get(key) or "").strip()
if s:
return s
return ""
def _root_org_name(org: dict[str, Any], org_by_oid: dict[str, dict[str, Any]]) -> str:
    """Walk up the organization tree and return the topmost known name.

    Starting at ``org``, parents are followed through ``org_by_oid`` until a
    node has no parent id, the parent is unknown, or a node id repeats. The
    result is the name of the highest node that had one.
    """
    parent_fields = ("pOId", "parentOId", "oIdParent", "parentId", "ParentId")
    visited: set[str] = set()
    node = org
    top_name = _org_name(node)
    while isinstance(node, dict):
        node_id = str(node.get("oId") or node.get("oid") or node.get("id") or "").strip()
        # ``visited`` never holds "", so nodes without an id are not cycle-checked.
        if node_id in visited:
            break
        if node_id:
            visited.add(node_id)
        top_name = _org_name(node) or top_name
        # First truthy parent reference wins, mirroring an ``or`` chain.
        raw_parent = next(filter(None, (node.get(field) for field in parent_fields)), "")
        parent_id = str(raw_parent).strip()
        if not parent_id or parent_id not in org_by_oid:
            break
        node = org_by_oid[parent_id]
    return top_name
def _is_current_employee(item: dict[str, Any], current_status_values: set[str]) -> bool:
    """Decide whether an EHR row describes a currently active employee.

    A row is rejected when its sub-nodes are malformed, when ``lastWorkDate``
    is set, when a status filter is given and the row's non-empty status is
    not in it, or when any deleted/disabled flag is true.
    """
    rec = item.get("recordInfo") or {}
    emp = item.get("employeeInfo") or {}
    if not (isinstance(rec, dict) and isinstance(emp, dict)):
        return False
    if str(rec.get("lastWorkDate") or "").strip():
        return False
    status = _field_translate_or_value(item, "EmployeeStatus")
    if current_status_values and status and status not in current_status_values:
        return False
    flag_keys = ("isDeleted", "IsDeleted", "deleted", "disabled", "Disabled")
    # The record is consulted for a flag only when the employee node lacks the key.
    return not any(_to_bool_or_none(emp.get(k, rec.get(k))) is True for k in flag_keys)
def _job_post_name(job_post: dict[str, Any]) -> str:
for key in ("name", "Name", "jobPostName", "JobPostName"):
s = str(job_post.get(key) or "").strip()
if s:
return s
return ""
class SyncEhrToAdUserJob(BaseJob):
    """Sync attributes of current EHR employees onto local AD users.

    - Only users that already exist in AD (matched by sAMAccountName) are
      updated; users are never created.
    - AD connection settings are injected through the job params/secrets.
    """

    job_id = "sync_ehr_to_ad.sync_users"

    def run(self, params: dict[str, Any], secrets: dict[str, Any]) -> dict[str, Any]:
        """Run one full EHR -> AD attribute sync.

        Args:
            params: Job parameters. Read here: ``ldap_*`` connection settings,
                ``dry_run``, ``verbose_trace``, ``stop_time``, ``capacity``
                (reset to 300 when outside 1..300), ``max_users``, the EHR
                field-key overrides, ``proxy_alias_domain``,
                ``department_code_ad_attribute``, ``postal_code``,
                ``default_company``, ``current_status_values`` and
                ``location_mappings``.
            secrets: Must hold ``app_key`` and ``app_secret``; may also hold
                the ``ldap_*`` connection settings as a fallback.

        Returns:
            A summary dict with the counters of the run; ``ok`` is true when
            no user failed.

        Raises:
            ValueError: If EHR credentials or AD connection settings are missing.
            RuntimeError: If an EHR response does not carry list data.
        """
        app_key = str(secrets.get("app_key") or "").strip()
        app_secret = str(secrets.get("app_secret") or "").strip()
        if not app_key or not app_secret:
            raise ValueError("secret_cfg.app_key and secret_cfg.app_secret are required")
        # Connection settings may live in params or, as a fallback, in secrets.
        ldap_uri = str(params.get("ldap_uri") or secrets.get("ldap_uri") or "").strip()
        ldap_base_dn = str(params.get("ldap_base_dn") or secrets.get("ldap_base_dn") or "").strip()
        ldap_bind_dn = str(params.get("ldap_bind_dn") or secrets.get("ldap_bind_dn") or "").strip()
        ldap_bind_password = str(params.get("ldap_bind_password") or secrets.get("ldap_bind_password") or "")
        if not ldap_uri or not ldap_base_dn or not ldap_bind_dn or not ldap_bind_password:
            raise ValueError("ldap_uri/ldap_base_dn/ldap_bind_dn/ldap_bind_password are required")
        ldap_user_filter = str(params.get("ldap_user_filter") or "(sAMAccountName={sAMAccountName})").strip()
        ldap_use_starttls = _to_bool_or_none(params.get("ldap_use_starttls"))
        ldap_verify_tls = _to_bool_or_none(params.get("ldap_verify_tls"))
        dry_run = _to_bool_or_none(params.get("dry_run"))
        verbose_trace = _to_bool_or_none(params.get("verbose_trace"))
        # Absent flags fall back to their defaults.
        if ldap_use_starttls is None:
            ldap_use_starttls = False
        if ldap_verify_tls is None:
            ldap_verify_tls = True
        if dry_run is None:
            dry_run = False
        if verbose_trace is None:
            verbose_trace = True
        stop_time = params.get("stop_time")
        capacity = int(params.get("capacity") or 300)
        if capacity <= 0 or capacity > 300:
            capacity = 300
        # 0 means "no limit" on the number of distinct accounts processed.
        max_users = int(params.get("max_users") or 0)
        connect_timeout_s = int(params.get("ldap_connect_timeout_s") or 10)
        domain_account_key = str(params.get("domain_account_custom_key") or "").strip() or _EHR_AD_ACCOUNT_KEY
        work_location_text_key = str(params.get("work_location_text_key") or "").strip() or _EHR_WORK_LOCATION_TEXT_KEY
        street_address_key = str(params.get("street_address_key") or "").strip() or _EHR_STREET_ADDRESS_KEY
        proxy_alias_domain = str(params.get("proxy_alias_domain") or "").strip() or None
        department_code_attr = str(params.get("department_code_ad_attribute") or "departmentNumber").strip()
        postal_code = str(params.get("postal_code") or "").strip()
        default_company = str(params.get("default_company") or "").strip()
        # The status filter is accepted as a list or as a comma-separated string.
        current_status_values_param = params.get("current_status_values")
        if isinstance(current_status_values_param, list):
            current_status_values = {str(x).strip() for x in current_status_values_param if str(x).strip()}
        elif str(current_status_values_param or "").strip():
            current_status_values = {x.strip() for x in str(current_status_values_param).split(",") if x.strip()}
        else:
            current_status_values = set()
        location_mappings = params.get("location_mappings")
        location_mappings = location_mappings if isinstance(location_mappings, dict) else None
        ehr = SyncEhrToAdApi(secret_params={"app_key": app_key, "app_secret": app_secret})
        try:
            # Built inside the try so the EHR client is closed even when the
            # AD client cannot be constructed.
            ad = ActiveDirectoryClient(
                ldap_uri=ldap_uri,
                bind_dn=ldap_bind_dn,
                bind_password=ldap_bind_password,
                base_dn=ldap_base_dn,
                user_filter=ldap_user_filter,
                use_starttls=ldap_use_starttls,
                verify_tls=ldap_verify_tls,
                connect_timeout_s=connect_timeout_s,
            )
            # Fail fast on bad AD settings before pulling anything from EHR.
            ad.ping()
            logger.info("AD 连接检查通过uri=%s base_dn=%s dry_run=%s", ldap_uri, ldap_base_dn, dry_run)
            emp_res = ehr.get_all_employees_with_record_by_time_window(
                stop_time=stop_time,
                capacity=capacity,
                with_disabled=False,
                is_with_deleted=False,
            )
            org_res = ehr.get_all_organizations_by_time_window(
                stop_time=stop_time,
                capacity=capacity,
                with_disabled=False,
                is_with_deleted=False,
            )
            job_post_res = ehr.get_all_job_posts_by_time_window(
                stop_time=stop_time,
                capacity=capacity,
                with_disabled=False,
                is_with_deleted=False,
            )
            emp_rows = emp_res.get("data") or []
            org_rows = org_res.get("data") or []
            job_post_rows = job_post_res.get("data") or []
            if not isinstance(emp_rows, list) or not isinstance(org_rows, list) or not isinstance(job_post_rows, list):
                raise RuntimeError("EHR result invalid: data is not list")
            # Index organizations and job posts by id for per-user lookups.
            org_by_oid: dict[str, dict[str, Any]] = {}
            for org in org_rows:
                if not isinstance(org, dict):
                    continue
                oid = str(org.get("oId") or org.get("oid") or org.get("id") or "").strip()
                if oid:
                    org_by_oid[oid] = org
            job_post_by_oid: dict[str, dict[str, Any]] = {}
            for job_post in job_post_rows:
                if not isinstance(job_post, dict):
                    continue
                oid = str(job_post.get("oId") or job_post.get("oid") or job_post.get("id") or "").strip()
                if oid:
                    job_post_by_oid[oid] = job_post
            # One row per account (case-insensitive); duplicates keep the newer row.
            users_by_sam: dict[str, dict[str, Any]] = {}
            # EHR userID -> account name, used to resolve managers.
            user_id_to_sam: dict[int, str] = {}
            for item in emp_rows:
                if not isinstance(item, dict) or not _is_current_employee(item, current_status_values):
                    continue
                sam = _field_value(item, domain_account_key)
                if not sam:
                    continue
                existing = users_by_sam.get(sam.lower())
                users_by_sam[sam.lower()] = item if existing is None else _choose_better_record(existing, item)
                user_id = _to_int_safe((item.get("employeeInfo") or {}).get("userID"))
                if user_id > 0:
                    user_id_to_sam[user_id] = sam
                if max_users > 0 and len(users_by_sam) >= max_users:
                    break
            logger.info(
                "EHR 当前用户准备完成employee_rows=%s current_with_ad_account=%s org_rows=%s job_post_rows=%s",
                len(emp_rows),
                len(users_by_sam),
                len(org_rows),
                len(job_post_rows),
            )
            processed = 0
            updated = 0
            skipped_missing_sam = 0
            skipped_not_found_ad = 0
            failed = 0
            # manager account -> DN; "" records a manager that is not in AD.
            manager_dn_cache: dict[str, str] = {}
            for sam_key, item in users_by_sam.items():
                emp = item.get("employeeInfo") or {}
                rec = item.get("recordInfo") or {}
                if not isinstance(emp, dict):
                    emp = {}
                if not isinstance(rec, dict):
                    rec = {}
                sam = _field_value(item, domain_account_key)
                if not sam:
                    skipped_missing_sam += 1
                    continue
                processed += 1
                # One user's failure is counted and logged; the run continues.
                try:
                    ad_user = ad.find_user(sam)
                    if not ad_user:
                        skipped_not_found_ad += 1
                        logger.warning("AD 用户不存在跳过sAMAccountName=%s", sam)
                        continue
                    org_oid = str(rec.get("oIdDepartment") or rec.get("OIdDepartment") or "").strip()
                    org = org_by_oid.get(org_oid, {})
                    job_post_oid = str(rec.get("oIdJobPost") or rec.get("OIdJobPost") or "").strip()
                    job_post = job_post_by_oid.get(job_post_oid, {})
                    given_name = _field_value(item, "PhoneticOfMing")
                    surname = _field_value(item, "PhoneticOfXing")
                    name = str(emp.get("name") or emp.get("Name") or "").strip()
                    title = _field_translate_or_value(item, "OIdJobPost") or _job_post_name(job_post)
                    department = _field_translate_or_value(item, "OIdDepartment") or _org_name(org)
                    department_code = _org_code(org)
                    employee_status = _field_translate_or_value(item, "EmployeeStatus")
                    email = _field_value(item, "Email") or _field_value(item, "email")
                    job_number = str(rec.get("jobNumber") or rec.get("JobNumber") or "").strip()
                    mobile = _field_value(item, "MobilePhone") or _extract_mobile_phone(emp)
                    office = _field_translate_or_value(item, "Place")
                    workplace_text = _field_translate_or_value(item, work_location_text_key)
                    street_address = _field_value(item, street_address_key)
                    company = default_company or _root_org_name(org, org_by_oid)
                    location_attrs = _location_from_workplace(workplace_text or office, location_mappings)
                    manager_dn = ""
                    manager_uid = _to_int_safe(rec.get("pOIdEmpAdmin") or rec.get("POIdEmpAdmin"))
                    manager_sam = user_id_to_sam.get(manager_uid, "")
                    if manager_sam:
                        # Membership test (not truthiness) so a manager missing
                        # from AD is searched only once per run.
                        if manager_sam not in manager_dn_cache:
                            manager_ad_user = ad.find_user(manager_sam)
                            manager_dn_cache[manager_sam] = (
                                str(manager_ad_user.get("dn") or "") if manager_ad_user else ""
                            )
                        manager_dn = manager_dn_cache[manager_sam]
                    # Blank values are dropped by modify_user, so they never
                    # clear an existing AD attribute.
                    attributes: dict[str, Any] = {
                        "sAMAccountName": sam,
                        "givenName": given_name,
                        "sn": surname,
                        "title": title,
                        "department": department,
                        "manager": manager_dn,
                        "proxyAddresses": _proxy_addresses(email, sam, proxy_alias_domain),
                        "co": location_attrs.get("co"),
                        "c": location_attrs.get("c"),
                        "countryCode": location_attrs.get("countryCode"),
                        "company": company,
                        "displayName": _display_name(given_name, surname, name),
                        "mail": email,
                        "employeeID": job_number,
                        "employeeType": employee_status,
                        "mobile": mobile,
                        "physicalDeliveryOfficeName": office,
                        "postalCode": postal_code,
                        "st": location_attrs.get("st"),
                        "l": location_attrs.get("l"),
                        "streetAddress": street_address,
                    }
                    if department_code_attr and department_code:
                        attributes[department_code_attr] = department_code
                    changed = ad.modify_user(str(ad_user["dn"]), attributes, dry_run=dry_run)
                    if changed:
                        updated += 1
                    if verbose_trace:
                        logger.info(
                            "AD 用户同步完成sam=%s dn=%s attrs=%s",
                            sam,
                            ad_user["dn"],
                            json.dumps({k: v for k, v in attributes.items() if v}, ensure_ascii=False, default=str),
                        )
                except Exception as e:  # noqa: BLE001
                    failed += 1
                    logger.exception("AD 用户同步失败sam_key=%s error=%r", sam_key, e)
            result = {
                "ok": failed == 0,
                "dry_run": dry_run,
                "ehr_employee_rows": len(emp_rows),
                "ehr_current_users_with_ad_account": len(users_by_sam),
                "processed": processed,
                "updated": updated,
                "skipped_missing_sam": skipped_missing_sam,
                "skipped_not_found_ad": skipped_not_found_ad,
                "failed": failed,
            }
            logger.info("EHR -> AD 同步结束:%s", result)
            return result
        finally:
            ehr.close()