volsu-contests/backend/app/routers/problems.py
2025-11-30 19:55:50 +03:00

289 lines
8.7 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models.user import User
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.schemas.problem import (
ProblemCreate,
ProblemUpdate,
ProblemResponse,
ProblemListResponse,
SampleTestResponse,
TestCaseCreate,
TestCaseResponse,
)
from app.dependencies import get_current_user, get_current_admin
# Router collecting the problem and test-case endpoints defined below.
# NOTE(review): no prefix is set here, so the URL prefix presumably comes from
# the application's include_router call — confirm.
router = APIRouter()
@router.get("/contest/{contest_id}", response_model=list[ProblemListResponse])
async def get_problems_by_contest(
    contest_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List all problems belonging to a contest, ordered by ``order_index``.

    Requires an authenticated user. An unknown ``contest_id`` produces an
    empty list rather than a 404.
    """
    # NOTE(review): contest visibility (e.g. contests that have not started)
    # is not checked here — confirm any authenticated user may list problems.
    stmt = (
        select(Problem)
        .where(Problem.contest_id == contest_id)
        .order_by(Problem.order_index)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
@router.get("/{problem_id}", response_model=ProblemResponse)
async def get_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one problem together with its sample tests.

    Only test cases flagged ``is_sample`` are included in the response; the
    remaining test cases are not exposed by this endpoint. Samples are ordered
    by ``order_index`` (ties broken by id), matching the ordering used by the
    admin test-case listing.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # Nothing in this query orders the loaded collection (unless the model's
    # relationship declares order_by — not visible here), so sort explicitly.
    # The key never compares None with an int, in case order_index is nullable.
    samples = sorted(
        (tc for tc in problem.test_cases if tc.is_sample),
        key=lambda tc: (tc.order_index is None, tc.order_index, tc.id),
    )
    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in samples
    ]
    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
@router.post("/", response_model=ProblemResponse, status_code=status.HTTP_201_CREATED)
async def create_problem(
    problem_data: ProblemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Create a problem and its test cases, committing once (admin only).

    The response has the same shape as ``GET /{problem_id}``: only test cases
    flagged ``is_sample`` are included, ordered by ``order_index``.
    """
    problem = Problem(
        contest_id=problem_data.contest_id,
        title=problem_data.title,
        description=problem_data.description,
        input_format=problem_data.input_format,
        output_format=problem_data.output_format,
        constraints=problem_data.constraints,
        time_limit_ms=problem_data.time_limit_ms,
        memory_limit_kb=problem_data.memory_limit_kb,
        total_points=problem_data.total_points,
        order_index=problem_data.order_index,
    )
    db.add(problem)
    # Flush so problem.id is populated before the test cases reference it.
    await db.flush()
    # Capture the id now: after commit the instance may be expired, and reading
    # an expired attribute would be implicit IO on an AsyncSession.
    new_problem_id = problem.id

    for tc_data in problem_data.test_cases:
        db.add(
            TestCase(
                problem_id=new_problem_id,
                input=tc_data.input,
                expected_output=tc_data.expected_output,
                is_sample=tc_data.is_sample,
                points=tc_data.points,
                order_index=tc_data.order_index,
            )
        )
    await db.commit()

    # A single re-select replaces the former refresh() + re-select pair:
    # populate_existing overwrites every column attribute with the committed
    # row, and selectinload loads test_cases eagerly so no lazy load happens.
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == new_problem_id)
        .execution_options(populate_existing=True)
    )
    problem = result.scalar_one()

    # Order samples explicitly; the key never compares None with an int.
    samples = sorted(
        (tc for tc in problem.test_cases if tc.is_sample),
        key=lambda tc: (tc.order_index is None, tc.order_index, tc.id),
    )
    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in samples
    ]
    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
@router.put("/{problem_id}", response_model=ProblemResponse)
async def update_problem(
    problem_id: int,
    problem_data: ProblemUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Apply a partial update to a problem (admin only).

    Only fields explicitly present in the request body are written
    (``exclude_unset=True``). The response has the same shape as
    ``GET /{problem_id}``: sample tests only, ordered by ``order_index``.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    for field, value in problem_data.model_dump(exclude_unset=True).items():
        setattr(problem, field, value)
    await db.commit()

    # Reload explicitly, as create_problem does, instead of relying on
    # refresh() to re-apply an earlier eager-load option: an unloaded
    # test_cases collection would otherwise need a lazy load, which an
    # AsyncSession cannot perform implicitly. populate_existing makes the
    # column values reflect the committed row, as refresh() did.
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
        .execution_options(populate_existing=True)
    )
    problem = result.scalar_one()

    # Order samples explicitly; the key never compares None with an int.
    samples = sorted(
        (tc for tc in problem.test_cases if tc.is_sample),
        key=lambda tc: (tc.order_index is None, tc.order_index, tc.id),
    )
    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in samples
    ]
    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
@router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a problem (admin only); responds with 204 and no body.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Problem not found")
    # NOTE(review): whether the problem's test cases go with it depends on the
    # model's cascade / FK ON DELETE settings, which are not visible here.
    await db.delete(target)
    await db.commit()
# Test cases endpoints (admin only)
@router.get("/{problem_id}/test-cases", response_model=list[TestCaseResponse])
async def get_test_cases(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List every test case of a problem, ordered by ``order_index`` (admin only).

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists. This matches
            ``add_test_case``; previously an unknown id silently returned ``[]``.
    """
    exists = await db.execute(select(Problem.id).where(Problem.id == problem_id))
    if exists.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    result = await db.execute(
        select(TestCase)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index)
    )
    return result.scalars().all()
@router.post("/{problem_id}/test-cases", response_model=TestCaseResponse, status_code=status.HTTP_201_CREATED)
async def add_test_case(
    problem_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Attach a new test case to an existing problem (admin only).

    The created row is refreshed after commit, so values populated by the
    database are present in the returned object.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    if lookup.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    new_case = TestCase(
        problem_id=problem_id,
        input=test_case_data.input,
        expected_output=test_case_data.expected_output,
        is_sample=test_case_data.is_sample,
        points=test_case_data.points,
        order_index=test_case_data.order_index,
    )
    db.add(new_case)
    await db.commit()
    await db.refresh(new_case)
    return new_case
@router.put("/test-cases/{test_case_id}", response_model=TestCaseResponse)
async def update_test_case(
    test_case_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Overwrite the fields of a test case that the request sets (admin only).

    Raises:
        HTTPException: 404 if no test case with ``test_case_id`` exists.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    # Only fields present in the request body are applied (exclude_unset).
    # NOTE(review): the body is validated as TestCaseCreate, so any fields that
    # schema requires must be sent even for a partial edit — confirm intended.
    changes = test_case_data.model_dump(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(target, attr_name, new_value)

    await db.commit()
    await db.refresh(target)
    return target
@router.delete("/test-cases/{test_case_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_test_case(
    test_case_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a single test case (admin only); responds with 204 and no body.

    Raises:
        HTTPException: 404 if no test case with ``test_case_id`` exists.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    target = lookup.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Test case not found")
    await db.delete(target)
    await db.commit()