32 lines
1.1 KiB
Python
32 lines
1.1 KiB
Python
from __future__ import annotations
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
|
|
from app.application.users.ports.unit_of_work import UnitOfWork
|
|
from app.infrastructure.db.sqlalchemy.repositories.user_repository import (
|
|
SqlAlchemyUserRepository,
|
|
)
|
|
|
|
|
|
class SqlAlchemyUnitOfWork(UnitOfWork):
    """Unit of work backed by a SQLAlchemy ``AsyncSession``.

    Each ``async with`` entry creates a fresh session from the supplied
    session factory and binds a ``SqlAlchemyUserRepository`` to it as
    ``self.users``. On exit the session is rolled back if the body raised,
    and it is always closed. Committing is explicit: callers must
    ``await uow.commit()`` inside the ``async with`` body.
    """

    def __init__(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
        """Store the session factory; no session is created until entry.

        Args:
            sessionmaker: Factory invoked once per ``async with`` entry to
                create the ``AsyncSession`` this unit of work operates on.
        """
        self._sessionmaker = sessionmaker
        self._session: AsyncSession | None = None
        # Set to a repository in __aenter__; None until the first entry.
        self.users = None

    def _require_session(self) -> AsyncSession:
        """Return the active session, or raise if the UoW was never entered."""
        if self._session is None:
            raise RuntimeError(
                "SqlAlchemyUnitOfWork has no active session; "
                "use it as 'async with uow:' before commit/rollback."
            )
        return self._session

    async def __aenter__(self) -> SqlAlchemyUnitOfWork:
        """Create a new session and bind the user repository to it."""
        self._session = self._sessionmaker()
        self.users = SqlAlchemyUserRepository(self._session)
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        """Roll back if the body raised, then always close the session.

        Returns ``None``, so exceptions from the ``async with`` body are not
        suppressed.
        """
        session = self._session
        if session is None:
            # Exited without a matching __aenter__: nothing to release.
            return
        try:
            if exc_type is not None:
                await self.rollback()
        finally:
            # Close even when rollback itself fails, so the session's
            # connection is not leaked.
            await session.close()

    async def commit(self) -> None:
        """Commit the active session's current transaction.

        Raises:
            RuntimeError: If called outside an ``async with`` block.
        """
        await self._require_session().commit()

    async def rollback(self) -> None:
        """Roll back the active session's current transaction.

        Raises:
            RuntimeError: If called outside an ``async with`` block.
        """
        await self._require_session().rollback()