from __future__ import annotations

from types import TracebackType

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from app.application.users.ports.unit_of_work import UnitOfWork
from app.infrastructure.db.sqlalchemy.repositories.user_repository import (
    SqlAlchemyUserRepository,
)


class SqlAlchemyUnitOfWork(UnitOfWork):
    """Unit of work backed by a SQLAlchemy ``AsyncSession``.

    Each ``async with`` entry opens a fresh session from the injected
    sessionmaker and binds a ``SqlAlchemyUserRepository`` to it as ``users``.
    Nothing is committed implicitly: callers must ``await uow.commit()``
    inside the block. On exit the session is rolled back if the block
    raised, and it is always closed.
    """

    def __init__(self, sessionmaker: async_sessionmaker[AsyncSession]) -> None:
        """Store the session factory; no session is opened until entry.

        Args:
            sessionmaker: Factory used to create one ``AsyncSession`` per
                ``async with`` entry.
        """
        self._sessionmaker = sessionmaker
        self._session: AsyncSession | None = None
        # Populated in __aenter__ with a repository bound to the active session.
        self.users = None

    async def __aenter__(self) -> SqlAlchemyUnitOfWork:
        """Open a new session and expose the user repository bound to it."""
        self._session = self._sessionmaker()
        self.users = SqlAlchemyUserRepository(self._session)
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        tb: TracebackType | None,
    ) -> None:
        """Roll back if the ``async with`` body raised, then close the session.

        Returns ``None``, so an exception raised in the body propagates to
        the caller after cleanup.

        Raises:
            RuntimeError: if no session was opened (``__aenter__`` not run).
        """
        session = self._require_session()
        try:
            # Check the exception *type*, not the instance's truthiness: an
            # exception defining __bool__/__len__ could be falsy and would
            # otherwise skip the rollback.
            if exc_type is not None:
                await self.rollback()
        finally:
            # Close even if rollback itself fails (e.g. a dropped connection)
            # so the session's connection is released rather than leaked.
            await session.close()

    async def commit(self) -> None:
        """Commit the active session's current transaction.

        Raises:
            RuntimeError: if called outside ``async with``.
        """
        await self._require_session().commit()

    async def rollback(self) -> None:
        """Roll back the active session's current transaction.

        Raises:
            RuntimeError: if called outside ``async with``.
        """
        await self._require_session().rollback()

    def _require_session(self) -> AsyncSession:
        """Return the active session, or fail with a clear message if none."""
        if self._session is None:
            raise RuntimeError(
                "SqlAlchemyUnitOfWork has no active session; "
                "use it inside 'async with' before calling commit() or rollback()."
            )
        return self._session