from __future__ import annotations

import logging
from datetime import datetime, timedelta
from typing import Any
from urllib.parse import quote_plus

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

from app.integrations.ehr import EhrClient

logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa")


class SyncEhrToOaApi:
    """
    Beisen EHR -> OA synchronisation API wrapper.

    Wrapped EHR endpoints:
    - employee + single employment record, time-window scroll query
    - organisation unit, time-window scroll query
    - job post, time-window scroll query
    - contract lookup by user ids (to resolve ``firstParty``)

    When ``sqlserver_params`` is supplied, a SQLAlchemy engine is also created
    so that record ids can be looked up by job number in the OA SQL Server
    database.

    Instances own an ``EhrClient`` and, optionally, an engine; call
    :meth:`close` (or use the instance as a context manager) to release them.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any] | None = None,
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """
        Build the EHR client and, if ``sqlserver_params`` is non-empty, the
        SQL Server engine.

        Raises:
            ValueError: if ``sqlserver_params`` is given but lacks
                host/username/password/database.
        """
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        # Copy so later mutation of the caller's dict cannot affect us.
        self._sqlserver_params = dict(sqlserver_params or {})
        self._sql_engine: Engine | None = None
        if self._sqlserver_params:
            try:
                self._sql_engine = self._create_sqlserver_engine(self._sqlserver_params)
            except Exception:
                # Construction is about to fail, so nobody will be able to call
                # close(); release the already-built client before re-raising.
                self._client.close()
                raise

    def __enter__(self) -> SyncEhrToOaApi:
        """Return ``self`` so the instance can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        """Release the client and engine when the ``with`` block exits."""
        self.close()

    def close(self) -> None:
        """Close the EHR client and dispose the SQL Server engine, if any."""
        try:
            self._client.close()
        finally:
            # Dispose the engine even when closing the client raised.
            engine, self._sql_engine = self._sql_engine, None
            if engine is not None:
                engine.dispose()

    @staticmethod
    def _odbc_flag(value: Any, default: str = "yes") -> str:
        """
        Normalise a yes/no style connection option to a lowercase string.

        Booleans map to ``"yes"``/``"no"``. Any other falsy or blank value
        yields ``default``; everything else is stringified, stripped and
        lowercased.
        """
        if isinstance(value, bool):
            # ``False or "yes"`` would silently turn an explicit opt-out into "yes".
            return "yes" if value else "no"
        flag = str(value or "").strip().lower()
        return flag or default

    @staticmethod
    def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
        """
        Create a ``mssql+pyodbc`` SQLAlchemy engine from ``sqlserver_params``.

        Recognised keys: ``host``, ``username``, ``password``, ``database``
        (all required), ``port`` (default 1433), ``connect_timeout_s``
        (default 10), ``driver`` (default "ODBC Driver 18 for SQL Server"),
        ``trust_server_certificate`` and ``encrypt`` (both default "yes";
        booleans are accepted).

        Raises:
            ValueError: if any required key is missing or blank.
        """
        host = str(sqlserver_params.get("host") or "").strip()
        username = str(sqlserver_params.get("username") or "").strip()
        password = str(sqlserver_params.get("password") or "").strip()
        database = str(sqlserver_params.get("database") or "").strip()
        if not host or not username or not password or not database:
            raise ValueError("sqlserver_params.host/username/password/database are required")

        port = int(sqlserver_params.get("port") or 1433)
        timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
        driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
        trust_server_certificate = SyncEhrToOaApi._odbc_flag(
            sqlserver_params.get("trust_server_certificate")
        )
        encrypt = SyncEhrToOaApi._odbc_flag(sqlserver_params.get("encrypt"))

        # Depends on mssql+pyodbc; the runtime must have pyodbc and the
        # matching ODBC driver installed.
        # Example: mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+18+for+SQL+Server
        # Options are passed as query parameters to reduce URL-encoding problems.
        username_q = quote_plus(username)
        password_q = quote_plus(password)
        driver_q = quote_plus(driver)
        trust_q = quote_plus(trust_server_certificate)
        encrypt_q = quote_plus(encrypt)
        conn_url = (
            f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}"
            f"?driver={driver_q}&TrustServerCertificate={trust_q}&Encrypt={encrypt_q}"
        )
        return create_engine(
            conn_url,
            pool_pre_ping=True,
            pool_recycle=1800,
            connect_args={"timeout": timeout_s},
        )

    def _require_sql_engine(self) -> Engine:
        """Return the SQL Server engine, raising ``RuntimeError`` if none was created."""
        if self._sql_engine is None:
            raise RuntimeError(
                "SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi"
            )
        return self._sql_engine

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a bracket-delimited T-SQL identifier, doubling any ``]``."""
        return "[" + name.replace("]", "]]") + "]"

    def ping_sqlserver(self) -> bool:
        """
        Run ``SELECT 1`` against SQL Server.

        Returns:
            ``True`` when the statement executes; connection errors propagate.

        Raises:
            RuntimeError: if no SQL Server engine was configured.
        """
        engine = self._require_sql_engine()
        with engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True

    def get_oa_record_id_map_from_sqlserver(
        self,
        *,
        table_name: str,
        job_numbers: list[str],
        job_no_column: str = "field0001",
        id_column: str = "id",
        schema: str | None = "dbo",
    ) -> dict[str, int]:
        """
        Look up record ids in an OA SQL Server table by job number.

        Args:
            table_name: table to query.
            job_numbers: job numbers to look up; blank entries are ignored.
            job_no_column: column holding the job number.
            id_column: column holding the record id.
            schema: optional schema name; ``None`` or blank means unqualified.

        Returns:
            ``{job_no: id}`` for every matched row whose id converts to ``int``.
            If a job number matches several rows, the last row read wins.

        Raises:
            RuntimeError: if no SQL Server engine was configured.
            ValueError: if the table or a column name is blank.
        """
        engine = self._require_sql_engine()

        t = str(table_name or "").strip()
        jc = str(job_no_column or "").strip()
        ic = str(id_column or "").strip()
        if not t or not jc or not ic:
            raise ValueError("table_name/job_no_column/id_column are required")
        if not job_numbers:
            return {}

        stripped = (str(x or "").strip() for x in job_numbers)
        clean_job_numbers = [s for s in stripped if s]
        if not clean_job_numbers:
            return {}
        # Duplicates cannot change the result, so do not spend bind parameters on them.
        unique_job_numbers = list(dict.fromkeys(clean_job_numbers))

        # Identifiers cannot be bound as parameters, so they are bracket-quoted
        # with "]" escaped to keep a hostile name from breaking out of the quotes.
        quote = self._quote_identifier
        quoted_table = quote(t)
        schema_name = str(schema or "").strip()
        if schema_name:
            quoted_table = f"{quote(schema_name)}.{quoted_table}"
        job_col = quote(jc)
        id_col = quote(ic)

        out: dict[str, int] = {}
        skipped = 0
        # Batch the IN (...) list to avoid too many parameters; 500 per batch.
        chunk_size = 500
        with engine.connect() as conn:
            for i in range(0, len(unique_job_numbers), chunk_size):
                chunk = unique_job_numbers[i : i + chunk_size]
                binds = ", ".join(f":v{idx}" for idx in range(len(chunk)))
                sql = text(
                    f"SELECT {job_col} AS job_no, {id_col} AS row_id "
                    f"FROM {quoted_table} WITH (NOLOCK) "
                    f"WHERE {job_col} IN ({binds})"
                )
                params = {f"v{idx}": value for idx, value in enumerate(chunk)}
                rows = conn.execute(sql, params).fetchall()
                for r in rows:
                    job_no = str(r.job_no or "").strip()
                    try:
                        row_id = int(r.row_id)
                    except (TypeError, ValueError, OverflowError):
                        # Best effort: a row without an integer id is unusable.
                        skipped += 1
                        continue
                    if job_no:
                        out[job_no] = row_id

        if skipped:
            logger.warning(
                "SQLServer job-number lookup skipped %s row(s) with a non-integer id: table=%s",
                skipped,
                t,
            )
        logger.info(
            "SQLServer 工号映射查询完成:table=%s schema=%s input=%s matched=%s",
            t,
            schema or "",
            len(clean_job_numbers),
            len(out),
        )
        return out

    @staticmethod
    def _to_datetime(value: datetime | str | None) -> datetime:
        """
        Coerce ``value`` to a ``datetime``.

        ``None`` becomes ``datetime.now()``; a ``datetime`` is returned as is.
        Strings are parsed as ISO-8601 when they contain ``"T"`` or a space
        (the space is treated as the date/time separator), otherwise as
        ``YYYY-MM-DD``.

        Raises:
            ValueError: if the string is blank or cannot be parsed.
        """
        if value is None:
            return datetime.now()
        if isinstance(value, datetime):
            return value
        s = str(value).strip()
        if not s:
            raise ValueError("datetime string cannot be empty")
        if "T" in s:
            return datetime.fromisoformat(s)
        if " " in s:
            return datetime.fromisoformat(s.replace(" ", "T"))
        return datetime.strptime(s, "%Y-%m-%d")

    @staticmethod
    def _to_api_datetime(value: datetime | str | None) -> str:
        """Format ``value`` as ``YYYY-MM-DDTHH:MM:SS`` for request bodies."""
        return SyncEhrToOaApi._to_datetime(value).strftime("%Y-%m-%dT%H:%M:%S")

    @staticmethod
    def _iter_windows(
        start_dt: datetime,
        stop_dt: datetime,
        max_days: int = 90,
    ) -> list[tuple[datetime, datetime]]:
        """
        Split ``[start_dt, stop_dt]`` into consecutive windows of at most
        ``max_days`` days each.

        Adjacent windows share their boundary instant. When
        ``start_dt == stop_dt`` a single zero-length window is returned.

        Raises:
            ValueError: if ``stop_dt < start_dt`` or ``max_days <= 0``.
        """
        if max_days <= 0:
            # A non-positive step would never reach stop_dt and loop forever.
            raise ValueError("max_days must be > 0")
        if stop_dt < start_dt:
            raise ValueError("stop_time must be greater than or equal to start_time")

        windows: list[tuple[datetime, datetime]] = []
        cur = start_dt
        max_delta = timedelta(days=max_days)
        while cur < stop_dt:
            nxt = min(cur + max_delta, stop_dt)
            windows.append((cur, nxt))
            cur = nxt
        if not windows:
            windows.append((start_dt, stop_dt))
        return windows

    @staticmethod
    def _parse_payload(resp: Any, api_name: str) -> dict[str, Any]:
        """
        Decode an EHR response and verify its envelope.

        An empty body is treated as ``{}``, which then fails the code check.

        Raises:
            RuntimeError: if the payload is not a JSON object or its ``code``
                is not ``"200"``.
        """
        payload = resp.json() if resp.content else {}
        if not isinstance(payload, dict):
            raise RuntimeError(f"EHR {api_name} invalid response: payload is not an object")
        code = str(payload.get("code", "") or "")
        if code != "200":
            message = payload.get("message")
            raise RuntimeError(f"EHR {api_name} failed code={code!r} message={message!r}")
        return payload

    def _get_all_by_time_window(
        self,
        *,
        api_path: str,
        api_name: str,
        stop_time: datetime | str | None,
        capacity: int,
        time_window_query_type: int,
        with_disabled: bool,
        is_with_deleted: bool,
        max_pages: int,
    ) -> dict[str, Any]:
        """
        Scroll through a ``GetByTimeWindow`` endpoint and collect every record.

        The range from the fixed start 2015-01-01T00:00:00 to ``stop_time``
        (default: now) is split into windows of at most 90 days. Each window is
        scrolled page by page until the response reports ``isLastData``.

        NOTE(review): adjacent windows share their boundary instant. Whether
        the API treats both bounds as inclusive is not visible here, so a
        record stamped exactly on a boundary may appear twice — confirm.

        Args:
            api_path: endpoint path to POST to.
            api_name: label used in log and error messages.
            stop_time: end of the queried range; ``None`` means now.
            capacity: page size, 1..300.
            time_window_query_type: forwarded as ``timeWindowQueryType``.
            with_disabled: forwarded as ``withDisabled``.
            is_with_deleted: forwarded as ``isWithDeleted``.
            max_pages: maximum number of pages allowed per window.

        Returns:
            A dict with keys ``startTime``, ``stopTime``, ``total``, ``pages``,
            ``count``, ``data``, ``lastScrollId`` and ``windowCount``.
            ``total`` and ``count`` are both the number of collected records.

        Raises:
            ValueError: on an out-of-range ``capacity``/``max_pages`` or a
                ``stop_time`` earlier than the fixed start.
            RuntimeError: on a failed or malformed response, or when a window
                needs more than ``max_pages`` pages.
        """
        if capacity <= 0 or capacity > 300:
            raise ValueError("capacity must be in range [1, 300]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")

        start_dt = datetime(2015, 1, 1, 0, 0, 0)
        stop_dt = self._to_datetime(stop_time)
        if stop_dt.tzinfo is not None:
            # The fixed start is naive, and comparing naive with aware
            # datetimes raises TypeError. Convert to naive local time, which is
            # what the datetime.now() default already produces.
            # NOTE(review): assumes the API expects local wall-clock time — confirm.
            stop_dt = stop_dt.astimezone().replace(tzinfo=None)
        windows = self._iter_windows(start_dt=start_dt, stop_dt=stop_dt, max_days=90)

        all_data: list[dict[str, Any]] = []
        total_pages = 0
        last_scroll_id = ""
        for idx, (w_start, w_stop) in enumerate(windows, start=1):
            start_time = self._to_api_datetime(w_start)
            stop_time_s = self._to_api_datetime(w_stop)
            # Every window starts a fresh scroll.
            scroll_id = ""
            page = 0
            window_total = 0
            while True:
                page += 1
                total_pages += 1
                if page > max_pages:
                    raise RuntimeError(f"scroll pages exceed max_pages={max_pages} in window index={idx}")

                body: dict[str, Any] = {
                    "startTime": start_time,
                    "stopTime": stop_time_s,
                    "timeWindowQueryType": time_window_query_type,
                    "scrollId": scroll_id,
                    "capacity": capacity,
                    "withDisabled": with_disabled,
                    "isWithDeleted": is_with_deleted,
                }
                resp = self._client.request(
                    "POST",
                    api_path,
                    json=body,
                    headers={"Content-Type": "application/json"},
                )
                payload = self._parse_payload(resp, api_name)

                batch = payload.get("data") or []
                if not isinstance(batch, list):
                    raise RuntimeError(f"EHR {api_name} invalid response: data is not a list")
                all_data.extend(x for x in batch if isinstance(x, dict))

                # "total" is only used for logging; keep the previous value if
                # it is missing or not numeric.
                total_val = payload.get("total")
                if total_val is not None:
                    try:
                        window_total = int(total_val)
                    except (TypeError, ValueError):
                        pass

                is_last_data = bool(payload.get("isLastData", False))
                scroll_id = str(payload.get("scrollId", "") or "")
                last_scroll_id = scroll_id
                logger.info(
                    "EHR %s window=%s/%s page=%s batch=%s window_total=%s isLastData=%s",
                    api_name,
                    idx,
                    len(windows),
                    page,
                    len(batch),
                    window_total,
                    is_last_data,
                )
                if is_last_data:
                    break

        return {
            "startTime": self._to_api_datetime(start_dt),
            "stopTime": self._to_api_datetime(stop_dt),
            "total": len(all_data),
            "pages": total_pages,
            "count": len(all_data),
            "data": all_data,
            "lastScrollId": last_scroll_id,
            "windowCount": len(windows),
        }

    def get_all_employees_with_record_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "employee + single employment record" result set.

        Fixed start time:
        - 2015-01-01T00:00:00
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/Employee/GetByTimeWindow",
            api_name="Employee.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    def get_all_organizations_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "organisation unit" result set.

        Fixed start time:
        - 2015-01-01T00:00:00
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/Organization/GetByTimeWindow",
            api_name="Organization.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    def get_all_job_posts_by_time_window(
        self,
        *,
        stop_time: datetime | str | None = None,
        capacity: int = 300,
        time_window_query_type: int = 1,
        with_disabled: bool = True,
        is_with_deleted: bool = True,
        max_pages: int = 100000,
    ) -> dict[str, Any]:
        """
        Scroll-query the full "job post" result set.

        Fixed start time:
        - 2015-01-01T00:00:00
        """
        return self._get_all_by_time_window(
            api_path="/TenantBaseExternal/api/v5/JobPost/GetByTimeWindow",
            api_name="JobPost.GetByTimeWindow",
            stop_time=stop_time,
            capacity=capacity,
            time_window_query_type=time_window_query_type,
            with_disabled=with_disabled,
            is_with_deleted=is_with_deleted,
            max_pages=max_pages,
        )

    @staticmethod
    def _pick_company_from_contracts(contracts: list[dict[str, Any]]) -> str:
        """
        Return the ``firstParty`` of the most recent contract that has one.

        Contracts are ordered newest first by the string value of
        ``effectiveDate``, falling back to ``createdTime`` and then
        ``modifiedTime``. Non-dict entries are ignored. Returns ``""`` when no
        contract carries a non-blank ``firstParty``.
        """
        if not contracts:
            return ""

        def _sort_key(item: dict[str, Any]) -> str:
            return str(item.get("effectiveDate") or item.get("createdTime") or item.get("modifiedTime") or "")

        sorted_items = sorted(
            (x for x in contracts if isinstance(x, dict)),
            key=_sort_key,
            reverse=True,
        )
        for c in sorted_items:
            first_party = str(c.get("firstParty") or "").strip()
            if first_party:
                return first_party
        return ""

    def get_contract_first_party_by_user_ids(
        self,
        *,
        user_ids: list[int],
        is_current_effective: bool = True,
        status: int | None = 1,
        contract_type: int | None = None,
        is_with_deleted: bool = False,
        columns: list[str] | None = None,
        enable_translate: bool = False,
        chunk_size: int = 300,
    ) -> dict[int, str]:
        """
        Resolve each employee's company (``firstParty``) through the contract API.

        Endpoint: POST /TenantBaseExternal/api/v5/Contract/GetByUserIds

        Args:
            user_ids: employee user ids; non-integer, non-positive and
                duplicate entries are dropped.
            is_current_effective: forwarded as ``isCurrentEffective``.
            status: forwarded as ``status`` unless ``None``.
            contract_type: forwarded as ``contractType`` unless ``None``.
            is_with_deleted: forwarded as ``isWithDeleted``.
            columns: forwarded as ``columns`` unless ``None``.
            enable_translate: forwarded as ``enableTranslate``.
            chunk_size: ids per request, 1..300.

        Returns:
            ``{user_id: first_party}`` for users with a non-blank ``firstParty``.

        Raises:
            ValueError: if ``chunk_size`` is out of range.
            RuntimeError: on a failed or malformed response.
        """
        if chunk_size <= 0 or chunk_size > 300:
            raise ValueError("chunk_size must be in range [1, 300]")
        if not user_ids:
            return {}

        clean_ids: list[int] = []
        seen: set[int] = set()
        for u in user_ids:
            try:
                uid = int(u)
            except (TypeError, ValueError, OverflowError):
                continue
            if uid <= 0 or uid in seen:
                continue
            seen.add(uid)
            clean_ids.append(uid)
        if not clean_ids:
            return {}

        out: dict[int, str] = {}
        for i in range(0, len(clean_ids), chunk_size):
            chunk = clean_ids[i : i + chunk_size]
            body: dict[str, Any] = {
                "oIds": chunk,
                "isCurrentEffective": is_current_effective,
                "isWithDeleted": is_with_deleted,
                "enableTranslate": enable_translate,
            }
            # Optional filters are omitted from the request when unset.
            if status is not None:
                body["status"] = status
            if contract_type is not None:
                body["contractType"] = contract_type
            if columns is not None:
                body["columns"] = columns

            resp = self._client.request(
                "POST",
                "/TenantBaseExternal/api/v5/Contract/GetByUserIds",
                json=body,
                headers={"Content-Type": "application/json"},
            )
            payload = self._parse_payload(resp, "Contract.GetByUserIds")

            data = payload.get("data") or {}
            if not isinstance(data, dict):
                raise RuntimeError("EHR Contract.GetByUserIds invalid response: data is not an object")

            for k, v in data.items():
                try:
                    uid = int(str(k))
                except (TypeError, ValueError):
                    continue
                contracts = v if isinstance(v, list) else []
                company = self._pick_company_from_contracts(contracts)
                if company:
                    out[uid] = company

        logger.info(
            "EHR 合同公司查询完成:input_user_ids=%s matched_first_party=%s",
            len(clean_ids),
            len(out),
        )
        return out