sessionmiddleware.py
from uuid import uuid4

from starlette.types import ASGIApp, Receive, Scope, Send

from app.core.database.session import (
    reset_session_context,
    session,
    set_session_context,
)
class SQLSessionMiddleware:
    """ASGI middleware that gives every ASGI call its own DB session scope.

    On each call it binds a fresh UUID to ``session_context`` (the value the
    module-level ``async_scoped_session`` uses as its registry key), runs the
    wrapped app, then removes the scoped session and restores the previous
    context value.

    Note: the scope is per ASGI call (one request / connection), not per
    client, and sessions are distinguished by this context id, not by event
    loop.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        token = set_session_context(session_id=str(uuid4()))
        try:
            await self.app(scope, receive, send)
        finally:
            # Each call uses a brand-new key, so unless the scoped session is
            # removed here (or by some other caller, e.g. a get_session
            # dependency) its registry entry is never dropped. remove() is
            # safe to call even if no session was created for this scope.
            try:
                await session.remove()
            finally:
                # Always restore the context, even if closing the session failed.
                reset_session_context(token)
session.py
import functools
from contextvars import ContextVar, Token
import traceback
from typing import Awaitable, Callable, ParamSpec, TypeVar
import logfire
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_scoped_session,
async_sessionmaker,
create_async_engine,
)
from app.core.config import get_settings
from app.core.exceptions import CustomBaseCommitException, IntegrityErrorException
# Id of the current DB-session scope. async_scoped_session (below) reads it through
# get_session_context as its registry key. No default is given, so reading it
# before set_session_context has been called in this context raises LookupError.
session_context: ContextVar[str] = ContextVar("session_context")
def get_session_context() -> str:
    """Return the session-scope id bound in the current context.

    Used as the ``scopefunc`` of the scoped session. Raises ``LookupError``
    when no id has been set in this context.
    """
    current_id = session_context.get()
    return current_id
def set_session_context(session_id: str) -> Token:
    """Bind ``session_id`` as the session-scope id for the current context.

    Returns the ``Token`` that ``reset_session_context`` needs to undo
    this binding.
    """
    token = session_context.set(session_id)
    return token
def reset_session_context(context: Token) -> None:
    """Undo a ``set_session_context`` call.

    ``context`` is the token that call returned; the context variable goes
    back to whatever value (or unset state) it had before.
    """
    var = session_context
    var.reset(context)
config = get_settings()


def _build_database_url(settings):
    """Build the asyncpg connection URL from the application ``settings``.

    ``URL.create`` keeps each component separate and escapes it when the URL
    is rendered, so credentials containing characters such as ``@``, ``:``,
    ``/`` or ``%`` can no longer corrupt the URL (plain string interpolation
    did).

    NOTE(review): if a password was previously stored percent-encoded to work
    around that, it must now be stored raw. Also confirm DB_PORT is an int
    or an int-like string.
    """
    # Local import: URL is only needed once, when the engine is created.
    from sqlalchemy.engine import URL

    return URL.create(
        drivername="postgresql+asyncpg",
        username=settings.POSTGRES_USER,
        password=settings.POSTGRES_PASSWORD,
        host=settings.DB_HOST,
        port=settings.DB_PORT,
        database=settings.POSTGRES_DB,
    )


engine = create_async_engine(
    _build_database_url(config),
    echo=bool(config.IS_TEST),  # echo SQL statements only in the test configuration
    pool_recycle=3600,  # replace pooled connections older than one hour
)
logfire.instrument_sqlalchemy(engine=engine.sync_engine)

# expire_on_commit=False keeps ORM objects readable after commit without a reload.
session_factory = async_sessionmaker(
    bind=engine, expire_on_commit=False, class_=AsyncSession
)

# One AsyncSession per distinct session_context value (see get_session_context).
session = async_scoped_session(
    session_factory=session_factory, scopefunc=get_session_context
)
async def dispose_engine():
    """Dispose of the module-level engine, closing its pooled connections.

    Presumably called on application shutdown; verify against the caller.
    """
    pool_owner = engine
    await pool_owner.dispose()
async def get_session():
    """Yield the scoped-session proxy and tear down this scope's session afterwards.

    ``session.remove()`` already closes the current scope's AsyncSession (if one
    was created) before dropping it from the registry. A separate ``close()``
    first was redundant, and when no session had been used it instantiated
    one only to close it again.
    """
    try:
        yield session
    finally:
        await session.remove()
# Return type of the coroutine function wrapped by `transactional`.
T = TypeVar("T")
# Parameter list of the wrapped function, so the decorated function keeps its signature for type checkers.
P = ParamSpec("P")
def transactional(func: Callable[P, Awaitable[T]]) -> Callable[P, Awaitable[T]]:
    """Run the coroutine function ``func`` as a unit of work on the scoped ``session``.

    On success the session is committed and ``func``'s result is returned.
    On failure:

    - ``IntegrityError`` (from ``func`` or from the commit itself): roll back
      and raise ``IntegrityErrorException`` chained to the original error.
    - ``CustomBaseCommitException``: the pending work is still committed,
      then the exception is re-raised.
    - any other ``Exception``: roll back and re-raise it unchanged.

    Every failure is logged with its traceback via ``logfire.exception``.

    Note: ``session`` is shared by everything running under the same
    ``session_context`` value, so a nested ``@transactional`` call commits
    the outer caller's pending work as well.
    """
    func_name = getattr(func, "__qualname__", repr(func))

    @functools.wraps(func)
    async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        try:
            result = await func(*args, **kwargs)
            # The commit stays inside the try so flush/commit-time errors
            # (e.g. constraint violations) go through the handlers below.
            await session.commit()
        except IntegrityError as e:
            await session.rollback()
            logfire.exception(
                "IntegrityError in {func_name}; transaction rolled back",
                func_name=func_name,
            )
            raise IntegrityErrorException() from e
        except CustomBaseCommitException:
            # Deliberately persist the work done before this exception.
            await session.commit()
            logfire.exception(
                "Commit-exception in {func_name}; pending work committed",
                func_name=func_name,
            )
            raise
        except Exception:
            await session.rollback()
            logfire.exception(
                "Error in {func_name}; transaction rolled back",
                func_name=func_name,
            )
            raise
        return result

    return _wrapper
솔직히 좋은 구현이라고 보기는 어려울 것 같다. 다만 session이 async_scoped_session이고, scopefunc(get_session_context)가 요청마다 다른 ID를 돌려주므로, 단일 스레드의 이벤트 루프에서 여러 요청(코루틴)이 동시에 실행되더라도 요청마다 구분된 session을 갖게 된다. 그래서 어차피 문제가 없다고 보고 이렇게 구현해 두었다.
