from __future__ import annotations
|
||
|
||
import logging
|
||
import os
|
||
import traceback as tb
|
||
from datetime import datetime
|
||
from typing import Any
|
||
from zoneinfo import ZoneInfo
|
||
|
||
from app.core.log_capture import capture_logs
|
||
from app.core.logging import setup_logging
|
||
from app.core.log_context import clear_job_context, set_job_context
|
||
from app.core.config import settings
|
||
from app.db import crud
|
||
from app.db.engine import engine, get_session
|
||
from app.db.models import JobStatus
|
||
from app.db.schema import ensure_schema
|
||
from app.plugins.manager import instantiate
|
||
from app.security.fernet import decrypt_json
|
||
from app.tasks.celery_app import celery_app
|
||
|
||
|
||
# Module-level logger for job-execution task events.
logger = logging.getLogger("connecthub.tasks.execute")

# Cap on how many raw WARNING lines are inlined into a persisted job-log message.
_MAX_MESSAGE_WARNING_LINES = 200
# Hard cap on the total composed message length written to the DB.
_MAX_MESSAGE_CHARS = 50_000
|
||
|
||
|
||
def _extract_warning_lines(run_log_text: str) -> list[str]:
|
||
"""
|
||
从 run_log 文本里提取 WARNING 行(保留原始行文本)。
|
||
capture_logs 的格式为:'%(asctime)s %(levelname)s %(name)s %(message)s'
|
||
"""
|
||
run_log_text = run_log_text or ""
|
||
lines = run_log_text.splitlines()
|
||
return [ln for ln in lines if " WARNING " in f" {ln} "]
|
||
|
||
|
||
def _compose_message(base_message: str, warning_lines: list[str]) -> str:
    """Compose the job-log message persisted to the DB.

    Layout: base_message, then the WARNING lines themselves (capped at
    _MAX_MESSAGE_WARNING_LINES), then a one-line summary.  The final text is
    truncated so it never exceeds _MAX_MESSAGE_CHARS.

    Args:
        base_message: Primary status text (e.g. "OK" or an exception repr);
            None is treated as "".
        warning_lines: Raw WARNING lines extracted from the run log; None is
            treated as empty.

    Returns:
        The combined, possibly truncated, message text.
    """
    base_message = base_message or ""
    warning_lines = warning_lines or []

    parts: list[str] = [base_message]
    if warning_lines:
        parts.append(f"WARNINGS ({len(warning_lines)}):")
        if len(warning_lines) <= _MAX_MESSAGE_WARNING_LINES:
            parts.extend(warning_lines)
        else:
            parts.extend(warning_lines[:_MAX_MESSAGE_WARNING_LINES])
            parts.append(f"[TRUNCATED] warnings exceeded {_MAX_MESSAGE_WARNING_LINES} lines")
        parts.append(f"SUMMARY: warnings={len(warning_lines)}")

    # parts only ever holds str values, so no None filtering is needed.
    msg = "\n".join(parts)
    if len(msg) > _MAX_MESSAGE_CHARS:
        # Derive the notice text from the constant so the number in the
        # message can never drift from the actual limit (was hardcoded 50000).
        msg = msg[: _MAX_MESSAGE_CHARS - 64] + f"\n[TRUNCATED] message exceeded {_MAX_MESSAGE_CHARS} chars"
    return msg
