Vastai-ConnectHub/app/security/session.py

98 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from __future__ import annotations
from datetime import datetime, timedelta, timezone
from uuid import uuid4
from fastapi import Request, Response
from sqlalchemy import delete, select
from app.core.config import settings
from app.db.engine import get_session
from app.db.models import Session, User
# Name of the cookie that carries the session id between client and server.
SESSION_COOKIE_NAME = "session_id"
def _now_utc() -> datetime:
    """Return the current instant as a timezone-aware UTC datetime."""
    utc = timezone.utc
    return datetime.now(tz=utc)
def _as_utc(dt: datetime) -> datetime:
    """Normalize *dt* to a timezone-aware UTC datetime.

    Tolerates legacy dirty data: a naive value (``tzinfo is None``) is
    interpreted as already being UTC and is only tagged, not shifted; an
    aware value is converted to UTC.
    """
    utc = timezone.utc
    is_naive = dt.tzinfo is None
    return dt.replace(tzinfo=utc) if is_naive else dt.astimezone(utc)
def create_session(user_id: int, request: Request) -> str:
    """Persist a new session row for *user_id* and return its id.

    The id is the hex form of a random UUID4. The stored row records the
    expiry (now plus ``settings.session_expiry_minutes``), the client host
    (empty string when ``request.client`` is falsy) and the ``User-Agent``
    header (empty string when the header is absent).
    """
    token = uuid4().hex
    issued_at = _now_utc()
    expiry = issued_at + timedelta(minutes=settings.session_expiry_minutes)

    db = get_session()
    try:
        client = request.client
        row = Session(
            id=token,
            user_id=user_id,
            expires_at=expiry,
            ip=client.host if client else "",
            user_agent=request.headers.get("User-Agent", ""),
        )
        db.add(row)
        db.commit()
    finally:
        # Always release the DB session, even when the insert fails.
        db.close()
    return token
def delete_session(session_id: str) -> None:
    """Delete the session row whose id equals *session_id* and commit."""
    purge_stmt = delete(Session).where(Session.id == session_id)
    db = get_session()
    try:
        db.execute(purge_stmt)
        db.commit()
    finally:
        # Always release the DB session, even when the delete fails.
        db.close()
def set_session_cookie(response: Response, session_id: str) -> None:
    """Attach the session cookie carrying *session_id* to *response*.

    The cookie is HTTP-only, uses ``SameSite=lax`` and has a max-age equal
    to ``settings.session_expiry_minutes`` converted to seconds.
    """
    lifetime_seconds = settings.session_expiry_minutes * 60
    response.set_cookie(
        key=SESSION_COOKIE_NAME,
        value=session_id,
        max_age=lifetime_seconds,
        httponly=True,
        samesite="lax",
    )
def clear_session_cookie(response: Response) -> None:
    """Ask the client to drop the session cookie via *response*."""
    cookie_name = SESSION_COOKIE_NAME
    response.delete_cookie(key=cookie_name)
def _resolve_session_user(session_id: str) -> User | None:
    """Return the active user owning *session_id*, or ``None``.

    ``None`` is returned when no session row matches, when the session has
    expired (the expired row is deleted and the deletion committed), or
    when the referenced user does not exist or is not active.
    """
    db = get_session()
    try:
        record = db.scalar(select(Session).where(Session.id == session_id))
        if record is None:
            return None

        # Stored expiry may be naive (legacy rows); normalize before comparing.
        if _as_utc(record.expires_at) <= _now_utc():
            # Purge the expired row so later requests skip it outright.
            db.execute(delete(Session).where(Session.id == session_id))
            db.commit()
            return None

        user = db.get(User, record.user_id)
        if user is None or not user.is_active:
            return None
        return user
    finally:
        # Always release the DB session, whichever branch returned.
        db.close()


def get_current_user(request: Request) -> User | None:
    """Return the user authenticated by the request's session cookie.

    The result (including ``None``) is cached on ``request.state.user`` so
    repeated calls within the same request skip the database lookup.

    Returns ``None`` when the cookie is missing or empty, the session is
    unknown or expired, or the user is missing or inactive.
    """
    # A previous call on this request already resolved the user.
    if hasattr(request.state, "user"):
        return request.state.user

    session_id = request.cookies.get(SESSION_COOKIE_NAME)
    user = _resolve_session_user(session_id) if session_id else None

    # Single caching point for every outcome, anonymous or authenticated.
    request.state.user = user
    return user