diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 902de50..d7e94cc 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -21,4 +21,4 @@ jobs: pip install -r requirements.txt - name: Analysing the code with pylint run: | - pylint -d C0415,C0200,C0301,C0114,R0903,C0115,W0246,R0914,C0209,E1121,C0103,C2801,R0801,E1101,E0401,E0611,R0911,C0116,W0212,W0719,W0601,W1203,W0123,W0511,W0621,R0913,R0917 $(git ls-files '*.py') + pylint -d R0912,C0415,C0200,C0301,C0114,R0903,C0115,W0246,R0914,C0209,E1121,C0103,C2801,R0801,E1101,E0401,E0611,R0911,C0116,W0212,W0719,W0601,W1203,W0123,W0511,W0621,R0913,R0917 $(git ls-files '*.py') diff --git a/src/controllers/flight.py b/src/controllers/flight.py index 9b30398..8c45258 100644 --- a/src/controllers/flight.py +++ b/src/controllers/flight.py @@ -1,11 +1,17 @@ +from fastapi import HTTPException, status + from src.controllers.interface import ( ControllerBase, controller_exception_handler, ) -from src.views.flight import FlightSimulation -from src.models.flight import FlightModel +from src.views.flight import FlightSimulation, FlightCreated +from src.models.flight import ( + FlightModel, + FlightWithReferencesRequest, +) from src.models.environment import EnvironmentModel from src.models.rocket import RocketModel +from src.repositories.interface import RepositoryInterface from src.services.flight import FlightService @@ -21,6 +27,56 @@ class FlightController(ControllerBase): def __init__(self): super().__init__(models=[FlightModel]) + async def _load_environment(self, environment_id: str) -> EnvironmentModel: + repo_cls = RepositoryInterface.get_model_repo(EnvironmentModel) + async with repo_cls() as repo: + environment = await repo.read_environment_by_id(environment_id) + if environment is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Environment not found", + ) + return environment + + async def _load_rocket(self, rocket_id: str) -> RocketModel: + repo_cls = RepositoryInterface.get_model_repo(RocketModel) + async with repo_cls() as repo: + rocket = await repo.read_rocket_by_id(rocket_id) + if rocket is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Rocket not found", + ) + return rocket + + @controller_exception_handler + async def create_flight_from_references( + self, payload: FlightWithReferencesRequest + ) -> FlightCreated: + environment = await self._load_environment(payload.environment_id) + rocket = await self._load_rocket(payload.rocket_id) + flight_model = payload.flight.assemble( + environment=environment, + rocket=rocket, + ) + return await self.post_flight(flight_model) + + @controller_exception_handler + async def update_flight_from_references( + self, + flight_id: str, + payload: FlightWithReferencesRequest, + ) -> None: + environment = await self._load_environment(payload.environment_id) + rocket = await self._load_rocket(payload.rocket_id) + flight_model = payload.flight.assemble( + environment=environment, + rocket=rocket, + ) + flight_model.set_id(flight_id) + await self.put_flight_by_id(flight_id, flight_model) + return + @controller_exception_handler async def update_environment_by_flight_id( self, flight_id: str, *, environment: EnvironmentModel diff --git a/src/controllers/rocket.py b/src/controllers/rocket.py index f586ccc..ce1c36c 100644 --- a/src/controllers/rocket.py +++ b/src/controllers/rocket.py @@ -1,9 +1,16 @@ +from fastapi import HTTPException, status + from src.controllers.interface import ( ControllerBase, controller_exception_handler, ) -from src.views.rocket import RocketSimulation -from src.models.rocket import RocketModel +from src.views.rocket import RocketSimulation, RocketCreated +from src.models.motor import MotorModel +from src.models.rocket import ( + RocketModel, + RocketWithMotorReferenceRequest, +) +from src.repositories.interface import RepositoryInterface from src.services.rocket import RocketService @@ -19,6 +26,37 @@ class RocketController(ControllerBase): def __init__(self): super().__init__(models=[RocketModel]) + async def _load_motor(self, motor_id: str) -> MotorModel: + repo_cls = RepositoryInterface.get_model_repo(MotorModel) + async with repo_cls() as repo: + motor = await repo.read_motor_by_id(motor_id) + if motor is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="Motor not found", + ) + return motor + + @controller_exception_handler + async def create_rocket_from_motor_reference( + self, payload: RocketWithMotorReferenceRequest + ) -> RocketCreated: + motor = await self._load_motor(payload.motor_id) + rocket_model = payload.rocket.assemble(motor) + return await self.post_rocket(rocket_model) + + @controller_exception_handler + async def update_rocket_from_motor_reference( + self, + rocket_id: str, + payload: RocketWithMotorReferenceRequest, + ) -> None: + motor = await self._load_motor(payload.motor_id) + rocket_model = payload.rocket.assemble(motor) + rocket_model.set_id(rocket_id) + await self.put_rocket_by_id(rocket_id, rocket_model) + return + @controller_exception_handler async def get_rocketpy_rocket_binary(self, rocket_id: str) -> bytes: """ diff --git a/src/models/flight.py b/src/models/flight.py index 93d36d4..b8957ca 100644 --- a/src/models/flight.py +++ b/src/models/flight.py @@ -1,4 +1,7 @@ +import json from typing import Optional, Self, ClassVar, Literal + +from pydantic import BaseModel, Field, field_validator from src.models.interface import ApiBaseModel from src.models.rocket import RocketModel from src.models.environment import EnvironmentModel @@ -69,3 +72,76 @@ def RETRIEVED(model_instance: type(Self)): **model_instance.model_dump(), ) ) + + @field_validator('environment', mode='before') + @classmethod + def _coerce_environment(cls, value): + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError( + 'Invalid JSON for environment payload' + ) from exc + return value + + @field_validator('rocket', mode='before') + @classmethod + def _coerce_rocket(cls, value): + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError('Invalid JSON for rocket payload') from exc + return value + + +class FlightPartialModel(BaseModel): + """Flight attributes required when rocket/environment are referenced.""" + + name: str = Field(default="flight") + rail_length: float = 1 + time_overshoot: bool = True + terminate_on_apogee: bool = False + equations_of_motion: Literal['standard', 'solid_propulsion'] = 'standard' + inclination: float = 90.0 + heading: float = 0.0 + max_time: Optional[int] = None + max_time_step: Optional[float] = None + min_time_step: Optional[int] = None + rtol: Optional[float] = None + atol: Optional[float] = None + verbose: Optional[bool] = None + + def assemble( + self, + *, + environment: EnvironmentModel, + rocket: RocketModel, + ) -> FlightModel: + """Compose a full flight model using referenced resources.""" + + flight_data = self.model_dump(exclude_none=True) + return FlightModel( + environment=environment, + rocket=rocket, + **flight_data, + ) + + +class FlightWithReferencesRequest(BaseModel): + """Payload for creating or updating flights via component references.""" + + environment_id: str + rocket_id: str + flight: FlightPartialModel + + @field_validator('flight', mode='before') + @classmethod + def _coerce_flight(cls, value): + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError('Invalid JSON for flight payload') from exc + return value diff --git a/src/models/motor.py b/src/models/motor.py index 4d1cdf5..93eb0b4 100644 --- a/src/models/motor.py +++ b/src/models/motor.py @@ -1,6 +1,7 @@ +import json from enum import Enum from typing import Optional, Tuple, List, Union, Self, ClassVar, Literal -from pydantic import model_validator +from pydantic import model_validator, field_validator from src.models.interface import ApiBaseModel from src.models.sub.tanks import MotorTank @@ -57,6 +58,18 @@ class MotorModel(ApiBaseModel): ] = 'nozzle_to_combustion_chamber' reshape_thrust_curve: Union[bool, tuple] = False + @field_validator('tanks', mode='before') + @classmethod + def _coerce_tanks(cls, value): + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError('Invalid JSON for tanks payload') from exc + if isinstance(value, dict): + value = [value] + return value + @model_validator(mode='after') # TODO: extend guard to check motor kinds and tank kinds specifics def validate_motor_kind(self): diff --git a/src/models/rocket.py b/src/models/rocket.py index 02a48f0..c1a22cd 100644 --- a/src/models/rocket.py +++ b/src/models/rocket.py @@ -1,4 +1,7 @@ +import json from typing import Optional, Tuple, List, Union, Self, ClassVar, Literal + +from pydantic import BaseModel, Field, field_validator from src.models.interface import ApiBaseModel from src.models.motor import MotorModel from src.models.sub.aerosurfaces import ( @@ -10,6 +13,15 @@ ) +def _maybe_parse_json(value): + if isinstance(value, str): + try: + return json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError('Invalid JSON payload') from exc + return value + + class RocketModel(ApiBaseModel): NAME: ClassVar = "rocket" METHODS: ClassVar = ("POST", "GET", "PUT", "DELETE") @@ -37,6 +49,42 @@ class RocketModel(ApiBaseModel): rail_buttons: Optional[RailButtons] = None tail: Optional[Tail] = None + @field_validator('motor', mode='before') + @classmethod + def _coerce_motor(cls, value): + return _maybe_parse_json(value) + + @field_validator('nose', mode='before') + @classmethod + def _coerce_nose(cls, value): + return _maybe_parse_json(value) + + @field_validator('fins', mode='before') + @classmethod + def _coerce_fins(cls, value): + value = _maybe_parse_json(value) + if isinstance(value, dict): + value = [value] + return value + + @field_validator('parachutes', mode='before') + @classmethod + def _coerce_parachutes(cls, value): + value = _maybe_parse_json(value) + if isinstance(value, dict): + value = [value] + return value + + @field_validator('rail_buttons', mode='before') + @classmethod + def _coerce_rail_buttons(cls, value): + return _maybe_parse_json(value) + + @field_validator('tail', mode='before') + @classmethod + def _coerce_tail(cls, value): + return _maybe_parse_json(value) + @staticmethod def UPDATED(): return @@ -61,3 +109,53 @@ def RETRIEVED(model_instance: type(Self)): **model_instance.model_dump(), ) ) + + +class RocketPartialModel(BaseModel): + """Rocket attributes required when a motor is supplied by reference.""" + + radius: float + mass: float + motor_position: float + center_of_mass_without_motor: float + inertia: Union[ + Tuple[float, float, float], + Tuple[float, float, float, float, float, float], + ] = (0, 0, 0) + power_off_drag: List[Tuple[float, float]] = Field( + default_factory=lambda: [(0, 0)] + ) + power_on_drag: List[Tuple[float, float]] = Field( + default_factory=lambda: [(0, 0)] + ) + coordinate_system_orientation: Literal['tail_to_nose', 'nose_to_tail'] = ( + 'tail_to_nose' + ) + nose: NoseCone + fins: List[Fins] + parachutes: Optional[List[Parachute]] = None + rail_buttons: Optional[RailButtons] = None + tail: Optional[Tail] = None + + def assemble(self, motor: MotorModel) -> RocketModel: + """Compose a full rocket model using the referenced motor.""" + + rocket_data = self.model_dump(exclude_none=True) + return RocketModel(motor=motor, **rocket_data) + + +class RocketWithMotorReferenceRequest(BaseModel): + """Payload for creating or updating rockets via motor reference.""" + + motor_id: str + rocket: RocketPartialModel + + @field_validator('rocket', mode='before') + @classmethod + def _coerce_rocket(cls, value): + if isinstance(value, str): + try: + value = json.loads(value) + except json.JSONDecodeError as exc: + raise ValueError('Invalid JSON for rocket payload') from exc + return value diff --git a/src/routes/flight.py b/src/routes/flight.py index 9aff0db..6afd13a 100644 --- a/src/routes/flight.py +++ b/src/routes/flight.py @@ -11,7 +11,7 @@ FlightRetrieved, ) from src.models.environment import EnvironmentModel -from src.models.flight import FlightModel +from src.models.flight import FlightModel, FlightWithReferencesRequest from src.models.rocket import RocketModel from src.controllers.flight import FlightController @@ -41,6 +41,25 @@ async def create_flight(flight: FlightModel) -> FlightCreated: return await controller.post_flight(flight) +@router.post("/from-references", status_code=201) +async def create_flight_from_references( + payload: FlightWithReferencesRequest, +) -> FlightCreated: + """ + Creates a flight using existing rocket and environment references. + + ## Args + ``` + environment_id: str + rocket_id: str + flight: Flight-only fields JSON + ``` + """ + with tracer.start_as_current_span("create_flight_from_references"): + controller = FlightController() + return await controller.create_flight_from_references(payload) + + @router.get("/{flight_id}") async def read_flight(flight_id: str) -> FlightRetrieved: """ @@ -70,6 +89,29 @@ async def update_flight(flight_id: str, flight: FlightModel) -> None: return await controller.put_flight_by_id(flight_id, flight) +@router.put("/{flight_id}/from-references", status_code=204) +async def update_flight_from_references( + flight_id: str, + payload: FlightWithReferencesRequest, +) -> None: + """ + Updates a flight using existing rocket and environment references. + + ## Args + ``` + flight_id: str + environment_id: str + rocket_id: str + flight: Flight-only fields JSON + ``` + """ + with tracer.start_as_current_span("update_flight_from_references"): + controller = FlightController() + return await controller.update_flight_from_references( + flight_id, payload + ) + + @router.delete("/{flight_id}", status_code=204) async def delete_flight(flight_id: str) -> None: """ diff --git a/src/routes/rocket.py b/src/routes/rocket.py index 5346d9e..d5e4bdf 100644 --- a/src/routes/rocket.py +++ b/src/routes/rocket.py @@ -10,7 +10,10 @@ RocketCreated, RocketRetrieved, ) -from src.models.rocket import RocketModel +from src.models.rocket import ( + RocketModel, + RocketWithMotorReferenceRequest, +) from src.controllers.rocket import RocketController router = APIRouter( @@ -39,6 +42,24 @@ async def create_rocket(rocket: RocketModel) -> RocketCreated: return await controller.post_rocket(rocket) +@router.post("/from-motor-reference", status_code=201) +async def create_rocket_from_motor_reference( + payload: RocketWithMotorReferenceRequest, +) -> RocketCreated: + """ + Creates a rocket using an existing motor reference. + + ## Args + ``` + motor_id: str + rocket: Rocket-only fields JSON + ``` + """ + with tracer.start_as_current_span("create_rocket_from_motor_reference"): + controller = RocketController() + return await controller.create_rocket_from_motor_reference(payload) + + @router.get("/{rocket_id}") async def read_rocket(rocket_id: str) -> RocketRetrieved: """ @@ -68,6 +89,28 @@ async def update_rocket(rocket_id: str, rocket: RocketModel) -> None: return await controller.put_rocket_by_id(rocket_id, rocket) +@router.put("/{rocket_id}/from-motor-reference", status_code=204) +async def update_rocket_from_motor_reference( + rocket_id: str, + payload: RocketWithMotorReferenceRequest, +) -> None: + """ + Updates a rocket using an existing motor reference. + + ## Args + ``` + rocket_id: str + motor_id: str + rocket: Rocket-only fields JSON + ``` + """ + with tracer.start_as_current_span("update_rocket_from_motor_reference"): + controller = RocketController() + return await controller.update_rocket_from_motor_reference( + rocket_id, payload + ) + + @router.delete("/{rocket_id}", status_code=204) async def delete_rocket(rocket_id: str) -> None: """ diff --git a/src/services/motor.py b/src/services/motor.py index cc3ce69..7275920 100644 --- a/src/services/motor.py +++ b/src/services/motor.py @@ -14,6 +14,8 @@ TankGeometry, ) +from fastapi import HTTPException, status + from src.models.sub.tanks import TankKinds from src.models.motor import MotorKinds, MotorModel from src.views.motor import MotorSimulation @@ -35,6 +37,12 @@ def from_motor_model(cls, motor: MotorModel) -> Self: MotorService containing the rocketpy motor object. """ + reshape_thrust_curve = motor.reshape_thrust_curve + if isinstance(reshape_thrust_curve, bool): + reshape_thrust_curve = False + elif isinstance(reshape_thrust_curve, list): + reshape_thrust_curve = tuple(reshape_thrust_curve) + motor_core = { "thrust_source": motor.thrust_source, "burn_time": motor.burn_time, @@ -44,7 +52,7 @@ def from_motor_model(cls, motor: MotorModel) -> Self: "center_of_dry_mass_position": motor.center_of_dry_mass_position, "coordinate_system_orientation": motor.coordinate_system_orientation, "interpolation_method": motor.interpolation_method, - "reshape_thrust_curve": False or motor.reshape_thrust_curve, + "reshape_thrust_curve": reshape_thrust_curve, } match MotorKinds(motor.motor_kind): @@ -63,15 +71,36 @@ def from_motor_model(cls, motor: MotorModel) -> Self: grains_center_of_mass_position=motor.grains_center_of_mass_position, ) case MotorKinds.SOLID: + grain_params = { + 'grain_number': motor.grain_number, + 'grain_density': motor.grain_density, + 'grain_outer_radius': motor.grain_outer_radius, + 'grain_initial_inner_radius': motor.grain_initial_inner_radius, + 'grain_initial_height': motor.grain_initial_height, + 'grain_separation': motor.grain_separation, + 'grains_center_of_mass_position': motor.grains_center_of_mass_position, + } + + missing = [ + key for key, value in grain_params.items() if value is None + ] + if missing: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=( + "Solid motor requires grain configuration: missing " + + ', '.join(missing) + ), + ) + + optional_params = {} + if motor.throat_radius is not None: + optional_params['throat_radius'] = motor.throat_radius + rocketpy_motor = SolidMotor( **motor_core, - grain_number=motor.grain_number, - grain_density=motor.grain_density, - grain_outer_radius=motor.grain_outer_radius, - grain_initial_inner_radius=motor.grain_initial_inner_radius, - grain_initial_height=motor.grain_initial_height, - grains_center_of_mass_position=motor.grains_center_of_mass_position, - grain_separation=motor.grain_separation, + **grain_params, + **optional_params, ) case _: rocketpy_motor = GenericMotor( @@ -84,7 +113,7 @@ def from_motor_model(cls, motor: MotorModel) -> Self: ) if motor.motor_kind not in (MotorKinds.SOLID, MotorKinds.GENERIC): - for tank in motor.tanks: + for tank in motor.tanks or []: tank_core = { "name": tank.name, "geometry": TankGeometry( diff --git a/src/services/rocket.py b/src/services/rocket.py index f67fbde..5864a81 100644 --- a/src/services/rocket.py +++ b/src/services/rocket.py @@ -12,6 +12,8 @@ Tail as RocketPyTail, ) +from fastapi import HTTPException, status + from src import logger from src.models.rocket import RocketModel, Parachute from src.models.sub.aerosurfaces import NoseCone, Tail, Fins @@ -132,13 +134,20 @@ def get_rocketpy_nose(nose: NoseCone) -> RocketPyNoseCone: RocketPyNoseCone """ - rocketpy_nose = RocketPyNoseCone( - name=nose.name, - length=nose.length, - kind=nose.kind, - base_radius=nose.base_radius, - rocket_radius=nose.rocket_radius, - ) + try: + rocketpy_nose = RocketPyNoseCone( + name=nose.name, + length=nose.length, + kind=nose.kind, + base_radius=nose.base_radius, + rocket_radius=nose.rocket_radius, + ) + except ValueError as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=str(exc), + ) from exc + rocketpy_nose.position = nose.position return rocketpy_nose @@ -159,25 +168,40 @@ def get_rocketpy_finset(fins: Fins, kind: str) -> RocketPyFins: RocketPyTrapezoidalFins RocketPyEllipticalFins """ + + if fins.rocket_radius is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Fin definition missing required field 'rocket_radius'", + ) + + base_kwargs = { + 'n': fins.n, + 'name': fins.name, + 'root_chord': fins.root_chord, + 'span': fins.span, + 'rocket_radius': fins.rocket_radius, + } + match kind: case "trapezoidal": - rocketpy_finset = RocketPyTrapezoidalFins( - n=fins.n, - name=fins.name, - root_chord=fins.root_chord, - span=fins.span, - **fins.get_additional_parameters(), - ) + factory = RocketPyTrapezoidalFins case "elliptical": - rocketpy_finset = RocketPyEllipticalFins( - n=fins.n, - name=fins.name, - root_chord=fins.root_chord, - span=fins.span, - **fins.get_additional_parameters(), - ) + factory = RocketPyEllipticalFins case _: raise ValueError(f"Invalid fins kind: {kind}") + + try: + rocketpy_finset = factory( + **base_kwargs, + **fins.get_additional_parameters(), + ) + except (TypeError, ValueError) as exc: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=str(exc), + ) from exc + rocketpy_finset.position = fins.position return rocketpy_finset diff --git a/tests/unit/test_routes/test_flights_route.py b/tests/unit/test_routes/test_flights_route.py index 0b8ac30..5190c45 100644 --- a/tests/unit/test_routes/test_flights_route.py +++ b/tests/unit/test_routes/test_flights_route.py @@ -1,10 +1,11 @@ from unittest.mock import patch, Mock, AsyncMock +import copy import json import pytest from fastapi.testclient import TestClient from fastapi import HTTPException, status from src.models.environment import EnvironmentModel -from src.models.flight import FlightModel +from src.models.flight import FlightModel, FlightWithReferencesRequest from src.models.rocket import RocketModel from src.views.motor import MotorView from src.views.rocket import RocketView @@ -54,9 +55,23 @@ def mock_controller_instance(): mock_controller_instance.get_rocketpy_flight_binary = Mock() mock_controller_instance.update_environment_by_flight_id = Mock() mock_controller_instance.update_rocket_by_flight_id = Mock() + mock_controller_instance.create_flight_from_references = Mock() + mock_controller_instance.update_flight_from_references = Mock() yield mock_controller_instance +@pytest.fixture +def stub_flight_reference_payload(stub_flight_dump): + partial_flight = copy.deepcopy(stub_flight_dump) + partial_flight.pop('environment', None) + partial_flight.pop('rocket', None) + return { + 'environment_id': 'env-123', + 'rocket_id': 'rocket-456', + 'flight': partial_flight, + } + + def test_create_flight(stub_flight_dump, mock_controller_instance): mock_response = AsyncMock(return_value=FlightCreated(flight_id='123')) mock_controller_instance.post_flight = mock_response @@ -71,6 +86,81 @@ def test_create_flight(stub_flight_dump, mock_controller_instance): ) +def test_create_flight_with_string_nested_fields( + stub_flight_dump, mock_controller_instance +): + payload = copy.deepcopy(stub_flight_dump) + payload['environment'] = json.dumps(payload['environment']) + payload['rocket'] = json.dumps(payload['rocket']) + mock_controller_instance.post_flight = AsyncMock( + return_value=FlightCreated(flight_id='123') + ) + response = client.post('/flights/', json=payload) + assert response.status_code == 201 + mock_controller_instance.post_flight.assert_called_once_with( + FlightModel(**payload) + ) + + +def test_create_flight_from_references( + stub_flight_reference_payload, mock_controller_instance +): + mock_response = AsyncMock(return_value=FlightCreated(flight_id='123')) + mock_controller_instance.create_flight_from_references = mock_response + response = client.post( + '/flights/from-references', json=stub_flight_reference_payload + ) + assert response.status_code == 201 + assert response.json() == { + 'flight_id': '123', + 'message': 'Flight successfully created', + } + mock_controller_instance.create_flight_from_references.assert_called_once_with( + FlightWithReferencesRequest(**stub_flight_reference_payload) + ) + + +def test_create_flight_from_references_with_string_payload( + stub_flight_reference_payload, mock_controller_instance +): + payload = copy.deepcopy(stub_flight_reference_payload) + payload['flight'] = json.dumps(payload['flight']) + mock_controller_instance.create_flight_from_references = AsyncMock( + return_value=FlightCreated(flight_id='123') + ) + response = client.post('/flights/from-references', json=payload) + assert response.status_code == 201 + mock_controller_instance.create_flight_from_references.assert_called_once_with( + FlightWithReferencesRequest(**payload) + ) + + +def test_create_flight_from_references_not_found( + stub_flight_reference_payload, mock_controller_instance +): + mock_controller_instance.create_flight_from_references.side_effect = ( + HTTPException(status_code=status.HTTP_404_NOT_FOUND) + ) + response = client.post( + '/flights/from-references', json=stub_flight_reference_payload + ) + assert response.status_code == 404 + assert response.json() == {'detail': 'Not Found'} + + +def test_create_flight_from_references_server_error( + stub_flight_reference_payload, mock_controller_instance +): + mock_controller_instance.create_flight_from_references.side_effect = ( + HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + ) + response = client.post( + '/flights/from-references', json=stub_flight_reference_payload + ) + assert response.status_code == 500 + assert response.json() == {'detail': 'Internal Server Error'} + + def test_create_flight_optional_params( stub_flight_dump, mock_controller_instance ): @@ -173,6 +263,78 @@ def test_update_flight_by_id(stub_flight_dump, mock_controller_instance): ) +def test_update_flight_with_string_nested_fields( + stub_flight_dump, mock_controller_instance +): + payload = copy.deepcopy(stub_flight_dump) + payload['environment'] = json.dumps(payload['environment']) + payload['rocket'] = json.dumps(payload['rocket']) + mock_controller_instance.put_flight_by_id = AsyncMock(return_value=None) + response = client.put('/flights/123', json=payload) + assert response.status_code == 204 + mock_controller_instance.put_flight_by_id.assert_called_once_with( + '123', FlightModel(**payload) + ) + + +def test_update_flight_from_references( + stub_flight_reference_payload, mock_controller_instance +): + mock_response = AsyncMock(return_value=None) + mock_controller_instance.update_flight_from_references = mock_response + response = client.put( + '/flights/123/from-references', + json=stub_flight_reference_payload, + ) + assert response.status_code == 204 + mock_controller_instance.update_flight_from_references.assert_called_once_with( + '123', FlightWithReferencesRequest(**stub_flight_reference_payload) + ) + + +def test_update_flight_from_references_with_string_payload( + stub_flight_reference_payload, mock_controller_instance +): + payload = copy.deepcopy(stub_flight_reference_payload) + payload['flight'] = json.dumps(payload['flight']) + mock_controller_instance.update_flight_from_references = AsyncMock( + return_value=None + ) + response = client.put('/flights/123/from-references', json=payload) + assert response.status_code == 204 + mock_controller_instance.update_flight_from_references.assert_called_once_with( + '123', FlightWithReferencesRequest(**payload) + ) + + +def test_update_flight_from_references_not_found( + stub_flight_reference_payload, mock_controller_instance +): + mock_controller_instance.update_flight_from_references.side_effect = ( + HTTPException(status_code=status.HTTP_404_NOT_FOUND) + ) + response = client.put( + '/flights/123/from-references', + json=stub_flight_reference_payload, + ) + assert response.status_code == 404 + assert response.json() == {'detail': 'Not Found'} + + +def test_update_flight_from_references_server_error( + stub_flight_reference_payload, mock_controller_instance +): + mock_controller_instance.update_flight_from_references.side_effect = ( + HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + ) + response = client.put( + '/flights/123/from-references', + json=stub_flight_reference_payload, + ) + assert response.status_code == 500 + assert response.json() == {'detail': 'Internal Server Error'} + + def test_update_environment_by_flight_id( stub_environment_dump, mock_controller_instance ): diff --git a/tests/unit/test_routes/test_rockets_route.py b/tests/unit/test_routes/test_rockets_route.py index 6b73d8c..a91041d 100644 --- a/tests/unit/test_routes/test_rockets_route.py +++ b/tests/unit/test_routes/test_rockets_route.py @@ -1,4 +1,5 @@ from unittest.mock import patch, Mock, AsyncMock +import copy import json import pytest from fastapi.testclient import TestClient @@ -8,7 +9,10 @@ RailButtons, Parachute, ) -from src.models.rocket import RocketModel +from src.models.rocket import ( + RocketModel, + RocketWithMotorReferenceRequest, +) from src.views.rocket import ( RocketCreated, RocketRetrieved, @@ -78,9 +82,21 @@ def mock_controller_instance(): mock_controller_instance.delete_rocket_by_id = Mock() mock_controller_instance.get_rocket_simulation = Mock() mock_controller_instance.get_rocketpy_rocket_binary = Mock() + mock_controller_instance.create_rocket_from_motor_reference = Mock() + mock_controller_instance.update_rocket_from_motor_reference = Mock() yield mock_controller_instance +@pytest.fixture +def stub_rocket_reference_payload(stub_rocket_dump): + partial_rocket = copy.deepcopy(stub_rocket_dump) + partial_rocket.pop('motor', None) + return { + 'motor_id': 'motor-123', + 'rocket': partial_rocket, + } + + def test_create_rocket(stub_rocket_dump, mock_controller_instance): mock_response = AsyncMock(return_value=RocketCreated(rocket_id='123')) mock_controller_instance.post_rocket = mock_response @@ -95,6 +111,89 @@ def test_create_rocket(stub_rocket_dump, mock_controller_instance): ) +def test_create_rocket_with_string_nested_fields( + stub_rocket_dump, + stub_parachute_dump, + stub_rail_buttons_dump, + stub_tail_dump, + mock_controller_instance, +): + payload = copy.deepcopy(stub_rocket_dump) + payload['motor'] = json.dumps(payload['motor']) + payload['nose'] = json.dumps(payload['nose']) + payload['fins'] = json.dumps(payload['fins']) + payload['parachutes'] = json.dumps([stub_parachute_dump]) + payload['rail_buttons'] = json.dumps(stub_rail_buttons_dump) + payload['tail'] = json.dumps(stub_tail_dump) + + mock_response = AsyncMock(return_value=RocketCreated(rocket_id='123')) + mock_controller_instance.post_rocket = mock_response + response = client.post('/rockets/', json=payload) + assert response.status_code == 201 + mock_controller_instance.post_rocket.assert_called_once_with( + RocketModel(**payload) + ) + + +def test_create_rocket_from_motor_reference( + stub_rocket_reference_payload, mock_controller_instance +): + mock_response = AsyncMock(return_value=RocketCreated(rocket_id='123')) + mock_controller_instance.create_rocket_from_motor_reference = mock_response + response = client.post( + '/rockets/from-motor-reference', json=stub_rocket_reference_payload + ) + assert response.status_code == 201 + assert response.json() == { + 'rocket_id': '123', + 'message': 'Rocket successfully created', + } + mock_controller_instance.create_rocket_from_motor_reference.assert_called_once_with( + RocketWithMotorReferenceRequest(**stub_rocket_reference_payload) + ) + + +def test_create_rocket_from_motor_reference_with_string_payload( + stub_rocket_reference_payload, mock_controller_instance +): + payload = copy.deepcopy(stub_rocket_reference_payload) + payload['rocket'] = json.dumps(payload['rocket']) + mock_controller_instance.create_rocket_from_motor_reference = AsyncMock( + return_value=RocketCreated(rocket_id='123') + ) + response = client.post('/rockets/from-motor-reference', json=payload) + assert response.status_code == 201 + mock_controller_instance.create_rocket_from_motor_reference.assert_called_once_with( + RocketWithMotorReferenceRequest(**payload) + ) + + +def test_create_rocket_from_motor_reference_not_found( + stub_rocket_reference_payload, mock_controller_instance +): + mock_controller_instance.create_rocket_from_motor_reference.side_effect = ( + HTTPException(status_code=status.HTTP_404_NOT_FOUND) + ) + response = client.post( + '/rockets/from-motor-reference', json=stub_rocket_reference_payload + ) + assert response.status_code == 404 + assert response.json() == {'detail': 'Not Found'} + + +def test_create_rocket_from_motor_reference_server_error( + stub_rocket_reference_payload, mock_controller_instance +): + mock_controller_instance.create_rocket_from_motor_reference.side_effect = ( + HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + ) + response = client.post( + '/rockets/from-motor-reference', json=stub_rocket_reference_payload + ) + assert response.status_code == 500 + assert response.json() == {'detail': 'Internal Server Error'} + + def test_create_rocket_optional_params( stub_rocket_dump, stub_tail_dump, @@ -365,6 +464,30 @@ def test_update_rocket(stub_rocket_dump, mock_controller_instance): ) +def test_update_rocket_with_string_nested_fields( + stub_rocket_dump, + stub_parachute_dump, + stub_rail_buttons_dump, + stub_tail_dump, + mock_controller_instance, +): + payload = copy.deepcopy(stub_rocket_dump) + payload['motor']['motor_kind'] = 'SOLID' + payload['motor'] = json.dumps(payload['motor']) + payload['nose'] = json.dumps(payload['nose']) + payload['fins'] = json.dumps(payload['fins']) + payload['parachutes'] = json.dumps([stub_parachute_dump]) + payload['rail_buttons'] = json.dumps(stub_rail_buttons_dump) + payload['tail'] = json.dumps(stub_tail_dump) + + mock_controller_instance.put_rocket_by_id = AsyncMock(return_value=None) + response = client.put('/rockets/123', json=payload) + assert response.status_code == 204 + mock_controller_instance.put_rocket_by_id.assert_called_once_with( + '123', RocketModel(**payload) + ) + + def test_update_rocket_invalid_input(): response = client.put( '/rockets/123', json={'mass': 'foo', 'radius': 'bar'} @@ -372,6 +495,64 @@ def test_update_rocket_invalid_input(): assert response.status_code == 422 +def test_update_rocket_from_motor_reference( + stub_rocket_reference_payload, mock_controller_instance +): + mock_response = AsyncMock(return_value=None) + mock_controller_instance.update_rocket_from_motor_reference = mock_response + response = client.put( + '/rockets/123/from-motor-reference', + json=stub_rocket_reference_payload, + ) + assert response.status_code == 204 + mock_controller_instance.update_rocket_from_motor_reference.assert_called_once_with( + '123', RocketWithMotorReferenceRequest(**stub_rocket_reference_payload) + ) + + +def test_update_rocket_from_motor_reference_with_string_payload( + stub_rocket_reference_payload, mock_controller_instance +): + payload = copy.deepcopy(stub_rocket_reference_payload) + payload['rocket'] = json.dumps(payload['rocket']) + mock_controller_instance.update_rocket_from_motor_reference = AsyncMock( + return_value=None + ) + response = client.put('/rockets/123/from-motor-reference', json=payload) + assert response.status_code == 204 + mock_controller_instance.update_rocket_from_motor_reference.assert_called_once_with( + '123', RocketWithMotorReferenceRequest(**payload) + ) + + +def test_update_rocket_from_motor_reference_not_found( + stub_rocket_reference_payload, mock_controller_instance +): + mock_controller_instance.update_rocket_from_motor_reference.side_effect = ( + HTTPException(status_code=status.HTTP_404_NOT_FOUND) + ) + response = client.put( + '/rockets/123/from-motor-reference', + json=stub_rocket_reference_payload, + ) + assert response.status_code == 404 + assert response.json() == {'detail': 'Not Found'} + + +def test_update_rocket_from_motor_reference_server_error( + stub_rocket_reference_payload, mock_controller_instance +): + mock_controller_instance.update_rocket_from_motor_reference.side_effect = ( + HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR) + ) + response = client.put( + '/rockets/123/from-motor-reference', + json=stub_rocket_reference_payload, + ) + assert response.status_code == 500 + assert response.json() == {'detail': 'Internal Server Error'} + + def test_update_rocket_not_found(stub_rocket_dump, mock_controller_instance): mock_controller_instance.put_rocket_by_id.side_effect = HTTPException( status_code=status.HTTP_404_NOT_FOUND