add sql server

This commit is contained in:
Marsway 2026-03-25 10:08:10 +08:00
parent 0928492ae4
commit 58b117264a
4 changed files with 205 additions and 37 deletions

View File

@ -22,6 +22,21 @@ RUN set -eux; \
apt-get update; \ apt-get update; \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
build-essential \ build-essential \
ca-certificates \
curl \
gnupg \
unixodbc \
unixodbc-dev \
apt-transport-https \
&& rm -rf /var/lib/apt/lists/*; \
curl -fsSL https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor -o /usr/share/keyrings/microsoft-prod.gpg; \
. /etc/os-release; \
curl -fsSL "https://packages.microsoft.com/config/debian/${VERSION_ID}/prod.list" \
| sed 's#^deb #deb [signed-by=/usr/share/keyrings/microsoft-prod.gpg] #' \
> /etc/apt/sources.list.d/microsoft-prod.list; \
apt-get update; \
ACCEPT_EULA=Y apt-get install -y --no-install-recommends \
msodbcsql18 \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*
COPY pyproject.toml /app/pyproject.toml COPY pyproject.toml /app/pyproject.toml

View File

@ -3,6 +3,10 @@ from __future__ import annotations
import logging import logging
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any from typing import Any
from urllib.parse import quote_plus
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
from app.integrations.ehr import EhrClient from app.integrations.ehr import EhrClient
@ -24,6 +28,7 @@ class SyncEhrToOaApi:
self, self,
*, *,
secret_params: dict[str, str], secret_params: dict[str, str],
sqlserver_params: dict[str, Any] | None = None,
base_url: str = "https://openapi.italent.cn", base_url: str = "https://openapi.italent.cn",
timeout_s: float = 10.0, timeout_s: float = 10.0,
retries: int = 2, retries: int = 2,
@ -36,9 +41,112 @@ class SyncEhrToOaApi:
retries=retries, retries=retries,
retry_backoff_s=retry_backoff_s, retry_backoff_s=retry_backoff_s,
) )
self._sqlserver_params = dict(sqlserver_params or {})
self._sql_engine: Engine | None = None
if self._sqlserver_params:
self._sql_engine = self._create_sqlserver_engine(self._sqlserver_params)
def close(self) -> None: def close(self) -> None:
self._client.close() self._client.close()
if self._sql_engine is not None:
self._sql_engine.dispose()
self._sql_engine = None
@staticmethod
def _create_sqlserver_engine(sqlserver_params: dict[str, Any]) -> Engine:
host = str(sqlserver_params.get("host") or "").strip()
username = str(sqlserver_params.get("username") or "").strip()
password = str(sqlserver_params.get("password") or "").strip()
database = str(sqlserver_params.get("database") or "").strip()
if not host or not username or not password or not database:
raise ValueError("sqlserver_params.host/username/password/database are required")
port = int(sqlserver_params.get("port") or 1433)
timeout_s = int(sqlserver_params.get("connect_timeout_s") or 10)
driver = str(sqlserver_params.get("driver") or "ODBC Driver 18 for SQL Server").strip()
trust_server_certificate = str(sqlserver_params.get("trust_server_certificate") or "yes").strip().lower()
encrypt = str(sqlserver_params.get("encrypt") or "yes").strip().lower()
# 依赖 mssql+pyodbc运行环境需安装 pyodbc 与对应 ODBC driver。
# 例mssql+pyodbc://user:pass@host:1433/db?driver=ODBC+Driver+18+for+SQL+Server
# 使用 query 参数以减少 URL 编码问题。
username_q = quote_plus(username)
password_q = quote_plus(password)
driver_q = quote_plus(driver)
conn_url = (
f"mssql+pyodbc://{username_q}:{password_q}@{host}:{port}/{database}"
f"?driver={driver_q}&TrustServerCertificate={trust_server_certificate}&Encrypt={encrypt}"
)
return create_engine(conn_url, pool_pre_ping=True, pool_recycle=1800, connect_args={"timeout": timeout_s})
def ping_sqlserver(self) -> bool:
if self._sql_engine is None:
raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
with self._sql_engine.connect() as conn:
conn.execute(text("SELECT 1"))
return True
def get_oa_record_id_map_from_sqlserver(
self,
*,
table_name: str,
job_numbers: list[str],
job_no_column: str = "field0001",
id_column: str = "id",
schema: str | None = "dbo",
) -> dict[str, int]:
"""
OA SQLServer 表按工号查询记录ID映射
返回{job_no: id}
"""
if self._sql_engine is None:
raise RuntimeError("SQL Server engine is not initialized; pass sqlserver_params when creating SyncEhrToOaApi")
t = str(table_name or "").strip()
jc = str(job_no_column or "").strip()
ic = str(id_column or "").strip()
if not t or not jc or not ic:
raise ValueError("table_name/job_no_column/id_column are required")
if not job_numbers:
return {}
clean_job_numbers = [str(x or "").strip() for x in job_numbers if str(x or "").strip()]
if not clean_job_numbers:
return {}
quoted_table = f"[{t}]"
if schema:
quoted_table = f"[{str(schema).strip()}].{quoted_table}"
out: dict[str, int] = {}
# SQL Server IN 参数分批,避免参数过多;每批 500。
chunk_size = 500
with self._sql_engine.connect() as conn:
for i in range(0, len(clean_job_numbers), chunk_size):
chunk = clean_job_numbers[i : i + chunk_size]
binds = ", ".join([f":v{idx}" for idx in range(len(chunk))])
sql = text(
f"SELECT [{jc}] AS job_no, [{ic}] AS row_id "
f"FROM {quoted_table} WITH (NOLOCK) "
f"WHERE [{jc}] IN ({binds})"
)
params = {f"v{idx}": chunk[idx] for idx in range(len(chunk))}
rows = conn.execute(sql, params).fetchall()
for r in rows:
job_no = str(r.job_no or "").strip()
try:
row_id = int(r.row_id)
except Exception:
continue
if job_no:
out[job_no] = row_id
logger.info(
"SQLServer 工号映射查询完成table=%s schema=%s input=%s matched=%s",
t,
schema or "",
len(clean_job_numbers),
len(out),
)
return out
@staticmethod @staticmethod
def _to_datetime(value: datetime | str | None) -> datetime: def _to_datetime(value: datetime | str | None) -> datetime:

View File

@ -11,6 +11,23 @@ from extensions.sync_ehr_to_oa.api import SyncEhrToOaApi
logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa") logger = logging.getLogger("connecthub.extensions.sync_ehr_to_oa")
# OA SQLServer按你的要求硬编码
_OA_SQLSERVER_PARAMS: dict[str, Any] = {
"host": "192.168.30.108",
"port": 1433,
"database": "seeyon",
"username": "SHOADB91",
"password": "E7nZ8x@12",
"driver": "ODBC Driver 18 for SQL Server",
"encrypt": "no",
"trust_server_certificate": "yes",
"connect_timeout_s": 10,
}
_OA_SQLSERVER_SCHEMA = "dbo"
_OA_SQLSERVER_TABLE = "formmain_20250359"
_OA_SQLSERVER_JOB_NO_COLUMN = "field0001"
_OA_SQLSERVER_ID_COLUMN = "id"
def _cell_value(cell: Any) -> str: def _cell_value(cell: Any) -> str:
if isinstance(cell, dict): if isinstance(cell, dict):
@ -79,6 +96,13 @@ def _normalize_job_no(v: Any) -> str:
return s.upper() return s.upper()
def _prefer_non_empty(new_val: Any, old_val: Any) -> str:
s_new = str(new_val or "").strip()
if s_new:
return s_new
return str(old_val or "").strip()
def _extract_oa_row_id_and_fields(row: dict[str, Any]) -> tuple[int | None, dict[str, Any]]: def _extract_oa_row_id_and_fields(row: dict[str, Any]) -> tuple[int | None, dict[str, Any]]:
""" """
兼容不同 OA export 返回结构提取 兼容不同 OA export 返回结构提取
@ -216,8 +240,22 @@ class SyncEhrToOaFormJob(BaseJob):
preview_limit = 20 preview_limit = 20
seeyon = SeeyonClient(base_url=oa_base_url, rest_user=rest_user, rest_password=rest_password, loginName=login_name) seeyon = SeeyonClient(base_url=oa_base_url, rest_user=rest_user, rest_password=rest_password, loginName=login_name)
ehr = SyncEhrToOaApi(secret_params={"app_key": app_key, "app_secret": app_secret}) ehr = SyncEhrToOaApi(
secret_params={"app_key": app_key, "app_secret": app_secret},
sqlserver_params=_OA_SQLSERVER_PARAMS,
)
try: try:
try:
ehr.ping_sqlserver()
logger.info(
"SQLServer 连通性检查通过host=%s db=%s table=%s",
_OA_SQLSERVER_PARAMS["host"],
_OA_SQLSERVER_PARAMS["database"],
_OA_SQLSERVER_TABLE,
)
except Exception as e: # noqa: BLE001
raise RuntimeError(f"SQLServer 连接失败: {e!r}") from e
# 1) EHR 拉取员工任职与组织 # 1) EHR 拉取员工任职与组织
emp_res = ehr.get_all_employees_with_record_by_time_window(stop_time=stop_time, capacity=capacity) emp_res = ehr.get_all_employees_with_record_by_time_window(stop_time=stop_time, capacity=capacity)
org_res = ehr.get_all_organizations_by_time_window(stop_time=stop_time, capacity=capacity) org_res = ehr.get_all_organizations_by_time_window(stop_time=stop_time, capacity=capacity)
@ -373,7 +411,6 @@ class SyncEhrToOaFormJob(BaseJob):
missing = [x for x in needed_displays if x not in display_to_code] missing = [x for x in needed_displays if x not in display_to_code]
if missing: if missing:
raise RuntimeError(f"OA export invalid: missing form fields by display names: {missing}") raise RuntimeError(f"OA export invalid: missing form fields by display names: {missing}")
rows = form.get("data") or [] rows = form.get("data") or []
if not isinstance(rows, list): if not isinstance(rows, list):
raise RuntimeError("OA export invalid: data is not a list") raise RuntimeError("OA export invalid: data is not a list")
@ -384,58 +421,51 @@ class SyncEhrToOaFormJob(BaseJob):
if v: if v:
oa_master_table_name = v oa_master_table_name = v
break break
if not oa_master_table_name and rows: if not oa_master_table_name and fields:
r0 = rows[0] if isinstance(rows[0], dict) else {} first_field = fields[0] if isinstance(fields[0], dict) else {}
master_tbl = r0.get("masterTable") oa_master_table_name = str(first_field.get("tableName") or "").strip()
if isinstance(master_tbl, dict): # 与 SQLServer 查询目标保持一致(优先使用硬编码表)
oa_master_table_name = str(master_tbl.get("name") or "").strip() oa_master_table_name = _OA_SQLSERVER_TABLE
if not oa_master_table_name: if not oa_master_table_name:
raise RuntimeError("public_cfg.oa_master_table_name is required (cannot infer from OA export)") raise RuntimeError("public_cfg.oa_master_table_name is required (cannot infer from OA export)")
logger.info( logger.info(
"OA 表单解析完成template=%s master_table=%s form_rows=%s", "OA 表单解析完成template=%s master_table=%s",
oa_template_code, oa_template_code,
oa_master_table_name, oa_master_table_name,
len(rows),
) )
# 从 export 中提取“工号 -> 字段值字典”,用于值兜底(避免把已有值覆盖为空)
oa_fields_by_job_no_norm: dict[str, dict[str, Any]] = {}
for row in rows:
if not isinstance(row, dict):
continue
_rid, field_map = _extract_oa_row_id_and_fields(row)
job_cell = field_map.get(display_to_code["工号"])
job_no = _cell_value(job_cell)
norm = _normalize_job_no(job_no)
if norm:
oa_fields_by_job_no_norm[norm] = field_map
logger.info("OA export 字段值索引完成rows=%s indexed_by_job_no=%s", len(rows), len(oa_fields_by_job_no_norm))
job_field_code = display_to_code["工号"] job_field_code = display_to_code["工号"]
oa_id_by_job_no: dict[str, int] = {} oa_id_by_job_no: dict[str, int] = {}
oa_id_by_job_no_norm: dict[str, int] = {} oa_id_by_job_no_norm: dict[str, int] = {}
row_parse_miss = 0 sql_map = ehr.get_oa_record_id_map_from_sqlserver(
for row in rows: table_name=_OA_SQLSERVER_TABLE,
if not isinstance(row, dict): schema=_OA_SQLSERVER_SCHEMA,
continue job_numbers=list(ehr_by_job_no.keys()),
row_id, field_map = _extract_oa_row_id_and_fields(row) job_no_column=_OA_SQLSERVER_JOB_NO_COLUMN,
job_no = _cell_value(field_map.get(job_field_code)) id_column=_OA_SQLSERVER_ID_COLUMN,
if not job_no:
row_parse_miss += 1
if verbose_trace and row_parse_miss <= 20:
logger.info(
"OA 行解析未取到工号job_field=%s row_keys=%s field_keys_sample=%s",
job_field_code,
list(row.keys())[:20],
list(field_map.keys())[:20],
) )
continue for job_no, row_id in sql_map.items():
if row_id is None:
row_parse_miss += 1
if verbose_trace and row_parse_miss <= 20:
logger.info(
"OA 行解析未取到记录IDjob_no=%s row_keys=%s",
job_no,
list(row.keys())[:20],
)
continue
oa_id_by_job_no[job_no] = row_id oa_id_by_job_no[job_no] = row_id
job_no_norm = _normalize_job_no(job_no) job_no_norm = _normalize_job_no(job_no)
if job_no_norm: if job_no_norm:
oa_id_by_job_no_norm[job_no_norm] = row_id oa_id_by_job_no_norm[job_no_norm] = row_id
logger.info( logger.info(
"OA 工号索引完成indexed_job_numbers=%s indexed_job_numbers_norm=%s parse_miss=%s", "OA 工号索引完成SQLServerindexed_job_numbers=%s indexed_job_numbers_norm=%s",
len(oa_id_by_job_no), len(oa_id_by_job_no),
len(oa_id_by_job_no_norm), len(oa_id_by_job_no_norm),
row_parse_miss,
) )
if verbose_trace: if verbose_trace:
for job_no, row_id in list(oa_id_by_job_no.items()): for job_no, row_id in list(oa_id_by_job_no.items()):
@ -476,6 +506,7 @@ class SyncEhrToOaFormJob(BaseJob):
org_oid = str(rec.get("oIdOrganization") or rec.get("oIdDepartment") or "").strip() org_oid = str(rec.get("oIdOrganization") or rec.get("oIdDepartment") or "").strip()
org = org_by_oid.get(org_oid, {}) org = org_by_oid.get(org_oid, {})
existing_field_map = oa_fields_by_job_no_norm.get(_normalize_job_no(job_no), {})
company = str((org or {}).get("name") or "") company = str((org or {}).get("name") or "")
name = str(emp.get("name") or "") name = str(emp.get("name") or "")
@ -491,6 +522,19 @@ class SyncEhrToOaFormJob(BaseJob):
is_leaving = "" if _date_only(rec.get("lastWorkDate")) else "" is_leaving = "" if _date_only(rec.get("lastWorkDate")) else ""
domain_account = _custom_prop_value(emp.get("customProperties"), domain_custom_key) or str(emp.get("_Name") or "") domain_account = _custom_prop_value(emp.get("customProperties"), domain_custom_key) or str(emp.get("_Name") or "")
company = _prefer_non_empty(company, _cell_value(existing_field_map.get(display_to_code["所属公司"])))
name = _prefer_non_empty(name, _cell_value(existing_field_map.get(display_to_code["姓名"])))
rd_attr = _prefer_non_empty(rd_attr, _cell_value(existing_field_map.get(display_to_code["研发属性"])))
place = _prefer_non_empty(place, _cell_value(existing_field_map.get(display_to_code["工作地点"])))
entry_date = _prefer_non_empty(entry_date, _cell_value(existing_field_map.get(display_to_code["入职日期"])))
# 离职日期按需求默认 2099-12-31仅当已有值且北森也空时可被已有值覆盖
leave_date = _prefer_non_empty(leave_date, _cell_value(existing_field_map.get(display_to_code["离职日期"])))
id_number = _prefer_non_empty(id_number, _cell_value(existing_field_map.get(display_to_code["身份证号"])))
hrbp = _prefer_non_empty(hrbp, _cell_value(existing_field_map.get(display_to_code["HRBP"])))
manager = _prefer_non_empty(manager, _cell_value(existing_field_map.get(display_to_code["汇报人"])))
is_leaving = _prefer_non_empty(is_leaving, _cell_value(existing_field_map.get(display_to_code["在离职"])))
domain_account = _prefer_non_empty(domain_account, _cell_value(existing_field_map.get(display_to_code["域账号"])))
fields_payload = [ fields_payload = [
{"name": display_to_code["所属公司"], "value": company, "showValue": company}, {"name": display_to_code["所属公司"], "value": company, "showValue": company},
{"name": display_to_code["姓名"], "value": name, "showValue": name}, {"name": display_to_code["姓名"], "value": name, "showValue": name},

View File

@ -20,6 +20,7 @@ dependencies = [
"httpx>=0.26", "httpx>=0.26",
"jinja2>=3.1", "jinja2>=3.1",
"watchfiles>=0.21", "watchfiles>=0.21",
"pyodbc>=5.1",
] ]
[build-system] [build-system]