395 lines
12 KiB
Python
395 lines
12 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.models.problem import Problem
|
|
from app.models.test_case import TestCase
|
|
from app.schemas.problem import (
|
|
ProblemCreate,
|
|
ProblemUpdate,
|
|
ProblemResponse,
|
|
ProblemListResponse,
|
|
SampleTestResponse,
|
|
TestCaseCreate,
|
|
TestCaseResponse,
|
|
BulkTestCaseImport,
|
|
BulkImportResult,
|
|
)
|
|
from app.dependencies import get_current_user, get_current_admin
|
|
|
|
|
|
# Router that collects the problem and test-case endpoints declared in this module.
router = APIRouter()
|
|
|
|
|
|
@router.get("/contest/{contest_id}", response_model=list[ProblemListResponse])
async def get_problems_by_contest(
    contest_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the problems that belong to a contest, sorted by ``order_index``.

    The result is an empty list when no problem references ``contest_id``.
    """
    contest_problems = (
        select(Problem)
        .where(Problem.contest_id == contest_id)
        .order_by(Problem.order_index)
    )
    rows = await db.execute(contest_problems)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{problem_id}", response_model=ProblemResponse)
async def get_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return a single problem together with its sample test cases.

    Test cases that are not flagged ``is_sample`` are left out of the
    response. Responds with 404 when no problem has the given id.
    """
    query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = (await db.execute(query)).scalar_one_or_none()
    if problem is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # Only the sample subset of the eagerly loaded test cases is exposed here.
    samples = [
        SampleTestResponse(input=case.input, output=case.expected_output)
        for case in problem.test_cases
        if case.is_sample
    ]

    # Attributes copied one-to-one from the ORM row into the response model.
    copied_fields = (
        "id",
        "contest_id",
        "title",
        "description",
        "input_format",
        "output_format",
        "constraints",
        "time_limit_ms",
        "memory_limit_kb",
        "total_points",
        "order_index",
        "created_at",
    )
    payload = {name: getattr(problem, name) for name in copied_fields}
    return ProblemResponse(sample_tests=samples, **payload)
|
|
|
|
|
|
@router.post("/", response_model=ProblemResponse, status_code=status.HTTP_201_CREATED)
async def create_problem(
    problem_data: ProblemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Create a problem along with the test cases supplied in the request.

    The response contains the stored problem and only those of its test
    cases that are flagged ``is_sample``.
    """
    problem_fields = (
        "contest_id",
        "title",
        "description",
        "input_format",
        "output_format",
        "constraints",
        "time_limit_ms",
        "memory_limit_kb",
        "total_points",
        "order_index",
    )
    problem = Problem(
        **{name: getattr(problem_data, name) for name in problem_fields}
    )
    db.add(problem)
    # Flush before building the test cases: they reference problem.id below.
    await db.flush()

    for spec in problem_data.test_cases:
        db.add(
            TestCase(
                problem_id=problem.id,
                input=spec.input,
                expected_output=spec.expected_output,
                is_sample=spec.is_sample,
                points=spec.points,
                order_index=spec.order_index,
            )
        )

    await db.commit()
    await db.refresh(problem)

    # Re-select with the test cases eagerly loaded so the response can read them.
    reload_query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem.id)
    )
    problem = (await db.execute(reload_query)).scalar_one()

    samples = [
        SampleTestResponse(input=case.input, output=case.expected_output)
        for case in problem.test_cases
        if case.is_sample
    ]

    response_fields = ("id",) + problem_fields + ("created_at",)
    payload = {name: getattr(problem, name) for name in response_fields}
    return ProblemResponse(sample_tests=samples, **payload)
|
|
|
|
|
|
@router.put("/{problem_id}", response_model=ProblemResponse)
async def update_problem(
    problem_id: int,
    problem_data: ProblemUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Partially update a problem and return it with its sample test cases.

    Only the fields explicitly present in the request body are changed.
    Responds with 404 when no problem has the given id.
    """
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = result.scalar_one_or_none()

    if not problem:
        raise HTTPException(status_code=404, detail="Problem not found")

    # exclude_unset keeps omitted fields untouched instead of resetting them.
    update_data = problem_data.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(problem, field, value)

    await db.commit()
    await db.refresh(problem)

    # refresh() reloads column attributes but does not eagerly reload a lazily
    # configured relationship, so reading problem.test_cases afterwards would
    # trigger an implicit lazy load, which AsyncSession does not allow.
    # Re-select with selectinload, the same way create_problem builds its
    # response. (Assumes test_cases uses the default lazy loading; the mapping
    # is not visible from this module.)
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = result.scalar_one()

    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in problem.test_cases
        if tc.is_sample
    ]

    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
|
|
|
|
|
|
@router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a problem by id.

    Responds with 404 when no problem has the given id.
    """
    lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    doomed = lookup.scalar_one_or_none()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    await db.delete(doomed)
    await db.commit()
|
|
|
|
|
|
# Test cases endpoints (admin only)
|
|
@router.get("/{problem_id}/test-cases", response_model=list[TestCaseResponse])
async def get_test_cases(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List every test case of a problem, sorted by ``order_index``.

    The result is an empty list when no test case references ``problem_id``.
    """
    cases_query = (
        select(TestCase)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index)
    )
    rows = await db.execute(cases_query)
    return rows.scalars().all()
