153 lines
6.5 KiB
Python
153 lines
6.5 KiB
Python
from __future__ import annotations
|
||
|
||
from contextlib import asynccontextmanager
|
||
from dataclasses import asdict
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List
|
||
|
||
from fastapi import FastAPI, Request
|
||
from fastapi.responses import HTMLResponse, JSONResponse
|
||
|
||
from ad_user_creator.config import update_ui_options
|
||
from ad_user_creator.exceptions import AppError, InputValidationError
|
||
from ad_user_creator.ldap_client import LdapClient
|
||
from ad_user_creator.models import AppConfig, UserInputRecord, UIOptionsConfig
|
||
from ad_user_creator.persistence import StateStore
|
||
from ad_user_creator.user_service import UserService
|
||
|
||
|
||
def _split_optional_groups(text: str) -> List[str]:
|
||
if not text.strip():
|
||
return []
|
||
items = [item.strip() for item in text.replace(",", ",").split(",")]
|
||
return [item for item in items if item]
|
||
|
||
|
||
def _body_to_record(body: Dict[str, Any], default_base_group: str) -> UserInputRecord:
    """Convert a decoded JSON request body into a ``UserInputRecord``.

    Missing or falsy text fields become ``""``; present values are stripped.
    An empty ``base_group`` falls back to ``default_base_group``. The two
    optional group fields are split with ``_split_optional_groups``.

    Args:
        body: Decoded request payload, read with ``body.get``.
        default_base_group: Base group used when the body does not name one.

    Returns:
        The populated ``UserInputRecord``.
    """

    def _text(key: str) -> str:
        # Falsy values (missing key, None, "") all collapse to "".
        return (body.get(key) or "").strip()

    def _groups(key: str) -> List[str]:
        return _split_optional_groups(body.get(key) or "")

    return UserInputRecord(
        display_name=_text("display_name"),
        sam_account_name=_text("sam_account_name"),
        email=_text("email"),
        dept_ou=_text("dept_ou"),
        base_group=_text("base_group") or default_base_group,
        project_groups=_groups("project_groups"),
        resource_groups=_groups("resource_groups"),
    )
