Skip to content

Commit 76033d1

Browse files
authored
Make Client: Send + Sync (#116)
1 parent ee926a1 commit 76033d1

File tree

7 files changed

+97
-80
lines changed

7 files changed

+97
-80
lines changed

src/client.rs

+32-19
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
use std::cell::RefCell;
21
use std::fmt::Debug;
32
use std::io::Write;
43
use std::sync::atomic::{AtomicI32, Ordering};
4+
use std::sync::{Arc, Mutex};
55

66
use byteorder::{BigEndian, WriteBytesExt};
77
use log::{debug, error, info};
@@ -40,7 +40,7 @@ pub struct Client {
4040

4141
managed_accounts: String,
4242
client_id: i32, // ID of client.
43-
pub(crate) message_bus: RefCell<Box<dyn MessageBus>>,
43+
pub(crate) message_bus: Arc<Mutex<dyn MessageBus>>,
4444
next_request_id: AtomicI32, // Next available request_id.
4545
order_id: AtomicI32, // Next available order_id. Starts with value returned on connection.
4646
}
@@ -67,11 +67,11 @@ impl Client {
6767
/// println!("next_order_id: {}", client.next_order_id());
6868
/// ```
6969
pub fn connect(address: &str, client_id: i32) -> Result<Client, Error> {
70-
let message_bus = RefCell::new(Box::new(TcpMessageBus::connect(address)?));
70+
let message_bus = Arc::new(Mutex::new(TcpMessageBus::connect(address)?));
7171
Client::do_connect(client_id, message_bus)
7272
}
7373

74-
fn do_connect(client_id: i32, message_bus: RefCell<Box<dyn MessageBus>>) -> Result<Client, Error> {
74+
fn do_connect(client_id: i32, message_bus: Arc<Mutex<dyn MessageBus>>) -> Result<Client, Error> {
7575
let mut client = Client {
7676
server_version: 0,
7777
connection_time: None,
@@ -87,7 +87,11 @@ impl Client {
8787
client.start_api()?;
8888
client.receive_account_info()?;
8989

90-
client.message_bus.borrow_mut().process_messages(client.server_version)?;
90+
client
91+
.message_bus
92+
.lock()
93+
.expect("MessageBus is poisoned")
94+
.process_messages(client.server_version)?;
9195

9296
Ok(client)
9397
}
@@ -98,9 +102,9 @@ impl Client {
98102
let version = format!("v{MIN_SERVER_VERSION}..{MAX_SERVER_VERSION}");
99103

100104
let packet = prefix.to_owned() + &encode_packet(&version);
101-
self.message_bus.borrow_mut().write(&packet)?;
105+
self.message_bus.lock().expect("MessageBus is poisoned").write(&packet)?;
102106

103-
let ack = self.message_bus.borrow_mut().read_message();
107+
let ack = self.message_bus.lock().expect("MessageBus is poisoned").read_message();
104108

105109
match ack {
106110
Ok(mut response_message) => {
@@ -133,7 +137,7 @@ impl Client {
133137
prelude.push_field(&"");
134138
}
135139

136-
self.message_bus.borrow_mut().write_message(prelude)?;
140+
self.message_bus.lock().expect("MessageBus is poisoned").write_message(prelude)?;
137141

138142
Ok(())
139143
}
@@ -146,7 +150,7 @@ impl Client {
146150
let mut attempts = 0;
147151
const MAX_ATTEMPTS: i32 = 100;
148152
loop {
149-
let mut message = self.message_bus.borrow_mut().read_message()?;
153+
let mut message = self.message_bus.lock().expect("MessageBus is poisoned").read_message()?;
150154

151155
match message.message_type() {
152156
IncomingMessages::NextValidId => {
@@ -886,7 +890,7 @@ impl Client {
886890
// == Internal Use ==
887891

888892
#[cfg(test)]
889-
pub(crate) fn stubbed(message_bus: RefCell<Box<dyn MessageBus>>, server_version: i32) -> Client {
893+
pub(crate) fn stubbed(message_bus: Arc<Mutex<dyn MessageBus>>, server_version: i32) -> Client {
890894
Client {
891895
server_version: server_version,
892896
connection_time: None,
@@ -900,47 +904,56 @@ impl Client {
900904
}
901905

902906
pub(crate) fn send_message(&self, packet: RequestMessage) -> Result<(), Error> {
903-
self.message_bus.borrow_mut().write_message(&packet)
907+
self.message_bus.lock().expect("MessageBus is poisoned").write_message(&packet)
904908
}
905909

906910
pub(crate) fn send_request(&self, request_id: i32, message: RequestMessage) -> Result<ResponseIterator, Error> {
907911
debug!("send_message({:?}, {:?})", request_id, message);
908-
self.message_bus.borrow_mut().send_generic_message(request_id, &message)
912+
self.message_bus
913+
.lock()
914+
.expect("MessageBus is poisoned")
915+
.send_generic_message(request_id, &message)
909916
}
910917

911918
pub(crate) fn send_durable_request(&self, request_id: i32, message: RequestMessage) -> Result<ResponseIterator, Error> {
912919
debug!("send_durable_request({:?}, {:?})", request_id, message);
913-
self.message_bus.borrow_mut().send_durable_message(request_id, &message)
920+
self.message_bus
921+
.lock()
922+
.expect("MessageBus is poisoned")
923+
.send_durable_message(request_id, &message)
914924
}
915925

916926
pub(crate) fn send_order(&self, order_id: i32, message: RequestMessage) -> Result<ResponseIterator, Error> {
917927
debug!("send_order({:?}, {:?})", order_id, message);
918-
self.message_bus.borrow_mut().send_order_message(order_id, &message)
928+
self.message_bus
929+
.lock()
930+
.expect("MessageBus is poisoned")
931+
.send_order_message(order_id, &message)
919932
}
920933

921934
/// Sends request for the next valid order id.
922935
pub(crate) fn request_next_order_id(&self, message: RequestMessage) -> Result<GlobalResponseIterator, Error> {
923-
self.message_bus.borrow_mut().request_next_order_id(&message)
936+
self.message_bus.lock().expect("MessageBus is poisoned").request_next_order_id(&message)
924937
}
925938

926939
/// Sends request for open orders.
927940
pub(crate) fn request_order_data(&self, message: RequestMessage) -> Result<GlobalResponseIterator, Error> {
928-
self.message_bus.borrow_mut().request_open_orders(&message)
941+
self.message_bus.lock().expect("MessageBus is poisoned").request_open_orders(&message)
929942
}
930943

931944
/// Sends request for market rule.
932945
pub(crate) fn request_market_rule(&self, message: RequestMessage) -> Result<GlobalResponseIterator, Error> {
933-
self.message_bus.borrow_mut().request_market_rule(&message)
946+
self.message_bus.lock().expect("MessageBus is poisoned").request_market_rule(&message)
934947
}
935948

936949
/// Sends request for positions.
937950
pub(crate) fn request_positions(&self, message: RequestMessage) -> Result<GlobalResponseIterator, Error> {
938-
self.message_bus.borrow_mut().request_positions(&message)
951+
self.message_bus.lock().expect("MessageBus is poisoned").request_positions(&message)
939952
}
940953

941954
/// Sends request for family codes.
942955
pub(crate) fn request_family_codes(&self, message: RequestMessage) -> Result<GlobalResponseIterator, Error> {
943-
self.message_bus.borrow_mut().request_family_codes(&message)
956+
self.message_bus.lock().expect("MessageBus is poisoned").request_family_codes(&message)
944957
}
945958

946959
pub(crate) fn check_server_version(&self, version: i32, message: &str) -> Result<(), Error> {

src/client/transport.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::net::TcpStream;
55
use std::sync::{Arc, RwLock};
66
use std::thread::{self, JoinHandle};
77
use std::time::Duration;
8+
use std::sync::Mutex;
89

910
use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
1011
use crossbeam::channel::{self, Receiver, Sender};
@@ -17,7 +18,7 @@ use recorder::MessageRecorder;
1718

1819
mod recorder;
1920

20-
pub(crate) trait MessageBus {
21+
pub(crate) trait MessageBus: Send + Sync {
2122
fn read_message(&mut self) -> Result<ResponseMessage, Error>;
2223

2324
fn write_message(&mut self, packet: &RequestMessage) -> Result<(), Error>;
@@ -43,7 +44,7 @@ pub(crate) trait MessageBus {
4344
#[derive(Debug)]
4445
pub struct TcpMessageBus {
4546
reader: Arc<TcpStream>,
46-
writer: Box<TcpStream>,
47+
writer: Arc<Mutex<TcpStream>>,
4748
handles: Vec<JoinHandle<i32>>,
4849
requests: Arc<SenderHash<i32, ResponseMessage>>,
4950
orders: Arc<SenderHash<i32, ResponseMessage>>,
@@ -101,7 +102,7 @@ impl TcpMessageBus {
101102
let stream = TcpStream::connect(connection_string)?;
102103

103104
let reader = Arc::new(stream.try_clone()?);
104-
let writer = Box::new(stream);
105+
let writer = Arc::new(Mutex::new(stream));
105106
let requests = Arc::new(SenderHash::new());
106107
let orders = Arc::new(SenderHash::new());
107108

@@ -213,7 +214,7 @@ impl MessageBus for TcpMessageBus {
213214
packet.write_u32::<BigEndian>(data.len() as u32)?;
214215
packet.write_all(data)?;
215216

216-
self.writer.write_all(&packet)?;
217+
self.writer.lock().expect("MessageBus writer is poisoned").write_all(&packet)?;
217218

218219
self.recorder.record_request(message);
219220

@@ -222,7 +223,7 @@ impl MessageBus for TcpMessageBus {
222223

223224
fn write(&mut self, data: &str) -> Result<(), Error> {
224225
debug!("{data:?} ->");
225-
self.writer.write_all(data.as_bytes())?;
226+
self.writer.lock().expect("MessageBus writer is poisoned").write_all(data.as_bytes())?;
226227
Ok(())
227228
}
228229

src/contracts/tests.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
use std::cell::RefCell;
1+
use std::sync::RwLock;
2+
use std::sync::{Arc, Mutex};
23

34
use super::*;
45

56
use crate::stubs::MessageBusStub;
67

78
#[test]
89
fn request_stock_contract_details() {
9-
let message_bus = RefCell::new(Box::new(MessageBusStub{
10-
request_messages: RefCell::new(vec![]),
10+
let message_bus = Arc::new(Mutex::new(MessageBusStub{
11+
request_messages: RwLock::new(vec![]),
1112
response_messages: vec![
1213
"10|9001|TSLA|STK||0||SMART|USD|TSLA|NMS|NMS|76792991|0.01||ACTIVETIM,AD,ADJUST,ALERT,ALGO,ALLOC,AON,AVGCOST,BASKET,BENCHPX,CASHQTY,COND,CONDORDER,DARKONLY,DARKPOLL,DAY,DEACT,DEACTDIS,DEACTEOD,DIS,DUR,GAT,GTC,GTD,GTT,HID,IBKRATS,ICE,IMB,IOC,LIT,LMT,LOC,MIDPX,MIT,MKT,MOC,MTL,NGCOMB,NODARK,NONALGO,OCA,OPG,OPGREROUT,PEGBENCH,PEGMID,POSTATS,POSTONLY,PREOPGRTH,PRICECHK,REL,REL2MID,RELPCTOFS,RPI,RTH,SCALE,SCALEODD,SCALERST,SIZECHK,SNAPMID,SNAPMKT,SNAPREL,STP,STPLMT,SWEEP,TRAIL,TRAILLIT,TRAILLMT,TRAILMIT,WHATIF|SMART,AMEX,NYSE,CBOE,PHLX,ISE,CHX,ARCA,ISLAND,DRCTEDGE,BEX,BATS,EDGEA,CSFBALGO,JEFFALGO,BYX,IEX,EDGX,FOXRIVER,PEARL,NYSENAT,LTSE,MEMX,PSX|1|0|TESLA INC|NASDAQ||Consumer, Cyclical|Auto Manufacturers|Auto-Cars/Light Trucks|US/Eastern|20221229:0400-20221229:2000;20221230:0400-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0400-20230103:2000|20221229:0930-20221229:1600;20221230:0930-20221230:1600;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0930-20230103:1600|||1|ISIN|US88160R1014|1|||26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26||COMMON|1|1|100||".to_string(),
1314
"10|9001|TSLA|STK||0||AMEX|USD|TSLA|NMS|NMS|76792991|0.01||ACTIVETIM,AD,ADJUST,ALERT,ALLOC,AVGCOST,BASKET,BENCHPX,CASHQTY,COND,CONDORDER,DAY,DEACT,DEACTDIS,DEACTEOD,GAT,GTC,GTD,GTT,HID,IOC,LIT,LMT,MIT,MKT,MTL,NGCOMB,NONALGO,OCA,PEGBENCH,SCALE,SCALERST,SNAPMID,SNAPMKT,SNAPREL,STP,STPLMT,TRAIL,TRAILLIT,TRAILLMT,TRAILMIT,WHATIF|SMART,AMEX,NYSE,CBOE,PHLX,ISE,CHX,ARCA,ISLAND,DRCTEDGE,BEX,BATS,EDGEA,CSFBALGO,JEFFALGO,BYX,IEX,EDGX,FOXRIVER,PEARL,NYSENAT,LTSE,MEMX,PSX|1|0|TESLA INC|NASDAQ||Consumer, Cyclical|Auto Manufacturers|Auto-Cars/Light Trucks|US/Eastern|20221229:0700-20221229:2000;20221230:0700-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0700-20230103:2000|20221229:0700-20221229:2000;20221230:0700-20221230:2000;20221231:CLOSED;20230101:CLOSED;20230102:CLOSED;20230103:0700-20230103:2000|||1|ISIN|US88160R1014|1|||26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26,26||COMMON|1|1|100||".to_string(),
@@ -21,7 +22,7 @@ fn request_stock_contract_details() {
2122

2223
let results = client.contract_details(&contract);
2324

24-
let request_messages = client.message_bus.borrow().request_messages();
25+
let request_messages = client.message_bus.lock().expect("MessageBus is poisoned").request_messages();
2526

2627
assert_eq!(request_messages[0].encode_simple(), "9|8|9000|0|TSLA|STK||0|||SMART||USD|||0|||");
2728

src/market_data/historical/tests.rs

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use std::cell::RefCell;
1+
use std::sync::RwLock;
2+
use std::sync::{Arc, Mutex};
23

34
use time::macros::datetime;
45

@@ -10,8 +11,8 @@ use super::*;
1011

1112
#[test]
1213
fn test_head_timestamp() {
13-
let message_bus = RefCell::new(Box::new(MessageBusStub {
14-
request_messages: RefCell::new(vec![]),
14+
let message_bus = Arc::new(Mutex::new(MessageBusStub {
15+
request_messages: RwLock::new(vec![]),
1516
response_messages: vec!["9|9000|1678323335|".to_owned()],
1617
}));
1718

@@ -27,7 +28,7 @@ fn test_head_timestamp() {
2728

2829
assert_eq!(head_timestamp, OffsetDateTime::from_unix_timestamp(1678323335).unwrap(), "bar.date");
2930

30-
let request_messages = client.message_bus.borrow().request_messages();
31+
let request_messages = client.message_bus.lock().expect("MessageBus is poisoned").request_messages();
3132

3233
let head_timestamp_request = &request_messages[0];
3334
assert_eq!(
@@ -70,8 +71,8 @@ fn test_histogram_data() {
7071

7172
#[test]
7273
fn test_historical_data() {
73-
let message_bus = RefCell::new(Box::new(MessageBusStub {
74-
request_messages: RefCell::new(vec![]),
74+
let message_bus = Arc::new(Mutex::new(MessageBusStub {
75+
request_messages: RwLock::new(vec![]),
7576
response_messages: vec![
7677
"17\09000\020230413 16:31:22\020230415 16:31:22\02\020230413\0182.9400\0186.5000\0180.9400\0185.9000\0948837.22\0184.869\0324891\020230414\0183.8800\0186.2800\0182.0100\0185.0000\0810998.27\0183.9865\0277547\0".to_owned()
7778
],
@@ -107,7 +108,7 @@ fn test_historical_data() {
107108

108109
// Assert Request
109110

110-
let request_messages = client.message_bus.borrow().request_messages();
111+
let request_messages = client.message_bus.lock().expect("MessageBus is poisoned").request_messages();
111112

112113
let head_timestamp_request = &request_messages[0];
113114
assert_eq!(

src/market_data/realtime/tests.rs

+5-4
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
use std::cell::RefCell;
1+
use std::sync::RwLock;
2+
use std::sync::{Arc, Mutex};
23

34
use time::OffsetDateTime;
45

@@ -11,8 +12,8 @@ use super::*;
1112

1213
#[test]
1314
fn realtime_bars() {
14-
let message_bus = RefCell::new(Box::new(MessageBusStub {
15-
request_messages: RefCell::new(vec![]),
15+
let message_bus = Arc::new(Mutex::new(MessageBusStub {
16+
request_messages: RwLock::new(vec![]),
1617
response_messages: vec!["50|3|9001|1678323335|4028.75|4029.00|4028.25|4028.50|2|4026.75|1|".to_owned()],
1718
}));
1819

@@ -46,7 +47,7 @@ fn realtime_bars() {
4647
// Should trigger cancel realtime bars
4748
drop(bars);
4849

49-
let request_messages = client.message_bus.borrow().request_messages();
50+
let request_messages = client.message_bus.lock().expect("MessageBus is poisoned").request_messages();
5051

5152
// Verify Requests
5253
let realtime_bars_request = &request_messages[0];

0 commit comments

Comments
 (0)