volsu-contests/backend/tests/conftest.py
2025-11-30 19:55:50 +03:00

188 lines
5.1 KiB
Python

import pytest
from datetime import datetime, timezone, timedelta
from typing import AsyncGenerator
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from app.main import app
from app.database import Base, get_db
from app.models.user import User
from app.models.contest import Contest
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.services.auth import get_password_hash, create_access_token
# Tests run against an in-memory SQLite database through the async aiosqlite
# driver, so no external database server is required.
# NOTE(review): every SQLite ":memory:" connection normally gets its own empty
# database; this setup relies on the aiosqlite dialect sharing a single
# connection (StaticPool) for ":memory:" URLs, so that tables created in
# ``setup_database`` are visible to all sessions — confirm for the installed
# SQLAlchemy version.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# echo=False keeps SQLAlchemy from logging every emitted SQL statement.
engine = create_async_engine(TEST_DATABASE_URL, echo=False)

# Factory for every AsyncSession used by the tests. With
# expire_on_commit=False, loaded attributes are not expired by commit(), so
# ORM objects returned from fixtures stay readable afterwards.
async_session_maker = async_sessionmaker(
    bind=engine,
    expire_on_commit=False,
    class_=AsyncSession,
)
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
    """Stand-in for the application's ``get_db`` dependency during tests.

    Yields a session from ``async_session_maker``, which is bound to the
    in-memory test engine. The ``async with`` block closes the session when
    the request finishes, including when the handler raises. An extra
    explicit ``close()`` would be redundant.
    """
    async with async_session_maker() as session:
        yield session


# Make every request handled by ``app`` use the test database session.
app.dependency_overrides[get_db] = override_get_db
# NOTE(review): async fixtures declared with plain ``@pytest.fixture`` need
# pytest-asyncio in "auto" mode (or an equivalent plugin) — confirm config.
@pytest.fixture(autouse=True)
async def setup_database():
    """Give every test a freshly created schema, then drop it afterwards.

    Runs automatically for each test. Every table registered on
    ``Base.metadata`` is created before the test body runs and dropped once
    it finishes, so rows never leak from one test into the next.
    """
    metadata = Base.metadata

    async with engine.begin() as connection:
        await connection.run_sync(metadata.create_all)

    yield

    # Teardown: remove all tables so the next test starts from scratch.
    async with engine.begin() as connection:
        await connection.run_sync(metadata.drop_all)
@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield an AsyncSession for tests that touch the database directly.

    The session comes from ``async_session_maker``, the same factory used by
    ``override_get_db``. It is closed when the test finishes.
    """
    session_context = async_session_maker()
    async with session_context as session:
        yield session
@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an httpx AsyncClient that sends requests straight to ``app``.

    Requests go through ``ASGITransport`` in-process rather than over a real
    socket. ``http://test`` is only a placeholder base URL.
    """
    async with AsyncClient(
        base_url="http://test",
        transport=ASGITransport(app=app),
    ) as http_client:
        yield http_client
async def _create_user(
    session: AsyncSession,
    *,
    email: str,
    username: str,
    password: str,
    role: str,
) -> User:
    """Persist a ``User`` and return it refreshed from the database.

    The plain-text ``password`` is stored as ``get_password_hash(password)``.
    """
    user = User(
        email=email,
        username=username,
        password_hash=get_password_hash(password),
        role=role,
    )
    session.add(user)
    await session.commit()
    await session.refresh(user)
    return user


@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Create a participant user whose password is ``testpassword``."""
    return await _create_user(
        db_session,
        email="test@example.com",
        username="testuser",
        password="testpassword",
        role="participant",
    )


@pytest.fixture
async def admin_user(db_session: AsyncSession) -> User:
    """Create an admin user whose password is ``adminpassword``."""
    return await _create_user(
        db_session,
        email="admin@example.com",
        username="adminuser",
        password="adminpassword",
        role="admin",
    )
@pytest.fixture
def user_token(test_user: User) -> str:
    """Return an access token whose ``sub`` claim is the test user's id."""
    claims = {"sub": str(test_user.id)}
    return create_access_token(data=claims)
@pytest.fixture
def admin_token(admin_user: User) -> str:
    """Return an access token whose ``sub`` claim is the admin user's id."""
    claims = {"sub": str(admin_user.id)}
    return create_access_token(data=claims)
@pytest.fixture
def auth_headers(user_token: str) -> dict:
    """Return request headers that authenticate as the test user."""
    bearer_value = f"Bearer {user_token}"
    return dict(Authorization=bearer_value)
@pytest.fixture
def admin_headers(admin_token: str) -> dict:
    """Return request headers that authenticate as the admin user."""
    bearer_value = f"Bearer {admin_token}"
    return dict(Authorization=bearer_value)
@pytest.fixture
async def test_contest(db_session: AsyncSession, admin_user: User) -> Contest:
    """Persist an active contest created by ``admin_user``.

    The contest window is open when the fixture runs. It started one hour
    ago and ends two hours from now, both offsets taken from the current
    UTC time.
    """
    # NOTE(review): these datetimes are timezone-aware, but SQLite has no
    # native timezone type, so values read back may be naive — confirm how the
    # Contest columns and the app's time comparisons handle this.
    current = datetime.now(timezone.utc)
    opens_at = current - timedelta(hours=1)
    closes_at = current + timedelta(hours=2)

    contest = Contest(
        title="Test Contest",
        description="A test contest",
        start_time=opens_at,
        end_time=closes_at,
        is_active=True,
        created_by=admin_user.id,
    )
    db_session.add(contest)
    await db_session.commit()
    await db_session.refresh(contest)
    return contest
@pytest.fixture
async def test_problem(db_session: AsyncSession, test_contest: Contest) -> Problem:
    """Persist a "sum of two integers" problem attached to ``test_contest``.

    Uses time_limit_ms=1000, memory_limit_kb=262144, total_points=100 and
    order_index=0.
    """
    # Human-readable statement text, kept separate from the judging limits.
    statement = {
        "title": "Sum of Two Numbers",
        "description": "Given two integers, find their sum.",
        "input_format": "Two integers a and b",
        "output_format": "Their sum",
        "constraints": "1 <= a, b <= 1000",
    }
    problem = Problem(
        contest_id=test_contest.id,
        **statement,
        time_limit_ms=1000,
        memory_limit_kb=262144,
        total_points=100,
        order_index=0,
    )
    db_session.add(problem)
    await db_session.commit()
    await db_session.refresh(problem)
    return problem
@pytest.fixture
async def test_cases(db_session: AsyncSession, test_problem: Problem) -> list[TestCase]:
    """Persist two test cases for ``test_problem`` and return them in order.

    Case 0 is a sample ("1 2" -> "3"). Case 1 is a non-sample case
    ("100 200" -> "300"). Each is worth 50 points, and ``order_index``
    matches its position in the returned list.
    """
    # (input, expected_output, is_sample) per case; list position = order_index.
    specs = [
        ("1 2", "3", True),
        ("100 200", "300", False),
    ]
    cases = [
        TestCase(
            problem_id=test_problem.id,
            input=case_input,
            expected_output=case_output,
            is_sample=is_sample,
            points=50,
            order_index=index,
        )
        for index, (case_input, case_output, is_sample) in enumerate(specs)
    ]

    # add_all is the Session's built-in equivalent of calling add() per object.
    db_session.add_all(cases)
    await db_session.commit()
    # Reload each row so database-generated fields (e.g. ids) are populated.
    for case in cases:
        await db_session.refresh(case)
    return cases