|
||
|
||
|
||
def create_app(config: AppConfig, state_store: StateStore) -> FastAPI:
    """Build the FastAPI application for the user-creation web UI.

    Routes registered:
        GET  /                       -> the ``templates/index.html`` page
        POST /api/preview            -> plan for a user, plus LDAP group checks
        POST /api/create             -> process (create) a user
        GET  /api/config/ui-options  -> current UI option lists
        PUT  /api/config/ui-options  -> persist and hot-reload UI option lists

    Args:
        config: Application configuration. It is published on
            ``app.state.config`` at startup and read from there by handlers.
        state_store: Passed through to ``UserService``.

    Returns:
        The configured ``FastAPI`` instance.
    """
    required_fields_detail = "显示名称、用户名、邮箱、部门 OU 为必填项"
    invalid_body_detail = "请求体必须是 JSON 对象"

    def bad_request(detail: str) -> JSONResponse:
        """Build a 400 response carrying ``detail``."""
        return JSONResponse(status_code=400, content={"detail": detail})

    async def read_json_object(request: Request) -> Dict[str, Any] | None:
        """Return the request body as a dict, or None if it is not a JSON object."""
        try:
            payload = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            return None
        # Handlers use ``.get``; a JSON list/number/string would crash them.
        return payload if isinstance(payload, dict) else None

    def required_fields_error(record: UserInputRecord) -> JSONResponse | None:
        """Return a 400 response when a mandatory field is empty, else None."""
        if record.display_name and record.sam_account_name and record.email and record.dept_ou:
            return None
        return bad_request(required_fields_detail)

    def clean_str_list(val: Any) -> List[str]:
        """Coerce a JSON list to stripped, non-empty strings; non-lists give []."""
        if not isinstance(val, list):
            return []
        stripped = (str(item).strip() for item in val)
        return [item for item in stripped if item]

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        ldap_client = LdapClient(config.ldap)
        if not config.behavior.dry_run:
            ldap_client.connect()
        try:
            service = UserService(config=config, state_store=state_store, ldap_client=ldap_client)
            app.state.config = config
            app.state.service = service
            yield
        finally:
            # Always release the client, even when startup after connect()
            # or the application itself raised.
            ldap_client.close()

    app = FastAPI(lifespan=lifespan)

    @app.get("/", response_class=HTMLResponse)
    async def index() -> HTMLResponse:
        path = Path(__file__).parent / "templates" / "index.html"
        html = path.read_text(encoding="utf-8")
        return HTMLResponse(html)

    @app.post("/api/preview")
    async def api_preview(request: Request) -> JSONResponse:
        body = await read_json_object(request)
        if body is None:
            return bad_request(invalid_body_detail)
        app_config: AppConfig = request.app.state.config
        record = _body_to_record(body, app_config.defaults.base_group)
        missing = required_fields_error(record)
        if missing is not None:
            return missing
        service: UserService = request.app.state.service
        plan = service.preview_plan(record)

        # Check whether each requested group exists in LDAP (skipped in dry-run).
        groups_check: Dict[str, Any] = {}
        if not app_config.behavior.dry_run:
            group_names = [record.base_group, *record.project_groups, *record.resource_groups]
            # dict.fromkeys de-duplicates while keeping first-seen order, so a
            # group listed twice costs a single LDAP lookup.
            for name in dict.fromkeys(group_names):
                groups_check[name] = service.ldap_client.group_exists(name)

        result_dict = asdict(plan)
        if groups_check:
            result_dict["groups_exist_in_ldap"] = groups_check

        return JSONResponse(content=result_dict)

    @app.get("/api/config/ui-options")
    async def api_get_ui_options(request: Request) -> JSONResponse:
        app_config: AppConfig = request.app.state.config
        return JSONResponse(content=asdict(app_config.ui_options))

    @app.put("/api/config/ui-options")
    async def api_update_ui_options(request: Request) -> JSONResponse:
        body = await read_json_object(request)
        if body is None:
            return bad_request(invalid_body_detail)
        app_config: AppConfig = request.app.state.config

        # Validate the input format: every option must be a list of strings.
        new_options = {
            "ou_list": clean_str_list(body.get("ou_list")),
            "base_group_list": clean_str_list(body.get("base_group_list")),
            "project_group_list": clean_str_list(body.get("project_group_list")),
            "resource_group_list": clean_str_list(body.get("resource_group_list")),
        }

        # This module does not know the config path directly; it relies on the
        # launcher (ad_user_creator/main.py or the CLI) having set it on the
        # config object, and otherwise falls back to a default location.
        config_path_str = getattr(app_config, "_config_path", "config/config.yaml")
        try:
            update_ui_options(config_path_str, new_options)
            # Hot-reload the configuration currently held in memory.
            app_config.ui_options = UIOptionsConfig(**new_options)
            return JSONResponse(content={"detail": "配置已更新"})
        except Exception as exc:
            # Deliberate catch-all: any save failure is reported to the client.
            return JSONResponse(status_code=500, content={"detail": f"保存配置失败: {exc}"})

    @app.post("/api/create")
    async def api_create(request: Request) -> JSONResponse:
        body = await read_json_object(request)
        if body is None:
            return bad_request(invalid_body_detail)
        app_config: AppConfig = request.app.state.config
        record = _body_to_record(body, app_config.defaults.base_group)
        missing = required_fields_error(record)
        if missing is not None:
            return missing
        service: UserService = request.app.state.service
        result = service.process_user(record, dry_run=app_config.behavior.dry_run)
        return JSONResponse(content=asdict(result))

    @app.exception_handler(InputValidationError)
    async def handle_input_validation_error(request: Request, exc: InputValidationError) -> JSONResponse:
        return JSONResponse(status_code=400, content={"detail": str(exc)})

    @app.exception_handler(AppError)
    async def handle_app_error(request: Request, exc: AppError) -> JSONResponse:
        return JSONResponse(status_code=500, content={"detail": str(exc)})

    @app.exception_handler(Exception)
    async def handle_exception(request: Request, exc: Exception) -> JSONResponse:
        # Generic message only; internal error details are not sent to clients.
        return JSONResponse(status_code=500, content={"detail": "服务器内部错误"})

    return app
|