Vastai-ConnectHub/extensions/sync_ehr_leaves_to_oa/api.py

343 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import logging
from datetime import date, timedelta
from typing import Any
from urllib.parse import quote_plus
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from app.integrations.ehr import EhrClient
logger = logging.getLogger("connecthub.extensions.sync_ehr_leaves_to_oa")
def _to_int_safe(v: Any) -> int:
try:
return int(str(v).strip())
except Exception:
return 0
def _extract_staff_code(staff_profile: dict[str, Any]) -> str:
if not isinstance(staff_profile, dict):
return ""
for key in (
"staffCode",
"StaffCode",
"code",
"Code",
"jobNumber",
"JobNumber",
"employeeNo",
"EmployeeNo",
):
val = str(staff_profile.get(key) or "").strip()
if val:
return val
lower_map = {str(k).lower(): v for k, v in staff_profile.items()}
for key in ("staffcode", "code", "jobnumber", "employeeno"):
val = str(lower_map.get(key) or "").strip()
if val:
return val
return ""
def _extract_staff_name(staff_profile: dict[str, Any]) -> str:
if not isinstance(staff_profile, dict):
return ""
for key in ("name", "Name", "staffName", "StaffName", "employeeName", "EmployeeName"):
val = str(staff_profile.get(key) or "").strip()
if val:
return val
lower_map = {str(k).lower(): v for k, v in staff_profile.items()}
for key in ("name", "staffname", "employeename"):
val = str(lower_map.get(key) or "").strip()
if val:
return val
return ""
class SyncEhrLeavesToOaApi:
    """Bridge between the iTalent EHR OpenAPI and an OA SQL Server database.

    Pulls leave ("vacation") records from the EHR side, reverse-looks-up
    staff codes/names, and indexes existing OA rows so a sync job can
    decide what to insert or update.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any],
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Create the EHR HTTP client and the SQL Server engine.

        Args:
            secret_params: credentials forwarded verbatim to ``EhrClient``.
            sqlserver_params: host/port/username/password/database plus
                optional ``driver``, ``encrypt``, ``trust_server_certificate``
                and ``connect_timeout_s``.
            base_url: EHR OpenAPI root URL.
            timeout_s: per-request HTTP timeout, seconds.
            retries: HTTP retry count.
            retry_backoff_s: backoff between HTTP retries, seconds.
        """
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        self._sql_engine = self._create_sqlserver_engine(sqlserver_params)

    def close(self) -> None:
        """Release the HTTP session and dispose the SQLAlchemy pool."""
        self._client.close()
        self._sql_engine.dispose()

    @staticmethod
    def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
        """Build a pyodbc-backed SQLAlchemy engine from a plain params dict.

        Raises:
            ValueError: if host/username/password/database is missing.
        """
        host = str(sqlserver_params.get("host") or "").strip()
        username = str(sqlserver_params.get("username") or "").strip()
        password = str(sqlserver_params.get("password") or "").strip()
        database = str(sqlserver_params.get("database") or "").strip()
        if not host or not username or not password or not database:
            raise ValueError("sqlserver_params.host/username/password/database are required")
        port = int(sqlserver_params.get("port") or 1433)
        timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
        driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
        trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower()
        encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower()
        # Credentials and the driver name may contain URL-reserved characters.
        username_q = quote_plus(username)
        password_q = quote_plus(password)
        driver_q = quote_plus(driver)
        conn_url = (
            f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}"
            f"?driver={driver_q}&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}"
        )
        # pool_pre_ping guards against stale pooled connections; recycle well
        # below typical server-side idle-kill windows.
        return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s})

    @staticmethod
    def _quote_ident(name: str) -> str:
        """Quote a T-SQL identifier, escaping closing brackets (``]`` -> ``]]``).

        Identifiers cannot be bound as query parameters, so they must be
        delimited safely before interpolation into SQL text.
        """
        return "[" + name.replace("]", "]]") + "]"

    def ping_sqlserver(self) -> bool:
        """Run ``SELECT 1`` to verify SQL Server connectivity.

        Returns True on success; lets the underlying driver error propagate
        on failure.
        """
        with self._sql_engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True

    def get_vacation_list_by_day(
        self,
        *,
        day: date,
        page_size: int = 100,
        max_pages: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch all vacation records effective on *day*, following cursors.

        Args:
            day: the calendar day to query.
            page_size: records per page; the API caps this at 100.
            max_pages: hard stop to protect against a cursor loop.

        Returns:
            All dict-typed records from ``data.vacationList`` across pages.

        Raises:
            ValueError: on invalid paging arguments.
            RuntimeError: on an API error code, malformed payload, or when
                ``max_pages`` is exceeded.
        """
        if page_size <= 0 or page_size > 100:
            raise ValueError("page_size must be in range [1, 100]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")
        out: list[dict[str, Any]] = []
        cursor: str | None = None
        page = 0
        while True:
            page += 1
            if page > max_pages:
                raise RuntimeError(f"EHR Vacation.GetListByDate exceeds max_pages={max_pages} day={day.isoformat()}")
            body: dict[str, Any] = {
                "day": day.isoformat(),
                "queryCursor": cursor,
                "pageSize": page_size,
            }
            resp = self._client.request(
                "POST",
                "/AttendanceOpen/api/v1/Vacation/GetListByDate",
                json=body,
                headers={"Content-Type": "application/json"},
            )
            payload = resp.json() if resp.content else {}
            code = str(payload.get("code") or "")
            # 200 = full page, 206 = partial/continuation; both carry data.
            if code not in ("200", "206"):
                raise RuntimeError(f"EHR Vacation.GetListByDate failed code={code!r} message={payload.get('message')!r}")
            data = payload.get("data") or {}
            if not isinstance(data, dict):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data is not an object")
            vacation_list = data.get("vacationList") or []
            if not isinstance(vacation_list, list):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data.vacationList is not a list")
            # Silently drop non-dict entries; the schema is not strictly enforced.
            out.extend([x for x in vacation_list if isinstance(x, dict)])
            is_last_page = bool(data.get("isLastPage", False))
            next_cursor = str(data.get("sortCursor") or "").strip() or None
            logger.info(
                "EHR Vacation.GetListByDate day=%s page=%s batch=%s is_last_page=%s",
                day.isoformat(),
                page,
                len(vacation_list),
                is_last_page,
            )
            if is_last_page:
                break
            cursor = next_cursor
            # A missing cursor with is_last_page=False would otherwise loop
            # forever on the first page — stop defensively.
            if not cursor:
                break
        return out

    def get_vacations_in_date_range(
        self,
        *,
        start_date: date,
        end_date: date,
        page_size: int = 100,
        max_pages_per_day: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch vacation records for every day in [start_date, end_date].

        The API is day-scoped, so this iterates one request-set per day.
        Records spanning multiple days will appear once per covered day.

        Raises:
            ValueError: if end_date precedes start_date.
        """
        if end_date < start_date:
            raise ValueError("end_date must be greater than or equal to start_date")
        out: list[dict[str, Any]] = []
        cur = start_date
        while cur <= end_date:
            out.extend(
                self.get_vacation_list_by_day(
                    day=cur,
                    page_size=page_size,
                    max_pages=max_pages_per_day,
                )
            )
            cur = cur + timedelta(days=1)
        logger.info(
            "EHR 请假拉取完成start_date=%s end_date=%s total_records=%s",
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
        )
        return out

    def get_staff_briefs_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, dict[str, str]]:
        """Reverse-look-up job numbers and names for EHR user ids.

        One GET per user id (the endpoint has no batch form); both the
        ``userId`` and ``userid`` parameter spellings are tried because
        deployments differ. ``chunk_size`` is retained for interface
        compatibility only — requests are issued per id either way.

        Returns:
            ``{user_id: {"job_no": ..., "name": ...}}`` for ids where at
            least one of the two fields could be extracted.
        """
        if chunk_size <= 0:
            chunk_size = 100
        # De-duplicate while preserving order; drop non-positive/garbage ids.
        clean_ids: list[int] = []
        seen: set[int] = set()
        for u in user_ids:
            uid = _to_int_safe(u)
            if uid <= 0 or uid in seen:
                continue
            seen.add(uid)
            clean_ids.append(uid)
        if not clean_ids:
            return {}
        out: dict[int, dict[str, str]] = {}
        for uid in clean_ids:
            profile: dict[str, Any] | None = None
            for params in ({"userId": str(uid)}, {"userid": str(uid)}):
                try:
                    resp = self._client.request(
                        "GET",
                        "/UserFrameworkApiV3/api/v1/staffs/Get",
                        params=params,
                    )
                except Exception:
                    # Best-effort: a failed attempt falls through to the next
                    # parameter spelling instead of aborting the whole batch.
                    logger.debug("EHR staffs/Get failed uid=%s params=%s", uid, params, exc_info=True)
                    continue
                payload = resp.json() if resp.content else {}
                profile = self._find_staff_profile_by_uid(payload, uid)
                if profile is not None:
                    break
            if profile is None:
                continue
            code = _extract_staff_code(profile)
            name = _extract_staff_name(profile)
            if code or name:
                out[uid] = {"job_no": code, "name": name}
        logger.info(
            "EHR 员工信息反查完成input_user_ids=%s matched_staff_profiles=%s",
            len(clean_ids),
            len(out),
        )
        return out

    def get_staff_codes_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, str]:
        """Convenience wrapper: user id -> non-empty job number only."""
        briefs = self.get_staff_briefs_by_user_ids(user_ids=user_ids, chunk_size=chunk_size)
        out: dict[int, str] = {}
        for uid, brief in briefs.items():
            code = str((brief or {}).get("job_no") or "").strip()
            if code:
                out[uid] = code
        return out

    @staticmethod
    def _find_staff_profile_by_uid(payload: Any, uid: int) -> dict[str, Any] | None:
        """Search an arbitrarily nested payload for the dict describing *uid*.

        Walks every dict in the payload tree; a dict whose user-id-like key
        equals *uid* AND that carries a staff-code-like key wins outright.
        Otherwise the first uid-matching dict is returned as a fallback.
        Returns None when nothing matches.
        """

        def _iter_dicts(node: Any):
            # Depth-first walk yielding every dict in the tree.
            if isinstance(node, dict):
                yield node
                for v in node.values():
                    yield from _iter_dicts(v)
            elif isinstance(node, list):
                for it in node:
                    yield from _iter_dicts(it)

        def _uid_from_dict(d: dict[str, Any]) -> int:
            for k in ("userId", "UserId", "userid", "UserID", "id", "Id", "ID"):
                if k in d:
                    return _to_int_safe(d.get(k))
            return 0

        best: dict[str, Any] | None = None
        for d in _iter_dicts(payload):
            if not isinstance(d, dict):
                continue
            if _uid_from_dict(d) != uid:
                continue
            if any(k in d for k in ("staffCode", "StaffCode", "code", "Code", "jobNumber", "JobNumber", "employeeNo", "EmployeeNo")):
                return d
            if best is None:
                best = d
        return best

    def get_oa_row_id_map_by_job_and_date(
        self,
        *,
        table_name: str,
        schema: str,
        job_no_column: str,
        date_column: str,
        start_date: date,
        end_date: date,
    ) -> dict[tuple[str, str], int]:
        """Index existing OA rows as ``(job_no, 'YYYY-MM-DD') -> row id``.

        Identifiers (schema/table/columns) cannot be bound as parameters, so
        they are bracket-quoted via :meth:`_quote_ident` (with ``]]`` escaping)
        to prevent breaking out of the delimiter; date bounds ARE bound as
        parameters. Reads with NOLOCK, so the snapshot may be dirty.

        Returns:
            The key map; on duplicate keys the last row wins and the
            collision count is logged.

        Raises:
            ValueError: if table/column names are missing.
        """
        t = str(table_name or "").strip()
        jc = str(job_no_column or "").strip()
        dc = str(date_column or "").strip()
        s = str(schema or "").strip() or "dbo"
        if not t or not jc or not dc:
            raise ValueError("table_name/job_no_column/date_column are required")
        qi = self._quote_ident
        sql = text(
            f"SELECT [id] AS row_id, {qi(jc)} AS job_no, {qi(dc)} AS leave_date "
            f"FROM {qi(s)}.{qi(t)} WITH (NOLOCK) "
            f"WHERE TRY_CONVERT(date, {qi(dc)}) >= :start_date "
            f"AND TRY_CONVERT(date, {qi(dc)}) <= :end_date "
            f"AND ISNULL(LTRIM(RTRIM({qi(jc)})), '') <> ''"
        )
        params = {
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        out: dict[tuple[str, str], int] = {}
        duplicate_keys = 0
        with self._sql_engine.connect() as conn:
            rows = conn.execute(sql, params).fetchall()
        for r in rows:
            try:
                row_id = int(r.row_id)
            except Exception:
                continue
            job_no = str(r.job_no or "").strip()
            leave_date = str(r.leave_date or "").strip()
            # Keep only the date part of a 'YYYY-MM-DD hh:mm:ss' string.
            leave_date = leave_date.split(" ", 1)[0]
            if not job_no or not leave_date:
                continue
            key = (job_no, leave_date)
            if key in out and out[key] != row_id:
                duplicate_keys += 1
            out[key] = row_id
        logger.info(
            "OA 现有记录索引完成table=%s month_range=%s~%s indexed=%s duplicate_keys=%s",
            t,
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
            duplicate_keys,
        )
        return out