98 lines
2.6 KiB
Python
98 lines
2.6 KiB
Python
from __future__ import annotations
|
||
|
||
from datetime import datetime, timedelta, timezone
|
||
from uuid import uuid4
|
||
|
||
from fastapi import Request, Response
|
||
from sqlalchemy import delete, select
|
||
|
||
from app.core.config import settings
|
||
from app.db.engine import get_session
|
||
from app.db.models import Session, User
|
||
|
||
|
||
# Cookie name under which set_session_cookie stores the session id and from
# which get_current_user reads it back.
SESSION_COOKIE_NAME: str = "session_id"
|
||
|
||
|
||
def _now_utc() -> datetime:
|
||
return datetime.now(timezone.utc)
|
||
|
||
|
||
def _as_utc(dt: datetime) -> datetime:
|
||
# 兼容历史脏数据:若为 naive,按 UTC 解释;若为 aware,统一转换到 UTC。
|
||
if dt.tzinfo is None:
|
||
return dt.replace(tzinfo=timezone.utc)
|
||
return dt.astimezone(timezone.utc)
|
||
|
||
|
||
def create_session(user_id: int, request: Request) -> str:
    """Persist a new session row for ``user_id`` and return its id.

    The id is ``uuid4().hex`` (32 hex characters). The row expires
    ``settings.session_expiry_minutes`` minutes from now, in UTC. The
    client's IP and ``User-Agent`` header are stored with it, and each
    falls back to ``""`` when unavailable.

    Args:
        user_id: Primary key of the user who owns the session.
        request: Incoming request that the client metadata is read from.

    Returns:
        The new session id, as passed to ``set_session_cookie``.
    """
    new_id = uuid4().hex
    expiry = _now_utc() + timedelta(minutes=settings.session_expiry_minutes)

    db = get_session()
    try:
        client = request.client
        client_ip = client.host if client else ""
        agent = request.headers.get("User-Agent", "")
        db.add(
            Session(
                id=new_id,
                user_id=user_id,
                expires_at=expiry,
                ip=client_ip,
                user_agent=agent,
            )
        )
        db.commit()
    finally:
        # Always release the DB session, even if the insert/commit fails.
        db.close()

    return new_id
|
||
|
||
|
||
def delete_session(session_id: str) -> None:
    """Delete the session row whose id equals ``session_id`` and commit.

    The statement is issued unconditionally, so an id that matches no row
    just results in a DELETE that affects zero rows.
    """
    stmt = delete(Session).where(Session.id == session_id)
    db = get_session()
    try:
        db.execute(stmt)
        db.commit()
    finally:
        db.close()
|
||
|
||
|
||
def set_session_cookie(response: Response, session_id: str) -> None:
    """Attach the session cookie carrying ``session_id`` to ``response``.

    The cookie is HttpOnly with ``SameSite=Lax``. Its ``Max-Age`` equals the
    server-side lifetime used by ``create_session``
    (``settings.session_expiry_minutes``, expressed in seconds).
    """
    # NOTE(review): no ``secure`` flag is passed, so the cookie is not
    # restricted to HTTPS. Confirm this is intended for production.
    lifetime_seconds = settings.session_expiry_minutes * 60
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=session_id,
        max_age=lifetime_seconds,
        httponly=True,
        samesite="lax",
    )
|
||
|
||
|
||
def clear_session_cookie(response: Response) -> None:
    """Tell the client, via ``response``, to drop the session cookie.

    Only the cookie is affected. Any server-side session row is left in place.
    """
    response.delete_cookie(key=SESSION_COOKIE_NAME)
|
||
|
||
|
||
def get_current_user(request: Request) -> User | None:
    """Return the user authenticated by ``request``'s session cookie.

    The outcome (a ``User`` or ``None``) is cached on ``request.state.user``.
    Later calls for the same request return that cached value without
    querying again.

    ``None`` is returned when the cookie is absent or empty, when no session
    row matches it, when the session has expired (the row is then deleted),
    or when the owning user is missing or inactive.
    """
    if hasattr(request.state, "user"):
        return request.state.user

    user = _resolve_session_user(request.cookies.get(SESSION_COOKIE_NAME))
    request.state.user = user
    return user


def _resolve_session_user(session_id: str | None) -> User | None:
    """Look up the active user owning ``session_id``, purging it if expired."""
    if not session_id:
        return None

    db = get_session()
    try:
        record = db.scalar(select(Session).where(Session.id == session_id))
        if not record:
            return None

        # Stored timestamps may be naive legacy values; compare in UTC.
        if _as_utc(record.expires_at) <= _now_utc():
            db.execute(delete(Session).where(Session.id == session_id))
            db.commit()
            return None

        owner = db.get(User, record.user_id)
        return owner if owner and owner.is_active else None
    finally:
        db.close()
|