from __future__ import annotations

import logging
from datetime import date, timedelta
from typing import Any
from urllib.parse import quote_plus

from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

from app.integrations.ehr import EhrClient

# Module-level logger for the EHR-leaves -> OA sync extension.
logger = logging.getLogger("connecthub.extensions.sync_ehr_leaves_to_oa")
def _to_int_safe(v: Any) -> int:
|
||
try:
|
||
return int(str(v).strip())
|
||
except Exception:
|
||
return 0
|
||
|
||
|
||
def _extract_staff_code(staff_profile: dict[str, Any]) -> str:
|
||
if not isinstance(staff_profile, dict):
|
||
return ""
|
||
for key in (
|
||
"staffCode",
|
||
"StaffCode",
|
||
"code",
|
||
"Code",
|
||
"jobNumber",
|
||
"JobNumber",
|
||
"employeeNo",
|
||
"EmployeeNo",
|
||
):
|
||
val = str(staff_profile.get(key) or "").strip()
|
||
if val:
|
||
return val
|
||
lower_map = {str(k).lower(): v for k, v in staff_profile.items()}
|
||
for key in ("staffcode", "code", "jobnumber", "employeeno"):
|
||
val = str(lower_map.get(key) or "").strip()
|
||
if val:
|
||
return val
|
||
return ""
|
||
|
||
|
||
def _extract_staff_name(staff_profile: dict[str, Any]) -> str:
|
||
if not isinstance(staff_profile, dict):
|
||
return ""
|
||
for key in ("name", "Name", "staffName", "StaffName", "employeeName", "EmployeeName"):
|
||
val = str(staff_profile.get(key) or "").strip()
|
||
if val:
|
||
return val
|
||
lower_map = {str(k).lower(): v for k, v in staff_profile.items()}
|
||
for key in ("name", "staffname", "employeename"):
|
||
val = str(lower_map.get(key) or "").strip()
|
||
if val:
|
||
return val
|
||
return ""
|
||
|
||
|
||
class SyncEhrLeavesToOaApi:
    """Fetch leave (vacation) records from the EHR OpenAPI and index existing
    OA leave rows in SQL Server so a caller can reconcile the two systems.

    EHR access goes through :class:`EhrClient`; OA access goes through a
    pyodbc-backed SQLAlchemy engine.  Call :meth:`close` when finished to
    release both the HTTP session and the connection pool.
    """

    def __init__(
        self,
        *,
        secret_params: dict[str, str],
        sqlserver_params: dict[str, Any],
        base_url: str = "https://openapi.italent.cn",
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
    ) -> None:
        """Build the EHR HTTP client and the SQL Server engine.

        Args:
            secret_params: credential parameters forwarded to ``EhrClient``.
            sqlserver_params: connection settings; required keys are
                ``host``/``username``/``password``/``database`` (see
                :meth:`_create_sqlserver_engine` for optional ones).
            base_url: EHR OpenAPI root URL.
            timeout_s: per-request HTTP timeout in seconds.
            retries: HTTP retry count for the client.
            retry_backoff_s: backoff between HTTP retries in seconds.

        Raises:
            ValueError: when required SQL Server parameters are missing.
        """
        self._client = EhrClient(
            base_url=base_url,
            secret_params=secret_params,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
        )
        self._sql_engine = self._create_sqlserver_engine(sqlserver_params)

    def close(self) -> None:
        """Release the HTTP client and dispose the SQLAlchemy connection pool."""
        self._client.close()
        self._sql_engine.dispose()

    @staticmethod
    def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
        """Build a pyodbc-backed SQLAlchemy engine from a parameter dict.

        Required keys: ``host``, ``username``, ``password``, ``database``.
        Optional keys: ``port`` (default 1433), ``connect_timeout_s`` (10),
        ``driver`` ("ODBC Driver 18 for SQL Server"),
        ``trust_server_certificate`` ("yes"), ``encrypt`` ("yes").

        Raises:
            ValueError: when any required key is missing or blank.
        """
        host = str(sqlserver_params.get("host") or "").strip()
        username = str(sqlserver_params.get("username") or "").strip()
        password = str(sqlserver_params.get("password") or "").strip()
        database = str(sqlserver_params.get("database") or "").strip()
        if not host or not username or not password or not database:
            raise ValueError("sqlserver_params.host/username/password/database are required")

        port = int(sqlserver_params.get("port") or 1433)
        timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
        driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
        trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower()
        encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower()

        # quote_plus guards against '@', ':' and '/' in credentials/driver
        # breaking the URL structure.
        conn_url = (
            f"mssql+pyodbc://{quote_plus(username)}:{quote_plus(password)}@{host}:{port}/{database}"
            f"?driver={quote_plus(driver)}"
            f"&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}"
        )
        # pool_pre_ping avoids handing out stale pooled connections; the
        # 30-minute recycle keeps connections younger than typical idle-kill
        # windows on firewalls/load balancers.
        return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s})

    def ping_sqlserver(self) -> bool:
        """Run ``SELECT 1`` against SQL Server; True on success, raises otherwise."""
        with self._sql_engine.connect() as conn:
            conn.execute(text("SELECT 1"))
        return True

    def get_vacation_list_by_day(
        self,
        *,
        day: date,
        page_size: int = 100,
        max_pages: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch all vacation records for *day* via cursor-based paging.

        Calls ``POST /AttendanceOpen/api/v1/Vacation/GetListByDate`` until the
        server reports the last page or stops returning a cursor.

        Args:
            day: calendar day to query.
            page_size: records per page; must be in [1, 100].
            max_pages: hard stop to guard against a cursor loop.

        Returns:
            All page items that are JSON objects, in server order.

        Raises:
            ValueError: on invalid ``page_size``/``max_pages``.
            RuntimeError: on a non-success response code, a malformed
                payload, or when ``max_pages`` is exceeded.
        """
        if page_size <= 0 or page_size > 100:
            raise ValueError("page_size must be in range [1, 100]")
        if max_pages <= 0:
            raise ValueError("max_pages must be > 0")

        out: list[dict[str, Any]] = []
        cursor: str | None = None
        page = 0
        while True:
            page += 1
            if page > max_pages:
                raise RuntimeError(f"EHR Vacation.GetListByDate exceeds max_pages={max_pages} day={day.isoformat()}")

            body: dict[str, Any] = {
                "day": day.isoformat(),
                "queryCursor": cursor,  # None on the first page
                "pageSize": page_size,
            }
            resp = self._client.request(
                "POST",
                "/AttendanceOpen/api/v1/Vacation/GetListByDate",
                json=body,
                headers={"Content-Type": "application/json"},
            )
            payload = resp.json() if resp.content else {}
            # This integration accepts "200" and "206" as success codes;
            # anything else is treated as a hard failure.
            code = str(payload.get("code") or "")
            if code not in ("200", "206"):
                raise RuntimeError(f"EHR Vacation.GetListByDate failed code={code!r} message={payload.get('message')!r}")

            data = payload.get("data") or {}
            if not isinstance(data, dict):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data is not an object")
            vacation_list = data.get("vacationList") or []
            if not isinstance(vacation_list, list):
                raise RuntimeError("EHR Vacation.GetListByDate invalid: data.vacationList is not a list")
            # Defensively drop any non-object entries the server might return.
            out.extend(item for item in vacation_list if isinstance(item, dict))

            is_last_page = bool(data.get("isLastPage", False))
            next_cursor = str(data.get("sortCursor") or "").strip() or None
            logger.info(
                "EHR Vacation.GetListByDate day=%s page=%s batch=%s is_last_page=%s",
                day.isoformat(),
                page,
                len(vacation_list),
                is_last_page,
            )
            if is_last_page:
                break
            cursor = next_cursor
            if not cursor:
                # No cursor handed back even though it was not the last page:
                # treat as end of data rather than looping forever.
                break
        return out

    def get_vacations_in_date_range(
        self,
        *,
        start_date: date,
        end_date: date,
        page_size: int = 100,
        max_pages_per_day: int = 500,
    ) -> list[dict[str, Any]]:
        """Fetch vacation records for every day in [start_date, end_date].

        NOTE(review): a leave spanning several days may be returned once per
        day and therefore appear multiple times here — confirm against the
        EHR API; callers should de-duplicate if that matters.

        Raises:
            ValueError: when end_date precedes start_date.
        """
        if end_date < start_date:
            raise ValueError("end_date must be greater than or equal to start_date")
        out: list[dict[str, Any]] = []
        cur = start_date
        while cur <= end_date:
            out.extend(
                self.get_vacation_list_by_day(
                    day=cur,
                    page_size=page_size,
                    max_pages=max_pages_per_day,
                )
            )
            cur += timedelta(days=1)
        logger.info(
            "EHR 请假拉取完成:start_date=%s end_date=%s total_records=%s",
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
        )
        return out

    def get_staff_briefs_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, dict[str, str]]:
        """Resolve job number and display name for each EHR user id.

        Each user id triggers up to two GET requests to
        ``/UserFrameworkApiV3/api/v1/staffs/Get`` (both "userId" and "userid"
        parameter casings are attempted — presumably deployments differ, TODO
        confirm).  Unresolvable users are skipped silently (best-effort).

        Args:
            user_ids: ids to resolve; non-positive and duplicate values are dropped.
            chunk_size: retained for backward compatibility only — lookups
                are issued one per user id regardless of its value.

        Returns:
            Mapping of user id -> {"job_no": ..., "name": ...}; entries with
            neither a code nor a name are omitted.
        """
        if chunk_size <= 0:
            chunk_size = 100  # normalised for interface compatibility; unused below

        # De-duplicate while preserving first-seen order.
        seen: set[int] = set()
        clean_ids: list[int] = []
        for raw in user_ids:
            uid = _to_int_safe(raw)
            if uid > 0 and uid not in seen:
                seen.add(uid)
                clean_ids.append(uid)
        if not clean_ids:
            return {}

        out: dict[int, dict[str, str]] = {}
        for uid in clean_ids:
            profile: dict[str, Any] | None = None
            for params in ({"userId": str(uid)}, {"userid": str(uid)}):
                try:
                    resp = self._client.request(
                        "GET",
                        "/UserFrameworkApiV3/api/v1/staffs/Get",
                        params=params,
                    )
                except Exception:
                    # Best-effort: a failed request just means we try the next
                    # parameter casing (or skip this user entirely).
                    continue
                payload = resp.json() if resp.content else {}
                profile = self._find_staff_profile_by_uid(payload, uid)
                if profile is not None:
                    break
            if profile is None:
                continue
            code = _extract_staff_code(profile)
            name = _extract_staff_name(profile)
            if code or name:
                out[uid] = {"job_no": code, "name": name}
        logger.info(
            "EHR 员工信息反查完成:input_user_ids=%s matched_staff_profiles=%s",
            len(clean_ids),
            len(out),
        )
        return out

    def get_staff_codes_by_user_ids(self, *, user_ids: list[int], chunk_size: int = 100) -> dict[int, str]:
        """Convenience wrapper over :meth:`get_staff_briefs_by_user_ids` that
        keeps only users with a non-empty job number."""
        briefs = self.get_staff_briefs_by_user_ids(user_ids=user_ids, chunk_size=chunk_size)
        out: dict[int, str] = {}
        for uid, brief in briefs.items():
            code = str((brief or {}).get("job_no") or "").strip()
            if code:
                out[uid] = code
        return out

    @staticmethod
    def _find_staff_profile_by_uid(payload: Any, uid: int) -> dict[str, Any] | None:
        """Search an arbitrarily nested JSON payload for the staff profile dict
        whose user-id field equals *uid*.

        A dict matches when its first known id key coerces to *uid*.  Among
        matches, one that also carries a staff-code key is preferred;
        otherwise the first match wins.  Returns None when nothing matches.
        """

        def _iter_dicts(node: Any):
            # Depth-first walk yielding every dict nested in dicts/lists.
            if isinstance(node, dict):
                yield node
                for child in node.values():
                    yield from _iter_dicts(child)
            elif isinstance(node, list):
                for item in node:
                    yield from _iter_dicts(item)

        def _uid_from_dict(d: dict[str, Any]) -> int:
            # The first known id key wins, even if its value does not parse.
            for key in ("userId", "UserId", "userid", "UserID", "id", "Id", "ID"):
                if key in d:
                    try:
                        return int(str(d.get(key)).strip())
                    except Exception:
                        return 0
            return 0

        code_keys = ("staffCode", "StaffCode", "code", "Code", "jobNumber", "JobNumber", "employeeNo", "EmployeeNo")
        best: dict[str, Any] | None = None
        for candidate in _iter_dicts(payload):
            if _uid_from_dict(candidate) != uid:
                continue
            if any(k in candidate for k in code_keys):
                return candidate  # strongest match: id plus a staff-code field
            if best is None:
                best = candidate
        return best

    @staticmethod
    def _quote_sql_identifier(name: str) -> str:
        """Bracket-quote a SQL Server identifier, escaping ']' as ']]'.

        Identifiers cannot be bound as query parameters, so they must be
        interpolated into the SQL text; escaping keeps an odd or malicious
        table/column name from breaking out of the brackets (injection).
        """
        return "[" + name.replace("]", "]]") + "]"

    def get_oa_row_id_map_by_job_and_date(
        self,
        *,
        table_name: str,
        schema: str,
        job_no_column: str,
        date_column: str,
        start_date: date,
        end_date: date,
    ) -> dict[tuple[str, str], int]:
        """Index existing OA leave rows as (job_no, ISO date) -> row id.

        Reads ``id``, *job_no_column* and *date_column* from
        ``schema.table_name`` (default schema "dbo") for rows whose date
        falls inside [start_date, end_date] and whose job number is
        non-blank.  Date strings are normalised to "YYYY-MM-DD" by dropping
        anything after the first space.  When several rows share a key, the
        last row wins and the collision is counted in the summary log.

        Raises:
            ValueError: when table/column names are blank.
        """
        t = str(table_name or "").strip()
        jc = str(job_no_column or "").strip()
        dc = str(date_column or "").strip()
        s = str(schema or "").strip() or "dbo"
        if not t or not jc or not dc:
            raise ValueError("table_name/job_no_column/date_column are required")

        # Identifiers are escaped (they cannot be bound as parameters);
        # the date bounds ARE bound as parameters.
        tq = self._quote_sql_identifier(t)
        sq = self._quote_sql_identifier(s)
        jq = self._quote_sql_identifier(jc)
        dq = self._quote_sql_identifier(dc)
        sql = text(
            f"SELECT [id] AS row_id, {jq} AS job_no, {dq} AS leave_date "
            f"FROM {sq}.{tq} WITH (NOLOCK) "
            f"WHERE TRY_CONVERT(date, {dq}) >= :start_date "
            f"AND TRY_CONVERT(date, {dq}) <= :end_date "
            f"AND ISNULL(LTRIM(RTRIM({jq})), '') <> ''"
        )
        params = {
            "start_date": start_date.isoformat(),
            "end_date": end_date.isoformat(),
        }
        out: dict[tuple[str, str], int] = {}
        duplicate_keys = 0
        with self._sql_engine.connect() as conn:
            rows = conn.execute(sql, params).fetchall()
        for r in rows:
            try:
                row_id = int(r.row_id)
            except Exception:
                continue  # skip rows whose id is not numeric
            job_no = str(r.job_no or "").strip()
            leave_date = str(r.leave_date or "").strip().split(" ", 1)[0]
            if not job_no or not leave_date:
                continue
            key = (job_no, leave_date)
            if key in out and out[key] != row_id:
                duplicate_keys += 1
            out[key] = row_id
        logger.info(
            "OA 现有记录索引完成:table=%s month_range=%s~%s indexed=%s duplicate_keys=%s",
            t,
            start_date.isoformat(),
            end_date.isoformat(),
            len(out),
            duplicate_keys,
        )
        return out