Vastai-ConnectHub/app/security/ldap_sync.py

68 lines
2.1 KiB
Python

from __future__ import annotations
from sqlalchemy import select
from app.db.engine import get_session
from app.db.models import LdapGroup, LdapGroupRole, Role, User
from app.security.ldap_client import LdapClient
def _ensure_groups(session, groups: list[dict[str, str]]) -> list[LdapGroup]:
existing = {
g.dn: g for g in session.scalars(select(LdapGroup).where(LdapGroup.dn.in_([g["dn"] for g in groups])))
}
results: list[LdapGroup] = []
for g in groups:
if g["dn"] in existing:
obj = existing[g["dn"]]
if g["name"] and obj.name != g["name"]:
obj.name = g["name"]
session.add(obj)
results.append(obj)
else:
obj = LdapGroup(dn=g["dn"], name=g["name"])
session.add(obj)
session.flush()
results.append(obj)
return results
def sync_user_ldap_roles(*, session, user: User, username: str, user_dn: str) -> None:
    """Replace ``user.roles`` with the roles mapped to the user's LDAP groups.

    The user's groups are fetched through a fresh ``LdapClient`` and passed
    to ``_ensure_groups``. The ``LdapGroupRole`` rows for the resulting group
    ids supply the role ids, and the matching ``Role`` rows become the user's
    roles. If there are no groups, or no role is mapped to them, the user's
    roles are set to an empty list. The user is added to the session in every
    case; nothing is committed here.

    Args:
        session: Session used for the lookups and for registering the user.
        user: The user whose ``roles`` attribute is overwritten.
        username: Passed to ``LdapClient.get_user_groups``.
        user_dn: Passed to ``LdapClient.get_user_groups``.
    """
    ldap_groups = LdapClient().get_user_groups(user_dn=user_dn, username=username)
    matched_groups = _ensure_groups(session, ldap_groups)

    resolved_roles = []
    if matched_groups:
        matched_ids = [grp.id for grp in matched_groups]
        role_id_query = select(LdapGroupRole.role_id).where(
            LdapGroupRole.ldap_group_id.in_(matched_ids)
        )
        mapped_role_ids = list(session.scalars(role_id_query))
        if mapped_role_ids:
            # Only query Role when at least one mapping exists.
            role_query = select(Role).where(Role.id.in_(mapped_role_ids))
            resolved_roles = list(session.scalars(role_query))

    user.roles = resolved_roles
    session.add(user)
def sync_all_ldap_users() -> int:
    """Re-sync roles for every user flagged ``is_ldap`` and commit the result.

    Users whose DN cannot be found via ``LdapClient.find_user_dn`` are
    skipped and left unchanged. All updates are committed together after the
    loop, so an exception raised while processing any user propagates before
    the commit is reached.

    Returns:
        The number of users whose roles were synced (skipped users are not
        counted).
    """
    db = get_session()
    try:
        # Build the client inside the `try` so the session is still closed
        # if the client constructor raises.
        client = LdapClient()
        updated = 0
        # Materialise the result before the loop, which adds to and flushes
        # the same session.
        users = list(db.scalars(select(User).where(User.is_ldap.is_(True))))
        for user in users:
            user_dn = client.find_user_dn(user.username)
            if not user_dn:
                continue
            sync_user_ldap_roles(session=db, user=user, username=user.username, user_dn=user_dn)
            updated += 1
        db.commit()
        return updated
    finally:
        db.close()