1
- use std:: { fmt:: Debug , marker :: PhantomData , net :: SocketAddr , ops:: Deref } ;
1
+ use std:: { fmt:: Debug , io , marker :: PhantomData , ops:: Deref } ;
2
2
3
3
use channel:: none:: NoReceiver ;
4
4
use serde:: { de:: DeserializeOwned , Serialize } ;
@@ -33,6 +33,20 @@ pub trait Channels<S: Service> {
33
33
type Rx : Receiver ;
34
34
}
35
35
36
+ mod wasm_browser {
37
+ #![ allow( dead_code) ]
38
+ pub type BoxedFuture < ' a , T > = std:: pin:: Pin < Box < dyn std:: future:: Future < Output = T > + ' a > > ;
39
+ }
40
+ mod multithreaded {
41
+ #![ allow( dead_code) ]
42
+ pub type BoxedFuture < ' a , T > =
43
+ std:: pin:: Pin < Box < dyn std:: future:: Future < Output = T > + Send + ' a > > ;
44
+ }
45
+ #[ cfg( not( feature = "wasm-browser" ) ) ]
46
+ pub use multithreaded:: * ;
47
+ #[ cfg( feature = "wasm-browser" ) ]
48
+ pub use wasm_browser:: * ;
49
+
36
50
/// Channels that abstract over local or remote sending
37
51
pub mod channel {
38
52
/// Oneshot channel, similar to tokio's oneshot channel
@@ -48,7 +62,7 @@ pub mod channel {
48
62
Tokio ( tokio:: sync:: oneshot:: Sender < T > ) ,
49
63
Boxed (
50
64
Box <
51
- dyn FnOnce ( T ) -> n0_future :: future :: Boxed < io:: Result < ( ) > >
65
+ dyn FnOnce ( T ) -> crate :: BoxedFuture < ' static , io:: Result < ( ) > >
52
66
+ Send
53
67
+ Sync
54
68
+ ' static ,
@@ -71,12 +85,27 @@ pub mod channel {
71
85
}
72
86
}
73
87
74
- #[ derive( Debug , derive_more :: From , derive_more :: Display ) ]
88
+ #[ derive( Debug ) ]
75
89
pub enum SendError {
76
90
ReceiverClosed ,
77
91
Io ( std:: io:: Error ) ,
78
92
}
79
93
94
+ impl From < std:: io:: Error > for SendError {
95
+ fn from ( e : std:: io:: Error ) -> Self {
96
+ Self :: Io ( e)
97
+ }
98
+ }
99
+
100
+ impl std:: fmt:: Display for SendError {
101
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
102
+ match self {
103
+ SendError :: ReceiverClosed => write ! ( f, "receiver closed" ) ,
104
+ SendError :: Io ( e) => write ! ( f, "io error: {}" , e) ,
105
+ }
106
+ }
107
+ }
108
+
80
109
impl std:: error:: Error for SendError {
81
110
fn source ( & self ) -> Option < & ( dyn std:: error:: Error + ' static ) > {
82
111
match self {
@@ -99,7 +128,7 @@ pub mod channel {
99
128
100
129
pub enum Receiver < T > {
101
130
Tokio ( tokio:: sync:: oneshot:: Receiver < T > ) ,
102
- Boxed ( n0_future :: future :: Boxed < std:: io:: Result < T > > ) ,
131
+ Boxed ( crate :: BoxedFuture < ' static , std:: io:: Result < T > > ) ,
103
132
}
104
133
105
134
impl < T > Future for Receiver < T > {
@@ -159,12 +188,27 @@ pub mod channel {
159
188
( tx. into ( ) , rx. into ( ) )
160
189
}
161
190
162
- #[ derive( Debug , derive_more :: From , derive_more :: Display ) ]
191
+ #[ derive( Debug ) ]
163
192
pub enum SendError {
164
193
ReceiverClosed ,
165
194
Io ( std:: io:: Error ) ,
166
195
}
167
196
197
+ impl From < std:: io:: Error > for SendError {
198
+ fn from ( e : std:: io:: Error ) -> Self {
199
+ Self :: Io ( e)
200
+ }
201
+ }
202
+
203
+ impl std:: fmt:: Display for SendError {
204
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
205
+ match self {
206
+ SendError :: ReceiverClosed => write ! ( f, "receiver closed" ) ,
207
+ SendError :: Io ( e) => write ! ( f, "io error: {}" , e) ,
208
+ }
209
+ }
210
+ }
211
+
168
212
impl std:: error:: Error for SendError {
169
213
fn source ( & self ) -> Option < & ( dyn std:: error:: Error + ' static ) > {
170
214
match self {
@@ -329,7 +373,7 @@ impl<I: Channels<S>, S: Service> Deref for WithChannels<I, S> {
329
373
pub enum ServiceSender < M , R , S > {
330
374
Local ( LocalMpscChannel < M , S > , PhantomData < R > ) ,
331
375
#[ cfg( feature = "rpc" ) ]
332
- Remote ( quinn:: Endpoint , SocketAddr , PhantomData < ( R , S ) > ) ,
376
+ Remote ( quinn:: Endpoint , std :: net :: SocketAddr , PhantomData < ( R , S ) > ) ,
333
377
}
334
378
335
379
impl < M , R , S > From < LocalMpscChannel < M , S > > for ServiceSender < M , R , S > {
@@ -339,12 +383,15 @@ impl<M, R, S> From<LocalMpscChannel<M, S>> for ServiceSender<M, R, S> {
339
383
}
340
384
341
385
impl < M : Send + Sync + ' static , R , S : Service > ServiceSender < M , R , S > {
342
- pub async fn request ( & self ) -> anyhow :: Result < ServiceRequest < M , R , S > > {
386
+ pub async fn request ( & self ) -> io :: Result < ServiceRequest < M , R , S > > {
343
387
match self {
344
388
Self :: Local ( tx, _) => Ok ( ServiceRequest :: from ( tx. clone ( ) ) ) ,
345
389
#[ cfg( feature = "rpc" ) ]
346
390
Self :: Remote ( endpoint, addr, _) => {
347
- let connection = endpoint. connect ( * addr, "localhost" ) ?. await ?;
391
+ let connection = endpoint
392
+ . connect ( * addr, "localhost" )
393
+ . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: Other , e) ) ?
394
+ . await ?;
348
395
let ( send, recv) = connection. open_bi ( ) . await ?;
349
396
Ok ( ServiceRequest :: Remote ( rpc:: RemoteRequest :: new ( send, recv) ) )
350
397
}
@@ -371,7 +418,6 @@ impl<M, S> Clone for LocalMpscChannel<M, S> {
371
418
pub mod rpc {
372
419
use std:: { fmt:: Debug , future:: Future , io, marker:: PhantomData , pin:: Pin , sync:: Arc } ;
373
420
374
- use n0_future:: task:: AbortOnDropHandle ;
375
421
use serde:: { de:: DeserializeOwned , Serialize } ;
376
422
use smallvec:: SmallVec ;
377
423
use tokio:: task:: JoinSet ;
@@ -399,9 +445,15 @@ pub mod rpc {
399
445
}
400
446
}
401
447
402
- #[ derive( Debug , derive_more :: From ) ]
448
+ #[ derive( Debug ) ]
403
449
pub struct RemoteRead ( quinn:: RecvStream ) ;
404
450
451
+ impl RemoteRead {
452
+ pub ( crate ) fn new ( recv : quinn:: RecvStream ) -> Self {
453
+ Self ( recv)
454
+ }
455
+ }
456
+
405
457
impl < T : DeserializeOwned > From < RemoteRead > for oneshot:: Receiver < T > {
406
458
fn from ( read : RemoteRead ) -> Self {
407
459
let fut = async move {
@@ -432,9 +484,15 @@ pub mod rpc {
432
484
}
433
485
}
434
486
435
- #[ derive( Debug , derive_more :: From ) ]
487
+ #[ derive( Debug ) ]
436
488
pub struct RemoteWrite ( quinn:: SendStream ) ;
437
489
490
+ impl RemoteWrite {
491
+ pub ( crate ) fn new ( send : quinn:: SendStream ) -> Self {
492
+ Self ( send)
493
+ }
494
+ }
495
+
438
496
impl < T : RpcMessage > From < RemoteWrite > for oneshot:: Sender < T > {
439
497
fn from ( write : RemoteWrite ) -> Self {
440
498
let mut writer = write. 0 ;
@@ -526,7 +584,7 @@ pub mod rpc {
526
584
}
527
585
528
586
impl < R : Serialize , S > RemoteRequest < R , S > {
529
- pub async fn write ( self , msg : impl Into < R > ) -> anyhow :: Result < ( RemoteRead , RemoteWrite ) > {
587
+ pub async fn write ( self , msg : impl Into < R > ) -> io :: Result < ( RemoteRead , RemoteWrite ) > {
530
588
let RemoteRequest ( mut send, recv, _) = self ;
531
589
let msg = msg. into ( ) ;
532
590
let mut buf = SmallVec :: < [ u8 ; 128 ] > :: new ( ) ;
@@ -538,51 +596,51 @@ pub mod rpc {
538
596
539
597
/// Type alias for a handler fn for remote requests
540
598
pub type Handler < R > = Arc <
541
- dyn Fn ( R , RemoteRead , RemoteWrite ) -> n0_future :: future :: Boxed < anyhow :: Result < ( ) > >
599
+ dyn Fn ( R , RemoteRead , RemoteWrite ) -> crate :: BoxedFuture < ' static , io :: Result < ( ) > >
542
600
+ Send
543
601
+ Sync
544
602
+ ' static ,
545
603
> ;
546
604
547
605
/// Utility function to listen for incoming connections and handle them with the provided handler
548
- pub fn listen < R : DeserializeOwned + ' static > (
606
+ pub async fn listen < R : DeserializeOwned + ' static > (
549
607
endpoint : quinn:: Endpoint ,
550
608
handler : Handler < R > ,
551
- ) -> AbortOnDropHandle < ( ) > {
552
- let task = tokio:: spawn ( async move {
553
- let mut tasks = JoinSet :: new ( ) ;
554
- while let Some ( incoming) = endpoint. accept ( ) . await {
555
- let handler = handler. clone ( ) ;
556
- tasks. spawn ( async move {
557
- let connection = match incoming. await {
558
- Ok ( connection) => connection,
609
+ ) {
610
+ let mut tasks = JoinSet :: new ( ) ;
611
+ while let Some ( incoming) = endpoint. accept ( ) . await {
612
+ let handler = handler. clone ( ) ;
613
+ tasks. spawn ( async move {
614
+ let connection = match incoming. await {
615
+ Ok ( connection) => connection,
616
+ Err ( cause) => {
617
+ warn ! ( "failed to accept connection {cause:?}" ) ;
618
+ return io:: Result :: Ok ( ( ) ) ;
619
+ }
620
+ } ;
621
+ loop {
622
+ let ( send, mut recv) = match connection. accept_bi ( ) . await {
623
+ Ok ( ( s, r) ) => ( s, r) ,
559
624
Err ( cause) => {
560
- warn ! ( "failed to accept connection {cause:?}" ) ;
561
- return anyhow :: Ok ( ( ) ) ;
625
+ warn ! ( "failed to accept bi stream {cause:?}" ) ;
626
+ return Ok ( ( ) ) ;
562
627
}
563
628
} ;
564
- loop {
565
- let ( send, mut recv) = match connection. accept_bi ( ) . await {
566
- Ok ( ( s, r) ) => ( s, r) ,
567
- Err ( cause) => {
568
- warn ! ( "failed to accept bi stream {cause:?}" ) ;
569
- return anyhow:: Ok ( ( ) ) ;
570
- }
571
- } ;
572
- let size = recv. read_varint_u64 ( ) . await ?. ok_or_else ( || {
573
- io:: Error :: new ( io:: ErrorKind :: UnexpectedEof , "failed to read size" )
574
- } ) ?;
575
- let mut buf = vec ! [ 0 ; size as usize ] ;
576
- recv. read_exact ( & mut buf) . await ?;
577
- let msg: R = postcard:: from_bytes ( & buf) ?;
578
- let rx = RemoteRead :: from ( recv) ;
579
- let tx = RemoteWrite :: from ( send) ;
580
- handler ( msg, rx, tx) . await ?;
581
- }
582
- } ) ;
583
- }
584
- } ) ;
585
- AbortOnDropHandle :: new ( task)
629
+ let size = recv. read_varint_u64 ( ) . await ?. ok_or_else ( || {
630
+ io:: Error :: new ( io:: ErrorKind :: UnexpectedEof , "failed to read size" )
631
+ } ) ?;
632
+ let mut buf = vec ! [ 0 ; size as usize ] ;
633
+ recv. read_exact ( & mut buf)
634
+ . await
635
+ . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: UnexpectedEof , e) ) ?;
636
+ let msg: R = postcard:: from_bytes ( & buf)
637
+ . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: InvalidData , e) ) ?;
638
+ let rx = RemoteRead :: new ( recv) ;
639
+ let tx = RemoteWrite :: new ( send) ;
640
+ handler ( msg, rx, tx) . await ?;
641
+ }
642
+ } ) ;
643
+ }
586
644
}
587
645
}
588
646
@@ -600,12 +658,15 @@ impl<M, R, S> From<LocalMpscChannel<M, S>> for ServiceRequest<M, R, S> {
600
658
}
601
659
602
660
impl < M : Send + Sync + ' static , S : Service > LocalMpscChannel < M , S > {
603
- pub async fn send < T > ( & self , value : impl Into < WithChannels < T , S > > ) -> anyhow :: Result < ( ) >
661
+ pub async fn send < T > ( & self , value : impl Into < WithChannels < T , S > > ) -> io :: Result < ( ) >
604
662
where
605
663
T : Channels < S > ,
606
664
M : From < WithChannels < T , S > > ,
607
665
{
608
- self . 0 . send ( value. into ( ) . into ( ) ) . await ?;
666
+ self . 0
667
+ . send ( value. into ( ) . into ( ) )
668
+ . await
669
+ . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: Other , e) ) ?;
609
670
Ok ( ( ) )
610
671
}
611
672
}
0 commit comments