Vastai-ConnectHub/extensions/sync_ehr_to_oa/api.py

519 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import logging
from datetime import datetime, timedelta
from typing import Any
from urllib.parse import quote_plus
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from app.integrations.ehr import EhrClient
logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa")
class SyncEhrToOaApi:
    """
    Beisen EHR -> OA sync API wrapper.

    Wrapped endpoints:
    - Rolling time-window query of employees plus a single employment record
    - Rolling time-window query of organization units
    - Rolling time-window query of job posts
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any] | None = None,
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Build the EHR HTTP client and, when SQL Server parameters are
        supplied, a pooled database engine as well."""
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        # Keep a private copy so later mutation by the caller has no effect.
        params = dict(sqlserver_params) if sqlserver_params else {}
        self._sqlserver_params = params
        self._sql_engine: Engine | None = (
            self._create_sqlserver_engine(params) if params else None
        )
def close(self) -> None:
self._client.close()
if self._sql_engine is not None:
self._sql_engine.dispose()
self._sql_engine = None
@staticmethod
def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
host = str(sqlserver_params.get("host") or "").strip()
username = str(sqlserver_params.get("username") or "").strip()
password = str(sqlserver_params.get("password") or "").strip()
database = str(sqlserver_params.get("database") or "").strip()
if not host or not username or not password or not database:
raise ValueError("sqlserver_params.host/username/password/database are required")
port = int(sqlserver_params.get("port") or 1433)
timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower()
encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower()
# 依赖 mssql+pyodbc运行环境需安装 pyodbc 与对应 ODBC driver。
# 例mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+18+for+SQL+Server
# 使用 query 参数以减少 URL 编码问题。
username_q = quote_plus(username)
password_q = quote_plus(password)
driver_q = quote_plus(driver)
conn_url = (
f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}"
f"?driver={driver_q}&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}"
)
return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s})
def ping_sqlserver(self) -> bool:
if self._sql_engine is None:
raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
with self._sql_engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
def get_oa_record_id_map_from_sqlserver(
self,
*,
table_name: str,
job_numbers: list[str],
job_no_column: str = "field0001",
id_column: str = "id",
schema: str | None = "dbo",
) -> dict[str, int]:
"""
从 OA SQLServer 表按工号查询记录ID映射。
返回:{job_no: id}
"""
if self._sql_engine is None:
raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
t = str(table_name or "").strip()
jc = str(job_no_column or "").strip()
ic = str(id_column or "").strip()
if not t or not jc or not ic:
raise ValueError("table_name/job_no_column/id_column are required")
if not job_numbers:
return {}
clean_job_numbers = [str(x or "").strip() for x in job_numbers if str(x or "").strip()]
if not clean_job_numbers:
return {}
quoted_table = f"[{t}]"
if schema:
quoted_table = f"[{str(schema).strip()}].{quoted_table}"
out: dict[str, int] = {}
# SQL Server IN 参数分批,避免参数过多;每批 500。
chunk_size = 500
with self._sql_engine.connect() as conn:
for i in range(0, len(clean_job_numbers), chunk_size):
chunk = clean_job_numbers[i : i + chunk_size]
binds = ", ".join([f":v{idx}" for idx in range(len(chunk))])
sql = text(
f"SELECT [{jc}] AS job_no, [{ic}] AS row_id "
f"FROM {quoted_table} WITH (NOLOCK) "
f"WHERE [{jc}] IN ({binds})"
)
params = {f"v{idx}": chunk[idx] for idx in range(len(chunk))}
rows = conn.execute(sql, params).fetchall()
for r in rows:
job_no = str(r.job_no or "").strip()
try:
row_id = int(r.row_id)
except Exception:
continue
if job_no:
out[job_no] = row_id
logger.info(
"SQLServer 工号映射查询完成table=%s schema=%s input=%s matched=%s",
t,
schema or "",
len(clean_job_numbers),
len(out),
)
return out
@staticmethod
def _to_datetime(value: datetime | str | None) -> datetime:
if value is None:
return datetime.now()
if isinstance(value, datetime):
return value
s = str(value).strip()
if not s:
raise ValueError("datetime string cannot be empty")
if "T" in s:
return datetime.fromisoformat(s)
if " " in s:
return datetime.fromisoformat(s.replace(" ", "T"))
return datetime.strptime(s, "%Y-%m-%d")
@staticmethod
def _to_api_datetime(value: datetime | str | None) -> str:
return SyncEhrToOaApi._to_datetime(value).strftime("%Y-%m-%dT%H:%M:%S")
@staticmethod
def _iter_windows(start_dt: datetime, stop_dt: datetime, max_days: int = 90) -> list[tuple[datetime, datetime]]:
if stop_dt < start_dt:
raise ValueError("stop_time must be greater than or equal to start_time")
windows: list[tuple[datetime, datetime]] = []
cur = start_dt
max_delta = timedelta(days=max_days)
while cur < stop_dt:
nxt = min(cur + max_delta, stop_dt)
windows.append((cur, nxt))
cur = nxt
if not windows:
windows.append((start_dt, stop_dt))
return windows
    def _get_all_by_time_window(
        self,
        *,
        api_path: str,
        api_name: str,
        stop_time: datetime | str | None,
        capacity: int,
        time_window_query_type: int,
        with_disabled: bool,
        is_with_deleted: bool,
        max_pages: int,
    ) -> dict[str, Any]:
        """
        Shared scroll driver for the EHR "GetByTimeWindow" endpoints.

        Walks fixed 90-day windows from 2015-01-01 up to *stop_time* and,
        inside each window, follows the scrollId cursor until the server
        reports isLastData. Accumulates every dict item from each page's
        "data" list and returns a summary dict (startTime/stopTime, total,
        pages, count, data, lastScrollId, windowCount).

        Raises:
            ValueError: capacity outside [1, 300] or max_pages <= 0.
            RuntimeError: non-"200" business code, "data" not a list, or a
                window requiring more than max_pages scroll pages.
        """
        if capacity <= 0 or capacity > 300:
            raise ValueError("capacity must be in range [1, 300]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")
        # Query range always starts at this fixed date (see wrapper docstrings).
        start_dt = datetime(2015, 1, 1, 0, 0, 0)
        stop_dt = self._to_datetime(stop_time)
        windows = self._iter_windows(start_dt=start_dt, stop_dt=stop_dt, max_days=90)
        all_data: list[dict[str, Any]] = []
        total_pages = 0
        last_scroll_id = ""
        for idx, (w_start, w_stop) in enumerate(windows, start=1):
            start_time = self._to_api_datetime(w_start)
            stop_time_s = self._to_api_datetime(w_stop)
            # The scroll cursor is reset at the start of every window.
            scroll_id = ""
            page = 0
            window_total = 0
            while True:
                page += 1
                total_pages += 1
                # Guard against a server that never sets isLastData.
                if page > max_pages:
                    raise RuntimeError(f"scroll pages exceed max_pages={max_pages} in window index={idx}")
                body: dict[str, Any] = {
                    "startTime": start_time,
                    "stopTime": stop_time_s,
                    "timeWindowQueryType": time_window_query_type,
                    "scrollId": scroll_id,
                    "capacity": capacity,
                    "withDisabled": with_disabled,
                    "isWithDeleted": is_with_deleted,
                }
                resp = self._client.request(
                    "POST",
                    api_path,
                    json=body,
                    headers={"Content-Type": "application/json"},
                )
                # Empty body -> treat as an empty payload rather than a JSON error.
                payload = resp.json() if resp.content else {}
                # Normalize the business code to a string before comparing.
                code = str(payload.get("code", "") or "")
                if code != "200":
                    message = payload.get("message")
                    raise RuntimeError(f"EHR {api_name} failed code={code!r} message={message!r}")
                batch = payload.get("data") or []
                if not isinstance(batch, list):
                    raise RuntimeError(f"EHR {api_name} invalid response: data is not a list")
                # Keep only dict items; anything else in the list is dropped.
                all_data.extend([x for x in batch if isinstance(x, dict)])
                total_val = payload.get("total")
                if total_val is not None:
                    try:
                        window_total = int(total_val)
                    except (TypeError, ValueError):
                        # Unparsable total is informational only; ignore it.
                        pass
                is_last_data = bool(payload.get("isLastData", False))
                # Next page continues from the cursor the server handed back.
                scroll_id = str(payload.get("scrollId", "") or "")
                last_scroll_id = scroll_id
                logger.info(
                    "EHR %s window=%s/%s page=%s batch=%s window_total=%s isLastData=%s",
                    api_name,
                    idx,
                    len(windows),
                    page,
                    len(batch),
                    window_total,
                    is_last_data,
                )
                if is_last_data:
                    break
        return {
            "startTime": self._to_api_datetime(start_dt),
            "stopTime": self._to_api_datetime(stop_dt),
            "total": len(all_data),
            "pages": total_pages,
            "count": len(all_data),
            "data": all_data,
            "lastScrollId": last_scroll_id,
            "windowCount": len(windows),
        }
