Aggressively terminate half-closed connections

Previously, if the client closed after receiving a reply, a keepalive connection
to the server would stick around until the timeout, even though we will never
reuse it.
This commit is contained in:
Mahmoud Al-Qudsi 2022-06-30 18:09:29 -05:00
parent b2c2876d03
commit 0164ef836a
2 changed files with 104 additions and 54 deletions

View File

@ -15,5 +15,5 @@ edition = "2018"
futures = "0.3.21"
getopts = "0.2.21"
rand = "0.8.5"
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", ] }
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", "sync" ] }
# trust-dns-resolver = "0.19.5"

View File

@ -1,8 +1,9 @@
use getopts::Options;
use std::env;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::join;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::broadcast;
type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
static DEBUG: AtomicBool = AtomicBool::new(false);
@ -10,8 +11,10 @@ static DEBUG: AtomicBool = AtomicBool::new(false);
fn print_usage(program: &str, opts: Options) {
let program_path = std::path::PathBuf::from(program);
let program_name = program_path.file_stem().unwrap().to_string_lossy();
let brief = format!("Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
program_name);
let brief = format!(
"Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
program_name
);
print!("{}", opts.usage(&brief));
}
@ -21,14 +24,18 @@ async fn main() -> Result<(), BoxedError> {
let program = args[0].clone();
let mut opts = Options::new();
opts.optopt("b",
"bind",
"The address on which to listen for incoming requests, defaulting to localhost",
"BIND_ADDR");
opts.optopt("l",
"local-port",
"The local port to which tcpproxy should bind to, randomly chosen otherwise",
"LOCAL_PORT");
opts.optopt(
"b",
"bind",
"The address on which to listen for incoming requests, defaulting to localhost",
"BIND_ADDR",
);
opts.optopt(
"l",
"local-port",
"The local port to which tcpproxy should bind to, randomly chosen otherwise",
"LOCAL_PORT",
);
opts.optflag("d", "debug", "Enable debug mode");
let matches = match opts.parse(&args[1..]) {
@ -44,7 +51,7 @@ async fn main() -> Result<(), BoxedError> {
_ => {
print_usage(&program, opts);
std::process::exit(-1);
},
}
};
if !remote.contains(':') {
@ -71,65 +78,108 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
} else {
format!("{}:{}", bind_ip, local_port)
};
let bind_sock = bind_addr.parse::<std::net::SocketAddr>().expect("Failed to parse bind address");
let bind_sock = bind_addr
.parse::<std::net::SocketAddr>()
.expect("Failed to parse bind address");
let listener = TcpListener::bind(&bind_sock).await?;
println!("Listening on {}", listener.local_addr().unwrap());
// We have either been provided an IP address or a host name.
// Instead of trying to check its format, just trying creating a SocketAddr from it.
// let parse_result = remote.parse::<SocketAddr>();
let remote = std::sync::Arc::new(remote.to_string());
async fn copy_with_abort<R, W>(
read: &mut R,
write: &mut W,
cancel: &broadcast::Sender<()>,
) -> tokio::io::Result<usize>
where
R: tokio::io::AsyncRead + Unpin,
W: tokio::io::AsyncWrite + Unpin,
{
let mut abort = cancel.subscribe();
let mut copied = 0;
let mut buf = [0u8; 1024];
loop {
let bytes_read;
tokio::select! {
biased;
result = read.read(&mut buf) => {
bytes_read = result?;
},
_ = abort.recv() => {
return Ok(copied);
}
}
if bytes_read == 0 {
break;
}
write.write_all(&buf[0..bytes_read]).await?;
copied += bytes_read;
}
let _ = cancel.send(());
Ok(copied)
}
loop {
let remote = remote.clone();
let (mut client, client_addr) = listener.accept().await?;
tokio::spawn(async move {
println!("New connection from {}", client_addr);
println!("New connection from {}", client_addr);
// Establish connection to upstream for each incoming client connection
let mut remote = TcpStream::connect(remote.as_str()).await?;
let (mut client_recv, mut client_send) = client.split();
let (mut remote_recv, mut remote_send) = remote.split();
// Establish connection to upstream for each incoming client connection
let mut remote = TcpStream::connect(remote.as_str()).await?;
let (mut client_read, mut client_write) = client.split();
let (mut remote_read, mut remote_write) = remote.split();
// This version of the join! macro does not require that the futures are fused and
// pinned prior to passing to join.
let (remote_bytes_copied, client_bytes_copied) = join!(
tokio::io::copy(&mut remote_recv, &mut client_send),
tokio::io::copy(&mut client_recv, &mut remote_send),
);
match remote_bytes_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!("Transferred {} bytes from remote client {} to upstream server",
count, client_addr);
}
let (cancel, _) = broadcast::channel::<()>(1);
let (remote_copied, client_copied) = tokio::join! {
copy_with_abort(&mut remote_read, &mut client_write, &cancel),
copy_with_abort(&mut client_read, &mut remote_write, &cancel),
};
match client_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from remote client {} to upstream server",
count, client_addr
);
}
Err(err) => {
eprintln!("Error writing from remote client {} to upstream server!",
client_addr);
eprintln!("{:?}", err);
}
};
}
Err(err) => {
eprintln!(
"Error writing bytes from remote client {} to upstream server",
client_addr
);
eprintln!("{}", err);
}
};
match client_bytes_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!("Transferred {} bytes from upstream server to remote client {}",
count, client_addr);
}
match remote_copied {
Ok(count) => {
if DEBUG.load(Ordering::Relaxed) {
eprintln!(
"Transferred {} bytes from upstream server to remote client {}",
count, client_addr
);
}
Err(err) => {
eprintln!("Error writing bytes from upstream server to remote client {}",
client_addr);
eprintln!("{:?}", err);
}
};
}
Err(err) => {
eprintln!(
"Error writing from upstream server to remote client {}!",
client_addr
);
eprintln!("{}", err);
}
};
let r: Result<(), BoxedError> = Ok(());
r
let r: Result<(), BoxedError> = Ok(());
r
});
}
}