"""Load, validate and update the YAML configuration used by ad_user_creator."""

from __future__ import annotations

import os
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional

import yaml
from filelock import FileLock

from ad_user_creator.exceptions import ConfigError
from ad_user_creator.models import AppConfig, BehaviorConfig, DefaultsConfig, LdapConfig, PathsConfig

# Fallback objectClass values applied when ldap.user_object_classes is absent.
_DEFAULT_USER_OBJECT_CLASSES = ("top", "person", "organizationalPerson", "user", "posixAccount")

_TRUE_WORDS = frozenset({"1", "true", "yes", "y", "on"})
_FALSE_WORDS = frozenset({"0", "false", "no", "n", "off"})


def _parse_bool(value: Any, default: bool) -> bool:
    """Interpret *value* as a boolean, returning *default* when it is not recognised.

    ``None`` yields *default*; real booleans are returned unchanged; anything
    else is compared, case-insensitively and stripped, against the accepted
    true/false words.
    """
    if value is None:
        return default
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in _TRUE_WORDS:
        return True
    if text in _FALSE_WORDS:
        return False
    return default


def _parse_str(value: Any, default: str) -> str:
    """Return ``str(value)``, or *default* when *value* is ``None``.

    ``dict.get(key, default)`` only applies its default to a missing key, so an
    explicit YAML null would otherwise be stringified to ``"None"``.
    """
    return default if value is None else str(value)


def _parse_int(value: Any, default: int, key: str) -> int:
    """Return *value* as an int, or *default* when it is ``None``.

    Raises:
        ConfigError: if *value* cannot be converted; *key* names the setting.
    """
    if value is None:
        return default
    try:
        return int(value)
    except (TypeError, ValueError) as exc:
        raise ConfigError(f"{key} 必须是整数: {value!r}") from exc


def _parse_list(value: Any, default: Iterable[Any], key: str) -> List[Any]:
    """Return *value* as a new list, or a copy of *default* when it is ``None``.

    A scalar string becomes a one-item list; calling ``list()`` on it directly
    would split it into single characters.

    Raises:
        ConfigError: if *value* is not iterable; *key* names the setting.
    """
    if value is None:
        return list(default)
    if isinstance(value, str):
        return [value]
    try:
        return list(value)
    except TypeError as exc:
        raise ConfigError(f"{key} 必须是列表: {value!r}") from exc


def _section(yaml_data: Dict[str, Any], name: str) -> Dict[str, Any]:
    """Return the mapping stored under *name*, or an empty dict when it is missing or falsy.

    Raises:
        ConfigError: if the section holds a truthy value that is not a mapping.
    """
    section = yaml_data.get(name) or {}
    if not isinstance(section, dict):
        raise ConfigError(f"{name} 必须是键值字典")
    return section


def _read_yaml(path: Path) -> Dict[str, Any]:
    """Read *path* as UTF-8 YAML and return its top-level mapping.

    A document that loads to a falsy value (for example an empty file) is
    returned as an empty dict.

    Raises:
        ConfigError: if the file does not exist, cannot be parsed as YAML, or
            its top level is not a mapping.
    """
    try:
        with path.open("r", encoding="utf-8") as handle:
            data = yaml.safe_load(handle) or {}
    except FileNotFoundError as exc:
        raise ConfigError(f"yaml 配置文件不存在: {path}") from exc
    except yaml.YAMLError as exc:
        raise ConfigError(f"yaml 解析失败: {exc}") from exc
    if not isinstance(data, dict):
        raise ConfigError("yaml 顶层结构必须是对象")
    return data


def _merge_ldap_config(yaml_data: Dict[str, Any]) -> LdapConfig:
    """Build the ``LdapConfig`` from the ``ldap`` section.

    Raises:
        ConfigError: if the section is malformed, ``port`` is not an integer,
            or any of host / bind_dn / base_dn / people_base_dn /
            groups_base_dn is missing or empty.
    """
    ldap_yaml = _section(yaml_data, "ldap")

    required = {
        "ldap.host": ldap_yaml.get("host"),
        "ldap.bind_dn": ldap_yaml.get("bind_dn"),
        "ldap.base_dn": ldap_yaml.get("base_dn"),
        "ldap.people_base_dn": ldap_yaml.get("people_base_dn"),
        "ldap.groups_base_dn": ldap_yaml.get("groups_base_dn"),
    }
    port = _parse_int(ldap_yaml.get("port"), 636, "ldap.port")

    missing = [key for key, value in required.items() if not value]
    if missing:
        raise ConfigError(f"缺少必要 LDAP 配置: {', '.join(missing)}")

    return LdapConfig(
        host=str(required["ldap.host"]),
        port=port,
        use_ssl=_parse_bool(ldap_yaml.get("use_ssl"), True),
        bind_dn=str(required["ldap.bind_dn"]),
        bind_password=_parse_str(ldap_yaml.get("bind_password"), ""),
        base_dn=str(required["ldap.base_dn"]),
        people_base_dn=str(required["ldap.people_base_dn"]),
        groups_base_dn=str(required["ldap.groups_base_dn"]),
        upn_suffix=_parse_str(ldap_yaml.get("upn_suffix"), ""),
        user_object_classes=_parse_list(
            ldap_yaml.get("user_object_classes"),
            _DEFAULT_USER_OBJECT_CLASSES,
            "ldap.user_object_classes",
        ),
        user_rdn_attr=_parse_str(ldap_yaml.get("user_rdn_attr"), "CN"),
    )


