diff --git a/RELEASE_NOTES.md b/RELEASE_NOTES.md index f590c52..f6d3752 100644 --- a/RELEASE_NOTES.md +++ b/RELEASE_NOTES.md @@ -6,16 +6,10 @@ ## Upgrading - +* Update base client from version `0.6.1` to `0.7.0` and upgrade the `Client` constructor accordingly. ## New Features -* Replace assert statements with proper exception handling -* Implement client instance reuse to avoid redundant TCP connections -* Move documentation and code examples to the documentation website -* Replace the local `PaginationParams` type with the `frequenz-client-common` one -* Remove dependency to `googleapis-common-protos` -* Replace `Energy` with `Power` for the `quantity` representation * Add str function for `DeliveryPeriod` object * Add integration tests for the API * Add an equality function to the Order type diff --git a/pyproject.toml b/pyproject.toml index 27aab4e..2b08665 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ dependencies = [ "frequenz-api-common >= 0.6.3, < 0.7.0", "grpcio >= 1.66.2, < 2", "frequenz-channels >= 1.0.0, < 2", - "frequenz-client-base >= 0.6.1, < 0.7.0", + "frequenz-client-base >= 0.7.0, < 0.8.0", "frequenz-client-common >= 0.1.0, < 0.3.0", "frequenz-api-electricity-trading >= 0.2.4, < 1", "protobuf >= 5.28.0, < 6", @@ -145,6 +145,7 @@ disable = [ "unsubscriptable-object", # Checked by mypy "no-member", + "no-name-in-module", # Checked by flake8 "f-string-without-interpolation", "line-too-long", diff --git a/src/frequenz/client/electricity_trading/_client.py b/src/frequenz/client/electricity_trading/_client.py index ec4976e..24079a2 100644 --- a/src/frequenz/client/electricity_trading/_client.py +++ b/src/frequenz/client/electricity_trading/_client.py @@ -3,10 +3,12 @@ """Module to define the client class.""" +from __future__ import annotations + import logging from datetime import datetime, timezone from decimal import Decimal, InvalidOperation -from typing import Awaitable, cast +from typing import TYPE_CHECKING, Any, Awaitable, cast import grpc @@ -17,9 +19,11 @@ ) from frequenz.channels import Receiver from frequenz.client.base.client import BaseApiClient +from frequenz.client.base.exception import ClientNotConnected from frequenz.client.base.streaming import GrpcStreamBroadcaster from frequenz.client.common.pagination import Params from google.protobuf import field_mask_pb2, struct_pb2 +from typing_extensions import override from ._types import ( DeliveryArea, @@ -41,6 +45,12 @@ UpdateOrder, ) +if TYPE_CHECKING: + from frequenz.api.electricity_trading.v1.electricity_trading_pb2_grpc import ( + ElectricityTradingServiceAsyncStub, + ) + + _logger = logging.getLogger(__name__) @@ -81,7 +91,7 @@ def validate_decimal_places(value: Decimal, decimal_places: int, name: str) -> N ) from exc -class Client(BaseApiClient[ElectricityTradingServiceStub]): +class Client(BaseApiClient): """Electricity trading client.""" _instances: dict[tuple[str, str | None], "Client"] = {} @@ -123,7 +133,10 @@ def __init__( if not hasattr( self, "_initialized" ): # Prevent re-initialization of existing instances - super().__init__(server_url, ElectricityTradingServiceStub, connect=connect) + super().__init__(server_url, connect=connect) + self._stub: ElectricityTradingServiceAsyncStub | None = None + if connect: + self._create_stub() self._initialized = True self._gridpool_orders_streams: dict[ @@ -149,6 +162,41 @@ def __init__( self._metadata = (("key", auth_key),) if auth_key else () + def _create_stub(self) -> None: + """Create a new gRPC stub for the Electricity Trading service.""" + stub: Any = ElectricityTradingServiceStub(self.channel) + self._stub = stub + + @override + def connect(self, server_url: str | None = None) -> None: + """Connect to the server, possibly using a new URL. + + If the client is already connected and the URL is the same as the previous URL, + this method does nothing. If you want to force a reconnection, you can call + [disconnect()][frequenz.client.base.client.BaseApiClient.disconnect] first. + + Args: + server_url: The URL of the server to connect to. If not provided, the + previously used URL is used. + """ + super().connect(server_url) + self._create_stub() + + @property + def stub(self) -> ElectricityTradingServiceAsyncStub: + """ + Get the gRPC stub for the Electricity Trading service. + + Returns: + The gRPC stub. + + Raises: + ClientNotConnected: If the client is not connected to the server. + """ + if self._stub is None: + raise ClientNotConnected(server_url=self.server_url, operation="stub") + return self._stub + async def stream_gridpool_orders( # pylint: disable=too-many-arguments, too-many-positional-arguments self, @@ -192,7 +240,7 @@ async def stream_gridpool_orders( try: self._gridpool_orders_streams[stream_key] = GrpcStreamBroadcaster( f"electricity-trading-{stream_key}", - lambda: self.stub.ReceiveGridpoolOrdersStream( # type: ignore + lambda: self.stub.ReceiveGridpoolOrdersStream( electricity_trading_pb2.ReceiveGridpoolOrdersStreamRequest( gridpool_id=gridpool_id, filter=gridpool_order_filter.to_pb(), @@ -251,7 +299,7 @@ async def stream_gridpool_trades( try: self._gridpool_trades_streams[stream_key] = GrpcStreamBroadcaster( f"electricity-trading-{stream_key}", - lambda: self.stub.ReceiveGridpoolTradesStream( # type: ignore + lambda: self.stub.ReceiveGridpoolTradesStream( electricity_trading_pb2.ReceiveGridpoolTradesStreamRequest( gridpool_id=gridpool_id, filter=gridpool_trade_filter.to_pb(), @@ -303,7 +351,7 @@ async def stream_public_trades( self._public_trades_streams[public_trade_filter] = ( GrpcStreamBroadcaster( f"electricity-trading-{public_trade_filter}", - lambda: self.stub.ReceivePublicTradesStream( # type: ignore + lambda: self.stub.ReceivePublicTradesStream( electricity_trading_pb2.ReceivePublicTradesStreamRequest( filter=public_trade_filter.to_pb(), ), diff --git a/tests/test_client.py b/tests/test_client.py index dbfff79..0a8c88e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,6 +3,7 @@ """Tests for the methods in the client.""" import asyncio +from dataclasses import dataclass from datetime import datetime, timedelta, timezone from decimal import Decimal from unittest.mock import AsyncMock @@ -35,13 +36,31 @@ ) +@dataclass +class SetupParams: # pylint: disable=too-many-instance-attributes + """Parameters for the setup of the test suite.""" + + client: Client + mock_stub: AsyncMock + loop: asyncio.AbstractEventLoop + gridpool_id: int + delivery_area: DeliveryArea + delivery_period: DeliveryPeriod + order_type: OrderType + side: MarketSide + price: Price + quantity: Power + order_execution_option: OrderExecutionOption + valid_until: datetime + + @pytest.fixture def set_up() -> Generator[Any, Any, Any]: """Set up the test suite.""" # Create a mock client and stub - _ = Client("grpc://unknown.host", connect=False) + client = Client("grpc://unknown.host", connect=False) mock_stub = AsyncMock() - _._stub = mock_stub # pylint: disable=protected-access + client._stub = mock_stub # pylint: disable=protected-access # Create a new event loop for each test loop = asyncio.new_event_loop() @@ -65,40 +84,40 @@ def set_up() -> Generator[Any, Any, Any]: order_execution_option = OrderExecutionOption.AON valid_until = delivery_start + timedelta(hours=3) - yield { - "client": _, - "mock_stub": mock_stub, - "loop": loop, - "gridpool_id": gridpool_id, - "delivery_area": delivery_area, - "delivery_period": delivery_period, - "order_type": order_type, - "side": side, - "price": price, - "quantity": quantity, - "order_execution_option": order_execution_option, - "valid_until": valid_until, - } + yield SetupParams( + client=client, + mock_stub=mock_stub, + loop=loop, + gridpool_id=gridpool_id, + delivery_area=delivery_area, + delivery_period=delivery_period, + order_type=order_type, + side=side, + price=price, + quantity=quantity, + order_execution_option=order_execution_option, + valid_until=valid_until, + ) loop.close() # pylint: disable=redefined-outer-name def set_up_order_detail_response( - set_up: dict[str, Any], + set_up: SetupParams, order_id: int = 1, ) -> electricity_trading_pb2.OrderDetail: """Set up an order detail response.""" return OrderDetail( order_id=order_id, order=Order( - delivery_area=set_up["delivery_area"], - delivery_period=set_up["delivery_period"], - type=set_up["order_type"], - side=set_up["side"], - price=set_up["price"], - quantity=set_up["quantity"], - execution_option=set_up["order_execution_option"], + delivery_area=set_up.delivery_area, + delivery_period=set_up.delivery_period, + type=set_up.order_type, + side=set_up.side, + price=set_up.price, + quantity=set_up.quantity, + execution_option=set_up.order_execution_option, ), state_detail=StateDetail( state=OrderState.ACTIVE, @@ -107,77 +126,77 @@ def set_up_order_detail_response( ), open_quantity=Power(mw=Decimal("5.00")), filled_quantity=Power(mw=Decimal("0.00")), - create_time=set_up["delivery_period"].start - timedelta(hours=2), - modification_time=set_up["delivery_period"].start - timedelta(hours=1), + create_time=set_up.delivery_period.start - timedelta(hours=2), + modification_time=set_up.delivery_period.start - timedelta(hours=1), ).to_pb() -def test_stream_gridpool_orders(set_up: dict[str, Any]) -> None: +def test_stream_gridpool_orders(set_up: SetupParams) -> None: """Test the method streaming gridpool orders.""" - set_up["loop"].run_until_complete( - set_up["client"].stream_gridpool_orders(set_up["gridpool_id"]) + set_up.loop.run_until_complete( + set_up.client.stream_gridpool_orders(set_up.gridpool_id) ) - set_up["mock_stub"].ReceiveGridpoolOrdersStream.assert_called_once() - args, _ = set_up["mock_stub"].ReceiveGridpoolOrdersStream.call_args - assert args[0].gridpool_id == set_up["gridpool_id"] + set_up.mock_stub.ReceiveGridpoolOrdersStream.assert_called_once() + args, _ = set_up.mock_stub.ReceiveGridpoolOrdersStream.call_args + assert args[0].gridpool_id == set_up.gridpool_id -def test_stream_gridpool_orders_with_optional_inputs(set_up: dict[str, Any]) -> None: +def test_stream_gridpool_orders_with_optional_inputs(set_up: SetupParams) -> None: """Test the method streaming gridpool orders with some fields to filter for.""" # Fields to filter for order_states = [OrderState.ACTIVE] - set_up["loop"].run_until_complete( - set_up["client"].stream_gridpool_orders( - set_up["gridpool_id"], order_states=order_states + set_up.loop.run_until_complete( + set_up.client.stream_gridpool_orders( + set_up.gridpool_id, order_states=order_states ) ) - set_up["mock_stub"].ReceiveGridpoolOrdersStream.assert_called_once() - args, _ = set_up["mock_stub"].ReceiveGridpoolOrdersStream.call_args - assert args[0].gridpool_id == set_up["gridpool_id"] + set_up.mock_stub.ReceiveGridpoolOrdersStream.assert_called_once() + args, _ = set_up.mock_stub.ReceiveGridpoolOrdersStream.call_args + assert args[0].gridpool_id == set_up.gridpool_id assert args[0].filter.states == [ order_state.to_pb() for order_state in order_states ] def test_stream_gridpool_trades( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """Test the method streaming gridpool trades.""" - set_up["loop"].run_until_complete( - set_up["client"].stream_gridpool_trades( - gridpool_id=set_up["gridpool_id"], market_side=set_up["side"] + set_up.loop.run_until_complete( + set_up.client.stream_gridpool_trades( + gridpool_id=set_up.gridpool_id, market_side=set_up.side ) ) - set_up["mock_stub"].ReceiveGridpoolTradesStream.assert_called_once() - args, _ = set_up["mock_stub"].ReceiveGridpoolTradesStream.call_args - assert args[0].gridpool_id == set_up["gridpool_id"] - assert args[0].filter.side == set_up["side"].to_pb() + set_up.mock_stub.ReceiveGridpoolTradesStream.assert_called_once() + args, _ = set_up.mock_stub.ReceiveGridpoolTradesStream.call_args + assert args[0].gridpool_id == set_up.gridpool_id + assert args[0].filter.side == set_up.side.to_pb() def test_stream_public_trades( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """Test the method streaming public trades.""" # Fields to filter for trade_states = [TradeState.ACTIVE] - set_up["loop"].run_until_complete( - set_up["client"].stream_public_trades(states=trade_states) + set_up.loop.run_until_complete( + set_up.client.stream_public_trades(states=trade_states) ) - set_up["mock_stub"].ReceivePublicTradesStream.assert_called_once() - args, _ = set_up["mock_stub"].ReceivePublicTradesStream.call_args + set_up.mock_stub.ReceivePublicTradesStream.assert_called_once() + args, _ = set_up.mock_stub.ReceivePublicTradesStream.call_args assert args[0].filter.states == [ trade_state.to_pb() for trade_state in trade_states ] def test_create_gridpool_order( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """ Test the method creating a gridpool order. @@ -190,34 +209,34 @@ def test_create_gridpool_order( mock_response = electricity_trading_pb2.CreateGridpoolOrderResponse( order_detail=order_detail_response ) - set_up["mock_stub"].CreateGridpoolOrder.return_value = mock_response - - set_up["loop"].run_until_complete( - set_up["client"].create_gridpool_order( - gridpool_id=set_up["gridpool_id"], - delivery_area=set_up["delivery_area"], - delivery_period=set_up["delivery_period"], - order_type=set_up["order_type"], - side=set_up["side"], - price=set_up["price"], - quantity=set_up["quantity"], - execution_option=set_up["order_execution_option"], # optional field + set_up.mock_stub.CreateGridpoolOrder.return_value = mock_response + + set_up.loop.run_until_complete( + set_up.client.create_gridpool_order( + gridpool_id=set_up.gridpool_id, + delivery_area=set_up.delivery_area, + delivery_period=set_up.delivery_period, + order_type=set_up.order_type, + side=set_up.side, + price=set_up.price, + quantity=set_up.quantity, + execution_option=set_up.order_execution_option, # optional field ) ) - set_up["mock_stub"].CreateGridpoolOrder.assert_called_once() - args, _ = set_up["mock_stub"].CreateGridpoolOrder.call_args - assert args[0].gridpool_id == set_up["gridpool_id"] - assert args[0].order.type == set_up["order_type"].to_pb() - assert args[0].order.quantity == set_up["quantity"].to_pb() - assert args[0].order.price == set_up["price"].to_pb() - assert args[0].order.delivery_period == set_up["delivery_period"].to_pb() - assert args[0].order.delivery_area == set_up["delivery_area"].to_pb() - assert args[0].order.execution_option == set_up["order_execution_option"].to_pb() + set_up.mock_stub.CreateGridpoolOrder.assert_called_once() + args, _ = set_up.mock_stub.CreateGridpoolOrder.call_args + assert args[0].gridpool_id == set_up.gridpool_id + assert args[0].order.type == set_up.order_type.to_pb() + assert args[0].order.quantity == set_up.quantity.to_pb() + assert args[0].order.price == set_up.price.to_pb() + assert args[0].order.delivery_period == set_up.delivery_period.to_pb() + assert args[0].order.delivery_area == set_up.delivery_area.to_pb() + assert args[0].order.execution_option == set_up.order_execution_option.to_pb() def test_update_gridpool_order( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """Test the method updating a gridpool order.""" # Setup the expected response with valid values, @@ -226,23 +245,23 @@ def test_update_gridpool_order( mock_response = electricity_trading_pb2.UpdateGridpoolOrderResponse( order_detail=order_detail_response ) - set_up["mock_stub"].UpdateGridpoolOrder.return_value = mock_response + set_up.mock_stub.UpdateGridpoolOrder.return_value = mock_response - set_up["loop"].run_until_complete( - set_up["client"].update_gridpool_order( - gridpool_id=set_up["gridpool_id"], + set_up.loop.run_until_complete( + set_up.client.update_gridpool_order( + gridpool_id=set_up.gridpool_id, order_id=1, - quantity=set_up["quantity"], - valid_until=set_up["valid_until"], + quantity=set_up.quantity, + valid_until=set_up.valid_until, ) ) valid_until_pb = timestamp_pb2.Timestamp() - valid_until_pb.FromDatetime(set_up["valid_until"]) + valid_until_pb.FromDatetime(set_up.valid_until) - set_up["mock_stub"].UpdateGridpoolOrder.assert_called_once() - args, _ = set_up["mock_stub"].UpdateGridpoolOrder.call_args - assert args[0].update_order_fields.quantity == set_up["quantity"].to_pb() + set_up.mock_stub.UpdateGridpoolOrder.assert_called_once() + args, _ = set_up.mock_stub.UpdateGridpoolOrder.call_args + assert args[0].update_order_fields.quantity == set_up.quantity.to_pb() assert args[0].update_order_fields.valid_until == valid_until_pb # Test that other fields e.g. price are not set assert not args[0].update_order_fields.HasField( @@ -251,7 +270,7 @@ def test_update_gridpool_order( def test_cancel_gridpool_order( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """Test the method cancelling gridpool orders.""" # Setup the expected response with valid values, @@ -264,22 +283,22 @@ def test_cancel_gridpool_order( # Order to cancel order_id = 1 - set_up["mock_stub"].CancelGridpoolOrder.return_value = mock_response + set_up.mock_stub.CancelGridpoolOrder.return_value = mock_response - set_up["loop"].run_until_complete( - set_up["client"].cancel_gridpool_order( - gridpool_id=set_up["gridpool_id"], order_id=order_id + set_up.loop.run_until_complete( + set_up.client.cancel_gridpool_order( + gridpool_id=set_up.gridpool_id, order_id=order_id ) ) - set_up["mock_stub"].CancelGridpoolOrder.assert_called_once() - args, _ = set_up["mock_stub"].CancelGridpoolOrder.call_args - assert args[0].gridpool_id == set_up["gridpool_id"] + set_up.mock_stub.CancelGridpoolOrder.assert_called_once() + args, _ = set_up.mock_stub.CancelGridpoolOrder.call_args + assert args[0].gridpool_id == set_up.gridpool_id assert args[0].order_id == order_id def test_list_gridpool_orders( - set_up: dict[str, Any], + set_up: SetupParams, ) -> None: """Test the method listing gridpool orders.""" # Setup the expected response with valid values, @@ -288,20 +307,20 @@ def test_list_gridpool_orders( mock_response = electricity_trading_pb2.ListGridpoolOrdersResponse( order_details=[order_detail_response] ) - set_up["mock_stub"].ListGridpoolOrders.return_value = mock_response + set_up.mock_stub.ListGridpoolOrders.return_value = mock_response # Fields to filter for side = MarketSide.BUY order_states = [OrderState.ACTIVE] - set_up["loop"].run_until_complete( - set_up["client"].list_gridpool_orders( - gridpool_id=set_up["gridpool_id"], side=side, order_states=order_states + set_up.loop.run_until_complete( + set_up.client.list_gridpool_orders( + gridpool_id=set_up.gridpool_id, side=side, order_states=order_states ) ) - set_up["mock_stub"].ListGridpoolOrders.assert_called_once() - args, _ = set_up["mock_stub"].ListGridpoolOrders.call_args + set_up.mock_stub.ListGridpoolOrders.assert_called_once() + args, _ = set_up.mock_stub.ListGridpoolOrders.call_args assert args[0].filter.states == [ order_state.to_pb() for order_state in order_states ] @@ -375,7 +394,7 @@ def test_list_gridpool_orders( ) def test_create_gridpool_order_with_invalid_params( # pylint: disable=too-many-arguments, too-many-positional-arguments - set_up: dict[str, Any], + set_up: SetupParams, price: Price, quantity: Power, delivery_period: DeliveryPeriod, @@ -385,10 +404,10 @@ def test_create_gridpool_order_with_invalid_params( ) -> None: """Test creating an order with invalid input parameters.""" with pytest.raises(expected_exception): - set_up["loop"].run_until_complete( - set_up["client"].create_gridpool_order( - gridpool_id=set_up["gridpool_id"], - delivery_area=set_up["delivery_area"], + set_up.loop.run_until_complete( + set_up.client.create_gridpool_order( + gridpool_id=set_up.gridpool_id, + delivery_area=set_up.delivery_area, delivery_period=delivery_period, order_type=OrderType.LIMIT, side=MarketSide.BUY, @@ -427,7 +446,7 @@ def test_create_gridpool_order_with_invalid_params( ], ) def test_update_gridpool_order_with_invalid_params( # pylint: disable=too-many-arguments - set_up: dict[str, Any], + set_up: SetupParams, price: Price, quantity: Power, valid_until: datetime, @@ -435,9 +454,9 @@ def test_update_gridpool_order_with_invalid_params( # pylint: disable=too-many- ) -> None: """Test updating an order with invalid input parameters.""" with pytest.raises(expected_exception): - set_up["loop"].run_until_complete( - set_up["client"].update_gridpool_order( - gridpool_id=set_up["gridpool_id"], + set_up.loop.run_until_complete( + set_up.client.update_gridpool_order( + gridpool_id=set_up.gridpool_id, order_id=1, price=price, quantity=quantity,