@@ -13,7 +13,6 @@ use tokio::{
1313 select,
1414 sync:: mpsc:: { channel, Receiver , Sender } ,
1515} ;
16- use tokio_rustls:: server:: TlsStream ;
1716use tokio_rustls:: TlsAcceptor ;
1817use tokio_yamux:: { Session , StreamHandle } ;
1918
@@ -24,7 +23,7 @@ use super::{AgentListener, AgentListenerEvent};
2423
2524pub type TunnelTlsStream = StreamHandle ;
2625pub struct AgentTlsListener < VALIDATE , HANDSHAKE : ClusterRequest > {
27- tls_acceptor : TlsAcceptor ,
26+ tls_acceptor : Arc < TlsAcceptor > ,
2827 validate : Arc < VALIDATE > ,
2928 listener : TcpListener ,
3029 internal_tx : Sender < AgentListenerEvent < HANDSHAKE :: Context , TunnelTlsStream > > ,
@@ -34,12 +33,13 @@ pub struct AgentTlsListener<VALIDATE, HANDSHAKE: ClusterRequest> {
3433
3534impl < VALIDATE , HANDSHAKE : ClusterRequest > AgentTlsListener < VALIDATE , HANDSHAKE > {
3635 pub async fn new ( addr : SocketAddr , validate : VALIDATE , key : PrivatePkcs8KeyDer < ' static > , cert : CertificateDer < ' static > ) -> anyhow:: Result < Self > {
36+ log:: info!( "[AgentTls] starting with addr {addr}" ) ;
3737 let ( internal_tx, internal_rx) = channel ( 10 ) ;
3838 let config = rustls:: ServerConfig :: builder ( ) . with_no_client_auth ( ) . with_single_cert ( vec ! [ cert] , PrivateKeyDer :: Pkcs8 ( key) ) ?;
3939 let tls_acceptor = TlsAcceptor :: from ( Arc :: new ( config) ) ;
4040
4141 Ok ( Self {
42- tls_acceptor,
42+ tls_acceptor : Arc :: new ( tls_acceptor ) ,
4343 listener : TcpListener :: bind ( addr) . await ?,
4444 internal_tx,
4545 internal_rx,
@@ -57,8 +57,14 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
5757 event = self . internal_rx. recv( ) => break Ok ( event. expect( "should receive event from internal channel" ) ) ,
5858 } ;
5959
60- let tls_stream = self . tls_acceptor . accept ( stream) . await ?;
61- tokio:: spawn ( run_connection ( self . validate . clone ( ) , tls_stream, remote, self . internal_tx . clone ( ) ) ) ;
60+ let tls_acceptor = self . tls_acceptor . clone ( ) ;
61+ let validate = self . validate . clone ( ) ;
62+ let internal_tx = self . internal_tx . clone ( ) ;
63+ tokio:: spawn ( async move {
64+ if let Err ( e) = run_connection ( validate, tls_acceptor, stream, remote, internal_tx) . await {
65+ log:: error!( "[AgentTls] connection {remote} error {e:?}" ) ;
66+ }
67+ } ) ;
6268 }
6369 }
6470
@@ -67,25 +73,28 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
6773
6874async fn run_connection < VALIDATE : ClusterValidator < REQ > , REQ : ClusterRequest > (
6975 validate : Arc < VALIDATE > ,
70- mut in_stream : TlsStream < TcpStream > ,
76+ tls_acceptor : Arc < TlsAcceptor > ,
77+ stream : TcpStream ,
7178 remote : SocketAddr ,
7279 internal_tx : Sender < AgentListenerEvent < REQ :: Context , TunnelTlsStream > > ,
7380) -> anyhow:: Result < ( ) > {
7481 let started = Instant :: now ( ) ;
75- log:: info!( "[AgentTcp] new connection from {}" , remote) ;
82+ log:: info!( "[AgentTls] new connection from {remote}, handshaking tls" ) ;
83+ let mut in_stream = tls_acceptor. accept ( stream) . await ?;
84+ log:: info!( "[AgentTls] new connection from {remote}, handshake tls success" ) ;
7685
7786 let mut buf = [ 0u8 ; 4096 ] ;
7887 let buf_len = in_stream. read ( & mut buf) . await ?;
7988
80- log:: info!( "[AgentTcp ] new connection got handhsake data {buf_len} bytes" ) ;
89+ log:: info!( "[AgentTls ] new connection from {remote} got handhsake data {buf_len} bytes" ) ;
8190
8291 let req = validate. validate_connect_req ( & buf[ 0 ..buf_len] ) ?;
8392 let domain = validate. generate_domain ( & req) ?;
8493 let agent_id = AgentId :: try_from_domain ( & domain) ?;
8594 let session_id = AgentSessionId :: rand ( ) ;
8695 let agent_ctx = req. context ( ) ;
8796
88- log:: info!( "[AgentTcp ] new connection validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}" ) ;
97+ log:: info!( "[AgentTls ] new connection from {remote} validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}" ) ;
8998
9099 let res_buf = validate. sign_response_res ( & req, None ) ;
91100 in_stream. write_all ( & res_buf) . await ?;
@@ -96,7 +105,7 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
96105 . await
97106 . expect ( "should send to main loop" ) ;
98107
99- log:: info!( "[AgentTcp ] new connection {agent_id} {session_id} started loop" ) ;
108+ log:: info!( "[AgentTls ] new connection {agent_id} {session_id} started loop" ) ;
100109 let mut session = Session :: new_client ( in_stream, Default :: default ( ) ) ;
101110 histogram ! ( METRICS_AGENT_HISTOGRAM ) . record ( started. elapsed ( ) . as_millis ( ) as f32 / 1000.0 ) ;
102111
@@ -105,17 +114,17 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
105114 control = control_rx. recv( ) => match control {
106115 Some ( control) => match control {
107116 AgentSessionControl :: CreateStream ( tx) => {
108- log:: info!( "[AgentTcp ] agent {agent_id} {session_id} create stream request" ) ;
117+ log:: info!( "[AgentTls ] agent {agent_id} {session_id} create stream request" ) ;
109118 match session. open_stream( ) {
110119 Ok ( stream) => {
111- log:: info!( "[AgentTcp ] agent {agent_id} {session_id} created stream" ) ;
120+ log:: info!( "[AgentTls ] agent {agent_id} {session_id} created stream" ) ;
112121 if let Err ( _e) = tx. send( Ok ( stream) ) {
113- log:: error!( "[AgentTcp ] agent {agent_id} {session_id} send created stream error" ) ;
122+ log:: error!( "[AgentTls ] agent {agent_id} {session_id} send created stream error" ) ;
114123 }
115124 } ,
116125 Err ( err) => {
117126 if let Err ( _e) = tx. send( Err ( err. into( ) ) ) {
118- log:: error!( "[AgentTcp ] agent {agent_id} {session_id} send create stream's error, may be internal channel failed" ) ;
127+ log:: error!( "[AgentTls ] agent {agent_id} {session_id} send create stream's error, may be internal channel failed" ) ;
119128 }
120129 } ,
121130 }
@@ -134,18 +143,18 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
134143 } ) ;
135144 } ,
136145 Some ( Err ( err) ) => {
137- log:: error!( "[AgentTcp ] agent {agent_id} {session_id} Tcp connection error {err:?}" ) ;
146+ log:: error!( "[AgentTls ] agent {agent_id} {session_id} Tcp connection error {err:?}" ) ;
138147 break ;
139148 } ,
140149 None => {
141- log:: error!( "[AgentTcp ] agent {agent_id} {session_id} Tcp connection broken with None" ) ;
150+ log:: error!( "[AgentTls ] agent {agent_id} {session_id} Tcp connection broken with None" ) ;
142151 break ;
143152 }
144153 }
145154 }
146155 }
147156
148- log:: info!( "[AgentTcp ] agent {agent_id} {session_id} stopped loop" ) ;
157+ log:: info!( "[AgentTls ] agent {agent_id} {session_id} stopped loop" ) ;
149158
150159 internal_tx. send ( AgentListenerEvent :: Disconnected ( agent_id, session_id) ) . await . expect ( "should send to main loop" ) ;
151160
0 commit comments