# Source: ad-user-creator/ad_user_creator/main.py (Python; listing metadata: 218 lines, 7.4 KiB)

from __future__ import annotations
import sys
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
from ad_user_creator.cli import build_parser
from ad_user_creator.config import load_config, update_group_gid_map
from ad_user_creator.exceptions import AppError, ConfigError, InputValidationError, LdapConnectionError
from ad_user_creator.input_parser import parse_input_file
from ad_user_creator.interactive import run_interactive_create
from ad_user_creator.ldap_client import LdapClient
from ad_user_creator.logging_setup import setup_logging
from ad_user_creator.models import UserProcessResult
from ad_user_creator.persistence import StateStore
from ad_user_creator.user_service import UserService
def _format_gid_map(gid_map: Optional[Dict[str, str]]) -> str:
if not gid_map:
return ""
return ";".join(f"{group}:{gid}" for group, gid in gid_map.items())
def _write_batch_report_xlsx(path: str, rows: List[Dict[str, str]]) -> None:
    """Write the batch result rows to an ``.xlsx`` report at ``path``.

    The parent directory is created if it does not exist. The sheet is named
    ``report`` and its columns are forced to the fixed order below; keys that
    are missing from a row become empty cells, and keys not in the list are
    dropped by the reindex.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)

    # Fixed column order of the report; the first seven mirror the input
    # sheet, the rest are filled in from the processing result.
    column_order = [
        "姓名",
        "用户名",
        "邮箱",
        "部门 OU",
        "基础组",
        "项目组",
        "资源组",
        "状态",
        "原因",
        "用户DN",
        "uid",
        "linuxuidnumber",
        "基础组gid",
        "项目组gid",
        "资源组gid",
    ]
    report = pd.DataFrame(rows).reindex(columns=column_order)
    report.to_excel(target, index=False, engine="openpyxl", sheet_name="report")
def _to_bool_text(value: str) -> bool:
return value.lower() == "true"
def _result_to_row(raw: Dict[str, str], result: UserProcessResult) -> Dict[str, str]:
    """Build one report row from the raw input record and its processing result.

    The first seven columns are copied from ``raw`` (empty string when the key
    is absent); the remaining columns come from ``result``. Numeric ids that
    are ``None`` are rendered as empty strings, and the project/resource gid
    maps are flattened with ``_format_gid_map``.
    """

    def _text_or_blank(value) -> str:
        # None means "not assigned"; show it as an empty cell rather than "None".
        return "" if value is None else str(value)

    passthrough_columns = ("姓名", "用户名", "邮箱", "部门 OU", "基础组", "项目组", "资源组")
    row = {column: raw.get(column, "") for column in passthrough_columns}

    row["状态"] = result.status
    row["原因"] = result.reason
    row["用户DN"] = result.user_dn
    row["uid"] = _text_or_blank(result.uid)
    row["linuxuidnumber"] = _text_or_blank(result.linux_uid_number)
    row["基础组gid"] = _text_or_blank(result.base_gid)
    row["项目组gid"] = _format_gid_map(result.project_group_gid_map)
    row["资源组gid"] = _format_gid_map(result.resource_group_gid_map)
    return row
def _run_interactive(service: UserService, config) -> int:
    """Run the interactive creation flow; return 0 on success, 1 on failure."""
    try:
        run_interactive_create(
            user_service=service,
            default_base_group=config.defaults.base_group,
            dry_run=config.behavior.dry_run,
        )
    except (InputValidationError, AppError) as exc:
        print(f"[ERROR] 交互式执行失败: {exc}")
        return 1
    return 0


def _run_batch(
    service: UserService,
    config,
    config_path: str,
    input_path: Optional[str],
    continue_on_error: bool,
    logger,
) -> int:
    """Process every record of the input file and write ``report.xlsx``.

    Returns 1 when the input path is missing, the input file fails
    validation, or at least one record ends up counted as failed; 0 otherwise.
    """
    if not input_path:
        print("[ERROR] batch 模式需要提供输入文件路径")
        return 1
    try:
        records = parse_input_file(input_path)
    except InputValidationError as exc:
        print(f"[ERROR] 输入文件校验失败: {exc}")
        return 1

    # Any status not listed here is counted as a failure.
    bucket_for_status = {
        "CREATED": "created",
        "UPDATED": "updated",
        "SKIPPED_EXISTS": "skipped",
        "SKIPPED_NO_CHANGE": "skipped",
    }
    tally = {"created": 0, "updated": 0, "skipped": 0, "failed": 0}
    result_rows: List[Dict[str, str]] = []
    record_count = len(records)  # loop-invariant, used only for progress output

    for idx, (record, raw) in enumerate(records, start=1):
        logger.info("处理第 %s 条: %s", idx, record.sam_account_name)
        try:
            result = service.process_user(record, dry_run=config.behavior.dry_run)
        except (InputValidationError, AppError) as exc:
            # Turn a per-record error into a FAILED row so the batch can go on.
            result = UserProcessResult(
                status="FAILED",
                reason=str(exc),
                raw=raw,
                uid=record.sam_account_name,
            )
        tally[bucket_for_status.get(result.status, "failed")] += 1
        result_rows.append(_result_to_row(raw, result))
        print(f"[{idx}/{record_count}] {raw.get('用户名', '')} -> {result.status} {result.reason}")
        if result.status == "FAILED" and not continue_on_error:
            break

    # The report is written even when the loop stopped early, so it always
    # reflects exactly the records that were attempted.
    report_path = str((Path.cwd() / "report.xlsx").resolve())
    _write_batch_report_xlsx(report_path, result_rows)

    # Best effort: a failure to write the discovered gid map back to the
    # config file is logged and does not change the exit code.
    try:
        update_group_gid_map(config_path, service.get_discovered_group_gid_map())
    except Exception as exc:
        logger.warning("回写 groups_gid_map 到 config.yaml 失败: %s", exc)

    total = len(result_rows)
    print(
        f"完成: total={total}, created={tally['created']}, updated={tally['updated']}, "
        f"skipped={tally['skipped']}, failed={tally['failed']}, "
        f"result={report_path}"
    )
    return 1 if tally["failed"] > 0 else 0


def execute_command(
    command: str,
    config_path: str = "config/config.yaml",
    dry_run: bool = False,
    input_path: Optional[str] = None,
    continue_on_error: bool = True,
    host: Optional[str] = None,
    port: Optional[int] = None,
) -> int:
    """Load the configuration and run one CLI command; return a process exit code.

    Args:
        command: ``"init-state"``, ``"web"``, ``"interactive"`` or ``"batch"``.
        config_path: Path of the YAML configuration file.
        dry_run: CLI dry-run flag; passed to ``load_config`` as ``cli_dry_run``.
        input_path: Input file for ``batch``; required by that command.
        continue_on_error: In ``batch``, keep going after a FAILED record.
        host: Bind address for ``web``; defaults to ``"0.0.0.0"``.
        port: Listen port for ``web``; defaults to ``8000``.

    Returns:
        2 on a configuration error or LDAP connection failure; 1 when the
        interactive flow fails, the batch input is missing/invalid, or any
        batch record failed; 0 otherwise.
    """
    try:
        config = load_config(config_path=config_path, cli_dry_run=dry_run)
    except ConfigError as exc:
        print(f"[FATAL] 配置错误: {exc}")
        return 2

    logger = setup_logging(config.paths.log_file)
    logger.info("配置加载完成")
    state = StateStore(
        uid_state_file=config.paths.uid_state_file,
        group_gid_map_file=config.paths.group_gid_map_file,
        initial_uid_number=config.defaults.initial_uid_number,
    )

    if command == "init-state":
        # NOTE(review): this checks the CLI flag, while the LDAP commands below
        # check config.behavior.dry_run — confirm the difference is intended.
        if dry_run:
            # Fix: the two paths used to be printed back to back with no separator.
            print(
                f"[DRY-RUN] 将初始化: {config.paths.uid_state_file}, "
                f"{config.paths.group_gid_map_file}"
            )
            return 0
        state.ensure_state_files()
        print("状态文件初始化完成。")
        return 0

    if command == "web":
        state.ensure_state_files()
        # Imported lazily so the other commands do not need the web stack.
        from ad_user_creator.web import create_app
        import uvicorn

        # Stash the config path on the config object before handing it to the
        # app; presumably the web layer reads it back — verify in ad_user_creator.web.
        config._config_path = config_path
        app = create_app(config, state)
        # NOTE(review): the default host "0.0.0.0" listens on all interfaces.
        uvicorn.run(app, host=host or "0.0.0.0", port=port or 8000)
        return 0

    state.ensure_state_files()
    ldap_client = LdapClient(config.ldap)
    if not config.behavior.dry_run:
        try:
            ldap_client.connect()
            logger.info("LDAP 连接成功")
        except LdapConnectionError as exc:
            print(f"[FATAL] LDAP 连接失败: {exc}")
            return 2

    service = UserService(config=config, state_store=state, ldap_client=ldap_client)
    try:
        if command == "interactive":
            return _run_interactive(service, config)
        if command == "batch":
            return _run_batch(
                service=service,
                config=config,
                config_path=config_path,
                input_path=input_path,
                continue_on_error=continue_on_error,
                logger=logger,
            )
    finally:
        # Always closed, including in dry-run mode where connect() was skipped.
        ldap_client.close()
    # NOTE(review): an unrecognised command reaches this point and reports success.
    return 0
def run() -> int:
    """Parse the command line and dispatch to ``execute_command``.

    Optional arguments are read with ``getattr`` defaults, so a parsed
    namespace that lacks one of them still works. Returns the exit code
    produced by ``execute_command``.
    """
    namespace = build_parser().parse_args()
    options = {
        "dry_run": bool(getattr(namespace, "dry_run", False)),
        "input_path": getattr(namespace, "input", None),
        # The flag is read as text ("true"/"false") and converted here.
        "continue_on_error": _to_bool_text(getattr(namespace, "continue_on_error", "true")),
        "host": getattr(namespace, "host", None),
        "port": getattr(namespace, "port", None),
    }
    return execute_command(
        command=namespace.command,
        config_path=namespace.config,
        **options,
    )
# Script entry point: exit the interpreter with the code returned by run().
if __name__ == "__main__":
    raise SystemExit(run())