289 lines
8.7 KiB
Python
289 lines
8.7 KiB
Python
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from app.database import get_db
|
|
from app.models.user import User
|
|
from app.models.problem import Problem
|
|
from app.models.test_case import TestCase
|
|
from app.schemas.problem import (
|
|
ProblemCreate,
|
|
ProblemUpdate,
|
|
ProblemResponse,
|
|
ProblemListResponse,
|
|
SampleTestResponse,
|
|
TestCaseCreate,
|
|
TestCaseResponse,
|
|
)
|
|
from app.dependencies import get_current_user, get_current_admin
|
|
|
|
|
|
# Router holding the problem and test-case endpoints defined below. Any URL
# prefix is applied where this router is included (not visible in this file).
router = APIRouter()
|
|
|
|
|
|
@router.get("/contest/{contest_id}", response_model=list[ProblemListResponse])
async def get_problems_by_contest(
    contest_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List every problem attached to a contest, ordered by ``order_index``.

    Requires the ``get_current_user`` dependency. The contest itself is not
    looked up, so an unknown ``contest_id`` yields an empty list rather than
    a 404.
    """
    stmt = (
        select(Problem)
        .where(Problem.contest_id == contest_id)
        .order_by(Problem.order_index)
    )
    rows = await db.execute(stmt)
    return rows.scalars().all()
|
|
|
|
|
|
@router.get("/{problem_id}", response_model=ProblemResponse)
async def get_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return one problem together with its sample test cases only.

    Hidden (non-sample) test cases are never included in the response, so
    participants only see the examples meant for them.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    problem = result.scalar_one_or_none()

    if not problem:
        raise HTTPException(status_code=404, detail="Problem not found")

    # Query only the sample rows instead of eager-loading every test case:
    # hidden tests may carry large input / expected-output payloads that this
    # endpoint would immediately discard. The explicit ORDER BY keeps samples
    # in a deterministic order, the same one get_test_cases uses.
    samples_result = await db.execute(
        select(TestCase)
        .where(TestCase.problem_id == problem.id, TestCase.is_sample.is_(True))
        .order_by(TestCase.order_index)
    )
    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in samples_result.scalars()
    ]

    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
|
|
|
|
|
|
@router.post("/", response_model=ProblemResponse, status_code=status.HTTP_201_CREATED)
async def create_problem(
    problem_data: ProblemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Create a problem and all of its submitted test cases (admin only).

    The problem row is flushed first so that its generated primary key can be
    used as ``problem_id`` on every test case. Everything is then committed
    together, and the response exposes only the sample test cases.
    """
    new_problem = Problem(
        contest_id=problem_data.contest_id,
        title=problem_data.title,
        description=problem_data.description,
        input_format=problem_data.input_format,
        output_format=problem_data.output_format,
        constraints=problem_data.constraints,
        time_limit_ms=problem_data.time_limit_ms,
        memory_limit_kb=problem_data.memory_limit_kb,
        total_points=problem_data.total_points,
        order_index=problem_data.order_index,
    )
    db.add(new_problem)
    # Flush (not commit) so new_problem.id is populated before it is referenced.
    await db.flush()

    db.add_all(
        [
            TestCase(
                problem_id=new_problem.id,
                input=spec.input,
                expected_output=spec.expected_output,
                is_sample=spec.is_sample,
                points=spec.points,
                order_index=spec.order_index,
            )
            for spec in problem_data.test_cases
        ]
    )

    await db.commit()
    await db.refresh(new_problem)

    # Re-select with test_cases eagerly loaded so the collection can be read
    # without an implicit lazy load, which AsyncSession does not support.
    reloaded = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == new_problem.id)
    )
    created = reloaded.scalar_one()

    samples = []
    for case in created.test_cases:
        if case.is_sample:
            samples.append(
                SampleTestResponse(input=case.input, output=case.expected_output)
            )

    return ProblemResponse(
        id=created.id,
        contest_id=created.contest_id,
        title=created.title,
        description=created.description,
        input_format=created.input_format,
        output_format=created.output_format,
        constraints=created.constraints,
        time_limit_ms=created.time_limit_ms,
        memory_limit_kb=created.memory_limit_kb,
        total_points=created.total_points,
        order_index=created.order_index,
        created_at=created.created_at,
        sample_tests=samples,
    )
