volsu-contests/backend/app/routers/problems.py

395 lines
12 KiB
Python

from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload
from app.database import get_db
from app.models.user import User
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.schemas.problem import (
ProblemCreate,
ProblemUpdate,
ProblemResponse,
ProblemListResponse,
SampleTestResponse,
TestCaseCreate,
TestCaseResponse,
BulkTestCaseImport,
BulkImportResult,
)
from app.dependencies import get_current_user, get_current_admin
# Router that collects the problem and test-case endpoints defined in this
# module; any URL prefix is applied wherever the router is included (not shown here).
router = APIRouter()
@router.get("/contest/{contest_id}", response_model=list[ProblemListResponse])
async def get_problems_by_contest(
    contest_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the problems of a contest, sorted by ``order_index``.

    The caller must pass the ``get_current_user`` dependency. The contest
    itself is not looked up, so an unknown ``contest_id`` yields an empty list.
    """
    query = (
        select(Problem)
        .where(Problem.contest_id == contest_id)
        .order_by(Problem.order_index)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
@router.get("/{problem_id}", response_model=ProblemResponse)
async def get_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return a single problem together with its sample tests.

    Only test cases flagged ``is_sample`` are included in the response;
    all other test cases are left out.

    Raises:
        HTTPException: 404 when no problem has the given id.
    """
    query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = (await db.execute(query)).scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # Expose sample tests only.
    sample_tests = []
    for tc in problem.test_cases:
        if tc.is_sample:
            sample_tests.append(
                SampleTestResponse(input=tc.input, output=tc.expected_output)
            )

    # Scalar attributes copied verbatim from the ORM object into the response.
    copied_fields = (
        "id",
        "contest_id",
        "title",
        "description",
        "input_format",
        "output_format",
        "constraints",
        "time_limit_ms",
        "memory_limit_kb",
        "total_points",
        "order_index",
        "created_at",
    )
    payload = {name: getattr(problem, name) for name in copied_fields}
    return ProblemResponse(**payload, sample_tests=sample_tests)
@router.post("/", response_model=ProblemResponse, status_code=status.HTTP_201_CREATED)
async def create_problem(
    problem_data: ProblemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Create a problem and its initial test cases in one transaction.

    Guarded by the ``get_current_admin`` dependency. The response contains
    the stored problem plus the test cases flagged ``is_sample``.
    """
    problem_columns = (
        "contest_id",
        "title",
        "description",
        "input_format",
        "output_format",
        "constraints",
        "time_limit_ms",
        "memory_limit_kb",
        "total_points",
        "order_index",
    )
    problem = Problem(**{name: getattr(problem_data, name) for name in problem_columns})
    db.add(problem)
    # Flush so the problem gets its primary key before test cases reference it.
    await db.flush()

    db.add_all(
        [
            TestCase(
                problem_id=problem.id,
                input=tc_data.input,
                expected_output=tc_data.expected_output,
                is_sample=tc_data.is_sample,
                points=tc_data.points,
                order_index=tc_data.order_index,
            )
            for tc_data in problem_data.test_cases
        ]
    )

    await db.commit()
    await db.refresh(problem)

    # Re-select with the relationship eagerly loaded so ``test_cases`` can be
    # read without triggering a lazy load.
    reload_query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem.id)
    )
    problem = (await db.execute(reload_query)).scalar_one()

    sample_tests = []
    for tc in problem.test_cases:
        if tc.is_sample:
            sample_tests.append(
                SampleTestResponse(input=tc.input, output=tc.expected_output)
            )

    response_fields = ("id",) + problem_columns + ("created_at",)
    payload = {name: getattr(problem, name) for name in response_fields}
    return ProblemResponse(**payload, sample_tests=sample_tests)
@router.put("/{problem_id}", response_model=ProblemResponse)
async def update_problem(
    problem_id: int,
    problem_data: ProblemUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Apply a partial update to a problem.

    Guarded by the ``get_current_admin`` dependency. Only fields explicitly
    present in the request body are written; omitted fields keep their value.

    Raises:
        HTTPException: 404 when no problem has the given id.
    """
    query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = (await db.execute(query)).scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # exclude_unset drops fields the client did not send.
    changes = problem_data.model_dump(exclude_unset=True)
    for name in changes:
        setattr(problem, name, changes[name])

    await db.commit()
    await db.refresh(problem)

    sample_tests = []
    for tc in problem.test_cases:
        if tc.is_sample:
            sample_tests.append(
                SampleTestResponse(input=tc.input, output=tc.expected_output)
            )

    copied_fields = (
        "id",
        "contest_id",
        "title",
        "description",
        "input_format",
        "output_format",
        "constraints",
        "time_limit_ms",
        "memory_limit_kb",
        "total_points",
        "order_index",
        "created_at",
    )
    payload = {name: getattr(problem, name) for name in copied_fields}
    return ProblemResponse(**payload, sample_tests=sample_tests)
@router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a problem by id and commit; responds with 204 and no body.

    Guarded by the ``get_current_admin`` dependency.

    Raises:
        HTTPException: 404 when no problem has the given id.
    """
    query = select(Problem).where(Problem.id == problem_id)
    problem = (await db.execute(query)).scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    await db.delete(problem)
    await db.commit()
# Test cases endpoints (admin only)
@router.get("/{problem_id}/test-cases", response_model=list[TestCaseResponse])
async def get_test_cases(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List every test case of a problem, sorted by ``order_index``.

    Guarded by the ``get_current_admin`` dependency. The problem itself is not
    looked up, so an unknown ``problem_id`` yields an empty list.
    """
    query = (
        select(TestCase)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index)
    )
    rows = await db.execute(query)
    return rows.scalars().all()
@router.post("/{problem_id}/test-cases", response_model=TestCaseResponse, status_code=status.HTTP_201_CREATED)
async def add_test_case(
    problem_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Attach one new test case to an existing problem.

    Guarded by the ``get_current_admin`` dependency.

    Raises:
        HTTPException: 404 when no problem has the given id.
    """
    # Reject the request early when the parent problem is missing.
    lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    parent = lookup.scalar_one_or_none()
    if not parent:
        raise HTTPException(status_code=404, detail="Problem not found")

    payload_fields = ("input", "expected_output", "is_sample", "points", "order_index")
    values = {name: getattr(test_case_data, name) for name in payload_fields}
    test_case = TestCase(problem_id=problem_id, **values)

    db.add(test_case)
    await db.commit()
    # Reload so database-generated values (such as the id) are populated.
    await db.refresh(test_case)
    return test_case
@router.put("/test-cases/{test_case_id}", response_model=TestCaseResponse)
async def update_test_case(
    test_case_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Update an existing test case with the fields sent in the request body.

    Guarded by the ``get_current_admin`` dependency. Fields not explicitly
    set in the body are left unchanged.

    Raises:
        HTTPException: 404 when no test case has the given id.
    """
    query = select(TestCase).where(TestCase.id == test_case_id)
    test_case = (await db.execute(query)).scalar_one_or_none()
    if test_case is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    changes = test_case_data.model_dump(exclude_unset=True)
    for name in changes:
        setattr(test_case, name, changes[name])

    await db.commit()
    await db.refresh(test_case)
    return test_case
@router.delete("/test-cases/{test_case_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_test_case(
    test_case_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a test case by id and commit; responds with 204 and no body.

    Guarded by the ``get_current_admin`` dependency.

    Raises:
        HTTPException: 404 when no test case has the given id.
    """
    query = select(TestCase).where(TestCase.id == test_case_id)
    test_case = (await db.execute(query)).scalar_one_or_none()
    if test_case is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    await db.delete(test_case)
    await db.commit()
def _parse_bulk_test_cases(
    text: str,
    test_delimiter: str,
    io_delimiter: str,
) -> list[tuple[str, str]]:
    """Split bulk-import text into ``(input, expected_output)`` pairs.

    Blocks that are empty after stripping are skipped. Every remaining block
    must contain ``io_delimiter`` exactly once, with non-empty text on both
    sides. Error messages number blocks by their 1-based position in the
    text, counting skipped blocks.

    Raises:
        HTTPException: 400 for an empty delimiter or a malformed block.
    """
    if not test_delimiter or not io_delimiter:
        # str.split("") raises ValueError, which would surface as a 500.
        raise HTTPException(status_code=400, detail="Delimiters must not be empty")

    pairs: list[tuple[str, str]] = []
    for position, raw_block in enumerate(text.split(test_delimiter), start=1):
        block = raw_block.strip()
        if not block:
            continue

        parts = block.split(io_delimiter)
        if len(parts) != 2:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"Test #{position}: Expected exactly one '{io_delimiter}' delimiter "
                    f"separating input and output, found {len(parts) - 1}"
                ),
            )

        input_data = parts[0].strip()
        output_data = parts[1].strip()
        if not input_data:
            raise HTTPException(
                status_code=400,
                detail=f"Test #{position}: Empty input",
            )
        if not output_data:
            raise HTTPException(
                status_code=400,
                detail=f"Test #{position}: Empty output",
            )
        pairs.append((input_data, output_data))
    return pairs


@router.post("/{problem_id}/test-cases/bulk", response_model=BulkImportResult)
async def bulk_import_test_cases(
    problem_id: int,
    import_data: BulkTestCaseImport,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """
    Bulk import test cases from text format.

    Format example:
    ```
    1 2
    ---
    3
    ===
    5 10
    ---
    15
    ```
    Supports both single-line and multi-line input/output.

    Tests are separated by `test_delimiter`; within a test, input and
    expected output are separated by `io_delimiter`. Imported tests are
    appended after the problem's current highest `order_index` and get
    0 points. When `mark_first_as_sample` is set, the first `sample_count`
    imported tests are flagged as samples.

    Responds with 404 if the problem does not exist and with 400 if the
    text is empty, a delimiter is empty, a test is malformed, or no test
    cases are found.
    """
    # Check if problem exists
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    if result.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    text = import_data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty text provided")

    # Validate the whole payload before touching the session, so a malformed
    # block never leaves half an import pending.
    parsed = _parse_bulk_test_cases(
        text,
        import_data.test_delimiter,
        import_data.io_delimiter,
    )
    if not parsed:
        raise HTTPException(status_code=400, detail="No valid test cases found in text")

    # Get current max order_index
    result = await db.execute(
        select(TestCase.order_index)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index.desc())
        .limit(1)
    )
    last_order = result.scalar()
    # Compare against None explicitly: an existing maximum of 0 is falsy, and
    # treating it as "no test cases" would hand out order_index 0 twice.
    next_order = 0 if last_order is None else last_order + 1

    created_tests = [
        TestCase(
            problem_id=problem_id,
            input=input_data,
            expected_output=output_data,
            # Index over accepted tests, so skipped empty blocks do not shift
            # which tests are marked as samples.
            is_sample=import_data.mark_first_as_sample and index < import_data.sample_count,
            points=0,
            order_index=next_order + index,
        )
        for index, (input_data, output_data) in enumerate(parsed)
    ]
    db.add_all(created_tests)
    await db.commit()

    # Refresh all to get IDs
    for test in created_tests:
        await db.refresh(test)

    return BulkImportResult(
        created_count=len(created_tests),
        test_cases=[TestCaseResponse.model_validate(t) for t in created_tests],
    )