# Vastai-ConnectHub/app/tasks/execute.py
from __future__ import annotations
import logging
import traceback as tb
from datetime import datetime
from typing import Any
from app.core.log_capture import capture_logs
from app.core.logging import setup_logging
from app.db import crud
from app.db.engine import engine, get_session
from app.db.models import JobStatus
from app.db.schema import ensure_schema
from app.plugins.manager import instantiate
from app.security.fernet import decrypt_json
from app.tasks.celery_app import celery_app
logger = logging.getLogger("connecthub.tasks.execute")
_MAX_MESSAGE_WARNING_LINES = 200
_MAX_MESSAGE_CHARS = 50_000
def _extract_warning_lines(run_log_text: str) -> list[str]:
"""
从 run_log 文本里提取 WARNING 行(保留原始行文本)。
capture_logs 的格式为:'%(asctime)s %(levelname)s %(name)s %(message)s'
"""
run_log_text = run_log_text or ""
lines = run_log_text.splitlines()
return [ln for ln in lines if " WARNING " in f" {ln} "]
def _compose_message(base_message: str, warning_lines: list[str]) -> str:
"""
base_message + warnings(具体内容) + summary并做截断保护。
"""
base_message = base_message or ""
warning_lines = warning_lines or []
parts: list[str] = [base_message]
if warning_lines:
parts.append(f"WARNINGS ({len(warning_lines)}):")
if len(warning_lines) <= _MAX_MESSAGE_WARNING_LINES:
parts.extend(warning_lines)
else:
parts.extend(warning_lines[:_MAX_MESSAGE_WARNING_LINES])
parts.append(f"[TRUNCATED] warnings exceeded {_MAX_MESSAGE_WARNING_LINES} lines")
parts.append(f"SUMMARY: warnings={len(warning_lines)}")
msg = "\n".join([p for p in parts if p is not None])
if len(msg) > _MAX_MESSAGE_CHARS:
msg = msg[: _MAX_MESSAGE_CHARS - 64] + "\n[TRUNCATED] message exceeded 50000 chars"
return msg
@celery_app.task(bind=True, name="connecthub.execute_job")
def execute_job(self, job_id: str | None = None, snapshot_params: dict[str, Any] | None = None) -> dict[str, Any]:
"""
通用执行入口:
- 传 job_id从 DB 读取 Job 定义
- 传 snapshot_params按快照重跑用于 Admin 一键重试)
"""
setup_logging()
# 确保 schema 已升级(即使 worker 先启动也不会写库失败)
try:
ensure_schema(engine)
except Exception:
# schema upgrade 失败不能影响执行(最多导致 run_log 无法写入)
pass
started_at = datetime.utcnow()
session = get_session()
status = JobStatus.SUCCESS
message = ""
traceback = ""
result: dict[str, Any] = {}
run_log_text = ""
log_id: int | None = None
celery_task_id = getattr(self.request, "id", "") or ""
attempt = int(getattr(self.request, "retries", 0) or 0)
snapshot: dict[str, Any] = {}
try:
with capture_logs(max_bytes=200_000) as get_run_log:
try:
if snapshot_params:
job_id = snapshot_params["job_id"]
handler_path = snapshot_params["handler_path"]
public_cfg = snapshot_params.get("public_cfg", {}) or {}
secret_token = snapshot_params.get("secret_cfg", "") or ""
else:
if not job_id:
raise ValueError("job_id or snapshot_params is required")
job = crud.get_job(session, job_id)
if not job:
raise ValueError(f"Job not found: {job_id}")
handler_path = job.handler_path
public_cfg = job.public_cfg or {}
secret_token = job.secret_cfg or ""
snapshot = snapshot_params or {
"job_id": job_id,
"handler_path": handler_path,
"public_cfg": public_cfg,
"secret_cfg": secret_token,
"meta": {
"trigger": "celery",
"celery_task_id": celery_task_id,
"started_at": started_at.isoformat(),
},
}
# 任务开始即落库一条 RUNNING 记录(若失败则降级为旧行为:结束时再 create
try:
running = crud.create_job_log(
session,
job_id=str(job_id or ""),
status=JobStatus.RUNNING,
snapshot_params=snapshot,
message="运行中",
traceback="",
run_log="",
celery_task_id=celery_task_id,
attempt=attempt,
started_at=started_at,
finished_at=None,
)
log_id = int(running.id)
except Exception:
log_id = None
secrets = decrypt_json(secret_token)
job_instance = instantiate(handler_path)
out = job_instance.run(params=public_cfg, secrets=secrets)
if isinstance(out, dict):
result = out
message = "OK"
except Exception as e: # noqa: BLE001 (framework-wide)
# 如果是 Celery retry 触发,框架可在此处扩展为自动 retry此版本先记录失败信息
status = JobStatus.FAILURE
message = repr(e)
traceback = tb.format_exc()
logger.exception("execute_job failed job_id=%s", job_id)
finally:
try:
run_log_text = get_run_log() or ""
except Exception:
run_log_text = ""
finally:
finished_at = datetime.utcnow()
warning_lines = _extract_warning_lines(run_log_text)
message = _compose_message(message, warning_lines)
# 结束时:优先更新 RUNNING 那条;若没有则创建最终记录(兼容降级)
if log_id is not None:
crud.update_job_log(
session,
log_id,
status=status,
message=message,
traceback=traceback,
run_log=run_log_text,
celery_task_id=celery_task_id,
attempt=attempt,
finished_at=finished_at,
)
else:
if not snapshot:
snapshot = snapshot_params or {
"job_id": job_id,
"handler_path": handler_path if "handler_path" in locals() else "",
"public_cfg": public_cfg if "public_cfg" in locals() else {},
"secret_cfg": secret_token if "secret_token" in locals() else "",
"meta": {
"trigger": "celery",
"celery_task_id": celery_task_id,
"started_at": started_at.isoformat(),
},
}
crud.create_job_log(
session,
job_id=str(job_id or ""),
status=status,
snapshot_params=snapshot,
message=message,
traceback=traceback,
run_log=run_log_text,
celery_task_id=celery_task_id,
attempt=attempt,
started_at=started_at,
finished_at=finished_at,
)
session.close()
return {"status": status.value, "job_id": job_id, "result": result, "message": message}