Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor channels #134

Merged
merged 2 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
use crate::client::{ResponseContext, SharesChannel, Subscribable, Subscription};
use crate::contracts::Contract;
use crate::messages::{IncomingMessages, OutgoingMessages, RequestMessage, ResponseMessage};
use crate::transport::Response;
use crate::{server_versions, Client, Error};

mod decoders;
Expand Down Expand Up @@ -411,7 +410,7 @@ pub(crate) fn family_codes(client: &Client) -> Result<Vec<FamilyCode>, Error> {
let subscription = client.send_shared_request(OutgoingMessages::RequestFamilyCodes, request)?;

// TODO: enumerate
if let Some(Response::Message(mut message)) = subscription.next() {
if let Some(Ok(mut message)) = subscription.next() {
decoders::decode_family_codes(&mut message)
} else {
Ok(Vec::default())
Expand Down Expand Up @@ -492,15 +491,14 @@ pub fn managed_accounts(client: &Client) -> Result<Vec<String>, Error> {
let subscription = client.send_shared_request(OutgoingMessages::RequestManagedAccounts, request)?;

match subscription.next() {
Some(Response::Message(mut message)) => {
Some(Ok(mut message)) => {
message.skip(); // message type
message.skip(); // message version

let accounts = message.next_string()?;
Ok(accounts.split(",").map(String::from).collect())
}
Some(Response::Cancelled) => Err(Error::Cancelled),
Some(Response::Disconnected) => Err(Error::ConnectionFailed),
Some(Err(e)) => Err(e),
None => Ok(Vec::default()),
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::market_data::realtime::{self, Bar, BarSize, MidPoint, WhatToShow};
use crate::messages::{IncomingMessages, OutgoingMessages};
use crate::messages::{RequestMessage, ResponseMessage};
use crate::orders::{Order, OrderDataResult, OrderNotification};
use crate::transport::{Connection, ConnectionMetadata, InternalSubscription, MessageBus, Response, TcpMessageBus};
use crate::transport::{Connection, ConnectionMetadata, InternalSubscription, MessageBus, TcpMessageBus};
use crate::{accounts, contracts, orders};

// Client
Expand Down Expand Up @@ -1162,7 +1162,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
pub fn next(&self) -> Option<T> {
loop {
match self.subscription.next() {
Some(Response::Message(mut message)) => {
Some(Ok(mut message)) => {
if T::RESPONSE_MESSAGE_IDS.contains(&message.message_type()) {
match T::decode(self.client.server_version(), &mut message) {
Ok(val) => return Some(val),
Expand All @@ -1178,11 +1178,11 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
info!("subscription iterator unexpected message: {message:?}");
}
}
Some(Response::Cancelled) => {
Some(Err(Error::Cancelled)) => {
debug!("subscription cancelled");
return None;
}
Some(Response::Disconnected) => {
Some(Err(Error::Shutdown)) => {
debug!("server disconnected");
return None;
}
Expand All @@ -1205,7 +1205,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
/// //}
/// ```
pub fn try_next(&self) -> Option<T> {
if let Some(Response::Message(mut message)) = self.subscription.try_next() {
if let Some(Ok(mut message)) = self.subscription.try_next() {
if message.message_type() == IncomingMessages::Error {
error!("{}", message.peek_string(4));
return None;
Expand Down Expand Up @@ -1235,7 +1235,7 @@ impl<'a, T: Subscribable<T>> Subscription<'a, T> {
/// //}
/// ```
pub fn next_timeout(&self, timeout: Duration) -> Option<T> {
if let Some(Response::Message(mut message)) = self.subscription.next_timeout(timeout) {
if let Some(Ok(mut message)) = self.subscription.next_timeout(timeout) {
if message.message_type() == IncomingMessages::Error {
error!("{}", message.peek_string(4));
return None;
Expand Down
10 changes: 4 additions & 6 deletions src/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ use crate::messages::IncomingMessages;
use crate::messages::OutgoingMessages;
use crate::messages::RequestMessage;
use crate::messages::ResponseMessage;
use crate::transport::Response;
use crate::Client;
use crate::{server_versions, Error, ToField};

Expand Down Expand Up @@ -458,7 +457,7 @@ pub(crate) fn contract_details(client: &Client, contract: &Contract) -> Result<V
let mut contract_details: Vec<ContractDetails> = Vec::default();

// TODO create iterator
while let Some(Response::Message(mut message)) = responses.next() {
while let Some(Ok(mut message)) = responses.next() {
match message.message_type() {
IncomingMessages::ContractData => {
let decoded = decoders::decode_contract_details(client.server_version(), &mut message)?;
Expand Down Expand Up @@ -531,7 +530,7 @@ pub(crate) fn matching_symbols(client: &Client, pattern: &str) -> Result<Vec<Con
let request = encoders::encode_request_matching_symbols(request_id, pattern)?;
let subscription = client.send_request(request_id, request)?;

if let Some(Response::Message(mut message)) = subscription.next() {
if let Some(Ok(mut message)) = subscription.next() {
match message.message_type() {
IncomingMessages::SymbolSamples => {
return decoders::decode_contract_descriptions(client.server_version(), &mut message);
Expand Down Expand Up @@ -574,9 +573,8 @@ pub(crate) fn market_rule(client: &Client, market_rule_id: i32) -> Result<Market
let subscription = client.send_shared_request(OutgoingMessages::RequestMarketRule, request)?;

match subscription.next() {
Some(Response::Message(mut message)) => Ok(decoders::decode_market_rule(&mut message)?),
Some(Response::Cancelled) => Err(Error::Simple("subscription cancelled".into())),
Some(Response::Disconnected) => Err(Error::Simple("server gone".into())),
Some(Ok(mut message)) => Ok(decoders::decode_market_rule(&mut message)?),
Some(Err(e)) => Err(e),
None => Err(Error::Simple("no market rule found".into())),
}
}
Expand Down
15 changes: 10 additions & 5 deletions src/errors.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::{num::ParseIntError, string::FromUtf8Error};
use std::{num::ParseIntError, string::FromUtf8Error, sync::Arc};

#[derive(Debug)]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum Error {
// Errors from external libraries
Io(std::io::Error),
Io(Arc<std::io::Error>),
ParseInt(ParseIntError),
FromUtf8(FromUtf8Error),
ParseTime(time::error::Parse),
Expand All @@ -17,6 +17,7 @@ pub enum Error {
Simple(String),
ConnectionFailed,
Cancelled,
Shutdown,
}

impl std::error::Error for Error {}
Expand All @@ -35,6 +36,7 @@ impl std::fmt::Display for Error {
Error::ServerVersion(wanted, have, message) => write!(f, "server version {wanted} required, got {have}: {message}"),
Error::ConnectionFailed => write!(f, "ConnectionFailed"),
Error::Cancelled => write!(f, "Cancelled"),
Error::Shutdown => write!(f, "Shutdown"),

Error::Simple(ref err) => write!(f, "error occurred: {err}"),
}
Expand All @@ -43,7 +45,7 @@ impl std::fmt::Display for Error {

impl From<std::io::Error> for Error {
fn from(err: std::io::Error) -> Error {
Error::Io(err)
Error::Io(Arc::new(err))
}
}

Expand Down Expand Up @@ -89,7 +91,10 @@ mod tests {
#[test]
fn test_error_display() {
let cases = vec![
(Error::Io(io::Error::new(io::ErrorKind::NotFound, "file not found")), "file not found"),
(
Error::Io(Arc::new(io::Error::new(io::ErrorKind::NotFound, "file not found"))),
"file not found",
),
(Error::ParseInt("123x".parse::<i32>().unwrap_err()), "invalid digit found in string"),
(
Error::FromUtf8(String::from_utf8(vec![0, 159, 146, 150]).unwrap_err()),
Expand Down
10 changes: 5 additions & 5 deletions src/market_data/historical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use time::{Date, OffsetDateTime};

use crate::contracts::Contract;
use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage};
use crate::transport::{InternalSubscription, Response};
use crate::transport::InternalSubscription;
use crate::{server_versions, Client, Error, ToField};

mod decoders;
Expand Down Expand Up @@ -304,7 +304,7 @@ pub(crate) fn head_timestamp(client: &Client, contract: &Contract, what_to_show:

let subscription = client.send_request(request_id, request)?;

if let Some(Response::Message(mut message)) = subscription.next() {
if let Some(Ok(mut message)) = subscription.next() {
decoders::decode_head_timestamp(&mut message)
} else {
Err(Error::Simple("did not receive head timestamp message".into()))
Expand Down Expand Up @@ -359,7 +359,7 @@ pub(crate) fn historical_data(

let subscription = client.send_request(request_id, request)?;

if let Some(Response::Message(mut message)) = subscription.next() {
if let Some(Ok(mut message)) = subscription.next() {
let time_zone = if let Some(tz) = client.time_zone {
tz
} else {
Expand Down Expand Up @@ -410,7 +410,7 @@ pub(crate) fn historical_schedule(

let subscription = client.send_request(request_id, request)?;

if let Some(Response::Message(mut message)) = subscription.next() {
if let Some(Ok(mut message)) = subscription.next() {
match message.message_type() {
IncomingMessages::HistoricalSchedule => decoders::decode_historical_schedule(&mut message),
IncomingMessages::Error => Err(Error::Simple(message.peek_string(4))),
Expand Down Expand Up @@ -547,7 +547,7 @@ impl<T: TickDecoder<T> + Debug> Iterator for TickIterator<T> {

loop {
match self.messages.next() {
Some(Response::Message(mut message)) => {
Some(Ok(mut message)) => {
if message.message_type() == Self::Item::message_type() {
let (ticks, done) = Self::Item::decode(&mut message).unwrap();

Expand Down
7 changes: 3 additions & 4 deletions src/market_data/realtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::contracts::Contract;
use crate::messages::{IncomingMessages, RequestMessage, ResponseMessage};
use crate::orders::TagValue;
use crate::server_versions;
use crate::transport::{InternalSubscription, Response};
use crate::transport::InternalSubscription;
use crate::ToField;
use crate::{Client, Error};

Expand Down Expand Up @@ -320,7 +320,7 @@ impl<'a> Iterator for TradeIterator<'a> {
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.responses.next() {
Some(Response::Message(mut message)) => match message.message_type() {
Some(Ok(mut message)) => match message.message_type() {
IncomingMessages::TickByTick => match decoders::decode_trade_tick(&mut message) {
Ok(tick) => return Some(tick),
Err(e) => error!("unexpected message {message:?}: {e:?}"),
Expand Down Expand Up @@ -363,14 +363,13 @@ impl<'a> Iterator for BidAskIterator<'a> {
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.responses.next() {
Some(Response::Message(mut message)) => match message.message_type() {
Some(Ok(mut message)) => match message.message_type() {
IncomingMessages::TickByTick => match decoders::bid_ask_tick(&mut message) {
Ok(tick) => return Some(tick),
Err(e) => error!("unexpected message {message:?}: {e:?}"),
},
_ => error!("unexpected message {message:?}"),
},
// TODO enumerate
_ => return None,
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/orders.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use log::{error, info};
use crate::contracts::{ComboLeg, ComboLegOpenClose, Contract, DeltaNeutralContract, SecurityType};
use crate::messages::{IncomingMessages, OutgoingMessages};
use crate::messages::{RequestMessage, ResponseMessage};
use crate::transport::{InternalSubscription, Response};
use crate::transport::InternalSubscription;
use crate::Client;
use crate::{encode_option_field, ToField};
use crate::{server_versions, Error};
Expand Down Expand Up @@ -1055,7 +1055,7 @@ impl Iterator for OrderNotificationIterator {
}

loop {
if let Some(Response::Message(mut message)) = self.messages.next() {
if let Some(Ok(mut message)) = self.messages.next() {
match message.message_type() {
IncomingMessages::OpenOrder => {
let open_order = decoders::decode_open_order(self.server_version, message);
Expand Down Expand Up @@ -1332,7 +1332,7 @@ impl Iterator for CancelOrderResultIterator {
/// Returns the next [CancelOrderResult]. Waits up to x seconds for next [CancelOrderResult].
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(Response::Message(mut message)) = self.messages.next() {
if let Some(Ok(mut message)) = self.messages.next() {
match message.message_type() {
IncomingMessages::OrderStatus => match decoders::decode_order_status(self.server_version, &mut message) {
Ok(val) => return Some(CancelOrderResult::OrderStatus(val)),
Expand Down Expand Up @@ -1373,7 +1373,7 @@ pub(crate) fn next_valid_order_id(client: &Client) -> Result<i32, Error> {

let subscription = client.send_shared_request(OutgoingMessages::RequestIds, message)?;

if let Some(Response::Message(message)) = subscription.next() {
if let Some(Ok(message)) = subscription.next() {
let order_id_index = 2;
let next_order_id = message.peek_int(order_id_index)?;

Expand Down Expand Up @@ -1418,7 +1418,7 @@ impl Iterator for OrderDataIterator {
/// Returns the next [OrderDataResult]. Waits up to x seconds for next [OrderDataResult].
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(Response::Message(mut message)) = self.messages.next() {
if let Some(Ok(mut message)) = self.messages.next() {
match message.message_type() {
IncomingMessages::CompletedOrder => match decoders::decode_completed_order(self.server_version, message) {
Ok(val) => return Some(OrderDataResult::OrderData(Box::new(val))),
Expand Down Expand Up @@ -1554,7 +1554,7 @@ impl Iterator for ExecutionDataIterator {
/// Returns the next [OrderDataResult]. Waits up to x seconds for next [OrderDataResult].
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(Response::Message(mut message)) = self.messages.next() {
if let Some(Ok(mut message)) = self.messages.next() {
match message.message_type() {
IncomingMessages::ExecutionData => match decoders::decode_execution_data(self.server_version, &mut message) {
Ok(val) => return Some(ExecutionDataResult::ExecutionData(Box::new(val))),
Expand Down
4 changes: 2 additions & 2 deletions src/stubs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock};
use crossbeam::channel;

use crate::messages::{OutgoingMessages, RequestMessage, ResponseMessage};
use crate::transport::{InternalSubscription, MessageBus, Response, SubscriptionBuilder};
use crate::transport::{InternalSubscription, MessageBus, SubscriptionBuilder};
use crate::Error;

pub(crate) struct MessageBusStub {
Expand Down Expand Up @@ -66,7 +66,7 @@ fn mock_request(

for message in &stub.response_messages {
let message = ResponseMessage::from(&message.replace('|', "\0"));
sender.send(Response::from(message)).unwrap();
sender.send(Ok(message)).unwrap();
}

let mut subscription = SubscriptionBuilder::new().shared_receiver(Arc::new(receiver)).signaler(s1);
Expand Down
Loading
Loading