from __future__ import annotations import hashlib import json as jsonlib import logging import time from typing import Any import httpx from app.integrations.base import BaseClient logger = logging.getLogger("connecthub.integrations.didi") def _contains_unsupported_sign_chars(s: str) -> bool: # 文档提示:签名中不支持 \0 \t \n \x0B \r 以及空格进行加密处理。 # 这里仅做检测与告警,不自动清洗,避免服务端/客户端不一致。 return any(ch in s for ch in ("\0", "\t", "\n", "\x0b", "\r", " ")) class DidiClient(BaseClient): """ 滴滴管理 API Client(2024 版): - POST /river/Auth/authorize 获取 access_token(建议缓存半小时;401 刷新后重试一次) - 按文档规则生成 sign(默认 MD5) 参考: - https://opendocs.xiaojukeji.com/version2024/10951 - https://opendocs.xiaojukeji.com/version2024/10945 """ def __init__( self, *, base_url: str, client_id: str, client_secret: str, sign_key: str, grant_type: str = "client_credentials", token_skew_s: int = 30, timeout_s: float = 10.0, retries: int = 2, retry_backoff_s: float = 0.5, headers: dict[str, str] | None = None, ) -> None: super().__init__( base_url=base_url, timeout_s=timeout_s, retries=retries, retry_backoff_s=retry_backoff_s, headers=headers, ) self.client_id = client_id self.client_secret = client_secret self.sign_key = sign_key self.grant_type = grant_type self.token_skew_s = token_skew_s self._access_token: str | None = None self._token_expires_at: float | None = None self._token_type: str | None = None def gen_sign(self, params: dict[str, Any], *, sign_method: str = "md5") -> str: """ 签名算法(默认 MD5): 1) 将 sign_key 加入参与签名参数(不参与传递,仅参与计算) 2) 参数名升序排序 3) 以 & 连接成 a=xxx&b=yyy... 4) md5/sha256 得到 sign(小写 hex) 文档:https://opendocs.xiaojukeji.com/version2024/10945 """ if sign_method.lower() != "md5": raise ValueError("Only md5 sign_method is supported in this client (default)") p = dict(params or {}) p["sign_key"] = self.sign_key # 排序并拼接 items: list[tuple[str, str]] = [] for k in sorted(p.keys()): v = p.get(k) sv = "" if v is None else str(v).strip() items.append((str(k), sv)) sign_str = "&".join([f"{k}={v}" for k, v in items]) if _contains_unsupported_sign_chars(sign_str): logger.warning("Didi sign_str contains unsupported chars per docs (signing anyway)") return hashlib.md5(sign_str.encode("utf-8")).hexdigest() def authorize(self) -> str: """ 授权获取 access_token: POST /river/Auth/authorize 文档:https://opendocs.xiaojukeji.com/version2024/10951 """ ts = int(time.time()) body: dict[str, Any] = { "client_id": self.client_id, "client_secret": self.client_secret, "grant_type": self.grant_type, "timestamp": ts, } body["sign"] = self.gen_sign(body) resp = super().request( "POST", "/river/Auth/authorize", json=body, headers={"Content-Type": "application/json"}, ) data = resp.json() if resp.content else {} access_token = str(data.get("access_token", "") or "") expires_in = int(data.get("expires_in", 0) or 0) token_type = str(data.get("token_type", "") or "Bearer") if not access_token: raise RuntimeError("Didi authorize failed (access_token missing)") now = time.time() skew = max(0, int(self.token_skew_s or 0)) # expires_in 单位秒;按文档通常为 1800 self._access_token = access_token self._token_type = token_type self._token_expires_at = now + max(0, expires_in - skew) logger.info("Didi access_token acquired (cached) expires_in=%s token_type=%s", expires_in, token_type) return access_token def _get_access_token(self) -> str: now = time.time() if self._access_token and self._token_expires_at and now < self._token_expires_at: return self._access_token return self.authorize() def _build_signed_query(self, *, company_id: str, extra_params: dict[str, Any]) -> dict[str, Any]: """ 构造“参与签名且实际传递”的 query 参数(不包含 sign_key): - client_id/access_token/company_id/timestamp + extra_params + sign """ token = self._get_access_token() ts = int(time.time()) params: dict[str, Any] = { "client_id": self.client_id, "access_token": token, "company_id": company_id, "timestamp": ts, } if extra_params: params.update(extra_params) params["sign"] = self.gen_sign({k: v for k, v in params.items() if k != "sign"}) return params @staticmethod def _raise_if_errno(api_name: str, payload: Any) -> None: try: errno = payload.get("errno") errmsg = payload.get("errmsg") except Exception as e: # noqa: BLE001 raise RuntimeError(f"{api_name} invalid response (not a dict)") from e if errno is None: raise RuntimeError(f"{api_name} invalid response (errno missing)") try: errno_i = int(errno) except Exception: errno_i = -1 if errno_i != 0: raise RuntimeError(f"{api_name} failed errno={errno} errmsg={errmsg!r}") def get_legal_entities( self, *, company_id: str, offset: int, length: int, keyword: str | None = None, legal_entity_id: str | None = None, out_legal_entity_id: str | None = None, ) -> dict[str, Any]: """ 公司主体查询: GET /river/LegalEntity/get """ extra: dict[str, Any] = {"offset": offset, "length": length} if keyword: extra["keyword"] = keyword if legal_entity_id: extra["legal_entity_id"] = legal_entity_id if out_legal_entity_id: extra["out_legal_entity_id"] = out_legal_entity_id params = self._build_signed_query(company_id=company_id, extra_params=extra) resp = super().request( "GET", "/river/LegalEntity/get", params=params, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) payload = resp.json() if resp.content else {} self._raise_if_errno("LegalEntity.get", payload) data = payload.get("data") or {} if not isinstance(data, dict): raise RuntimeError("LegalEntity.get invalid response (data not a dict)") return data # {total, records} def get_member_detail( self, *, company_id: str, employee_number: str | None = None, member_id: str | None = None, phone: str | None = None, ) -> dict[str, Any]: """ 员工明细: GET /river/Member/detail """ extra: dict[str, Any] = {} if member_id: extra["member_id"] = member_id elif employee_number: extra["employee_number"] = employee_number elif phone: extra["phone"] = phone else: raise ValueError("member_id/employee_number/phone cannot all be empty") params = self._build_signed_query(company_id=company_id, extra_params=extra) resp = super().request( "GET", "/river/Member/detail", params=params, headers={"Content-Type": "application/x-www-form-urlencoded"}, ) payload = resp.json() if resp.content else {} self._raise_if_errno("Member.detail", payload) data = payload.get("data") or {} if not isinstance(data, dict): raise RuntimeError("Member.detail invalid response (data not a dict)") return data def edit_member_legal_entity( self, *, company_id: str, member_id: str | None, employee_number: str | None, legal_entity_id: str, ) -> None: """ 员工修改:更新员工所在公司主体(legal_entity_id) POST /river/Member/edit """ if not member_id and not employee_number: raise ValueError("member_id or employee_number is required") if not legal_entity_id: raise ValueError("legal_entity_id is required") token = self._get_access_token() ts = int(time.time()) data_str = self.dumps_data_for_sign({"legal_entity_id": legal_entity_id}) body: dict[str, Any] = { "client_id": self.client_id, "access_token": token, "company_id": company_id, "timestamp": ts, "data": data_str, } if member_id: body["member_id"] = member_id if employee_number: body["employee_number"] = employee_number body["sign"] = self.gen_sign({k: v for k, v in body.items() if k != "sign"}) resp = super().request( "POST", "/river/Member/edit", json=body, headers={"Content-Type": "application/json"}, ) payload = resp.json() if resp.content else {} self._raise_if_errno("Member.edit", payload) return None def request_authed( self, method: str, path: str, *, params: dict[str, Any] | None = None, json: Any = None, data: Any = None, headers: dict[str, str] | None = None, signed_params: dict[str, Any] | None = None, **kwargs: Any, ) -> httpx.Response: """ 统一带 token +(可选)签名的请求: - Authorization: Bearer - 若 signed_params 提供:自动补 timestamp 与 sign,并注入到 json/params/data(优先注入到 dict 类型的 json,其次 params,再次 data,否则默认注入 json dict) - 遇到 401:清空 token,重新 authorize 后重试一次 """ token = self._get_access_token() token_type = self._token_type or "Bearer" extra_headers = dict(headers or {}) extra_headers["Authorization"] = f"{token_type} {token}" sp: dict[str, Any] | None = None if signed_params is not None: sp = dict(signed_params) if "timestamp" not in sp: sp["timestamp"] = int(time.time()) # 如该接口签名参数包含 access_token,则需参与签名 if "access_token" in sp and not sp.get("access_token"): sp["access_token"] = token sp["sign"] = self.gen_sign({k: v for k, v in sp.items() if k != "sign"}) def _inject(target_json: Any, target_params: dict[str, Any] | None, target_data: Any) -> tuple[Any, dict[str, Any] | None, Any]: if sp is None: return target_json, target_params, target_data if isinstance(target_json, dict): merged = dict(target_json) merged.update(sp) return merged, target_params, target_data if isinstance(target_params, dict): merged_p = dict(target_params) merged_p.update(sp) return target_json, merged_p, target_data if isinstance(target_data, dict): merged_d = dict(target_data) merged_d.update(sp) return target_json, target_params, merged_d # 默认注入到 json dict return dict(sp), target_params, target_data json2, params2, data2 = _inject(json, params, data) try: return super().request(method, path, params=params2, json=json2, data=data2, headers=extra_headers, **kwargs) except httpx.HTTPStatusError as e: resp = e.response if resp.status_code != 401: raise # 401:token 无效或过期,刷新后仅重试一次 logger.info("Didi access_token invalid (401), refreshing and retrying once") self._access_token = None self._token_expires_at = None self._token_type = None token2 = self._get_access_token() token_type2 = self._token_type or "Bearer" extra_headers2 = dict(headers or {}) extra_headers2["Authorization"] = f"{token_type2} {token2}" # 若签名参数中包含 access_token,需要更新并重新计算 sign if signed_params is not None: sp2 = dict(signed_params) if "timestamp" not in sp2: sp2["timestamp"] = int(time.time()) if "access_token" in sp2: sp2["access_token"] = token2 sp2["sign"] = self.gen_sign({k: v for k, v in sp2.items() if k != "sign"}) json2_retry, params2_retry, data2_retry = _inject(json, params, data) # _inject 使用闭包 sp;这里临时覆盖行为以避免额外结构改动 if isinstance(json, dict): json2_retry = dict(json) json2_retry.update(sp2) elif isinstance(params, dict): params2_retry = dict(params) params2_retry.update(sp2) elif isinstance(data, dict): data2_retry = dict(data) data2_retry.update(sp2) else: json2_retry = dict(sp2) return super().request( method, path, params=params2_retry, json=json2_retry, data=data2_retry, headers=extra_headers2, **kwargs, ) return super().request(method, path, params=params2, json=json2, data=data2, headers=extra_headers2, **kwargs) def post_signed_json(self, path: str, *, body: dict[str, Any]) -> httpx.Response: """ 便捷方法:JSON POST + 自动补 timestamp/sign + 自动带 Authorization。 注意:如 body 内包含复杂字段(例如 data 为对象),建议调用方先 json.dumps(...) 成字符串再参与签名。 """ if not isinstance(body, dict): raise ValueError("body must be a dict") return self.request_authed( "POST", path, json=body, signed_params=body, headers={"Content-Type": "application/json"}, ) @staticmethod def dumps_data_for_sign(data_obj: Any) -> str: """ 将复杂 data 对象序列化为“参与签名的字符串”(紧凑 JSON),以贴近文档示例。 """ return jsonlib.dumps(data_obj, ensure_ascii=False, separators=(",", ":"))