from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from sqlalchemy.orm import selectinload

from app.database import get_db
from app.models.user import User
from app.models.problem import Problem
from app.models.test_case import TestCase
from app.schemas.problem import (
    ProblemCreate,
    ProblemUpdate,
    ProblemResponse,
    ProblemListResponse,
    SampleTestResponse,
    TestCaseCreate,
    TestCaseResponse,
)
from app.dependencies import get_current_user, get_current_admin

router = APIRouter()


@router.get("/contest/{contest_id}", response_model=list[ProblemListResponse])
async def get_problems_by_contest(
    contest_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """List the problems belonging to ``contest_id``, ordered by ``order_index``.

    Requires an authenticated user (``get_current_user``). Returns an empty
    list when the contest has no problems (or does not exist).
    """
    # NOTE(review): no check that the contest exists, has started, or that the
    # user may see it — confirm this is intended for participant access.
    result = await db.execute(
        select(Problem)
        .where(Problem.contest_id == contest_id)
        .order_by(Problem.order_index)
    )
    return result.scalars().all()


@router.get("/{problem_id}", response_model=ProblemResponse)
async def get_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_user),
):
    """Return a single problem, exposing only its sample test cases.

    Non-sample (hidden) test cases are filtered out of the response.

    Raises:
        HTTPException: 404 if no problem has id ``problem_id``.
    """
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
    )
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )

    # Unless the relationship declares an order_by, the selectin load carries
    # no ORDER BY, so sort explicitly to match the order_index ordering used by
    # the admin test-case listing. NULL order_index values (if the column
    # allows them) sort last instead of raising TypeError on comparison; id
    # breaks ties so the order is fully deterministic.
    samples = sorted(
        (tc for tc in problem.test_cases if tc.is_sample),
        key=lambda tc: (tc.order_index is None, tc.order_index or 0, tc.id),
    )
    sample_tests = [
        SampleTestResponse(input=tc.input, output=tc.expected_output)
        for tc in samples
    ]

    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=sample_tests,
    )
def _problem_to_response(problem: Problem) -> ProblemResponse:
    """Build the participant-facing ``ProblemResponse`` for ``problem``.

    Only sample test cases are exposed; hidden ones are filtered out.
    ``problem.test_cases`` must already be loaded (e.g. via ``selectinload``):
    AsyncSession does not support implicit lazy loading.
    """
    # Sort samples by order_index (NULLs last, id as tie-break) so the order
    # is deterministic and matches the admin test-case listing.
    samples = sorted(
        (tc for tc in problem.test_cases if tc.is_sample),
        key=lambda tc: (tc.order_index is None, tc.order_index or 0, tc.id),
    )
    return ProblemResponse(
        id=problem.id,
        contest_id=problem.contest_id,
        title=problem.title,
        description=problem.description,
        input_format=problem.input_format,
        output_format=problem.output_format,
        constraints=problem.constraints,
        time_limit_ms=problem.time_limit_ms,
        memory_limit_kb=problem.memory_limit_kb,
        total_points=problem.total_points,
        order_index=problem.order_index,
        created_at=problem.created_at,
        sample_tests=[
            SampleTestResponse(input=tc.input, output=tc.expected_output)
            for tc in samples
        ],
    )


async def _load_problem_with_tests(db: AsyncSession, problem_id: int) -> Problem:
    """Fetch a problem with ``test_cases`` eagerly loaded.

    ``populate_existing`` overwrites any identity-mapped copy of the row, so
    the result reflects the database whether or not the session expired the
    instance on commit.

    Raises:
        HTTPException: 404 if the problem no longer exists.
    """
    result = await db.execute(
        select(Problem)
        .options(selectinload(Problem.test_cases))
        .where(Problem.id == problem_id)
        .execution_options(populate_existing=True)
    )
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )
    return problem


async def _problem_exists(db: AsyncSession, problem_id: int) -> bool:
    """Return True if a problem with ``problem_id`` exists (selects only the id)."""
    result = await db.execute(select(Problem.id).where(Problem.id == problem_id))
    return result.scalar_one_or_none() is not None


