# Shared pytest fixtures: in-memory test database, HTTP client, and seed data.
from datetime import datetime, timezone, timedelta
from typing import AsyncGenerator

import pytest
from httpx import AsyncClient, ASGITransport
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.database import Base, get_db
from app.models.user import User
from app.models.contest import Contest
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.services.auth import get_password_hash, create_access_token
|
|
|
|
|
|
# Use an in-memory SQLite database for testing.
TEST_DATABASE_URL = "sqlite+aiosqlite:///:memory:"

# A ":memory:" SQLite database lives inside a single connection: a second
# connection would see a brand-new, empty database. StaticPool pins one shared
# connection so the tables created by ``setup_database`` are visible to the
# app's sessions and to the fixtures. Stating this explicitly avoids relying
# on the dialect's implicit default pool choice.
engine = create_async_engine(
    TEST_DATABASE_URL,
    echo=False,
    poolclass=StaticPool,
)

# expire_on_commit=False keeps ORM attributes loaded after commit, so fixture
# objects (e.g. ``test_user.id``) stay readable without an implicit reload,
# which is not allowed outside an awaited call under asyncio.
async_session_maker = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
|
|
|
|
|
async def override_get_db() -> AsyncGenerator[AsyncSession, None]:
    """Stand-in for the app's ``get_db`` dependency during tests.

    Yields a session bound to the in-memory test engine. Leaving the
    ``async with`` block closes the session, including when an exception
    is thrown into this generator. No separate ``close()`` call is needed.
    """
    async with async_session_maker() as session:
        yield session
|
|
|
|
|
|
# Send every request that depends on ``get_db`` to the in-memory test database.
# NOTE(review): installed at import time and never removed, so the override
# stays active for the whole test session.
app.dependency_overrides.update({get_db: override_get_db})
|
|
|
|
|
|
@pytest.fixture(autouse=True)
async def setup_database():
    """Give every test a freshly created schema and drop it afterwards.

    NOTE(review): a plain ``@pytest.fixture`` on an ``async def`` relies on
    pytest-asyncio running in ``auto`` mode (or an equivalent plugin);
    confirm in the pytest configuration.
    """
    schema = Base.metadata
    async with engine.begin() as connection:
        await connection.run_sync(schema.create_all)
    yield  # the test body runs here
    async with engine.begin() as connection:
        await connection.run_sync(schema.drop_all)
|
|
|
|
|
|
@pytest.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session on the test engine for arranging or inspecting data directly.

    The session is closed when the ``async with`` block exits after the test.
    """
    async with async_session_maker() as direct_session:
        yield direct_session
|
|
|
|
|
|
@pytest.fixture
async def client() -> AsyncGenerator[AsyncClient, None]:
    """Yield an async HTTP client wired straight to the ASGI app via ASGITransport."""
    async with AsyncClient(
        transport=ASGITransport(app=app),
        base_url="http://test",
    ) as http:
        yield http
|
|
|
|
|
|
@pytest.fixture
async def test_user(db_session: AsyncSession) -> User:
    """Persist and return a regular user with the ``participant`` role.

    Login: username ``testuser`` / email ``test@example.com``,
    password ``testpassword``.
    """
    hashed = get_password_hash("testpassword")
    participant = User(
        email="test@example.com",
        username="testuser",
        password_hash=hashed,
        role="participant",
    )
    db_session.add(participant)
    await db_session.commit()
    # Reload the row so database-generated values are present on the object.
    await db_session.refresh(participant)
    return participant
|
|
|
|
|
|
@pytest.fixture
async def admin_user(db_session: AsyncSession) -> User:
    """Persist and return a user with the ``admin`` role.

    Login: username ``adminuser`` / email ``admin@example.com``,
    password ``adminpassword``.
    """
    hashed = get_password_hash("adminpassword")
    administrator = User(
        email="admin@example.com",
        username="adminuser",
        password_hash=hashed,
        role="admin",
    )
    db_session.add(administrator)
    await db_session.commit()
    # Reload the row so database-generated values are present on the object.
    await db_session.refresh(administrator)
    return administrator
|
|
|
|
|
|
@pytest.fixture
def user_token(test_user: User) -> str:
    """Issue a JWT access token whose ``sub`` claim is the test user's id."""
    subject = str(test_user.id)
    return create_access_token(data={"sub": subject})