def _merge_defaults_config(yaml_data: Dict[str, Any]) -> DefaultsConfig:
    """Build the ``DefaultsConfig`` from the ``defaults`` section.

    Raises:
        ConfigError: if the section is malformed or ``initial_uid_number`` is
            not an integer.
    """
    defaults_yaml = _section(yaml_data, "defaults")
    return DefaultsConfig(
        base_group=_parse_str(defaults_yaml.get("base_group"), "staff"),
        initial_uid_number=_parse_int(
            defaults_yaml.get("initial_uid_number"), 2106, "defaults.initial_uid_number"
        ),
        # NOTE(review): hard-coded fallback password; confirm this is intended
        # for deployments that omit defaults.initial_password.
        initial_password=_parse_str(defaults_yaml.get("initial_password"), "1234.com"),
    )


def _merge_paths_config(yaml_data: Dict[str, Any]) -> PathsConfig:
    """Build the ``PathsConfig`` from the ``paths`` section, applying the default locations."""
    paths_yaml = _section(yaml_data, "paths")
    return PathsConfig(
        uid_state_file=_parse_str(paths_yaml.get("uid_state_file"), "state/uid_state.json"),
        group_gid_map_file=_parse_str(paths_yaml.get("group_gid_map_file"), "state/group_gid_map.yaml"),
        batch_result_file=_parse_str(paths_yaml.get("batch_result_file"), "state/last_batch_result.csv"),
        log_file=_parse_str(paths_yaml.get("log_file"), "state/run.log"),
    )


def _merge_behavior_config(yaml_data: Dict[str, Any], cli_dry_run: Optional[bool]) -> BehaviorConfig:
    """Build the ``BehaviorConfig`` from the ``behavior`` section.

    A non-``None`` *cli_dry_run* takes precedence over the ``dry_run`` value in
    the file.
    """
    behavior_yaml = _section(yaml_data, "behavior")
    if cli_dry_run is not None:
        dry_run = cli_dry_run
    else:
        dry_run = _parse_bool(behavior_yaml.get("dry_run"), False)
    return BehaviorConfig(
        skip_if_user_exists=_parse_bool(behavior_yaml.get("skip_if_user_exists"), True),
        skip_missing_optional_groups=_parse_bool(behavior_yaml.get("skip_missing_optional_groups"), True),
        dry_run=dry_run,
        require_ldaps_for_password=_parse_bool(behavior_yaml.get("require_ldaps_for_password"), True),
    )


def _merge_ui_options_config(yaml_data: Dict[str, Any]) -> "UIOptionsConfig":
    """Build the ``UIOptionsConfig`` from the ``ui_options`` section; every list defaults to empty."""
    # Imported at call time, as in the original module layout.
    from ad_user_creator.models import UIOptionsConfig

    ui_yaml = _section(yaml_data, "ui_options")
    return UIOptionsConfig(
        ou_list=_parse_list(ui_yaml.get("ou_list"), (), "ui_options.ou_list"),
        base_group_list=_parse_list(ui_yaml.get("base_group_list"), (), "ui_options.base_group_list"),
        project_group_list=_parse_list(
            ui_yaml.get("project_group_list"), (), "ui_options.project_group_list"
        ),
        resource_group_list=_parse_list(
            ui_yaml.get("resource_group_list"), (), "ui_options.resource_group_list"
        ),
    )


def _merge_groups_gid_map(yaml_data: Dict[str, Any]) -> Dict[str, int]:
    """Return ``groups_gid_map`` with keys coerced to ``str`` and values to ``int``.

    Raises:
        ConfigError: if the section is not a mapping or a gid is not an integer.
    """
    raw_map = yaml_data.get("groups_gid_map", {}) or {}
    if not isinstance(raw_map, dict):
        raise ConfigError("groups_gid_map 必须是键值字典")
    merged: Dict[str, int] = {}
    for group, gid in raw_map.items():
        try:
            merged[str(group)] = int(gid)
        except (TypeError, ValueError) as exc:
            raise ConfigError(f"groups_gid_map 非法值: {group}={gid}") from exc
    return merged


