ad-user-creator/ad_user_creator/config.py

213 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
import os
from pathlib import Path
from typing import Any, Dict, Optional
import yaml
from filelock import FileLock
from ad_user_creator.exceptions import ConfigError
from ad_user_creator.models import AppConfig, BehaviorConfig, DefaultsConfig, LdapConfig, PathsConfig
def _parse_bool(value: Any, default: bool) -> bool:
    """Interpret *value* as a boolean flag, falling back to *default*.

    Real ``bool`` values are returned unchanged and ``None`` yields *default*.
    Anything else is stringified, stripped and lower-cased, then matched
    against the accepted truthy / falsy spellings; unrecognized text also
    yields *default*.
    """
    if isinstance(value, bool):
        return value
    if value is None:
        return default
    token = str(value).strip().lower()
    truthy = ("1", "true", "yes", "y", "on")
    falsy = ("0", "false", "no", "n", "off")
    if token in truthy:
        return True
    return False if token in falsy else default
def _read_yaml(path: Path) -> Dict[str, Any]:
    """Load a YAML file whose top-level node must be a mapping.

    Args:
        path: location of the YAML file.

    Returns:
        The parsed mapping. An empty file (or a document that loads as a
        falsy value) yields ``{}``.

    Raises:
        ConfigError: if the file does not exist, cannot be opened/decoded as
            UTF-8, is not valid YAML, or its top level is not a mapping.
    """
    if not path.exists():
        raise ConfigError(f"yaml 配置文件不存在: {path}")
    try:
        with path.open("r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle) or {}
    except yaml.YAMLError as exc:
        raise ConfigError(f"yaml 解析失败: {exc}") from exc
    except (OSError, UnicodeDecodeError) as exc:
        # Directory instead of file, permission denied, file removed after the
        # exists() check, or non-UTF-8 content: report it as a config problem
        # rather than leaking a raw I/O exception to callers.
        raise ConfigError(f"yaml 配置文件读取失败: {path}: {exc}") from exc
    if not isinstance(data, dict):
        raise ConfigError("yaml 顶层结构必须是对象")
    return data
def _merge_ldap_config(yaml_data: Dict[str, Any]) -> LdapConfig:
    """Build an ``LdapConfig`` from the ``ldap`` section of the parsed YAML.

    ``host``, ``bind_dn``, ``base_dn``, ``people_base_dn`` and
    ``groups_base_dn`` are required and must be non-empty. Optional keys fall
    back to their defaults when absent *or* present with a null value
    (``key:`` with nothing after it), so a blank entry never turns into the
    literal string ``"None"``.

    Args:
        yaml_data: the top-level mapping returned by ``_read_yaml``.

    Returns:
        The populated LDAP configuration.

    Raises:
        ConfigError: if the ``ldap`` section is not a mapping, a required key
            is missing or empty, or ``port`` is not an integer in 1-65535.
    """
    ldap_yaml = yaml_data.get("ldap", {}) or {}
    if not isinstance(ldap_yaml, dict):
        raise ConfigError("ldap 必须是键值字典")

    def optional(key: str, default: Any) -> Any:
        # An explicit YAML null is treated exactly like a missing key.
        value = ldap_yaml.get(key)
        return default if value is None else value

    host = ldap_yaml.get("host")
    bind_dn = ldap_yaml.get("bind_dn")
    base_dn = ldap_yaml.get("base_dn")
    people_base_dn = ldap_yaml.get("people_base_dn")
    groups_base_dn = ldap_yaml.get("groups_base_dn")
    missing = [
        key
        for key, value in {
            "ldap.host": host,
            "ldap.bind_dn": bind_dn,
            "ldap.base_dn": base_dn,
            "ldap.people_base_dn": people_base_dn,
            "ldap.groups_base_dn": groups_base_dn,
        }.items()
        if not value
    ]
    if missing:
        raise ConfigError(f"缺少必要 LDAP 配置: {', '.join(missing)}")

    raw_port = optional("port", 636)
    try:
        port = int(raw_port)
    except (TypeError, ValueError) as exc:
        raise ConfigError(f"ldap.port 非法值: {raw_port}") from exc
    if not 0 < port < 65536:
        raise ConfigError(f"ldap.port 非法值: {raw_port}")

    object_classes = optional(
        "user_object_classes",
        ["top", "person", "organizationalPerson", "user", "posixAccount"],
    )
    if isinstance(object_classes, str):
        # A single class written as a YAML scalar must stay one entry;
        # list("user") would otherwise split it into characters.
        object_classes = [object_classes]

    return LdapConfig(
        host=str(host),
        port=port,
        use_ssl=_parse_bool(ldap_yaml.get("use_ssl", True), True),
        bind_dn=str(bind_dn),
        bind_password=str(optional("bind_password", "")),
        base_dn=str(base_dn),
        people_base_dn=str(people_base_dn),
        groups_base_dn=str(groups_base_dn),
        upn_suffix=str(optional("upn_suffix", "")),
        user_object_classes=list(object_classes),
        user_rdn_attr=str(optional("user_rdn_attr", "CN")),
    )
def _merge_defaults_config(yaml_data: Dict[str, Any]) -> DefaultsConfig:
    """Build a ``DefaultsConfig`` from the optional ``defaults`` section.

    Missing or null entries fall back to the built-in values: base group
    ``staff``, first UID number ``2106`` and initial password ``1234.com``.

    Args:
        yaml_data: the top-level mapping returned by ``_read_yaml``.

    Raises:
        ConfigError: if the section is not a mapping or
            ``initial_uid_number`` is not an integer.
    """
    defaults_yaml = yaml_data.get("defaults", {}) or {}
    if not isinstance(defaults_yaml, dict):
        raise ConfigError("defaults 必须是键值字典")

    def optional(key: str, default: Any) -> Any:
        # An explicit YAML null must not become the string "None".
        value = defaults_yaml.get(key)
        return default if value is None else value

    raw_uid = optional("initial_uid_number", 2106)
    try:
        initial_uid_number = int(raw_uid)
    except (TypeError, ValueError) as exc:
        raise ConfigError(f"defaults.initial_uid_number 非法值: {raw_uid}") from exc

    return DefaultsConfig(
        base_group=str(optional("base_group", "staff")),
        initial_uid_number=initial_uid_number,
        # NOTE(review): "1234.com" is a weak, publicly visible fallback
        # password; consider requiring defaults.initial_password in the YAML.
        initial_password=str(optional("initial_password", "1234.com")),
    )
def _merge_paths_config(yaml_data: Dict[str, Any]) -> PathsConfig:
    """Build a ``PathsConfig`` from the optional ``paths`` section.

    Each entry defaults to a file under ``state/`` when it is missing, null or
    an empty string. Paths are returned as given (possibly relative); making
    them absolute is done separately.

    Args:
        yaml_data: the top-level mapping returned by ``_read_yaml``.

    Raises:
        ConfigError: if the ``paths`` section is not a mapping.
    """
    paths_yaml = yaml_data.get("paths", {}) or {}
    if not isinstance(paths_yaml, dict):
        raise ConfigError("paths 必须是键值字典")

    fallbacks = {
        "uid_state_file": "state/uid_state.json",
        "group_gid_map_file": "state/group_gid_map.yaml",
        "batch_result_file": "state/last_batch_result.csv",
        "log_file": "state/run.log",
    }
    resolved: Dict[str, str] = {}
    for key, fallback in fallbacks.items():
        value = paths_yaml.get(key)
        # A null would otherwise become the path "None", and "" would later
        # resolve to the workspace directory itself.
        resolved[key] = fallback if value is None or value == "" else str(value)
    return PathsConfig(**resolved)
def _merge_behavior_config(yaml_data: Dict[str, Any], cli_dry_run: Optional[bool]) -> BehaviorConfig:
    """Build a ``BehaviorConfig`` from the ``behavior`` section.

    Every flag is parsed with ``_parse_bool``. ``dry_run`` defaults to False
    and the other three flags default to True. A non-None *cli_dry_run*
    overrides whatever ``dry_run`` the YAML specifies.
    """
    section = yaml_data.get("behavior", {}) or {}

    def flag(key: str, fallback: bool) -> bool:
        return _parse_bool(section.get(key), fallback)

    yaml_dry_run = flag("dry_run", False)
    if cli_dry_run is None:
        effective_dry_run = yaml_dry_run
    else:
        effective_dry_run = cli_dry_run
    return BehaviorConfig(
        skip_if_user_exists=flag("skip_if_user_exists", True),
        skip_missing_optional_groups=flag("skip_missing_optional_groups", True),
        dry_run=effective_dry_run,
        require_ldaps_for_password=flag("require_ldaps_for_password", True),
    )
def _merge_ui_options_config(yaml_data: Dict[str, Any]) -> "UIOptionsConfig":
    """Build a ``UIOptionsConfig`` from the optional ``ui_options`` section.

    ``ou_list``, ``base_group_list``, ``project_group_list`` and
    ``resource_group_list`` each become a list. A missing or null entry gives
    ``[]``, and a single scalar string becomes a one-element list.

    Args:
        yaml_data: the top-level mapping returned by ``_read_yaml``.

    Raises:
        ConfigError: if the section is not a mapping or an entry is not a
            list-like value.
    """
    # Imported locally as in the original; the reason (possibly an import
    # cycle) is not visible from this module — TODO confirm.
    from ad_user_creator.models import UIOptionsConfig

    ui_yaml = yaml_data.get("ui_options", {}) or {}
    if not isinstance(ui_yaml, dict):
        raise ConfigError("ui_options 必须是键值字典")

    def as_list(key: str) -> list:
        value = ui_yaml.get(key)
        if value is None:
            # `key:` with no value in YAML; list(None) would raise TypeError.
            return []
        if isinstance(value, str):
            # list("Sales") would split the name into characters.
            return [value]
        try:
            return list(value)
        except TypeError as exc:
            raise ConfigError(f"ui_options.{key} 必须是列表: {value}") from exc

    return UIOptionsConfig(
        ou_list=as_list("ou_list"),
        base_group_list=as_list("base_group_list"),
        project_group_list=as_list("project_group_list"),
        resource_group_list=as_list("resource_group_list"),
    )
def _merge_groups_gid_map(yaml_data: Dict[str, Any]) -> Dict[str, int]:
    """Return the ``groups_gid_map`` section as ``{group name: int GID}``.

    Keys are stringified and values converted with ``int``. A missing or null
    section yields an empty dict.

    Raises:
        ConfigError: if the section is not a mapping, or a value cannot be
            converted to an integer.
    """
    section = yaml_data.get("groups_gid_map", {}) or {}
    if not isinstance(section, dict):
        raise ConfigError("groups_gid_map 必须是键值字典")

    def to_gid(name: Any, raw: Any) -> int:
        try:
            return int(raw)
        except (TypeError, ValueError) as exc:
            raise ConfigError(f"groups_gid_map 非法值: {name}={raw}") from exc

    return {str(name): to_gid(name, raw) for name, raw in section.items()}
def _resolve_paths(config: AppConfig, workspace_root: Path) -> AppConfig:
    """Make every entry of ``config.paths`` absolute, mutating it in place.

    Absolute entries are kept (normalized through ``Path``). Relative entries
    are joined onto *workspace_root* and resolved. The same *config* object is
    returned for convenience.
    """
    paths = config.paths
    for attr in ("uid_state_file", "group_gid_map_file", "batch_result_file", "log_file"):
        candidate = Path(getattr(paths, attr))
        if not candidate.is_absolute():
            candidate = (workspace_root / candidate).resolve()
        setattr(paths, attr, str(candidate))
    return config
def load_config(
    config_path: str = "config/config.yaml",
    cli_dry_run: Optional[bool] = None,
    workspace_root: Optional[str] = None,
) -> AppConfig:
    """Load, validate and normalize the application configuration.

    Args:
        config_path: YAML file; a relative path is taken relative to the
            workspace root.
        cli_dry_run: when not None, overrides ``behavior.dry_run`` from YAML.
        workspace_root: base directory for relative paths; defaults to the
            current working directory.

    Returns:
        A populated ``AppConfig`` whose ``paths`` entries are absolute.

    Raises:
        ConfigError: if the YAML cannot be read or is invalid, a section fails
            validation, or password handling requires LDAPS but
            ``ldap.use_ssl`` is false.
    """
    base_dir = Path(workspace_root or os.getcwd()).resolve()
    yaml_file = Path(config_path)
    if not yaml_file.is_absolute():
        yaml_file = (base_dir / yaml_file).resolve()
    raw = _read_yaml(yaml_file)

    assembled = AppConfig(
        ldap=_merge_ldap_config(raw),
        defaults=_merge_defaults_config(raw),
        paths=_merge_paths_config(raw),
        behavior=_merge_behavior_config(raw, cli_dry_run=cli_dry_run),
        groups_gid_map=_merge_groups_gid_map(raw),
        ui_options=_merge_ui_options_config(raw),
    )
    config = _resolve_paths(assembled, base_dir)

    # Passwords must only be sent over an encrypted (LDAPS) connection.
    ldaps_required = config.behavior.require_ldaps_for_password
    if ldaps_required and not config.ldap.use_ssl:
        raise ConfigError("启用密码设置时必须使用 LDAPS请将 ldap.use_ssl 设置为 true")
    return config
def update_group_gid_map(config_path: str, discovered_map: Dict[str, int]) -> None:
    """Merge newly discovered group -> GID entries into ``groups_gid_map`` on disk.

    Existing entries are kept and entries from *discovered_map* win on
    conflict. The merged map is written back sorted by group name. The file is
    rewritten under a file lock, via a temp file that then replaces the
    original. Because the file is regenerated with ``yaml.safe_dump``, comments
    and original formatting are not preserved.

    Args:
        config_path: path to the YAML file. A relative path resolves against
            the current working directory.
        discovered_map: group name -> GID. Nothing is written when empty.

    Raises:
        ConfigError: if the file is missing or invalid, or an existing
            ``groups_gid_map`` value is not an integer.
    """
    import shutil

    if not discovered_map:
        return
    path = Path(config_path).resolve()
    lock = FileLock(str(path) + ".lock")
    with lock:
        data = _read_yaml(path)
        current = data.get("groups_gid_map", {}) or {}
        if not isinstance(current, dict):
            current = {}
        merged: Dict[str, int] = {}
        for group, gid in current.items():
            try:
                merged[str(group)] = int(gid)
            except (TypeError, ValueError) as exc:
                raise ConfigError(f"groups_gid_map 非法值: {group}={gid}") from exc
        for group, gid in discovered_map.items():
            merged[str(group)] = int(gid)
        data["groups_gid_map"] = dict(sorted(merged.items()))

        temp_path = path.with_suffix(path.suffix + ".tmp")
        try:
            with temp_path.open("w", encoding="utf-8") as handle:
                # Copy the original permission bits before any data is written:
                # the file also holds ldap.bind_password, and replacing it would
                # otherwise publish the secret with the default umask mode.
                shutil.copymode(path, temp_path)
                # allow_unicode=True keeps non-ASCII text elsewhere in the file
                # (e.g. Chinese OU names) readable instead of \u-escaped.
                yaml.safe_dump(data, handle, allow_unicode=True, sort_keys=False)
            temp_path.replace(path)
        except BaseException:
            # Never leave a partially written copy of the config behind.
            try:
                temp_path.unlink()
            except OSError:
                pass
            raise
def update_ui_options(config_path: str, ui_options_data: dict) -> None:
    """Replace the ``ui_options`` section of the YAML config file on disk.

    The whole file is re-read and the section swapped for *ui_options_data*
    (stored as given). The result is written back under a file lock, via a
    temp file that then replaces the original. Because the file is regenerated
    with ``yaml.safe_dump``, comments and original formatting are not preserved.

    Args:
        config_path: path to the YAML file. A relative path resolves against
            the current working directory.
        ui_options_data: new value for the ``ui_options`` key.

    Raises:
        ConfigError: if the existing file is missing or invalid.
    """
    import shutil

    path = Path(config_path).resolve()
    lock = FileLock(str(path) + ".lock")
    with lock:
        data = _read_yaml(path)
        data["ui_options"] = ui_options_data
        temp_path = path.with_suffix(path.suffix + ".tmp")
        try:
            with temp_path.open("w", encoding="utf-8") as handle:
                # Copy the original permission bits before any data is written:
                # the file also holds ldap.bind_password, and replacing it would
                # otherwise publish the secret with the default umask mode.
                shutil.copymode(path, temp_path)
                yaml.safe_dump(data, handle, allow_unicode=True, sort_keys=False)
            temp_path.replace(path)
        except BaseException:
            # Never leave a partially written copy of the config behind.
            try:
                temp_path.unlink()
            except OSError:
                pass
            raise