"""Helpers for syncing EHR leave (vacation) records against an OA SQL Server table.

The module wraps two data sources:

* the EHR open API (through ``EhrClient``) for vacation records and staff
  profiles, and
* a SQL Server database (through SQLAlchemy + pyodbc) holding the OA rows
  that already exist.
"""

from __future__ import annotations

import logging
from datetime import date, timedelta
from typing import Any
from urllib.parse import quote_plus

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

from app.integrations.ehr import EhrClient

logger = logging.getLogger("connecthub.extensions.sync_ehr_leaves_to_oa")

# Candidate keys under which a staff profile may carry the job number.
_STAFF_CODE_KEYS = (
    "staffCode",
    "StaffCode",
    "code",
    "Code",
    "jobNumber",
    "JobNumber",
    "employeeNo",
    "EmployeeNo",
)
_STAFF_CODE_KEYS_LOWER = ("staffcode", "code", "jobnumber", "employeeno")

# Candidate keys under which a staff profile may carry the display name.
_STAFF_NAME_KEYS = ("name", "Name", "staffName", "StaffName", "employeeName", "EmployeeName")
_STAFF_NAME_KEYS_LOWER = ("name", "staffname", "employeename")

# Candidate keys under which a payload dict may carry the user id.
_USER_ID_KEYS = ("userId", "UserId", "userid", "UserID", "id", "Id", "ID")


def _to_int_safe(v: Any) -> int:
    """Return ``v`` parsed as an int, or ``0`` when it cannot be parsed."""
    try:
        return int(str(v).strip())
    except (TypeError, ValueError):
        return 0


def _first_text_value(
    mapping: Any,
    exact_keys: tuple[str, ...],
    lowered_keys: tuple[str, ...],
) -> str:
    """Return the first non-empty, stripped string found under the given keys.

    ``exact_keys`` are tried first, in order, with their exact spelling. If
    none yields a value, the mapping's keys are lower-cased and
    ``lowered_keys`` are tried. Returns ``""`` when nothing matches or when
    ``mapping`` is not a dict.
    """
    if not isinstance(mapping, dict):
        return ""
    for key in exact_keys:
        val = str(mapping.get(key) or "").strip()
        if val:
            return val
    lower_map = {str(k).lower(): v for k, v in mapping.items()}
    for key in lowered_keys:
        val = str(lower_map.get(key) or "").strip()
        if val:
            return val
    return ""


def _extract_staff_code(staff_profile: dict[str, Any]) -> str:
    """Return the staff job number from a profile dict, or ``""`` if absent."""
    return _first_text_value(staff_profile, _STAFF_CODE_KEYS, _STAFF_CODE_KEYS_LOWER)


def _extract_staff_name(staff_profile: dict[str, Any]) -> str:
    """Return the staff display name from a profile dict, or ``""`` if absent."""
    return _first_text_value(staff_profile, _STAFF_NAME_KEYS, _STAFF_NAME_KEYS_LOWER)


def _odbc_flag(value: Any, default: str) -> str:
    """Normalize a yes/no style connection option to a lower-case string.

    Real booleans map to ``"yes"`` / ``"no"`` so that an explicit ``False``
    is honoured instead of silently falling back to ``default``. Missing or
    blank values yield ``default``; any other value is stringified, stripped
    and lower-cased.
    """
    if isinstance(value, bool):
        return "yes" if value else "no"
    if not value:
        return default
    return str(value).strip().lower() or default


def _quote_mssql_identifier(name: str) -> str:
    """Wrap ``name`` in square brackets, doubling any ``]`` it contains."""
    return "[" + name.replace("]", "]]") + "]"