|
|
|
|
|
|
@router.post("/{problem_id}/test-cases", response_model=TestCaseResponse, status_code=status.HTTP_201_CREATED)
async def add_test_case(
    problem_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Attach a new test case to an existing problem.

    Responds with 404 when no problem has the given id.
    """
    # Reject the request early if the parent problem is missing.
    parent_lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    if parent_lookup.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    copied_fields = ("input", "expected_output", "is_sample", "points", "order_index")
    new_case = TestCase(
        problem_id=problem_id,
        **{name: getattr(test_case_data, name) for name in copied_fields},
    )
    db.add(new_case)
    await db.commit()
    # Reload the row after the commit before handing it to the response model.
    await db.refresh(new_case)
    return new_case
|
|
|
|
|
|
@router.put("/test-cases/{test_case_id}", response_model=TestCaseResponse)
async def update_test_case(
    test_case_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Overwrite the fields of a test case with the values sent in the request.

    Responds with 404 when no test case has the given id.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    existing = lookup.scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    # Only fields the client actually sent are applied to the stored row.
    changes = test_case_data.model_dump(exclude_unset=True)
    for name in changes:
        setattr(existing, name, changes[name])

    await db.commit()
    await db.refresh(existing)
    return existing
|
|
|
|
|
|
@router.delete("/test-cases/{test_case_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_test_case(
    test_case_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a test case by id.

    Responds with 404 when no test case has the given id.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    doomed = lookup.scalar_one_or_none()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    await db.delete(doomed)
    await db.commit()
|
|
|
|
|
|
@router.post("/{problem_id}/test-cases/bulk", response_model=BulkImportResult)
async def bulk_import_test_cases(
    problem_id: int,
    import_data: BulkTestCaseImport,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """
    Bulk import test cases from text format.

    Format example:
    ```
    1 2
    ---
    3
    ===
    5 10
    ---
    15
    ```

    Supports both single-line and multi-line input/output.

    Blank blocks between test delimiters are skipped. Imported test cases
    are appended after the problem's current highest order index and are
    created with 0 points. Responds with 404 when the problem does not
    exist and with 400 when the text or a delimiter is empty, a block is
    malformed, or no test case is found.
    """
    # Check if problem exists
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    if result.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # Get current max order_index
    result = await db.execute(
        select(TestCase.order_index)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index.desc())
        .limit(1)
    )
    max_order = result.scalar()
    if max_order is None:
        # Compare against None explicitly: a legitimate maximum of 0 is falsy
        # and must not be mistaken for "no test cases yet".
        max_order = -1

    # Parse the text
    text = import_data.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Empty text provided")

    # str.split raises ValueError on an empty separator; report that as a
    # client error instead of letting it surface as a server error.
    if not import_data.test_delimiter or not import_data.io_delimiter:
        raise HTTPException(status_code=400, detail="Delimiters must not be empty")

    # Split by test delimiter
    test_blocks = text.split(import_data.test_delimiter)
    created_tests = []

    # i is the block's position in the raw text (blank blocks included) and is
    # only used to point the user at the offending block in error messages.
    for i, block in enumerate(test_blocks):
        block = block.strip()
        if not block:
            continue

        # Split by IO delimiter
        parts = block.split(import_data.io_delimiter)
        if len(parts) != 2:
            raise HTTPException(
                status_code=400,
                detail=f"Test #{i + 1}: Expected exactly one '{import_data.io_delimiter}' delimiter separating input and output, found {len(parts) - 1}"
            )

        input_data = parts[0].strip()
        output_data = parts[1].strip()

        if not input_data:
            raise HTTPException(
                status_code=400,
                detail=f"Test #{i + 1}: Empty input"
            )
        if not output_data:
            raise HTTPException(
                status_code=400,
                detail=f"Test #{i + 1}: Empty output"
            )

        # Determine if this should be a sample. Count the test cases actually
        # created so far rather than the raw block index, so skipped blank
        # blocks do not use up the sample quota.
        is_sample = (
            import_data.mark_first_as_sample
            and len(created_tests) < import_data.sample_count
        )

        max_order += 1
        test_case = TestCase(
            problem_id=problem_id,
            input=input_data,
            expected_output=output_data,
            is_sample=is_sample,
            points=0,
            order_index=max_order,
        )
        db.add(test_case)
        created_tests.append(test_case)

    if not created_tests:
        raise HTTPException(status_code=400, detail="No valid test cases found in text")

    await db.commit()

    # Refresh all to get IDs
    for test in created_tests:
        await db.refresh(test)

    return BulkImportResult(
        created_count=len(created_tests),
        test_cases=[TestCaseResponse.model_validate(t) for t in created_tests]
    )
|