"""Shared pytest fixtures: in-memory test database, HTTP client and seed data.

NOTE(review): the async fixtures below use plain ``@pytest.fixture``. That only
works when pytest-asyncio runs in ``auto`` mode (or another plugin drives async
fixtures); in ``strict`` mode they would need ``@pytest_asyncio.fixture``.
Confirm the project's pytest configuration.
"""

from datetime import datetime, timedelta, timezone
from typing import AsyncGenerator

import pytest
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base, get_db
from app.models.user import User
from app.models.contest import Contest
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.services.auth import get_password_hash, create_access_token

# Use SQLite for testing.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# Each SQLite ":memory:" connection is a separate, empty database. StaticPool
# pins the engine to one shared connection so the tables created by
# ``setup_database`` and the rows inserted by fixtures are visible to every
# session, including the ones the app opens through ``override_get_db``.
engine = create_async_engine(
    TEST_DATABASE_URL,
    echo=False,
    poolclass=StaticPool,
)

# expire_on_commit=False keeps attribute values loaded after commit, so
# fixture objects stay readable after their session commits.
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session bound to the test engine in place of the app's ``get_db``."""
    # Leaving the ``async with`` block closes the session, even if the request
    # handler raised, so no explicit close is needed.
    async with async_session_maker() as session:
        yield session


# Module-level on purpose: every request made through the app during the test
# run uses the in-memory test database.
app.dependency_overrides[get_db] = override_get_db


@pytest.fixture(autouse=True)
async def setup_database():
    """Create all tables before each test and drop them afterwards."""
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield
    # Dropping everything gives the next test an empty schema, so no rows
    # leak between tests.
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)


@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Get a database session for direct database operations in tests."""
    async with async_session_maker() as session:
        yield session


@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    """Get an async HTTP client that calls the ASGI app in-process."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac


@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Create and persist a user with role ``participant``."""
    user = User(
        email="test@example.com",
        username="testuser",
        password_hash=get_password_hash("testpassword"),
        role="participant",
    )
    db_session.add(user)
    await db_session.commit()
    # Reload to pick up database-generated values such as the primary key.
    await db_session.refresh(user)
    return user


@pytest.fixture
async def admin_user(db_session: AsyncSession) -> User:
    """Create and persist a user with role ``admin``."""
    user = User(
        email="admin@example.com",
        username="adminuser",
        password_hash=get_password_hash("adminpassword"),
        role="admin",
    )
    db_session.add(user)
    await db_session.commit()
    await db_session.refresh(user)
    return user


@pytest.fixture
def user_token(test_user: User) -> str:
    """Get a JWT access token whose ``sub`` claim is the test user's id."""
    return create_access_token(data={"sub": str(test_user.id)})


@pytest.fixture
def admin_token(admin_user: User) -> str:
    """Get a JWT access token whose ``sub`` claim is the admin user's id."""
    return create_access_token(data={"sub": str(admin_user.id)})


@pytest.fixture
def auth_headers(user_token: str) -> dict[str, str]:
    """Get Bearer authorization headers for the test user."""
    return {"Authorization": f"Bearer {user_token}"}


@pytest.fixture
def admin_headers(admin_token: str) -> dict[str, str]:
    """Get Bearer authorization headers for the admin user."""
    return {"Authorization": f"Bearer {admin_token}"}


@pytest.fixture
async def test_contest(db_session: AsyncSession, admin_user: User) -> Contest:
    """Create an active contest, created by the admin user, that is running now.

    The contest started one hour ago and ends two hours from now.
    """
    # NOTE(review): SQLite does not store UTC offsets. If these columns are
    # timezone-aware, values read back from the database may be naive.
    # Confirm before comparing them against aware datetimes.
    now = datetime.now(timezone.utc)
    contest = Contest(
        title="Test Contest",
        description="A test contest",
        start_time=now - timedelta(hours=1),
        end_time=now + timedelta(hours=2),
        is_active=True,
        created_by=admin_user.id,
    )
    db_session.add(contest)
    await db_session.commit()
    await db_session.refresh(contest)
    return contest


@pytest.fixture
async def test_problem(db_session: AsyncSession, test_contest: Contest) -> Problem:
    """Create an "add two integers" problem attached to the test contest."""
    problem = Problem(
        contest_id=test_contest.id,
        title="Sum of Two Numbers",
        description="Given two integers, find their sum.",
        input_format="Two integers a and b",
        output_format="Their sum",
        constraints="1 <= a, b <= 1000",
        time_limit_ms=1000,
        memory_limit_kb=262144,
        total_points=100,
        order_index=0,
    )
    db_session.add(problem)
    await db_session.commit()
    await db_session.refresh(problem)
    return problem


@pytest.fixture
async def test_cases(db_session: AsyncSession, test_problem: Problem) -> list[TestCase]:
    """Create two test cases for the test problem, worth 50 points each.

    The first case is a sample; the second is hidden. The list is returned in
    ``order_index`` order.
    """
    cases = [
        TestCase(
            problem_id=test_problem.id,
            input="1 2",
            expected_output="3",
            is_sample=True,
            points=50,
            order_index=0,
        ),
        TestCase(
            problem_id=test_problem.id,
            input="100 200",
            expected_output="300",
            is_sample=False,
            points=50,
            order_index=1,
        ),
    ]
    db_session.add_all(cases)
    await db_session.commit()
    for tc in cases:
        await db_session.refresh(tc)
    return cases