68 lines
2.1 KiB
Python
68 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy import select
|
|
|
|
from app.db.engine import get_session
|
|
from app.db.models import LdapGroup, LdapGroupRole, Role, User
|
|
from app.security.ldap_client import LdapClient
|
|
|
|
|
|
def _ensure_groups(session, groups: list[dict[str, str]]) -> list[LdapGroup]:
|
|
existing = {
|
|
g.dn: g for g in session.scalars(select(LdapGroup).where(LdapGroup.dn.in_([g["dn"] for g in groups])))
|
|
}
|
|
results: list[LdapGroup] = []
|
|
for g in groups:
|
|
if g["dn"] in existing:
|
|
obj = existing[g["dn"]]
|
|
if g["name"] and obj.name != g["name"]:
|
|
obj.name = g["name"]
|
|
session.add(obj)
|
|
results.append(obj)
|
|
else:
|
|
obj = LdapGroup(dn=g["dn"], name=g["name"])
|
|
session.add(obj)
|
|
session.flush()
|
|
results.append(obj)
|
|
return results
|
|
|
|
|
|
def sync_user_ldap_roles(*, session, user: User, username: str, user_dn: str) -> None:
    """Replace ``user.roles`` with the roles mapped to the user's LDAP groups.

    The user's groups are fetched through a fresh ``LdapClient``, mirrored
    into ``LdapGroup`` rows via ``_ensure_groups``, and translated to roles
    through ``LdapGroupRole``. When there are no groups, or no role mapped to
    any of them, the user ends up with an empty role list.

    Args:
        session: Database session used for the lookups; the user is added to
            it, but nothing is committed here.
        user: The user whose ``roles`` attribute is overwritten.
        username: Passed to the LDAP client for the group lookup.
        user_dn: Passed to the LDAP client for the group lookup.
    """
    ldap_groups = LdapClient().get_user_groups(user_dn=user_dn, username=username)
    group_rows = _ensure_groups(session, ldap_groups)

    # Default to "no roles"; only overwritten when both lookups find something.
    resolved_roles = []
    if group_rows:
        mapped_role_ids = list(
            session.scalars(
                select(LdapGroupRole.role_id).where(
                    LdapGroupRole.ldap_group_id.in_([row.id for row in group_rows])
                )
            )
        )
        if mapped_role_ids:
            resolved_roles = list(
                session.scalars(select(Role).where(Role.id.in_(mapped_role_ids)))
            )

    user.roles = resolved_roles
    session.add(user)
|
|
|
|
|
|
def sync_all_ldap_users() -> int:
    """Re-sync roles for every LDAP-backed user and commit the result.

    Every ``User`` with ``is_ldap`` set is resolved to a DN through the LDAP
    client; users for which no DN is found are skipped. All changes are
    committed together once the loop finishes, and the session is always
    closed on the way out.

    Returns:
        The number of users that were processed (those skipped for a missing
        DN are not counted). A processed user is counted whether or not its
        roles actually changed.
    """
    db = get_session()
    try:
        # Built inside the ``try`` so the session is still closed if the
        # client cannot be constructed.
        client = LdapClient()
        updated = 0
        # Materialise the users up front so the loop does not iterate a live
        # result while further queries run on the same session.
        users = list(db.scalars(select(User).where(User.is_ldap.is_(True))))
        for user in users:
            user_dn = client.find_user_dn(user.username)
            if not user_dn:
                continue
            sync_user_ldap_roles(session=db, user=user, username=user.username, user_dn=user_dn)
            updated += 1
        db.commit()
        return updated
    finally:
        db.close()
|