|
||
|
||
|
||
def _safe_job_dir_name(job_id: str) -> str:
|
||
"""
|
||
将 job_id 映射为安全的目录名(避免路径分隔符造成目录穿越/嵌套)。
|
||
"""
|
||
s = (job_id or "").strip() or "unknown"
|
||
return s.replace("/", "_").replace("\\", "_")
|
||
|
||
|
||
@celery_app.task(bind=True, name="connecthub.execute_job")
def execute_job(
    self,
    job_id: str | None = None,
    snapshot_params: dict[str, Any] | None = None,
    log_id: int | None = None,
) -> dict[str, Any]:
    """Generic job-execution entry point.

    Two calling modes:
      - job_id: load the Job definition from the database.
      - snapshot_params: re-run from a captured parameter snapshot (used by
        the Admin one-click retry).

    Args:
        job_id: Id of the Job row to execute; required unless
            snapshot_params is given.
        snapshot_params: Captured run parameters ("job_id", "handler_path",
            "public_cfg", "secret_cfg", "meta"); takes precedence over job_id.
        log_id: Pre-created job-log row to update instead of creating a new
            RUNNING record.

    Returns:
        Dict with "status", "job_id", "result" and "message".
    """
    setup_logging()

    # Make sure the schema has been upgraded (so a worker that starts before
    # the web app does not fail on DB writes).
    try:
        ensure_schema(engine)
    except Exception:
        # A failed schema upgrade must not block execution (at worst the
        # run_log cannot be written).
        pass

    # NOTE(review): naive UTC timestamps via utcnow(); presumably the rest of
    # the schema stores naive UTC as well — confirm before switching to
    # timezone-aware datetimes.
    started_at = datetime.utcnow()
    session = get_session()
    status = JobStatus.SUCCESS
    message = ""
    traceback = ""
    result: dict[str, Any] = {}
    run_log_text = ""
    job_log_id: int | None = log_id
    celery_task_id = getattr(self.request, "id", "") or ""
    attempt = int(getattr(self.request, "retries", 0) or 0)
    snapshot: dict[str, Any] = {}

    try:
        if snapshot_params:
            # Snapshot mode: everything needed to run comes from the caller.
            job_id = snapshot_params["job_id"]
            handler_path = snapshot_params["handler_path"]
            public_cfg = snapshot_params.get("public_cfg", {}) or {}
            secret_token = snapshot_params.get("secret_cfg", "") or ""
        else:
            if not job_id:
                raise ValueError("job_id or snapshot_params is required")
            job = crud.get_job(session, job_id)
            if not job:
                raise ValueError(f"Job not found: {job_id}")
            handler_path = job.handler_path
            public_cfg = job.public_cfg or {}
            secret_token = job.secret_cfg or ""

        # Snapshot of this run's parameters, persisted with the job log so the
        # run can be replayed later.
        snapshot = snapshot_params or {
            "job_id": job_id,
            "handler_path": handler_path,
            "public_cfg": public_cfg,
            "secret_cfg": secret_token,
            "meta": {
                "trigger": "celery",
                "celery_task_id": celery_task_id,
                "started_at": started_at.isoformat(),
            },
        }

        # Persist a RUNNING record as soon as the task starts.  If the caller
        # already supplied log_id, only that record is updated at the end; if
        # creation fails, degrade to the old behaviour (create when finished).
        if job_log_id is None:
            try:
                running = crud.create_job_log(
                    session,
                    job_id=str(job_id or ""),
                    status=JobStatus.RUNNING,
                    snapshot_params=snapshot,
                    message="运行中",
                    traceback="",
                    run_log="",
                    celery_task_id=celery_task_id,
                    attempt=attempt,
                    started_at=started_at,
                    finished_at=None,
                )
                job_log_id = int(running.id)
            except Exception:
                job_log_id = None

        # Best-effort full per-run log file on disk.  Without job_log_id the
        # file name cannot be made unique, so skip it entirely.
        per_run_log_path: str | None = None
        if job_log_id is not None and job_id:
            try:
                log_root = settings.log_dir or os.path.join(settings.data_dir, "logs")
                job_dir = os.path.join(log_root, _safe_job_dir_name(str(job_id)))
                os.makedirs(job_dir, exist_ok=True)
                tz = ZoneInfo("Asia/Shanghai")
                ts = datetime.now(tz).strftime("%Y-%m-%d_%H-%M-%S")
                per_run_log_path = os.path.join(job_dir, f"{ts}_log-{int(job_log_id)}.log")
            except Exception:
                per_run_log_path = None
                logger.warning("prepare per-run log file failed job_id=%s log_id=%s", job_id, job_log_id)

        ctx_tokens = None
        with capture_logs(max_bytes=200_000, job_log_id=job_log_id, file_path=per_run_log_path) as get_run_log:
            try:
                if job_log_id is not None and job_id:
                    ctx_tokens = set_job_context(job_id=str(job_id), job_log_id=int(job_log_id))

                secrets = decrypt_json(secret_token)
                job_instance = instantiate(handler_path)
                out = job_instance.run(params=public_cfg, secrets=secrets)
                if isinstance(out, dict):
                    result = out
                message = "OK"

            except Exception as e:  # noqa: BLE001 (framework-wide)
                # If this were triggered via Celery retry, the framework could
                # auto-retry here; this version only records the failure.
                status = JobStatus.FAILURE
                message = repr(e)
                traceback = tb.format_exc()
                logger.exception("execute_job failed job_id=%s", job_id)
            finally:
                try:
                    clear_job_context(ctx_tokens)
                except Exception:
                    # Best effort: must never affect the task itself.
                    pass
                try:
                    run_log_text = get_run_log() or ""
                except Exception:
                    run_log_text = ""
    finally:
        finished_at = datetime.utcnow()
        warning_lines = _extract_warning_lines(run_log_text)
        message = _compose_message(message, warning_lines)
        # On completion: prefer updating the RUNNING record; otherwise create
        # the final record from scratch.
        if job_log_id is not None:
            crud.update_job_log(
                session,
                job_log_id,
                status=status,
                message=message,
                traceback=traceback,
                run_log=run_log_text,
                celery_task_id=celery_task_id,
                attempt=attempt,
                finished_at=finished_at,
            )
        else:
            if not snapshot:
                # Failure happened before the snapshot was built; fall back to
                # whatever locals exist so a log record is still written.
                snapshot = snapshot_params or {
                    "job_id": job_id,
                    "handler_path": handler_path if "handler_path" in locals() else "",
                    "public_cfg": public_cfg if "public_cfg" in locals() else {},
                    "secret_cfg": secret_token if "secret_token" in locals() else "",
                    "meta": {
                        "trigger": "celery",
                        "celery_task_id": celery_task_id,
                        "started_at": started_at.isoformat(),
                    },
                }
            crud.create_job_log(
                session,
                job_id=str(job_id or ""),
                status=status,
                snapshot_params=snapshot,
                message=message,
                traceback=traceback,
                run_log=run_log_text,
                celery_task_id=celery_task_id,
                attempt=attempt,
                started_at=started_at,
                finished_at=finished_at,
            )
        session.close()

    return {"status": status.value, "job_id": job_id, "result": result, "message": message}
|
||
|
||
|