def get_all_employees_with_record_by_time_window(
self,
*,
stop_time: datetime | str | None = None,
capacity: int = 300,
time_window_query_type: int = 1,
with_disabled: bool = True,
is_with_deleted: bool = True,
max_pages: int = 100000,
) -> dict[str, Any]:
"""
滚动查询“员工 + 单条任职”全量结果。
固定起始时间:
- 2015-01-01T00:00:00
"""
return self._get_all_by_time_window(
api_path="/TenantBaseExternal/api/v5/Employee/GetByTimeWindow",
api_name="Employee.GetByTimeWindow",
stop_time=stop_time,
capacity=capacity,
time_window_query_type=time_window_query_type,
with_disabled=with_disabled,
is_with_deleted=is_with_deleted,
max_pages=max_pages,
)
def get_all_organizations_by_time_window(
self,
*,
stop_time: datetime | str | None = None,
capacity: int = 300,
time_window_query_type: int = 1,
with_disabled: bool = True,
is_with_deleted: bool = True,
max_pages: int = 100000,
) -> dict[str, Any]:
"""
滚动查询“组织单元”全量结果。
固定起始时间:
- 2015-01-01T00:00:00
"""
return self._get_all_by_time_window(
api_path="/TenantBaseExternal/api/v5/Organization/GetByTimeWindow",
api_name="Organization.GetByTimeWindow",
stop_time=stop_time,
capacity=capacity,
time_window_query_type=time_window_query_type,
with_disabled=with_disabled,
is_with_deleted=is_with_deleted,
max_pages=max_pages,
)
def get_all_job_posts_by_time_window(
self,
*,
stop_time: datetime | str | None = None,
capacity: int = 300,
time_window_query_type: int = 1,
with_disabled: bool = True,
is_with_deleted: bool = True,
max_pages: int = 100000,
) -> dict[str, Any]:
"""
滚动查询“职务”全量结果。
固定起始时间:
- 2015-01-01T00:00:00
"""
return self._get_all_by_time_window(
api_path="/TenantBaseExternal/api/v5/JobPost/GetByTimeWindow",
api_name="JobPost.GetByTimeWindow",
stop_time=stop_time,
capacity=capacity,
time_window_query_type=time_window_query_type,
with_disabled=with_disabled,
is_with_deleted=is_with_deleted,
max_pages=max_pages,
)
@staticmethod
def _pick_company_from_contracts(contracts: list[dict[str, Any]]) -> str:
if not contracts:
return ""
def _sort_key(item: dict[str, Any]) -> str:
return str(item.get("effectiveDate") or item.get("createdTime") or item.get("modifiedTime") or "")
sorted_items = sorted([x for x in contracts if isinstance(x, dict)], key=_sort_key, reverse=True)
for c in sorted_items:
first_party = str(c.get("firstParty") or "").strip()
if first_party:
return first_party
return ""
def get_contract_first_party_by_user_ids(
self,
*,
user_ids: list[int],
is_current_effective: bool = True,
status: int | None = 1,
contract_type: int | None = None,
is_with_deleted: bool = False,
columns: list[str] | None = None,
enable_translate: bool = False,
chunk_size: int = 300,
) -> dict[int, str]:
"""
调用合同接口按员工 UserID 集合获取所属公司firstParty
接口POST /TenantBaseExternal/api/v5/Contract/GetByUserIds
"""
if chunk_size <= 0 or chunk_size > 300:
raise ValueError("chunk_size must be in range [1, 300]")
if not user_ids:
return {}
clean_ids: list[int] = []
seen: set[int] = set()
for u in user_ids:
try:
uid = int(u)
except Exception:
continue
if uid <= 0:
continue
if uid in seen:
continue
seen.add(uid)
clean_ids.append(uid)
if not clean_ids:
return {}
out: dict[int, str] = {}
for i in range(0, len(clean_ids), chunk_size):
chunk = clean_ids[i : i + chunk_size]
body: dict[str, Any] = {
"oIds": chunk,
"isCurrentEffective": is_current_effective,
"isWithDeleted": is_with_deleted,
"enableTranslate": enable_translate,
}
if status is not None:
body["status"] = status
if contract_type is not None:
body["contractType"] = contract_type
if columns is not None:
body["columns"] = columns
resp = self._client.request(
"POST",
"/TenantBaseExternal/api/v5/Contract/GetByUserIds",
json=body,
headers={"Content-Type": "application/json"},
)
payload = resp.json() if resp.content else {}
code = str(payload.get("code", "") or "")
if code != "200":
raise RuntimeError(f"EHR Contract.GetByUserIds failed code={code!r} message={payload.get('message')!r}")
data = payload.get("data") or {}
if not isinstance(data, dict):
raise RuntimeError("EHR Contract.GetByUserIds invalid response: data is not an object")
for k, v in data.items():
try:
uid = int(str(k))
except Exception:
continue
contracts = v if isinstance(v, list) else []
company = self._pick_company_from_contracts(contracts)
if company:
out[uid] = company
logger.info(
"EHR 合同公司查询完成input_user_ids=%s matched_first_party=%s",
len(clean_ids),
len(out),
)
return out
def get_staff_profiles_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, dict[str, Any]]:
"""
调用 UserFrameworkApiV3 获取员工信息(按 userId
接口GET /UserFrameworkApiV3/api/v1/staffs/Get
返回:{userId: staff_profile}
"""
if chunk_size <= 0:
chunk_size = 100
clean_ids: list[int] = []
seen: set[int] = set()
for u in user_ids:
try:
uid = int(u)
except Exception:
continue
if uid <= 0 or uid in seen:
continue
seen.add(uid)
clean_ids.append(uid)
if not clean_ids:
return {}
out: dict[int, dict[str, Any]] = {}
for i in range(0, len(clean_ids), chunk_size):
chunk = clean_ids[i : i + chunk_size]
for uid in chunk:
resp = self._client.request(
"GET",
"/UserFrameworkApiV3/api/v1/staffs/Get",
params={"userId": uid},
)
payload = resp.json() if resp.content else {}
# 兼容多种返回结构
data = payload.get("data", payload)
if isinstance(data, list):
items = [x for x in data if isinstance(x, dict)]
if items:
out[uid] = items[0]
continue
if isinstance(data, dict):
# 可能是单条,也可能是 map
if "userId" in data or "UserId" in data:
out[uid] = data
else:
# map 场景key=uid
d2 = data.get(str(uid))
if isinstance(d2, dict):
out[uid] = d2
# 非 200 业务码场景不硬失败,避免单个用户影响全量
logger.info("EHR 员工详情查询完成input_user_ids=%s matched_profiles=%s", len(clean_ids), len(out))
return out