class SyncEhrLeavesToOaApi:
    """Read-side API used by the EHR-leave to OA sync.

    Owns an ``EhrClient`` and a SQLAlchemy engine; call :meth:`close` to
    release both.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any],
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Create the EHR client and the SQL Server engine.

        Raises:
            ValueError: if required ``sqlserver_params`` entries are missing.
        """
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        self._sql_engine = self._create_sqlserver_engine(sqlserver_params)

    def close(self) -> None:
        """Close the EHR client and dispose of the SQL engine."""
        try:
            self._client.close()
        finally:
            # Dispose the pool even when closing the HTTP client fails.
            self._sql_engine.dispose()

    @staticmethod
    def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
        """Build a ``mssql+pyodbc`` engine from a parameter dict.

        Reads ``host``, ``username``, ``password`` and ``database`` (all
        required) plus the optional ``port`` (default 1433),
        ``connect_timeout_s`` (default 10), ``driver``,
        ``trust_server_certificate`` and ``encrypt`` (both default ``"yes"``).

        Raises:
            ValueError: if any required entry is missing or blank.
        """
        host = str(sqlserver_params.get("host") or "").strip()
        username = str(sqlserver_params.get("username") or "").strip()
        password = str(sqlserver_params.get("password") or "").strip()
        database = str(sqlserver_params.get("database") or "").strip()
        if not host or not username or not password or not database:
            raise ValueError("sqlserver_params.host/username/password/database are required")

        port = int(sqlserver_params.get("port") or 1433)
        timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
        driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
        trust_server_certificate = _odbc_flag(sqlserver_params.get("trust_server_certificate"), "yes")
        encrypt = _odbc_flag(sqlserver_params.get("encrypt"), "yes")

        conn_url = (
            f"mssql+pyodbc://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/{database}"
            f"?driver={quote_plus(driver)}"
            f"&TrustServerCertificate={quote_plus(trust_server_certificate)}"
            f"&Encrypt={quote_plus(encrypt)}"
        )
        return create_engine(
            conn_url,
            pool_pre_ping=True,
            pool_recycle=1800,
            connect_args={"timeout": timeout_s},
        )

    def ping_sqlserver(self) -> bool:
        """Run ``SELECT 1`` against SQL Server; return ``True`` on success."""
        with self._sql_engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True

    def get_vacation_list_by_day(
        self,
        *,
        day: date,
        page_size: int = 100,
        max_pages: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch every vacation record the EHR API reports for ``day``.

        Pages through ``Vacation/GetListByDate`` using the returned
        ``sortCursor`` until the API flags the last page or stops returning
        a cursor.

        Raises:
            ValueError: if ``page_size`` is outside ``[1, 100]`` or
                ``max_pages`` is not positive.
            RuntimeError: if the API reports a failure code, the response
                has an unexpected shape, or ``max_pages`` pages were read
                without reaching the end.
        """
        if page_size <= 0 or page_size > 100:
            raise ValueError("page_size must be in range [1, 100]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")

        day_text = day.isoformat()
        out: list[dict[str, Any]] = []
        cursor: str | None = None

        for page in range(1, max_pages + 1):
            body: dict[str, Any] = {
                "day": day_text,
                "queryCursor": cursor,
                "pageSize": page_size,
            }
            resp = self._client.request(
                "POST",
                "/AttendanceOpen/api/v1/Vacation/GetListByDate",
                json=body,
                headers={"Content-Type": "application/json"},
            )
            payload = resp.json() if resp.content else {}
            if not isinstance(payload, dict):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: response is not an object")

            code = str(payload.get("code") or "")
            if code not in ("200", "206"):
                raise RuntimeError(
                    f"EHR Vacation.GetListByDate failed code={code!r} message={payload.get('message')!r}"
                )

            data = payload.get("data") or {}
            if not isinstance(data, dict):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data is not an object")

            vacation_list = data.get("vacationList") or []
            if not isinstance(vacation_list, list):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data.vacationList is not a list")

            # Non-dict entries are dropped rather than failing the whole page.
            out.extend(x for x in vacation_list if isinstance(x, dict))

            is_last_page = bool(data.get("isLastPage", False))
            cursor = str(data.get("sortCursor") or "").strip() or None
            logger.info(
                "EHR Vacation.GetListByDate day=%s page=%s batch=%s is_last_page=%s",
                day_text,
                page,
                len(vacation_list),
                is_last_page,
            )
            # Without a cursor there is no way to request a further page.
            if is_last_page or not cursor:
                break
        else:
            raise RuntimeError(f"EHR Vacation.GetListByDate exceeds max_pages={max_pages} day={day_text}")

        return out

    def get_vacations_in_date_range(
        self,
        *,
        start_date: date,
        end_date: date,
        page_size: int = 100,
        max_pages_per_day: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch vacation records for every day in ``[start_date, end_date]``.

        Raises:
            ValueError: if ``end_date`` is before ``start_date``.
        """
        if end_date < start_date:
            raise ValueError("end_date must be greater than or equal to start_date")

        out: list[dict[str, Any]] = []
        cur = start_date
        while cur <= end_date:
            out.extend(
                self.get_vacation_list_by_day(
                    day=cur,
                    page_size=page_size,
                    max_pages=max_pages_per_day,
                )
            )
            cur += timedelta(days=1)

        logger.info(
            "EHR 请假拉取完成:start_date=%s end_date=%s total_records=%s",
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
        )
        return out

    def get_staff_briefs_by_user_ids(
        self,
        *,
        user_ids: list[int],
        chunk_size: int = 100,
    ) -> dict[int, dict[str, str]]:
        """Look up job number and name for each EHR user id.

        Ids that do not parse to a positive int, and repeated ids, are
        skipped. Users whose profile cannot be fetched or carries neither a
        job number nor a name are omitted from the result.

        Args:
            user_ids: EHR user ids to resolve.
            chunk_size: kept for backward compatibility; the staff endpoint
                is queried once per user, so the value has no effect.

        Returns:
            ``{user_id: {"job_no": ..., "name": ...}}``.
        """
        clean_ids: list[int] = []
        seen: set[int] = set()
        for u in user_ids:
            uid = _to_int_safe(u)
            if uid <= 0 or uid in seen:
                continue
            seen.add(uid)
            clean_ids.append(uid)
        if not clean_ids:
            return {}

        out: dict[int, dict[str, str]] = {}
        for uid in clean_ids:
            profile = self._fetch_staff_profile(uid)
            if profile is None:
                continue
            code = _extract_staff_code(profile)
            name = _extract_staff_name(profile)
            if code or name:
                out[uid] = {"job_no": code, "name": name}

        logger.info(
            "EHR 员工信息反查完成:input_user_ids=%s matched_staff_profiles=%s",
            len(clean_ids),
            len(out),
        )
        return out

    def _fetch_staff_profile(self, uid: int) -> dict[str, Any] | None:
        """Fetch one staff profile, trying both query-parameter spellings.

        Best effort: a failed request or an unparsable body for one spelling
        moves on to the next one. Returns ``None`` when no attempt yields a
        matching profile.
        """
        last_error: Exception | None = None
        for params in ({"userId": str(uid)}, {"userid": str(uid)}):
            try:
                resp = self._client.request(
                    "GET",
                    "/UserFrameworkApiV3/api/v1/staffs/Get",
                    params=params,
                )
                payload = resp.json() if resp.content else {}
            except Exception as exc:
                # The client's exception types are not visible from this
                # module, so the catch stays broad; the failure is recorded
                # instead of being dropped silently.
                last_error = exc
                logger.debug("EHR staffs/Get attempt failed uid=%s params=%s error=%r", uid, params, exc)
                continue
            profile = self._find_staff_profile_by_uid(payload, uid)
            if profile is not None:
                return profile

        if last_error is not None:
            logger.warning("EHR staffs/Get lookup failed uid=%s last_error=%r", uid, last_error)
        return None

    def get_staff_codes_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, str]:
        """Return ``{user_id: job_no}`` for users that have a job number."""
        briefs = self.get_staff_briefs_by_user_ids(user_ids=user_ids, chunk_size=chunk_size)
        codes = ((uid, str((brief or {}).get("job_no") or "").strip()) for uid, brief in briefs.items())
        return {uid: code for uid, code in codes if code}

    @staticmethod
    def _find_staff_profile_by_uid(payload: Any, uid: int) -> dict[str, Any] | None:
        """Search a nested payload for the dict describing user ``uid``.

        Every dict nested anywhere in ``payload`` is visited. The first one
        whose id matches ``uid`` and that carries a staff-code key is
        returned immediately; otherwise the first id-matching dict is
        returned, or ``None`` when nothing matches.
        """

        def _iter_dicts(node: Any):
            if isinstance(node, dict):
                yield node
                for v in node.values():
                    yield from _iter_dicts(v)
            elif isinstance(node, list):
                for it in node:
                    yield from _iter_dicts(it)

        def _uid_from_dict(d: dict[str, Any]) -> int:
            # Only the first id-like key that is present is consulted.
            for k in _USER_ID_KEYS:
                if k in d:
                    return _to_int_safe(d.get(k))
            return 0

        best: dict[str, Any] | None = None
        for d in _iter_dicts(payload):
            if _uid_from_dict(d) != uid:
                continue
            if any(k in d for k in _STAFF_CODE_KEYS):
                return d
            if best is None:
                best = d
        return best

    def get_oa_row_id_map_by_job_and_date(
        self,
        *,
        table_name: str,
        schema: str,
        job_no_column: str,
        date_column: str,
        start_date: date,
        end_date: date,
    ) -> dict[tuple[str, str], int]:
        """Return ``{(job_no, leave_date): row_id}`` for existing OA rows."""
        rows = self.get_oa_rows_by_job_and_date(
            table_name=table_name,
            schema=schema,
            job_no_column=job_no_column,
            date_column=date_column,
            leave_days_column=None,
            name_column=None,
            start_date=start_date,
            end_date=end_date,
        )
        out: dict[tuple[str, str], int] = {}
        for k, v in rows.items():
            rid = _to_int_safe((v or {}).get("id"))
            if rid > 0:
                out[k] = rid
        return out

    def get_oa_rows_by_job_and_date(
        self,
        *,
        table_name: str,
        schema: str,
        job_no_column: str,
        date_column: str,
        leave_days_column: str | None,
        name_column: str | None,
        start_date: date,
        end_date: date,
    ) -> dict[tuple[str, str], dict[str, str]]:
        """Index existing OA rows in a date range by ``(job_no, leave_date)``.

        Selects rows whose date column falls within
        ``[start_date, end_date]`` and whose job-number column is non-blank.
        When several rows share a key, the last one read wins and the
        collision is counted in the log line.

        Args:
            table_name: table to read (required).
            schema: schema name; blank falls back to ``dbo``.
            job_no_column: column holding the job number (required).
            date_column: column holding the leave date (required).
            leave_days_column: optional column to return as ``leave_days``.
            name_column: optional column to return as ``name``.
            start_date: first date included.
            end_date: last date included.

        Returns:
            ``{(job_no, leave_date): {"id", "job_no", "leave_date",
            "leave_days", "name"}}`` with all values as strings.

        Raises:
            ValueError: if a required identifier is blank.
        """
        t = str(table_name or "").strip()
        jc = str(job_no_column or "").strip()
        dc = str(date_column or "").strip()
        lc = str(leave_days_column or "").strip()
        nc = str(name_column or "").strip()
        s = str(schema or "").strip() or "dbo"
        if not t or not jc or not dc:
            raise ValueError("table_name/job_no_column/date_column are required")

        # Identifiers cannot be bound as parameters, so they are bracket-quoted
        # with "]" doubled to keep them from terminating the delimiter early.
        q_job = _quote_mssql_identifier(jc)
        q_date = _quote_mssql_identifier(dc)
        q_table = f"{_quote_mssql_identifier(s)}.{_quote_mssql_identifier(t)}"

        select_parts = [
            "[id] AS row_id",
            f"{q_job} AS job_no",
            f"{q_date} AS leave_date",
        ]
        if lc:
            select_parts.append(f"{_quote_mssql_identifier(lc)} AS leave_days")
        if nc:
            select_parts.append(f"{_quote_mssql_identifier(nc)} AS staff_name")

        sql = text(
            f"SELECT {', '.join(select_parts)} "
            f"FROM {q_table} WITH (NOLOCK) "
            f"WHERE TRY_CONVERT(date, {q_date}) >= :start_date "
            f"AND TRY_CONVERT(date, {q_date}) <= :end_date "
            f"AND ISNULL(LTRIM(RTRIM({q_job})), '') <> ''"
        )
        params = {
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }

        with self._sql_engine.connect() as conn:
            rows = conn.execute(sql, params).fetchall()

        out: dict[tuple[str, str], dict[str, str]] = {}
        duplicate_keys = 0
        for r in rows:
            try:
                row_id = int(r.row_id)
            except (TypeError, ValueError, OverflowError):
                continue

            job_no = str(r.job_no or "").strip()
            # Keep only the date part when the value renders with a time.
            leave_date = str(r.leave_date or "").strip().split(" ", 1)[0]
            if not job_no or not leave_date:
                continue

            leave_days = str(getattr(r, "leave_days", "") or "").strip() if lc else ""
            staff_name = str(getattr(r, "staff_name", "") or "").strip() if nc else ""

            key = (job_no, leave_date)
            if key in out and _to_int_safe((out[key] or {}).get("id")) != row_id:
                duplicate_keys += 1
            out[key] = {
                "id": str(row_id),
                "job_no": job_no,
                "leave_date": leave_date,
                "leave_days": leave_days,
                "name": staff_name,
            }

        logger.info(
            "OA 现有记录索引完成:table=%s month_range=%s~%s indexed=%s duplicate_keys=%s",
            t,
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
            duplicate_keys,
        )
        return out