|
|
|
|
|
|
@pytest.fixture
def admin_token(admin_user: User) -> str:
    """Issue a JWT access token whose ``sub`` claim is the admin user's id."""
    subject = str(admin_user.id)
    return create_access_token(data={"sub": subject})
|
|
|
|
|
|
@pytest.fixture
def auth_headers(user_token: str) -> dict:
    """Build the bearer ``Authorization`` header for requests made as the test user."""
    bearer_value = "Bearer {}".format(user_token)
    return {"Authorization": bearer_value}
|
|
|
|
|
|
@pytest.fixture
def admin_headers(admin_token: str) -> dict:
    """Build the bearer ``Authorization`` header for requests made as the admin user."""
    bearer_value = "Bearer {}".format(admin_token)
    return {"Authorization": bearer_value}
|
|
|
|
|
|
@pytest.fixture
async def test_contest(db_session: AsyncSession, admin_user: User) -> Contest:
    """Persist and return an active contest, created by the admin, that is running now.

    The window opens one hour before fixture setup and closes two hours after it.
    """
    now = datetime.now(timezone.utc)
    starts_at = now - timedelta(hours=1)
    ends_at = now + timedelta(hours=2)
    # NOTE(review): SQLite has no native timezone-aware datetime type, so these
    # UTC-aware values may be read back as naive datetimes; confirm the app's
    # comparisons tolerate that under the test database.
    contest = Contest(
        title="Test Contest",
        description="A test contest",
        start_time=starts_at,
        end_time=ends_at,
        is_active=True,
        created_by=admin_user.id,
    )
    db_session.add(contest)
    await db_session.commit()
    await db_session.refresh(contest)
    return contest
|
|
|
|
|
|
@pytest.fixture
async def test_problem(db_session: AsyncSession, test_contest: Contest) -> Problem:
    """Persist and return an "add two integers" problem in ``test_contest``.

    It is worth 100 points, has a 1000 ms time limit and memory_limit_kb=262144,
    and is the first problem (order_index 0) in the contest.
    """
    attributes = {
        "contest_id": test_contest.id,
        "title": "Sum of Two Numbers",
        "description": "Given two integers, find their sum.",
        "input_format": "Two integers a and b",
        "output_format": "Their sum",
        "constraints": "1 <= a, b <= 1000",
        "time_limit_ms": 1000,
        "memory_limit_kb": 262144,
        "total_points": 100,
        "order_index": 0,
    }
    problem = Problem(**attributes)
    db_session.add(problem)
    await db_session.commit()
    await db_session.refresh(problem)
    return problem
|
|
|
|
|
|
@pytest.fixture
async def test_cases(db_session: AsyncSession, test_problem: Problem) -> list[TestCase]:
    """Persist and return two test cases for the test problem.

    The first case is a sample (visible) case and the second is hidden.
    Each is worth 50 points, and they are returned in ``order_index`` order.
    """
    cases = [
        TestCase(
            problem_id=test_problem.id,
            input="1 2",
            expected_output="3",
            is_sample=True,
            points=50,
            order_index=0,
        ),
        TestCase(
            problem_id=test_problem.id,
            input="100 200",
            expected_output="300",
            is_sample=False,
            points=50,
            order_index=1,
        ),
    ]
    # One add_all call instead of a manual per-object add() loop.
    db_session.add_all(cases)
    await db_session.commit()
    # Reload each row so database-generated values are present on the objects.
    for case in cases:
        await db_session.refresh(case)
    return cases
|