Abort the other half of a connection regardless of Ok/Err

Previously we were only sending an abort signal on successful completion of one
half of a proxied connection and relying on errors to also affect the other
half.

Now we explicitly send an abort after completing with one half of a connection
regardless of whether we did that successfully or with an error.
This commit is contained in:
Mahmoud Al-Qudsi 2022-06-30 18:18:33 -05:00
parent 0164ef836a
commit d520400413
1 changed files with 7 additions and 6 deletions

View File

@ -1,3 +1,4 @@
use futures::FutureExt;
use getopts::Options; use getopts::Options;
use std::env; use std::env;
use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::atomic::{AtomicBool, Ordering};
@ -90,13 +91,12 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
async fn copy_with_abort<R, W>( async fn copy_with_abort<R, W>(
read: &mut R, read: &mut R,
write: &mut W, write: &mut W,
cancel: &broadcast::Sender<()>, mut abort: broadcast::Receiver<()>,
) -> tokio::io::Result<usize> ) -> tokio::io::Result<usize>
where where
R: tokio::io::AsyncRead + Unpin, R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin, W: tokio::io::AsyncWrite + Unpin,
{ {
let mut abort = cancel.subscribe();
let mut copied = 0; let mut copied = 0;
let mut buf = [0u8; 1024]; let mut buf = [0u8; 1024];
loop { loop {
@ -108,7 +108,7 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
bytes_read = result?; bytes_read = result?;
}, },
_ = abort.recv() => { _ = abort.recv() => {
return Ok(copied); break;
} }
} }
@ -120,7 +120,6 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
copied += bytes_read; copied += bytes_read;
} }
let _ = cancel.send(());
Ok(copied) Ok(copied)
} }
@ -138,8 +137,10 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
let (cancel, _) = broadcast::channel::<()>(1); let (cancel, _) = broadcast::channel::<()>(1);
let (remote_copied, client_copied) = tokio::join! { let (remote_copied, client_copied) = tokio::join! {
copy_with_abort(&mut remote_read, &mut client_write, &cancel), copy_with_abort(&mut remote_read, &mut client_write, cancel.subscribe())
copy_with_abort(&mut client_read, &mut remote_write, &cancel), .then(|r| { let _ = cancel.send(()); async { r } }),
copy_with_abort(&mut client_read, &mut remote_write, cancel.subscribe())
.then(|r| { let _ = cancel.send(()); async { r } }),
}; };
match client_copied { match client_copied {