Vastai-ConnectHub/app/integrations/didi.py

412 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import hashlib
import json as jsonlib
import logging
import time
from typing import Any
import httpx
from app.integrations.base import BaseClient
# Module logger for the Didi client: used for signing warnings, token
# acquisition info and 401-refresh info messages.
logger = logging.getLogger("connecthub.integrations.didi")
def _contains_unsupported_sign_chars(s: str) -> bool:
# 文档提示:签名中不支持 \0 \t \n \x0B \r 以及空格进行加密处理。
# 这里仅做检测与告警,不自动清洗,避免服务端/客户端不一致。
return any(ch in s for ch in ("\0", "\t", "\n", "\x0b", "\r", " "))
class DidiClient(BaseClient):
    """
    Didi management API client (2024 API version).

    - ``POST /river/Auth/authorize`` obtains an ``access_token``. It is cached in
      memory until ``expires_in - token_skew_s`` seconds have elapsed (the docs
      suggest caching it for about half an hour).
    - Requests are signed by ``gen_sign`` (MD5) following the documented rules.
    - ``request_authed`` / ``post_signed_json`` drop the cached token and retry
      exactly once when the request fails with ``httpx.HTTPStatusError`` 401.

    References:
    - https://opendocs.xiaojukeji.com/version2024/10951
    - https://opendocs.xiaojukeji.com/version2024/10945
    """

    # Characters trimmed from both ends of every value before signing. The
    # signing docs list exactly these (NUL, TAB, LF, VT, CR, space) as
    # unsupported; this is also PHP trim()'s default set. A bare str.strip()
    # differs: it additionally removes Unicode whitespace such as U+3000/U+00A0
    # (common in Chinese text) and keeps NUL, so the locally computed sign
    # could diverge from the server's for such values.
    # NOTE(review): assumes the server trims with this exact set - confirm.
    _SIGN_TRIM_CHARS = " \t\n\r\0\x0b"

    def __init__(
        self,
        *,
        base_url: str,
        client_id: str,
        client_secret: str,
        sign_key: str,
        grant_type: str = "client_credentials",
        token_skew_s: int = 30,
        timeout_s: float = 10.0,
        retries: int = 2,
        retry_backoff_s: float = 0.5,
        headers: dict[str, str] | None = None,
    ) -> None:
        """
        Args:
            base_url: API root, forwarded to ``BaseClient``.
            client_id: application id; sent to authorize and in signed requests.
            client_secret: application secret; sent (and signed) in authorize.
            sign_key: secret mixed into every signature; never sent itself.
            grant_type: ``grant_type`` value for ``/river/Auth/authorize``.
            token_skew_s: seconds subtracted from ``expires_in`` so the cached
                token is renewed slightly before the server-side expiry.
            timeout_s, retries, retry_backoff_s, headers: forwarded to ``BaseClient``.
        """
        super().__init__(
            base_url=base_url,
            timeout_s=timeout_s,
            retries=retries,
            retry_backoff_s=retry_backoff_s,
            headers=headers,
        )
        self.client_id = client_id
        self.client_secret = client_secret
        self.sign_key = sign_key
        self.grant_type = grant_type
        self.token_skew_s = token_skew_s
        # Token cache filled by authorize(); None means "not fetched yet".
        self._access_token: str | None = None
        self._token_expires_at: float | None = None  # time.time() epoch seconds
        self._token_type: str | None = None

    def gen_sign(self, params: dict[str, Any], *, sign_method: str = "md5") -> str:
        """
        Compute the request signature (MD5 only).

        Steps (https://opendocs.xiaojukeji.com/version2024/10945):
        1) add ``sign_key`` to the signed parameters (used for signing, never sent);
        2) sort the parameters by name, ascending;
        3) join them as ``a=xxx&b=yyy...``; ``None`` becomes ``""`` and each
           value is trimmed of ``_SIGN_TRIM_CHARS`` at both ends;
        4) return the lowercase hex MD5 digest of the UTF-8 bytes.

        Args:
            params: parameters taking part in the signature (without ``sign``).
            sign_method: only ``"md5"`` (case-insensitive) is accepted.

        Raises:
            ValueError: if ``sign_method`` is not md5.
        """
        if sign_method.lower() != "md5":
            raise ValueError("Only md5 sign_method is supported in this client (default)")
        merged = dict(params or {})
        merged["sign_key"] = self.sign_key
        items = [
            (str(k), "" if v is None else str(v).strip(self._SIGN_TRIM_CHARS))
            for k, v in sorted(merged.items(), key=lambda kv: str(kv[0]))
        ]
        # Warn only: cleaning the values here would make the signed value differ
        # from the one actually sent. Log key names only - values may be secret.
        bad_keys = [k for k, v in items if _contains_unsupported_sign_chars(f"{k}={v}")]
        if bad_keys:
            logger.warning(
                "Didi sign_str contains unsupported chars per docs (signing anyway) keys=%s",
                bad_keys,
            )
        sign_str = "&".join(f"{k}={v}" for k, v in items)
        return hashlib.md5(sign_str.encode("utf-8")).hexdigest()

    def authorize(self) -> str:
        """
        Fetch a new access_token via ``POST /river/Auth/authorize`` and cache it.

        Docs: https://opendocs.xiaojukeji.com/version2024/10951

        Returns:
            The new access token.

        Raises:
            RuntimeError: if the body is not a JSON object or has no access_token.
        """
        body: dict[str, Any] = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": self.grant_type,
            "timestamp": int(time.time()),
        }
        body["sign"] = self.gen_sign(body)
        resp = super().request(
            "POST",
            "/river/Auth/authorize",
            json=body,
            headers={"Content-Type": "application/json"},
        )
        data = resp.json() if resp.content else {}
        if not isinstance(data, dict):
            raise RuntimeError("Didi authorize failed (invalid response, not a dict)")
        access_token = str(data.get("access_token", "") or "")
        if not access_token:
            # Include errno/errmsg in case the server answered with an error
            # envelope like the other /river endpoints (None when absent).
            raise RuntimeError(
                "Didi authorize failed (access_token missing) "
                f"errno={data.get('errno')} errmsg={data.get('errmsg')!r}"
            )
        expires_in = int(data.get("expires_in", 0) or 0)  # seconds; docs: usually 1800
        token_type = str(data.get("token_type", "") or "Bearer")
        skew = max(0, int(self.token_skew_s or 0))
        self._access_token = access_token
        self._token_type = token_type
        self._token_expires_at = time.time() + max(0, expires_in - skew)
        logger.info("Didi access_token acquired (cached) expires_in=%s token_type=%s", expires_in, token_type)
        return access_token

    def _get_access_token(self) -> str:
        """Return the cached token while it is unexpired, otherwise call ``authorize``."""
        if (
            self._access_token
            and self._token_expires_at is not None
            and time.time() < self._token_expires_at
        ):
            return self._access_token
        return self.authorize()

    def _invalidate_token(self) -> None:
        """Drop the cached token so the next ``_get_access_token`` re-authorizes."""
        self._access_token = None
        self._token_expires_at = None
        self._token_type = None

    def _build_signed_query(self, *, company_id: str, extra_params: dict[str, Any]) -> dict[str, Any]:
        """
        Build query parameters that are both signed and sent (``sign_key`` excluded):
        client_id / access_token / company_id / timestamp + extra_params + sign.
        Keys in ``extra_params`` override the base keys.
        """
        params: dict[str, Any] = {
            "client_id": self.client_id,
            "access_token": self._get_access_token(),
            "company_id": company_id,
            "timestamp": int(time.time()),
        }
        if extra_params:
            params.update(extra_params)
        params["sign"] = self.gen_sign({k: v for k, v in params.items() if k != "sign"})
        return params

    @staticmethod
    def _raise_if_errno(api_name: str, payload: Any) -> None:
        """
        Raise ``RuntimeError`` unless ``payload`` is a dict whose ``errno`` is 0.

        A missing errno is an invalid response; a non-integer errno is a failure.
        """
        if not isinstance(payload, dict):
            raise RuntimeError(f"{api_name} invalid response (not a dict)")
        errno = payload.get("errno")
        errmsg = payload.get("errmsg")
        if errno is None:
            raise RuntimeError(f"{api_name} invalid response (errno missing)")
        try:
            errno_i = int(errno)
        except (TypeError, ValueError, OverflowError):
            errno_i = -1
        if errno_i != 0:
            raise RuntimeError(f"{api_name} failed errno={errno} errmsg={errmsg!r}")

    def _checked_payload(self, api_name: str, resp: httpx.Response) -> dict[str, Any]:
        """Decode a /river JSON response (empty body -> {}) and enforce errno == 0."""
        payload = resp.json() if resp.content else {}
        self._raise_if_errno(api_name, payload)
        return payload

    @staticmethod
    def _payload_data(api_name: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Return ``payload["data"]`` ({} when falsy); raise if it is not a dict."""
        data = payload.get("data") or {}
        if not isinstance(data, dict):
            raise RuntimeError(f"{api_name} invalid response (data not a dict)")
        return data

    def get_legal_entities(
        self,
        *,
        company_id: str,
        offset: int,
        length: int,
        keyword: str | None = None,
        legal_entity_id: str | None = None,
        out_legal_entity_id: str | None = None,
    ) -> dict[str, Any]:
        """
        Query company legal entities: ``GET /river/LegalEntity/get``.

        Optional filters are sent only when truthy.

        Returns:
            The response ``data`` dict (expected to hold ``total`` and ``records``).

        Raises:
            RuntimeError: on errno != 0 or a malformed response.
        """
        extra: dict[str, Any] = {"offset": offset, "length": length}
        if keyword:
            extra["keyword"] = keyword
        if legal_entity_id:
            extra["legal_entity_id"] = legal_entity_id
        if out_legal_entity_id:
            extra["out_legal_entity_id"] = out_legal_entity_id
        params = self._build_signed_query(company_id=company_id, extra_params=extra)
        resp = super().request(
            "GET",
            "/river/LegalEntity/get",
            params=params,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        payload = self._checked_payload("LegalEntity.get", resp)
        return self._payload_data("LegalEntity.get", payload)

    def get_member_detail(
        self,
        *,
        company_id: str,
        employee_number: str | None = None,
        member_id: str | None = None,
        phone: str | None = None,
    ) -> dict[str, Any]:
        """
        Fetch one member's details: ``GET /river/Member/detail``.

        Only one identifier is sent, by priority member_id > employee_number > phone.

        Raises:
            ValueError: if all three identifiers are empty.
            RuntimeError: on errno != 0 or a malformed response.
        """
        extra: dict[str, Any] = {}
        if member_id:
            extra["member_id"] = member_id
        elif employee_number:
            extra["employee_number"] = employee_number
        elif phone:
            extra["phone"] = phone
        else:
            raise ValueError("member_id/employee_number/phone cannot all be empty")
        params = self._build_signed_query(company_id=company_id, extra_params=extra)
        resp = super().request(
            "GET",
            "/river/Member/detail",
            params=params,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
        )
        payload = self._checked_payload("Member.detail", resp)
        return self._payload_data("Member.detail", payload)

    def edit_member_legal_entity(
        self,
        *,
        company_id: str,
        member_id: str | None,
        employee_number: str | None,
        legal_entity_id: str,
    ) -> None:
        """
        Move a member to another legal entity: ``POST /river/Member/edit``.

        The changed fields travel in ``data`` as a compact JSON string, which is
        also part of the signature. Both identifiers are sent when both are given.

        Raises:
            ValueError: if neither identifier or no legal_entity_id is given.
            RuntimeError: on errno != 0 or a malformed response.
        """
        if not member_id and not employee_number:
            raise ValueError("member_id or employee_number is required")
        if not legal_entity_id:
            raise ValueError("legal_entity_id is required")
        body: dict[str, Any] = {
            "client_id": self.client_id,
            "access_token": self._get_access_token(),
            "company_id": company_id,
            "timestamp": int(time.time()),
            "data": self.dumps_data_for_sign({"legal_entity_id": legal_entity_id}),
        }
        if member_id:
            body["member_id"] = member_id
        if employee_number:
            body["employee_number"] = employee_number
        body["sign"] = self.gen_sign({k: v for k, v in body.items() if k != "sign"})
        resp = super().request(
            "POST",
            "/river/Member/edit",
            json=body,
            headers={"Content-Type": "application/json"},
        )
        self._checked_payload("Member.edit", resp)
        return None

    def _signed_fields(
        self,
        signed_params: dict[str, Any] | None,
        token: str,
        *,
        replace_token: bool,
    ) -> dict[str, Any] | None:
        """
        Return a signed copy of ``signed_params`` (None if not given).

        ``timestamp`` is added only when absent. An ``access_token`` key (for
        endpoints whose signature covers it) is set to ``token`` when empty, or
        always when ``replace_token`` is true (the retry after a 401).
        """
        if signed_params is None:
            return None
        fields = dict(signed_params)
        if "timestamp" not in fields:
            fields["timestamp"] = int(time.time())
        if "access_token" in fields and (replace_token or not fields["access_token"]):
            fields["access_token"] = token
        fields["sign"] = self.gen_sign({k: v for k, v in fields.items() if k != "sign"})
        return fields

    @staticmethod
    def _inject_signed(
        fields: dict[str, Any] | None,
        json: Any,
        params: dict[str, Any] | None,
        data: Any,
    ) -> tuple[Any, dict[str, Any] | None, Any]:
        """
        Merge ``fields`` into the first dict among json, params, data (in that
        order); if none is a dict, ``fields`` becomes the JSON body. Inputs are
        never mutated. Returns ``(json, params, data)``.
        """
        if fields is None:
            return json, params, data
        if isinstance(json, dict):
            return {**json, **fields}, params, data
        if isinstance(params, dict):
            return json, {**params, **fields}, data
        if isinstance(data, dict):
            return json, params, {**data, **fields}
        return dict(fields), params, data

    def _prepare_authed(
        self,
        *,
        headers: dict[str, str] | None,
        json: Any,
        params: dict[str, Any] | None,
        data: Any,
        signed_params: dict[str, Any] | None,
        replace_token: bool,
    ) -> tuple[dict[str, str], Any, dict[str, Any] | None, Any]:
        """Get a token, then build ``(headers, json, params, data)`` with auth and signature."""
        token = self._get_access_token()
        auth_headers = dict(headers or {})
        auth_headers["Authorization"] = f"{self._token_type or 'Bearer'} {token}"
        fields = self._signed_fields(signed_params, token, replace_token=replace_token)
        json_out, params_out, data_out = self._inject_signed(fields, json, params, data)
        return auth_headers, json_out, params_out, data_out

    def request_authed(
        self,
        method: str,
        path: str,
        *,
        params: dict[str, Any] | None = None,
        json: Any = None,
        data: Any = None,
        headers: dict[str, str] | None = None,
        signed_params: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> httpx.Response:
        """
        Send a request with ``Authorization: <token_type> <access_token>``, optionally signed.

        If ``signed_params`` is given, a copy gets ``timestamp`` (when absent) and
        ``sign`` and is merged into the first dict among ``json``, ``params``,
        ``data``; if none is a dict it becomes the JSON body. An empty
        ``access_token`` key in it is filled with the current token before signing.

        On ``httpx.HTTPStatusError`` with status 401 the cached token is dropped, a
        new one is fetched, any ``access_token`` key is overwritten with it, the
        payload is re-signed and the request is retried exactly once. Any other
        error propagates unchanged.
        """
        # Token fetch and signing stay outside the try: only the request itself
        # triggers the 401 refresh.
        auth_headers, json1, params1, data1 = self._prepare_authed(
            headers=headers,
            json=json,
            params=params,
            data=data,
            signed_params=signed_params,
            replace_token=False,
        )
        try:
            return super().request(
                method, path, params=params1, json=json1, data=data1, headers=auth_headers, **kwargs
            )
        except httpx.HTTPStatusError as e:
            if e.response.status_code != 401:
                raise
            # 401: the token is invalid or expired; refresh and retry only once.
            logger.info("Didi access_token invalid (401), refreshing and retrying once")
            self._invalidate_token()
            auth_headers2, json2, params2, data2 = self._prepare_authed(
                headers=headers,
                json=json,
                params=params,
                data=data,
                signed_params=signed_params,
                replace_token=True,
            )
            return super().request(
                method, path, params=params2, json=json2, data=data2, headers=auth_headers2, **kwargs
            )

    def post_signed_json(self, path: str, *, body: dict[str, Any]) -> httpx.Response:
        """
        Convenience JSON POST: adds timestamp/sign and Authorization automatically.

        Complex fields (e.g. an object in ``data``) should be serialized by the
        caller first (see ``dumps_data_for_sign``) so they are signed as strings.

        Raises:
            ValueError: if ``body`` is not a dict.
        """
        if not isinstance(body, dict):
            raise ValueError("body must be a dict")
        return self.request_authed(
            "POST",
            path,
            json=body,
            signed_params=body,
            headers={"Content-Type": "application/json"},
        )

    @staticmethod
    def dumps_data_for_sign(data_obj: Any) -> str:
        """
        Serialize a complex ``data`` object to compact JSON (no spaces after
        separators, non-ASCII kept as-is) so it can be sent and signed as a
        string, matching the docs' examples.
        """
        return jsonlib.dumps(data_obj, ensure_ascii=False, separators=(",", ":"))