Abort the other half of a connection regardless of Ok/Err

Previously we were only sending an abort signal on successful completion of one
half of a proxied connection and relying on errors to also affect the other
half.

Now we explicitly send an abort after completing with one half of a connection
regardless of whether we did that successfully or with an error.
This commit is contained in:
Mahmoud Al-Qudsi 2022-06-30 18:18:33 -05:00
parent 0164ef836a
commit d520400413
1 changed files with 7 additions and 6 deletions

View File

@ -1,3 +1,4 @@
use futures::FutureExt;
use getopts::Options;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
@ -90,13 +91,12 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
async fn copy_with_abort<R, W>(
read: &mut R,
write: &mut W,
cancel: &broadcast::Sender<()>,
mut abort: broadcast::Receiver<()>,
) -> tokio::io::Result<usize>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
let mut abort = cancel.subscribe();
let mut copied = 0;
let mut buf = [0u8; 1024];
loop {
@ -108,7 +108,7 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
bytes_read = result?;
},
_ = abort.recv() => {
return Ok(copied);
break;
}
}
@ -120,7 +120,6 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
copied += bytes_read;
}
let _ = cancel.send(());
Ok(copied)
}
@ -138,8 +137,10 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
let (cancel, _) = broadcast::channel::<()>(1);
let (remote_copied, client_copied) = tokio::join! {
copy_with_abort(&mut remote_read, &mut client_write, &cancel),
copy_with_abort(&mut client_read, &mut remote_write, &cancel),
copy_with_abort(&mut remote_read, &mut client_write, cancel.subscribe())
.then(|r| { let _ = cancel.send(()); async { r } }),
copy_with_abort(&mut client_read, &mut remote_write, cancel.subscribe())
.then(|r| { let _ = cancel.send(()); async { r } }),
};
match client_copied {