Vastai-ConnectHub/app/security/fernet.py

78 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import json
import os
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from app.core.config import settings
def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain *path*, if it names one.

    A bare file name has no directory component, so nothing is created for
    it. Directories that already exist are left untouched.
    """
    directory = os.path.dirname(path)
    if not directory:
        return
    os.makedirs(directory, exist_ok=True)
def _read_fernet_key(key_path: str) -> bytes:
    """Read the key stored at *key_path*, rejecting a file with no key material."""
    with open(key_path, "rb") as f:
        key = f.read().strip()
    if not key:
        # Returning b"" here would only fail later, when the value is handed
        # to Fernet, with an error that does not mention the key file.
        raise ValueError(f"Fernet key file is empty: {key_path}")
    return key


def get_or_create_fernet_key(path: str | None = None) -> bytes:
    """Return the Fernet key stored on disk, generating and saving one if absent.

    Args:
        path: Location of the key file. Falls back to
            ``settings.fernet_key_path`` when omitted or empty.

    Returns:
        The key bytes with surrounding whitespace stripped.

    Raises:
        ValueError: If the key file exists but holds no key material.
        OSError: If the key file can be neither read nor written.
    """
    key_path = path or settings.fernet_key_path

    # Read first and react to a missing file, instead of exists() + open():
    # the file can appear or vanish between the two calls.
    try:
        return _read_fernet_key(key_path)
    except FileNotFoundError:
        pass  # no key yet (or its directory is missing): create one below

    _ensure_parent_dir(key_path)
    key = Fernet.generate_key()
    try:
        # O_EXCL makes creation fail if another process got there first, so an
        # existing key is never overwritten. Mode 0o600 is best-effort: not
        # every filesystem honours it.
        fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(key + b"\n")
    except FileExistsError:
        # Race: another process wrote the key; use theirs, not ours.
        return _read_fernet_key(key_path)
    except OSError:
        # Exclusive create failed for another reason; fall back to a plain
        # write so the key is still persisted.
        with open(key_path, "wb") as f:
            f.write(key + b"\n")
    return key
def _fernet() -> Fernet:
    """Build a Fernet cipher from the key returned by get_or_create_fernet_key()."""
    key = get_or_create_fernet_key()
    return Fernet(key)
def encrypt_json(obj: dict[str, Any]) -> str:
    """Serialize *obj* to JSON and return it Fernet-encrypted as a text token.

    The JSON is compact (no spaces after separators), key-sorted, and keeps
    non-ASCII characters unescaped before being UTF-8 encoded and encrypted.
    """
    serialized = json.dumps(
        obj,
        sort_keys=True,
        separators=(",", ":"),
        ensure_ascii=False,
    )
    ciphertext = _fernet().encrypt(serialized.encode("utf-8"))
    return ciphertext.decode("utf-8")
def decrypt_json(token: str) -> dict[str, Any]:
    """Decrypt a Fernet token holding JSON and return the decoded object.

    Tolerates a few ways a stored token may have been damaged:
    empty or whitespace-only input, plaintext JSON saved by mistake, and
    trailing ``=`` padding that was trimmed off.

    Args:
        token: Fernet token text, or plaintext JSON object text.

    Returns:
        The decoded object; ``{}`` when *token* is empty or only whitespace.

    Raises:
        ValueError: If the token cannot be decrypted with the current key, or
            the decrypted payload is not valid UTF-8 JSON.
    """
    if not token:
        return {}
    token = token.strip()
    if not token:
        # Whitespace-only input carries no more data than "" does.
        return {}

    # Compatibility: legacy rows or manual input may have stored plaintext JSON.
    if token.startswith("{"):
        try:
            obj = json.loads(token)
        except (ValueError, RecursionError):
            # Not parseable as JSON after all; treat it as ciphertext below.
            obj = None
        if isinstance(obj, dict):
            return obj

    # Compatibility: trailing '=' padding may have been trimmed, which makes
    # base64 decoding fail when len % 4 != 0. Restore it (no-op if aligned).
    token += "=" * (-len(token) % 4)

    try:
        raw = _fernet().decrypt(token.encode("utf-8"))
    except InvalidToken as e:
        raise ValueError("Invalid secret_cfg token (Fernet)") from e
    return json.loads(raw.decode("utf-8"))