|
|
|
|
|
|
@router.put("/{problem_id}", response_model=ProblemResponse)
async def update_problem(
    problem_id: int,
    problem_data: ProblemUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Partially update a problem (admin only).

    Only the fields the client explicitly sent (``exclude_unset=True``) are
    copied onto the stored problem with ``setattr``. Omitted fields keep their
    current values. The response exposes only the sample test cases.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    query = (
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    target = (await db.execute(query)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    # NOTE(review): an explicit null in the request body is applied as-is;
    # confirm ProblemUpdate rejects nulls for non-nullable columns.
    changes = problem_data.model_dump(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(target, attr_name, new_value)

    await db.commit()
    # test_cases was loaded via selectinload above; refresh() re-applies that
    # eager loader, so the collection is readable again without lazy IO.
    await db.refresh(target)

    samples = []
    for case in target.test_cases:
        if case.is_sample:
            samples.append(
                SampleTestResponse(input=case.input, output=case.expected_output)
            )

    return ProblemResponse(
        id=target.id,
        contest_id=target.contest_id,
        title=target.title,
        description=target.description,
        input_format=target.input_format,
        output_format=target.output_format,
        constraints=target.constraints,
        time_limit_ms=target.time_limit_ms,
        memory_limit_kb=target.memory_limit_kb,
        total_points=target.total_points,
        order_index=target.order_index,
        created_at=target.created_at,
        sample_tests=samples,
    )
|
|
|
|
|
|
@router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a problem by id (admin only); responds 204 with no body.

    Whether the problem's test cases are removed too depends on the ORM/DB
    cascade configuration, which is not visible in this file.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    doomed = lookup.scalar_one_or_none()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    await db.delete(doomed)
    await db.commit()
|
|
|
|
|
|
# Test cases endpoints (admin only)
|
|
@router.get("/{problem_id}/test-cases", response_model=list[TestCaseResponse])
async def get_test_cases(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List all test cases (sample and hidden) of a problem, admin only.

    Results are ordered by ``order_index``.

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists, matching
            the behaviour of ``add_test_case``.
    """
    result = await db.execute(
        select(TestCase)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index)
    )
    test_cases = result.scalars().all()

    # An empty list is ambiguous: the problem either has no test cases yet or
    # does not exist. Only then pay for an existence check, so a wrong id
    # yields 404 instead of a silent [].
    if not test_cases:
        exists = await db.execute(select(Problem.id).where(Problem.id == problem_id))
        if exists.scalar_one_or_none() is None:
            raise HTTPException(status_code=404, detail="Problem not found")

    return test_cases
|
|
|
|
|
|
@router.post("/{problem_id}/test-cases", response_model=TestCaseResponse, status_code=status.HTTP_201_CREATED)
async def add_test_case(
    problem_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Attach a new test case to an existing problem (admin only).

    Raises:
        HTTPException: 404 if no problem with ``problem_id`` exists.
    """
    parent_lookup = await db.execute(select(Problem).where(Problem.id == problem_id))
    if parent_lookup.scalar_one_or_none() is None:
        raise HTTPException(status_code=404, detail="Problem not found")

    new_case = TestCase(
        problem_id=problem_id,
        input=test_case_data.input,
        expected_output=test_case_data.expected_output,
        is_sample=test_case_data.is_sample,
        points=test_case_data.points,
        order_index=test_case_data.order_index,
    )
    db.add(new_case)
    await db.commit()
    # Reload the committed row so the response reflects the stored state.
    await db.refresh(new_case)
    return new_case
|
|
|
|
|
|
@router.put("/test-cases/{test_case_id}", response_model=TestCaseResponse)
async def update_test_case(
    test_case_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Update a test case in place (admin only).

    The body is validated as ``TestCaseCreate``, but only the fields the client
    explicitly sent (``exclude_unset=True``) are written onto the stored row.

    Raises:
        HTTPException: 404 if no test case with ``test_case_id`` exists.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    existing = lookup.scalar_one_or_none()
    if existing is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    changes = test_case_data.model_dump(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(existing, attr_name, new_value)

    await db.commit()
    await db.refresh(existing)
    return existing
|
|
|
|
|
|
@router.delete("/test-cases/{test_case_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_test_case(
    test_case_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a single test case by id (admin only); responds 204 with no body.

    Raises:
        HTTPException: 404 if no test case with ``test_case_id`` exists.
    """
    lookup = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    doomed = lookup.scalar_one_or_none()
    if doomed is None:
        raise HTTPException(status_code=404, detail="Test case not found")

    await db.delete(doomed)
    await db.commit()
|