from __future__ import annotations
|
||
|
||
import logging
|
||
from datetime import datetime, timedelta
|
||
from typing import Any
|
||
from urllib.parse import quote_plus
|
||
|
||
from sqlalchemy import create_engine, text
|
||
from sqlalchemy.engine import Engine
|
||
|
||
from app.integrations.ehr import EhrClient
|
||
|
||
|
||
logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa")
|
||
|
||
|
||
class SyncEhrToOaApi:
    """
    Beisen EHR -> OA synchronisation API wrapper.

    Wrapped EHR endpoints (time-window scrolling queries, fixed start time
    2015-01-01T00:00:00):
    - employee + single employment record
    - organization unit
    - job post

    Additionally supports:
    - resolving OA record ids by job number from a SQL Server table
    - fetching contract first-party (employer company) by user ids
    - fetching staff profiles by user ids
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any] | None = None,
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """
        Create the EHR HTTP client and, optionally, the OA SQL Server engine.

        Args:
            secret_params: credentials forwarded to ``EhrClient``.
            sqlserver_params: optional SQL Server connection settings; when
                provided, the engine is created eagerly so configuration
                errors surface at construction time.
            base_url: EHR OpenAPI base URL.
            timeout_s: per-request HTTP timeout in seconds.
            retries: HTTP retry count forwarded to the client.
            retry_backoff_s: backoff between HTTP retries in seconds.
        """
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        self._sqlserver_params = dict(sqlserver_params or {})
        self._sql_engine: Engine | None = None
        if self._sqlserver_params:
            self._sql_engine = self._create_sqlserver_engine(self._sqlserver_params)

    def close(self) -> None:
        """Close the HTTP client and dispose of the SQL engine (if created)."""
        self._client.close()
        if self._sql_engine is not None:
            self._sql_engine.dispose()
            self._sql_engine = None

    @staticmethod
    def _quote_ident(name: str) -> str:
        """
        Bracket-quote a SQL Server identifier.

        Any closing bracket inside the identifier is doubled (QUOTENAME
        semantics) so a crafted table/column name cannot break out of the
        quoting and inject SQL.
        """
        return "[" + str(name).replace("]", "]]") + "]"

    @staticmethod
    def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
        """
        Build a pooled SQLAlchemy engine for SQL Server from a params dict.

        Required keys: host, username, password, database.
        Optional keys: port (default 1433), connect_timeout_s (default 10),
        driver, trust_server_certificate, encrypt.

        Raises:
            ValueError: if any required key is missing or blank.
        """
        host = str(sqlserver_params.get("host") or "").strip()
        username = str(sqlserver_params.get("username") or "").strip()
        password = str(sqlserver_params.get("password") or "").strip()
        database = str(sqlserver_params.get("database") or "").strip()
        if not host or not username or not password or not database:
            raise ValueError("sqlserver_params.host/username/password/database are required")

        port = int(sqlserver_params.get("port") or 1433)
        timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
        driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
        trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower()
        encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower()

        # Relies on mssql+pyodbc; the runtime environment must provide pyodbc
        # and the matching ODBC driver.
        # e.g. mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+18+for+SQL+Server
        # Credentials and driver are percent-encoded to avoid URL parsing issues.
        username_q = quote_plus(username)
        password_q = quote_plus(password)
        driver_q = quote_plus(driver)
        conn_url = (
            f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}"
            f"?driver={driver_q}&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}"
        )
        return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s})

    def ping_sqlserver(self) -> bool:
        """
        Run ``SELECT 1`` against SQL Server to verify connectivity.

        Raises:
            RuntimeError: if the engine was never initialized.
        """
        if self._sql_engine is None:
            raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
        with self._sql_engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True

    def get_oa_record_id_map_from_sqlserver(
        self,
        *,
        table_name: str,
        job_numbers: list[str],
        job_no_column: str = "field0001",
        id_column: str = "id",
        schema: str | None = "dbo",
    ) -> dict[str, int]:
        """
        Query record-id mappings by job number from an OA SQL Server table.

        Args:
            table_name: target table (identifier, bracket-quoted/escaped).
            job_numbers: job numbers to look up; blanks are dropped.
            job_no_column: column holding the job number.
            id_column: column holding the numeric record id.
            schema: optional schema name.

        Returns:
            ``{job_no: id}`` for every matched row with a valid integer id.

        Raises:
            RuntimeError: if the engine was never initialized.
            ValueError: if a required identifier argument is blank.
        """
        if self._sql_engine is None:
            raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
        t = str(table_name or "").strip()
        jc = str(job_no_column or "").strip()
        ic = str(id_column or "").strip()
        if not t or not jc or not ic:
            raise ValueError("table_name/job_no_column/id_column are required")
        if not job_numbers:
            return {}

        clean_job_numbers = [str(x or "").strip() for x in job_numbers if str(x or "").strip()]
        if not clean_job_numbers:
            return {}

        # Identifiers cannot be bound parameters, so they are bracket-quoted
        # with ']' escaping to prevent breaking out of the quoting.
        quoted_table = self._quote_ident(t)
        if schema:
            quoted_table = f"{self._quote_ident(str(schema).strip())}.{quoted_table}"
        quoted_jc = self._quote_ident(jc)
        quoted_ic = self._quote_ident(ic)

        out: dict[str, int] = {}
        # Batch the IN parameters (500 per statement) to stay well under SQL
        # Server's bind-parameter limit.
        chunk_size = 500
        with self._sql_engine.connect() as conn:
            for i in range(0, len(clean_job_numbers), chunk_size):
                chunk = clean_job_numbers[i : i + chunk_size]
                binds = ", ".join(f":v{idx}" for idx in range(len(chunk)))
                sql = text(
                    f"SELECT {quoted_jc} AS job_no, {quoted_ic} AS row_id "
                    f"FROM {quoted_table} WITH (NOLOCK) "
                    f"WHERE {quoted_jc} IN ({binds})"
                )
                params = {f"v{idx}": value for idx, value in enumerate(chunk)}
                rows = conn.execute(sql, params).fetchall()
                for r in rows:
                    job_no = str(r.job_no or "").strip()
                    try:
                        row_id = int(r.row_id)
                    except (TypeError, ValueError):
                        # Skip rows whose id column is not an integer.
                        continue
                    if job_no:
                        out[job_no] = row_id
        logger.info(
            "SQLServer 工号映射查询完成:table=%s schema=%s input=%s matched=%s",
            t,
            schema or "",
            len(clean_job_numbers),
            len(out),
        )
        return out

    @staticmethod
    def _to_datetime(value: datetime | str | None) -> datetime:
        """
        Coerce a datetime, ISO-ish string, or ``None`` (meaning "now") to a
        ``datetime``.

        Accepted string forms: ``YYYY-MM-DDTHH:MM:SS``,
        ``YYYY-MM-DD HH:MM:SS`` and ``YYYY-MM-DD``.

        Raises:
            ValueError: for an empty/blank string or an unparseable value.
        """
        if value is None:
            return datetime.now()
        if isinstance(value, datetime):
            return value
        s = str(value).strip()
        if not s:
            raise ValueError("datetime string cannot be empty")
        if "T" in s:
            return datetime.fromisoformat(s)
        if " " in s:
            return datetime.fromisoformat(s.replace(" ", "T"))
        return datetime.strptime(s, "%Y-%m-%d")

    @staticmethod
    def _to_api_datetime(value: datetime | str | None) -> str:
        """Format a value as the EHR API datetime string ``YYYY-MM-DDTHH:MM:SS``."""
        return SyncEhrToOaApi._to_datetime(value).strftime("%Y-%m-%dT%H:%M:%S")

    @staticmethod
    def _iter_windows(start_dt: datetime, stop_dt: datetime, max_days: int = 90) -> list[tuple[datetime, datetime]]:
        """
        Split ``[start_dt, stop_dt]`` into contiguous windows of at most
        ``max_days`` days.

        A degenerate range (``start_dt == stop_dt``) yields one zero-length
        window so callers always get at least one window.

        Raises:
            ValueError: if ``stop_dt`` precedes ``start_dt``.
        """
        if stop_dt < start_dt:
            raise ValueError("stop_time must be greater than or equal to start_time")
        windows: list[tuple[datetime, datetime]] = []
        cur = start_dt
        max_delta = timedelta(days=max_days)
        while cur < stop_dt:
            nxt = min(cur + max_delta, stop_dt)
            windows.append((cur, nxt))
            cur = nxt
        if not windows:
            windows.append((start_dt, stop_dt))
        return windows

    def _get_all_by_time_window(
        self,
        *,
        api_path: str,
        api_name: str,
        stop_time: datetime | str | None,
        capacity: int,
        time_window_query_type: int,
        with_disabled: bool,
        is_with_deleted: bool,
        max_pages: int,
    ) -> dict[str, Any]:
        """
        Scroll an EHR time-window endpoint across 90-day windows starting at
        2015-01-01 and collect every page of data.

        Args:
            api_path: POST endpoint path.
            api_name: label used in log/error messages.
            stop_time: upper bound of the query range (``None`` = now).
            capacity: page size, 1..300 (EHR API limit).
            time_window_query_type: passed through to the API.
            with_disabled: include disabled records.
            is_with_deleted: include deleted records.
            max_pages: safety cap on pages per window.

        Returns:
            Summary dict with ``data`` (all rows), ``total``/``count``,
            ``pages``, window bounds and ``lastScrollId``.

        Raises:
            ValueError: on invalid ``capacity``/``max_pages``.
            RuntimeError: on a non-200 API code, malformed payload, or when
                a window exceeds ``max_pages``.
        """
        if capacity <= 0 or capacity > 300:
            raise ValueError("capacity must be in range [1, 300]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")

        # Fixed historical lower bound for all time-window syncs.
        start_dt = datetime(2015, 1, 1, 0, 0, 0)
        stop_dt = self._to_datetime(stop_time)
        windows = self._iter_windows(start_dt=start_dt, stop_dt=stop_dt, max_days=90)

        all_data: list[dict[str, Any]] = []
        total_pages = 0
        last_scroll_id = ""

        for idx, (w_start, w_stop) in enumerate(windows, start=1):
            start_time = self._to_api_datetime(w_start)
            stop_time_s = self._to_api_datetime(w_stop)
            scroll_id = ""
            page = 0
            window_total = 0

            while True:
                page += 1
                total_pages += 1
                # Hard stop so a server that never sets isLastData cannot
                # spin this loop forever.
                if page > max_pages:
                    raise RuntimeError(f"scroll pages exceed max_pages={max_pages} in window index={idx}")

                body: dict[str, Any] = {
                    "startTime": start_time,
                    "stopTime": stop_time_s,
                    "timeWindowQueryType": time_window_query_type,
                    "scrollId": scroll_id,
                    "capacity": capacity,
                    "withDisabled": with_disabled,
                    "isWithDeleted": is_with_deleted,
                }

                resp = self._client.request(
                    "POST",
                    api_path,
                    json=body,
                    headers={"Content-Type": "application/json"},
                )
                payload = resp.json() if resp.content else {}

                code = str(payload.get("code", "") or "")
                if code != "200":
                    message = payload.get("message")
                    raise RuntimeError(f"EHR {api_name} failed code={code!r} message={message!r}")

                batch = payload.get("data") or []
                if not isinstance(batch, list):
                    raise RuntimeError(f"EHR {api_name} invalid response: data is not a list")
                # Keep only dict rows; anything else is malformed noise.
                all_data.extend([x for x in batch if isinstance(x, dict)])

                total_val = payload.get("total")
                if total_val is not None:
                    try:
                        window_total = int(total_val)
                    except (TypeError, ValueError):
                        pass

                is_last_data = bool(payload.get("isLastData", False))
                scroll_id = str(payload.get("scrollId", "") or "")
                last_scroll_id = scroll_id

                logger.info(
                    "EHR %s window=%s/%s page=%s batch=%s window_total=%s isLastData=%s",
                    api_name,
                    idx,
                    len(windows),
                    page,
                    len(batch),
                    window_total,
                    is_last_data,
                )

                if is_last_data:
                    break

        return {
            "startTime": self._to_api_datetime(start_dt),
            "stopTime": self._to_api_datetime(stop_dt),
            "total": len(all_data),
            "pages": total_pages,
            "count": len(all_data),
            "data": all_data,
            "lastScrollId": last_scroll_id,
            "windowCount": len(windows),
        }

    def get_all_employees_with_record_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "employee + single employment record" set.

        Fixed start time: 2015-01-01T00:00:00.
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/Employee/GetByTimeWindow",
            api_name="Employee.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    def get_all_organizations_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "organization unit" set.

        Fixed start time: 2015-01-01T00:00:00.
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/Organization/GetByTimeWindow",
            api_name="Organization.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    def get_all_job_posts_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "job post" set.

        Fixed start time: 2015-01-01T00:00:00.
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/JobPost/GetByTimeWindow",
            api_name="JobPost.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    @staticmethod
    def _pick_company_from_contracts(contracts: list[dict[str, Any]]) -> str:
        """
        Pick the employer (``firstParty``) from a list of contract dicts.

        Contracts are sorted newest-first by effectiveDate/createdTime/
        modifiedTime (string comparison of ISO-like dates) and the first
        non-empty ``firstParty`` wins. Returns "" if none found.
        """
        if not contracts:
            return ""

        def _sort_key(item: dict[str, Any]) -> str:
            return str(item.get("effectiveDate") or item.get("createdTime") or item.get("modifiedTime") or "")

        sorted_items = sorted([x for x in contracts if isinstance(x, dict)], key=_sort_key, reverse=True)
        for c in sorted_items:
            first_party = str(c.get("firstParty") or "").strip()
            if first_party:
                return first_party
        return ""

    def get_contract_first_party_by_user_ids(
        self,
        *,
        user_ids: list[int],
        is_current_effective: bool = True,
        status: int | None = 1,
        contract_type: int | None = None,
        is_with_deleted: bool = False,
        columns: list[str] | None = None,
        enable_translate: bool = False,
        chunk_size: int = 300,
    ) -> dict[int, str]:
        """
        Resolve the employer company (``firstParty``) per employee UserID via
        the contract endpoint.

        Endpoint: POST /TenantBaseExternal/api/v5/Contract/GetByUserIds

        Args:
            user_ids: employee ids; non-positive, duplicate, or unparseable
                entries are dropped (first occurrence order preserved).
            chunk_size: ids per request, 1..300 (API limit).

        Returns:
            ``{user_id: first_party}`` for every user with a resolvable
            company; users without one are omitted.

        Raises:
            ValueError: on invalid ``chunk_size``.
            RuntimeError: on a non-200 API code or malformed payload.
        """
        if chunk_size <= 0 or chunk_size > 300:
            raise ValueError("chunk_size must be in range [1, 300]")
        if not user_ids:
            return {}

        # De-duplicate while preserving first-seen order.
        clean_ids: list[int] = []
        seen: set[int] = set()
        for u in user_ids:
            try:
                uid = int(u)
            except (TypeError, ValueError):
                continue
            if uid <= 0:
                continue
            if uid in seen:
                continue
            seen.add(uid)
            clean_ids.append(uid)
        if not clean_ids:
            return {}

        out: dict[int, str] = {}
        for i in range(0, len(clean_ids), chunk_size):
            chunk = clean_ids[i : i + chunk_size]
            body: dict[str, Any] = {
                "oIds": chunk,
                "isCurrentEffective": is_current_effective,
                "isWithDeleted": is_with_deleted,
                "enableTranslate": enable_translate,
            }
            # Optional filters are only sent when explicitly provided.
            if status is not None:
                body["status"] = status
            if contract_type is not None:
                body["contractType"] = contract_type
            if columns is not None:
                body["columns"] = columns

            resp = self._client.request(
                "POST",
                "/TenantBaseExternal/api/v5/Contract/GetByUserIds",
                json=body,
                headers={"Content-Type": "application/json"},
            )
            payload = resp.json() if resp.content else {}
            code = str(payload.get("code", "") or "")
            if code != "200":
                raise RuntimeError(f"EHR Contract.GetByUserIds failed code={code!r} message={payload.get('message')!r}")

            data = payload.get("data") or {}
            if not isinstance(data, dict):
                raise RuntimeError("EHR Contract.GetByUserIds invalid response: data is not an object")

            for k, v in data.items():
                try:
                    uid = int(str(k))
                except (TypeError, ValueError):
                    continue
                contracts = v if isinstance(v, list) else []
                company = self._pick_company_from_contracts(contracts)
                if company:
                    out[uid] = company

        logger.info(
            "EHR 合同公司查询完成:input_user_ids=%s matched_first_party=%s",
            len(clean_ids),
            len(out),
        )
        return out

    def get_staff_profiles_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, dict[str, Any]]:
        """
        Fetch staff profiles by userId via UserFrameworkApiV3.

        Endpoint: GET /UserFrameworkApiV3/api/v1/staffs/Get

        Each id is queried individually, trying both the ``userId`` and
        ``userid`` query-parameter spellings (the API is inconsistent);
        request failures are treated as best-effort misses.

        Args:
            user_ids: employee ids; non-positive, duplicate, or unparseable
                entries are dropped.
            chunk_size: kept for interface compatibility; requests are
                per-id, so it does not change behavior.

        Returns:
            ``{user_id: staff_profile}`` for every resolvable user.
        """
        if chunk_size <= 0:
            chunk_size = 100
        clean_ids: list[int] = []
        seen: set[int] = set()
        for u in user_ids:
            try:
                uid = int(u)
            except (TypeError, ValueError):
                continue
            if uid <= 0 or uid in seen:
                continue
            seen.add(uid)
            clean_ids.append(uid)
        if not clean_ids:
            return {}

        out: dict[int, dict[str, Any]] = {}
        for uid in clean_ids:
            profile: dict[str, Any] | None = None
            # The API accepts either parameter casing depending on tenant
            # version; try both before giving up on this uid.
            for params in ({"userId": str(uid)}, {"userid": str(uid)}):
                try:
                    resp = self._client.request(
                        "GET",
                        "/UserFrameworkApiV3/api/v1/staffs/Get",
                        params=params,
                    )
                except Exception:
                    # Best-effort lookup: a failed request for one spelling
                    # just falls through to the next attempt.
                    continue
                payload = resp.json() if resp.content else {}
                profile = self._find_staff_profile_by_uid(payload, uid)
                if profile is not None:
                    break
            if profile is not None:
                out[uid] = profile
        logger.info("EHR 员工详情查询完成:input_user_ids=%s matched_profiles=%s", len(clean_ids), len(out))
        return out

    @staticmethod
    def _find_staff_profile_by_uid(payload: Any, uid: int) -> dict[str, Any] | None:
        """
        Walk an arbitrarily nested payload and return the dict whose user-id
        field equals ``uid``.

        Dicts carrying a job-number-like key (staffCode/code/jobNumber/...)
        are preferred; otherwise the first matching dict is returned, or
        ``None`` when nothing matches.
        """

        def _iter_dicts(node: Any):
            # Depth-first traversal yielding every dict in the payload.
            if isinstance(node, dict):
                yield node
                for v in node.values():
                    yield from _iter_dicts(v)
            elif isinstance(node, list):
                for it in node:
                    yield from _iter_dicts(it)

        def _uid_from_dict(d: dict[str, Any]) -> int:
            # The API mixes key casings; probe the common variants.
            for k in ("userId", "UserId", "userid", "UserID", "id", "Id", "ID"):
                if k in d:
                    try:
                        return int(str(d.get(k)).strip())
                    except (TypeError, ValueError):
                        continue
            return 0

        best: dict[str, Any] | None = None
        for d in _iter_dicts(payload):
            if not isinstance(d, dict):
                continue
            duid = _uid_from_dict(d)
            if duid != uid:
                continue
            # Prefer the object that carries a job-number field.
            if any(k in d for k in ("staffCode", "StaffCode", "code", "Code", "jobNumber", "JobNumber", "employeeNo", "EmployeeNo")):
                return d
            if best is None:
                best = d
        return best