From 58b117264a8650774eab1e4fcec85d56058bc06e Mon Sep 17 00:00:00 2001 From: Marsway Date: Wed, 25 Mar 2026 10:08:10 +0800 Subject: [PATCH] add sql server --- docker/Dockerfile | 15 ++++ extensions/sync_ehr_to_oa/api.py | 108 ++++++++++++++++++++++++++++ extensions/sync_ehr_to_oa/job.py | 118 +++++++++++++++++++++---------- pyproject.toml | 1 + 4 files changed, 205 insertions(+), 37 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 5613a99..3a56bfd 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -22,6 +22,21 @@ RUN set -eux; \ apt-get update; \ apt-get install -y --no-install-recommends \ build-essential \ + ca-certificates \ + curl \ + gnupg \ + unixodbc \ + unixodbc-dev \ + apt-transport-https \ + && rm -rf /var/lib/apt/lists/*; \ + curl -fsSL https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor -o /usr/share/keyrings/microsoft-prod.gpg; \ + . /etc/os-release; \ + curl -fsSL "https://packages.microsoft.com/config/debian/${VERSION_ID}/prod.list" \ + | sed 's#^deb #deb [signed-by=/usr/share/keyrings/microsoft-prod.gpg] #' \ + > /etc/apt/sources.list.d/microsoft-prod.list; \ + apt-get update; \ + ACCEPT_EULA=Y apt-get install -y --no-install-recommends \ + msodbcsql18 \ && rm -rf /var/lib/apt/lists/* COPY pyproject.toml /app/pyproject.toml diff --git a/extensions/sync_ehr_to_oa/api.py b/extensions/sync_ehr_to_oa/api.py index 0f5557c..c305360 100644 --- a/extensions/sync_ehr_to_oa/api.py +++ b/extensions/sync_ehr_to_oa/api.py @@ -3,6 +3,10 @@ from __future__ import annotations import logging from datetime import datetime, timedelta from typing import Any +from urllib.parse import quote_plus + +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine from app.integrations.ehr import EhrClient @@ -24,6 +28,7 @@ class SyncEhrToOaApi: self, *, secret_params: dict[str, str], + sqlserver_params: dict[str, Any] | None = None, base_url: str = "https://openapi.italent.cn", timeout_s: float = 10.0, retries: int = 2, @@ -36,9 +41,112 @@ class SyncEhrToOaApi: retries=retries, retry_backoff_s=retry_backoff_s, ) + self._sqlserver_params = dict(sqlserver_params or {}) + self._sql_engine: Engine | None = None + if self._sqlserver_params: + self._sql_engine = self._create_sqlserver_engine(self._sqlserver_params) def close(self) -> None: self._client.close() + if self._sql_engine is not None: + self._sql_engine.dispose() + self._sql_engine = None + + @staticmethod + def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine: + host = str(sqlserver_params.get("host") or "").strip() + username = str(sqlserver_params.get("username") or "").strip() + password = str(sqlserver_params.get("password") or "").strip() + database = str(sqlserver_params.get("database") or "").strip() + if not host or not username or not password or not database: + raise ValueError("sqlserver_params.host/username/password/database are required") + + port = int(sqlserver_params.get("port") or 1433) + timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10) + driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip() + trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower() + encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower() + + # 依赖 mssql+pyodbc;运行环境需安装 pyodbc 与对应 ODBC driver。 + # 例:mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+18+for+SQL+Server + # 使用 query 参数以减少 URL 编码问题。 + username_q = quote_plus(username) + password_q = quote_plus(password) + driver_q = quote_plus(driver) + conn_url = ( + f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}" + f"?driver={driver_q}&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}" + ) + return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s}) + + def ping_sqlserver(self) -> bool: + if self._sql_engine is None: + raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi") + with self._sql_engine.connect() as conn: + conn.execute(text("SELECT 1")) + return True + + def get_oa_record_id_map_from_sqlserver( + self, + *, + table_name: str, + job_numbers: list[str], + job_no_column: str = "field0001", + id_column: str = "id", + schema: str | None = "dbo", + ) -> dict[str, int]: + """ + 从 OA SQLServer 表按工号查询记录ID映射。 + 返回:{job_no: id} + """ + if self._sql_engine is None: + raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi") + t = str(table_name or "").strip() + jc = str(job_no_column or "").strip() + ic = str(id_column or "").strip() + if not t or not jc or not ic: + raise ValueError("table_name/job_no_column/id_column are required") + if not job_numbers: + return {} + + clean_job_numbers = [str(x or "").strip() for x in job_numbers if str(x or "").strip()] + if not clean_job_numbers: + return {} + + quoted_table = f"[{t}]" + if schema: + quoted_table = f"[{str(schema).strip()}].{quoted_table}" + + out: dict[str, int] = {} + # SQL Server IN 参数分批,避免参数过多;每批 500。 + chunk_size = 500 + with self._sql_engine.connect() as conn: + for i in range(0, len(clean_job_numbers), chunk_size): + chunk = clean_job_numbers[i : i + chunk_size] + binds = ", ".join([f":v{idx}" for idx in range(len(chunk))]) + sql = text( + f"SELECT [{jc}] AS job_no, [{ic}] AS row_id " + f"FROM {quoted_table} WITH (NOLOCK) " + f"WHERE [{jc}] IN ({binds})" + ) + params = {f"v{idx}": chunk[idx] for idx in range(len(chunk))} + rows = conn.execute(sql, params).fetchall() + for r in rows: + job_no = str(r.job_no or "").strip() + try: + row_id = int(r.row_id) + except Exception: + continue + if job_no: + out[job_no] = row_id + logger.info( + "SQLServer 工号映射查询完成:table=%s schema=%s input=%s matched=%s", + t, + schema or "", + len(clean_job_numbers), + len(out), + ) + return out @staticmethod def _to_datetime(value: datetime | str | None) -> datetime: diff --git a/extensions/sync_ehr_to_oa/job.py b/extensions/sync_ehr_to_oa/job.py index 716142b..a0d6056 100644 --- a/extensions/sync_ehr_to_oa/job.py +++ b/extensions/sync_ehr_to_oa/job.py @@ -11,6 +11,23 @@ from extensions.sync_ehr_to_oa.api import SyncEhrToOaApi logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa") +# OA SQLServer(按你的要求硬编码) +_OA_SQLSERVER_PARAMS: dict[str, Any] = { + "host": "192.168.30.108", + "port": 1433, + "database": "seeyon", + "username": "SHOADB91", + "password": "E7nZ8x@12", + "driver": "ODBC Driver 18 for SQL Server", + "encrypt": "no", + "trust_server_certificate": "yes", + "connect_timeout_s": 10, +} +_OA_SQLSERVER_SCHEMA = "dbo" +_OA_SQLSERVER_TABLE = "formmain_20250359" +_OA_SQLSERVER_JOB_NO_COLUMN = "field0001" +_OA_SQLSERVER_ID_COLUMN = "id" + def _cell_value(cell: Any) -> str: if isinstance(cell, dict): @@ -79,6 +96,13 @@ def _normalize_job_no(v: Any) -> str: return s.upper() +def _prefer_non_empty(new_val: Any, old_val: Any) -> str: + s_new = str(new_val or "").strip() + if s_new: + return s_new + return str(old_val or "").strip() + + def _extract_oa_row_id_and_fields(row: dict[str, Any]) -> tuple[int | None, dict[str, Any]]: """ 兼容不同 OA export 返回结构,提取: @@ -216,8 +240,22 @@ class SyncEhrToOaFormJob(BaseJob): preview_limit = 20 seeyon = SeeyonClient(base_url=oa_base_url, rest_user=rest_user, rest_password=rest_password, loginName=login_name) - ehr = SyncEhrToOaApi(secret_params={"app_key": app_key, "app_secret": app_secret}) + ehr = SyncEhrToOaApi( + secret_params={"app_key": app_key, "app_secret": app_secret}, + sqlserver_params=_OA_SQLSERVER_PARAMS, + ) try: + try: + ehr.ping_sqlserver() + logger.info( + "SQLServer 连通性检查通过:host=%s db=%s table=%s", + _OA_SQLSERVER_PARAMS["host"], + _OA_SQLSERVER_PARAMS["database"], + _OA_SQLSERVER_TABLE, + ) + except Exception as e: # noqa: BLE001 + raise RuntimeError(f"SQLServer 连接失败: {e!r}") from e + # 1) EHR 拉取员工任职与组织 emp_res = ehr.get_all_employees_with_record_by_time_window(stop_time=stop_time, capacity=capacity) org_res = ehr.get_all_organizations_by_time_window(stop_time=stop_time, capacity=capacity) @@ -373,7 +411,6 @@ class SyncEhrToOaFormJob(BaseJob): missing = [x for x in needed_displays if x not in display_to_code] if missing: raise RuntimeError(f"OA export invalid: missing form fields by display names: {missing}") - rows = form.get("data") or [] if not isinstance(rows, list): raise RuntimeError("OA export invalid: data is not a list") @@ -384,58 +421,51 @@ class SyncEhrToOaFormJob(BaseJob): if v: oa_master_table_name = v break - if not oa_master_table_name and rows: - r0 = rows[0] if isinstance(rows[0], dict) else {} - master_tbl = r0.get("masterTable") - if isinstance(master_tbl, dict): - oa_master_table_name = str(master_tbl.get("name") or "").strip() + if not oa_master_table_name and fields: + first_field = fields[0] if isinstance(fields[0], dict) else {} + oa_master_table_name = str(first_field.get("tableName") or "").strip() + # 与 SQLServer 查询目标保持一致(优先使用硬编码表) + oa_master_table_name = _OA_SQLSERVER_TABLE if not oa_master_table_name: raise RuntimeError("public_cfg.oa_master_table_name is required (cannot infer from OA export)") logger.info( - "OA 表单解析完成:template=%s master_table=%s form_rows=%s", + "OA 表单解析完成:template=%s master_table=%s", oa_template_code, oa_master_table_name, - len(rows), ) + # 从 export 中提取“工号 -> 字段值字典”,用于值兜底(避免把已有值覆盖为空) + oa_fields_by_job_no_norm: dict[str, dict[str, Any]] = {} + for row in rows: + if not isinstance(row, dict): + continue + _rid, field_map = _extract_oa_row_id_and_fields(row) + job_cell = field_map.get(display_to_code["工号"]) + job_no = _cell_value(job_cell) + norm = _normalize_job_no(job_no) + if norm: + oa_fields_by_job_no_norm[norm] = field_map + logger.info("OA export 字段值索引完成:rows=%s indexed_by_job_no=%s", len(rows), len(oa_fields_by_job_no_norm)) + job_field_code = display_to_code["工号"] oa_id_by_job_no: dict[str, int] = {} oa_id_by_job_no_norm: dict[str, int] = {} - row_parse_miss = 0 - for row in rows: - if not isinstance(row, dict): - continue - row_id, field_map = _extract_oa_row_id_and_fields(row) - job_no = _cell_value(field_map.get(job_field_code)) - if not job_no: - row_parse_miss += 1 - if verbose_trace and row_parse_miss <= 20: - logger.info( - "OA 行解析未取到工号:job_field=%s row_keys=%s field_keys_sample=%s", - job_field_code, - list(row.keys())[:20], - list(field_map.keys())[:20], - ) - continue - - if row_id is None: - row_parse_miss += 1 - if verbose_trace and row_parse_miss <= 20: - logger.info( - "OA 行解析未取到记录ID:job_no=%s row_keys=%s", - job_no, - list(row.keys())[:20], - ) - continue + sql_map = ehr.get_oa_record_id_map_from_sqlserver( + table_name=_OA_SQLSERVER_TABLE, + schema=_OA_SQLSERVER_SCHEMA, + job_numbers=list(ehr_by_job_no.keys()), + job_no_column=_OA_SQLSERVER_JOB_NO_COLUMN, + id_column=_OA_SQLSERVER_ID_COLUMN, + ) + for job_no, row_id in sql_map.items(): oa_id_by_job_no[job_no] = row_id job_no_norm = _normalize_job_no(job_no) if job_no_norm: oa_id_by_job_no_norm[job_no_norm] = row_id logger.info( - "OA 工号索引完成:indexed_job_numbers=%s indexed_job_numbers_norm=%s parse_miss=%s", + "OA 工号索引完成(SQLServer):indexed_job_numbers=%s indexed_job_numbers_norm=%s", len(oa_id_by_job_no), len(oa_id_by_job_no_norm), - row_parse_miss, ) if verbose_trace: for job_no, row_id in list(oa_id_by_job_no.items()): @@ -476,6 +506,7 @@ class SyncEhrToOaFormJob(BaseJob): org_oid = str(rec.get("oIdOrganization") or rec.get("oIdDepartment") or "").strip() org = org_by_oid.get(org_oid, {}) + existing_field_map = oa_fields_by_job_no_norm.get(_normalize_job_no(job_no), {}) company = str((org or {}).get("name") or "") name = str(emp.get("name") or "") @@ -491,6 +522,19 @@ class SyncEhrToOaFormJob(BaseJob): is_leaving = "是" if _date_only(rec.get("lastWorkDate")) else "否" domain_account = _custom_prop_value(emp.get("customProperties"), domain_custom_key) or str(emp.get("_Name") or "") + company = _prefer_non_empty(company, _cell_value(existing_field_map.get(display_to_code["所属公司"]))) + name = _prefer_non_empty(name, _cell_value(existing_field_map.get(display_to_code["姓名"]))) + rd_attr = _prefer_non_empty(rd_attr, _cell_value(existing_field_map.get(display_to_code["研发属性"]))) + place = _prefer_non_empty(place, _cell_value(existing_field_map.get(display_to_code["工作地点"]))) + entry_date = _prefer_non_empty(entry_date, _cell_value(existing_field_map.get(display_to_code["入职日期"]))) + # 离职日期按需求默认 2099-12-31,仅当已有值且北森也空时可被已有值覆盖 + leave_date = _prefer_non_empty(leave_date, _cell_value(existing_field_map.get(display_to_code["离职日期"]))) + id_number = _prefer_non_empty(id_number, _cell_value(existing_field_map.get(display_to_code["身份证号"]))) + hrbp = _prefer_non_empty(hrbp, _cell_value(existing_field_map.get(display_to_code["HRBP"]))) + manager = _prefer_non_empty(manager, _cell_value(existing_field_map.get(display_to_code["汇报人"]))) + is_leaving = _prefer_non_empty(is_leaving, _cell_value(existing_field_map.get(display_to_code["在离职"]))) + domain_account = _prefer_non_empty(domain_account, _cell_value(existing_field_map.get(display_to_code["域账号"]))) + fields_payload = [ {"name": display_to_code["所属公司"], "value": company, "showValue": company}, {"name": display_to_code["姓名"], "value": name, "showValue": name}, diff --git a/pyproject.toml b/pyproject.toml index b9adfcd..f17c00c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "httpx>=0.26", "jinja2>=3.1", "watchfiles>=0.21", + "pyodbc>=5.1", ] [build-system]