def _resolve_paths(config: AppConfig, workspace_root: Path) -> AppConfig:
    """Rewrite the four entries of ``config.paths`` as absolute paths, in place.

    Relative entries are resolved against *workspace_root*; absolute entries are
    kept as given. The same *config* object is returned.
    """

    def make_abs(path_text: str) -> str:
        candidate = Path(path_text)
        if candidate.is_absolute():
            return str(candidate)
        return str((workspace_root / candidate).resolve())

    paths = config.paths
    paths.uid_state_file = make_abs(paths.uid_state_file)
    paths.group_gid_map_file = make_abs(paths.group_gid_map_file)
    paths.batch_result_file = make_abs(paths.batch_result_file)
    paths.log_file = make_abs(paths.log_file)
    return config


def load_config(
    config_path: str = "config/config.yaml",
    cli_dry_run: Optional[bool] = None,
    workspace_root: Optional[str] = None,
) -> AppConfig:
    """Load the YAML file at *config_path* and return a validated ``AppConfig``.

    Args:
        config_path: Path of the YAML file; a relative path is resolved against
            the workspace root.
        cli_dry_run: When not ``None``, overrides ``behavior.dry_run`` from the file.
        workspace_root: Base directory for relative paths; defaults to the
            current working directory.

    Raises:
        ConfigError: if the file is missing or malformed, a required LDAP
            setting is absent, or ``behavior.require_ldaps_for_password`` is
            enabled while ``ldap.use_ssl`` is off.
    """
    root = Path(workspace_root or os.getcwd()).resolve()
    yaml_full_path = Path(config_path)
    if not yaml_full_path.is_absolute():
        yaml_full_path = (root / yaml_full_path).resolve()

    yaml_data = _read_yaml(yaml_full_path)
    app_config = AppConfig(
        ldap=_merge_ldap_config(yaml_data),
        defaults=_merge_defaults_config(yaml_data),
        paths=_merge_paths_config(yaml_data),
        behavior=_merge_behavior_config(yaml_data, cli_dry_run=cli_dry_run),
        groups_gid_map=_merge_groups_gid_map(yaml_data),
        ui_options=_merge_ui_options_config(yaml_data),
    )
    app_config = _resolve_paths(app_config, root)

    if app_config.behavior.require_ldaps_for_password and not app_config.ldap.use_ssl:
        raise ConfigError("启用密码设置时必须使用 LDAPS,请将 ldap.use_ssl 设置为 true")
    return app_config


def _write_yaml_via_temp(path: Path, data: Dict[str, Any]) -> None:
    """Dump *data* to ``<path>.tmp`` as UTF-8 YAML, then move it over *path*.

    Non-ASCII text is written literally (``allow_unicode=True``) so that every
    updater of the file produces the same representation. If dumping or the
    final replace fails, the temporary file is removed and the error re-raised.
    """
    temp_path = path.with_suffix(path.suffix + ".tmp")
    try:
        with temp_path.open("w", encoding="utf-8") as handle:
            yaml.safe_dump(data, handle, allow_unicode=True, sort_keys=False)
        temp_path.replace(path)
    except BaseException:
        # Best-effort cleanup; the original failure is what the caller needs to see.
        try:
            temp_path.unlink()
        except OSError:
            pass
        raise


def update_group_gid_map(config_path: str, discovered_map: Dict[str, int]) -> None:
    """Merge *discovered_map* into ``groups_gid_map`` of the config file and rewrite it.

    Does nothing when *discovered_map* is empty. Entries in *discovered_map*
    override existing ones, and the stored map is sorted by group name. The
    read-modify-write runs while holding the ``<config_path>.lock`` file lock.
    """
    if not discovered_map:
        return
    path = Path(config_path).resolve()
    with FileLock(str(path) + ".lock"):
        data = _read_yaml(path)
        current = data.get("groups_gid_map", {}) or {}
        if not isinstance(current, dict):
            # A malformed existing section is discarded rather than blocking the update.
            current = {}
        merged = {str(group): int(gid) for group, gid in current.items()}
        merged.update((str(group), int(gid)) for group, gid in discovered_map.items())
        data["groups_gid_map"] = dict(sorted(merged.items()))
        _write_yaml_via_temp(path, data)


def update_ui_options(config_path: str, ui_options_data: dict) -> None:
    """Replace the ``ui_options`` section of the config file with *ui_options_data*.

    The read-modify-write runs while holding the ``<config_path>.lock`` file lock.
    """
    path = Path(config_path).resolve()
    with FileLock(str(path) + ".lock"):
        data = _read_yaml(path)
        data["ui_options"] = ui_options_data
        _write_yaml_via_temp(path, data)