snowstorm/net/protocols/
tunnel.rs

1use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
2
3/// Creates a tunnel between the given [`Stream`] and [`TcpStream`].
4///
5/// On success, the total number of bytes sent and received is returned.
6pub async fn tcp_tunnel<D, U>(stream: D, upstream: U) -> Result<(u64, u64), std::io::Error>
7where
8    U: AsyncRead + AsyncWrite + Unpin + Send + 'static,
9    D: AsyncRead + AsyncWrite + Unpin + Send + 'static,
10{
11    let (mut upstream_read, mut upstream_write) = tokio::io::split(upstream);
12    let (mut stream_read, mut stream_write) = tokio::io::split(stream);
13
14    // copy data from upstream to stream
15    let upstream_to_stream = tokio::spawn(async move {
16        let result = tokio::io::copy(&mut upstream_read, &mut stream_write).await;
17        let _ = stream_write.shutdown().await;
18
19        match result {
20            Ok(bytes_copied) => Ok(bytes_copied),
21            Err(err) => Err(err),
22        }
23    });
24
25    // copy data from stream to upstream
26    let stream_to_upstream = tokio::spawn(async move {
27        let result = tokio::io::copy(&mut stream_read, &mut upstream_write).await;
28
29        let _ = upstream_write.shutdown().await;
30
31        match result {
32            Ok(bytes_copied) => Ok(bytes_copied),
33            Err(err) => Err(err),
34        }
35    });
36
37    let (upstream_to_stream, stream_to_upstream) =
38        tokio::join!(upstream_to_stream, stream_to_upstream);
39
40    match (upstream_to_stream?, stream_to_upstream?) {
41        (Ok(upstream_bytes), Ok(stream_bytes)) => Ok((upstream_bytes, stream_bytes)),
42        (Err(err), _) | (_, Err(err)) => Err(err),
43    }
44}