@router.post("/", response_model=ProblemResponse, status_code=status.HTTP_201_CREATED)
async def create_problem(
    problem_data: ProblemCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Create a problem together with its test cases (admin only).

    Returns the new problem in participant-facing form (sample tests only).
    """
    problem = Problem(
        contest_id=problem_data.contest_id,
        title=problem_data.title,
        description=problem_data.description,
        input_format=problem_data.input_format,
        output_format=problem_data.output_format,
        constraints=problem_data.constraints,
        time_limit_ms=problem_data.time_limit_ms,
        memory_limit_kb=problem_data.memory_limit_kb,
        total_points=problem_data.total_points,
        order_index=problem_data.order_index,
    )
    db.add(problem)
    # Flush so the database assigns the primary key the test cases reference.
    await db.flush()
    # Capture the id now: after commit the instance may be expired, and
    # reading an expired attribute under AsyncSession would need implicit IO.
    # This makes the former refresh() round trip unnecessary.
    new_problem_id = problem.id

    db.add_all(
        [
            TestCase(
                problem_id=new_problem_id,
                input=tc_data.input,
                expected_output=tc_data.expected_output,
                is_sample=tc_data.is_sample,
                points=tc_data.points,
                order_index=tc_data.order_index,
            )
            for tc_data in problem_data.test_cases
        ]
    )
    await db.commit()

    problem = await _load_problem_with_tests(db, new_problem_id)
    return _problem_to_response(problem)


@router.put("/{problem_id}", response_model=ProblemResponse)
async def update_problem(
    problem_id: int,
    problem_data: ProblemUpdate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Apply a partial update to a problem (admin only).

    Only fields explicitly present in the request body (``exclude_unset``)
    are written.

    Raises:
        HTTPException: 404 if no problem has id ``problem_id``.
    """
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )

    for field, value in problem_data.model_dump(exclude_unset=True).items():
        setattr(problem, field, value)
    await db.commit()

    # Reload explicitly (same as create_problem) instead of relying on
    # Session.refresh() to re-populate a relationship that was eager-loaded
    # via a query option.
    problem = await _load_problem_with_tests(db, problem_id)
    return _problem_to_response(problem)


@router.delete("/{problem_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_problem(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a problem (admin only).

    Raises:
        HTTPException: 404 if no problem has id ``problem_id``.
    """
    # NOTE(review): removal of the problem's test cases depends on the ORM
    # cascade / FK ON DELETE configuration, which is not visible here — confirm.
    result = await db.execute(select(Problem).where(Problem.id == problem_id))
    problem = result.scalar_one_or_none()
    if problem is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )
    await db.delete(problem)
    await db.commit()


# Test cases endpoints (admin only)


@router.get("/{problem_id}/test-cases", response_model=list[TestCaseResponse])
async def get_test_cases(
    problem_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """List every test case (sample and hidden) of a problem, ordered by ``order_index``.

    Raises:
        HTTPException: 404 if no problem has id ``problem_id``.
    """
    result = await db.execute(
        select(TestCase)
        .where(TestCase.problem_id == problem_id)
        .order_by(TestCase.order_index)
    )
    test_cases = result.scalars().all()
    # An empty list is ambiguous (no tests vs. no problem); only then pay for
    # the existence check, matching add_test_case's 404 behavior.
    if not test_cases and not await _problem_exists(db, problem_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )
    return test_cases


@router.post(
    "/{problem_id}/test-cases",
    response_model=TestCaseResponse,
    status_code=status.HTTP_201_CREATED,
)
async def add_test_case(
    problem_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Attach a new test case to an existing problem (admin only).

    Raises:
        HTTPException: 404 if no problem has id ``problem_id``.
    """
    if not await _problem_exists(db, problem_id):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Problem not found"
        )

    test_case = TestCase(
        problem_id=problem_id,
        input=test_case_data.input,
        expected_output=test_case_data.expected_output,
        is_sample=test_case_data.is_sample,
        points=test_case_data.points,
        order_index=test_case_data.order_index,
    )
    db.add(test_case)
    await db.commit()
    await db.refresh(test_case)
    return test_case


@router.put("/test-cases/{test_case_id}", response_model=TestCaseResponse)
async def update_test_case(
    test_case_id: int,
    test_case_data: TestCaseCreate,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Update a test case (admin only).

    Accepts the ``TestCaseCreate`` schema but writes only the fields that were
    explicitly sent (``exclude_unset``).

    Raises:
        HTTPException: 404 if no test case has id ``test_case_id``.
    """
    result = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    test_case = result.scalar_one_or_none()
    if test_case is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Test case not found"
        )

    for field, value in test_case_data.model_dump(exclude_unset=True).items():
        setattr(test_case, field, value)
    await db.commit()
    await db.refresh(test_case)
    return test_case


@router.delete("/test-cases/{test_case_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_test_case(
    test_case_id: int,
    db: AsyncSession = Depends(get_db),
    current_user: User = Depends(get_current_admin),
):
    """Delete a single test case (admin only).

    Raises:
        HTTPException: 404 if no test case has id ``test_case_id``.
    """
    result = await db.execute(select(TestCase).where(TestCase.id == test_case_id))
    test_case = result.scalar_one_or_none()
    if test_case is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Test case not found"
        )
    await db.delete(test_case)
    await db.commit()