84 lines
3.6 KiB
Python
84 lines
3.6 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Dict
|
|
|
|
import yaml
|
|
from filelock import FileLock
|
|
|
|
from ad_user_creator.exceptions import StatePersistenceError
|
|
|
|
|
|
class StateStore:
    """Persist the UID allocation counter and the group -> GID mapping on disk.

    The UID counter is stored as JSON in ``uid_state_file``. Every read and
    read-modify-write of it goes through a ``FileLock`` on the sibling file
    ``<uid_state_file>.lock``. The GID mapping is stored as YAML in
    ``group_gid_map_file``. Both files are replaced via a temp file and a
    rename, so an interrupted write cannot leave a truncated file behind.
    """

    def __init__(self, uid_state_file: str, group_gid_map_file: str, initial_uid_number: int = 2106) -> None:
        """Record the state file locations and the initial UID counter value.

        Args:
            uid_state_file: Path of the JSON file holding ``next_uid_number``.
            group_gid_map_file: Path of the YAML group -> GID mapping file.
            initial_uid_number: Counter value written when ``ensure_state_files``
                creates the UID state file for the first time.
        """
        self.uid_state_path = Path(uid_state_file)
        self.group_gid_map_path = Path(group_gid_map_file)
        self.initial_uid_number = initial_uid_number
        # The lock lives in a separate ".lock" file next to the state file.
        self._uid_lock = FileLock(str(self.uid_state_path) + ".lock")

    def ensure_state_files(self) -> None:
        """Create missing parent directories and seed any missing state file.

        The UID state file is seeded with ``initial_uid_number``. The GID map is
        seeded with ``{"staff": 3000}``. Files that already exist are left as is.

        Raises:
            StatePersistenceError: If a seed file cannot be written.
        """
        self.uid_state_path.parent.mkdir(parents=True, exist_ok=True)
        self.group_gid_map_path.parent.mkdir(parents=True, exist_ok=True)

        # Check-and-create under the UID lock. Without it, two processes could
        # both see the file missing, and the later write would reset a counter
        # the other had already advanced, handing out duplicate UIDs.
        with self._uid_lock:
            if not self.uid_state_path.exists():
                self._write_uid_state({"next_uid_number": self.initial_uid_number, "updated_at": self._now_iso()})

        if not self.group_gid_map_path.exists():
            self._write_group_gid_map({"staff": 3000})

    def get_next_uid_number(self) -> int:
        """Return the UID the next commit would hand out, without consuming it.

        Raises:
            StatePersistenceError: If the UID state is missing or invalid.
        """
        with self._uid_lock:
            return self._read_next_uid_number()

    def commit_next_uid_number(self) -> int:
        """Consume and return the current UID, storing ``current + 1`` as the next one.

        The read and the write happen while holding ``_uid_lock``, so callers
        that share that lock never receive the same value.

        Raises:
            StatePersistenceError: If the UID state cannot be read or written.
        """
        with self._uid_lock:
            current = self._read_next_uid_number()
            self._write_uid_state({"next_uid_number": current + 1, "updated_at": self._now_iso()})
            return current

    def load_group_gid_map(self) -> Dict[str, int]:
        """Load the group -> GID mapping, coercing keys to ``str`` and values to ``int``.

        An empty (or otherwise falsy) YAML document yields an empty mapping.

        Raises:
            StatePersistenceError: If the file is missing, unreadable, not valid
                YAML, not a mapping, or contains a value ``int()`` rejects.
        """
        if not self.group_gid_map_path.exists():
            raise StatePersistenceError(f"gid 映射文件不存在: {self.group_gid_map_path}")
        try:
            with self.group_gid_map_path.open("r", encoding="utf-8") as handle:
                data = yaml.safe_load(handle) or {}
        except (OSError, yaml.YAMLError, ValueError, TypeError) as exc:
            # ValueError also covers UnicodeDecodeError from a non-UTF-8 file.
            raise StatePersistenceError(f"读取 gid 映射失败: {exc}") from exc
        if not isinstance(data, dict):
            raise StatePersistenceError("gid 映射文件内容必须是字典")
        try:
            return {str(k): int(v) for k, v in data.items()}
        except (ValueError, TypeError) as exc:
            raise StatePersistenceError(f"读取 gid 映射失败: {exc}") from exc

    def _read_uid_state(self) -> Dict[str, object]:
        """Read and minimally validate the UID state JSON document.

        Returns:
            The decoded JSON object, guaranteed to be a dict containing
            ``next_uid_number``.

        Raises:
            StatePersistenceError: If the file is missing, unreadable, not valid
                JSON, not a JSON object, or lacks ``next_uid_number``.
        """
        if not self.uid_state_path.exists():
            raise StatePersistenceError(f"uid 状态文件不存在: {self.uid_state_path}")
        try:
            with self.uid_state_path.open("r", encoding="utf-8") as handle:
                data = json.load(handle)
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
            raise StatePersistenceError(f"uid 状态解析失败: {exc}") from exc
        except OSError as exc:
            raise StatePersistenceError(f"读取 uid 状态失败: {exc}") from exc
        if not isinstance(data, dict):
            raise StatePersistenceError("uid 状态文件内容必须是字典")
        if "next_uid_number" not in data:
            raise StatePersistenceError("uid 状态缺少 next_uid_number")
        return data

    def _read_next_uid_number(self) -> int:
        """Return ``next_uid_number`` from the state file converted with ``int()``.

        Callers are expected to hold ``self._uid_lock``.

        Raises:
            StatePersistenceError: If the state is unreadable or the value is not
                int-convertible.
        """
        value = self._read_uid_state()["next_uid_number"]
        try:
            return int(value)
        except (TypeError, ValueError) as exc:
            raise StatePersistenceError(f"uid 状态中 next_uid_number 无效: {value!r}") from exc

    def _write_uid_state(self, payload: Dict[str, object]) -> None:
        """Replace the UID state file with *payload* serialized as indented ASCII JSON.

        Raises:
            StatePersistenceError: If the file cannot be written.
        """
        # Serialize first so an unserializable payload never touches the file.
        text = json.dumps(payload, ensure_ascii=True, indent=2)
        try:
            self._atomic_write_text(self.uid_state_path, text)
        except OSError as exc:
            raise StatePersistenceError(f"写入 uid 状态失败: {exc}") from exc

    def _write_group_gid_map(self, payload: Dict[str, int]) -> None:
        """Replace the GID map file with *payload* serialized as key-sorted YAML.

        Raises:
            StatePersistenceError: If the file cannot be written.
        """
        text = yaml.safe_dump(payload, allow_unicode=False, sort_keys=True)
        try:
            self._atomic_write_text(self.group_gid_map_path, text)
        except OSError as exc:
            raise StatePersistenceError(f"写入 gid 映射失败: {exc}") from exc

    @staticmethod
    def _atomic_write_text(path: Path, text: str) -> None:
        """Write *text* (UTF-8) to *path* through a sibling temp file and a rename.

        The temp file is flushed and fsync'ed before ``Path.replace`` swaps it in.
        On POSIX that rename is atomic, so readers see either the old or the new
        content. Permission bits of an existing *path* are carried over.

        Raises:
            OSError: If any filesystem step fails. A temp file this call created
                is removed before re-raising.
        """
        import os
        import shutil
        import uuid

        # Same directory as the target so the rename never crosses filesystems.
        tmp_path = path.with_name(f".{path.name}.{uuid.uuid4().hex}.tmp")
        # Opened outside the cleanup try: with "x", a failure here means we created nothing.
        handle = tmp_path.open("x", encoding="utf-8")
        try:
            with handle:
                handle.write(text)
                handle.flush()
                os.fsync(handle.fileno())
            try:
                shutil.copymode(path, tmp_path)
            except FileNotFoundError:
                pass  # first write: keep the default mode of a newly created file
            tmp_path.replace(path)
        except BaseException:
            try:
                tmp_path.unlink()
            except OSError:
                pass
            raise

    @staticmethod
    def _now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string with seconds precision."""
        return datetime.now(tz=timezone.utc).replace(microsecond=0).isoformat()
|