mirror of https://github.com/neosmart/tcpproxy.git
Aggressively terminate half-closed connections
Previously, if the client closed after receiving a reply, a keepalive connection to the server would stick around until the timeout, even though we will never reuse it.
This commit is contained in:
parent
b2c2876d03
commit
0164ef836a
|
@ -15,5 +15,5 @@ edition = "2018"
|
|||
futures = "0.3.21"
|
||||
getopts = "0.2.21"
|
||||
rand = "0.8.5"
|
||||
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", ] }
|
||||
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", "sync" ] }
|
||||
# trust-dns-resolver = "0.19.5"
|
||||
|
|
114
src/main.rs
114
src/main.rs
|
@ -1,8 +1,9 @@
|
|||
use getopts::Options;
|
||||
use std::env;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use tokio::join;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
use tokio::sync::broadcast;
|
||||
|
||||
type BoxedError = Box<dyn std::error::Error + Sync + Send + 'static>;
|
||||
static DEBUG: AtomicBool = AtomicBool::new(false);
|
||||
|
@ -10,8 +11,10 @@ static DEBUG: AtomicBool = AtomicBool::new(false);
|
|||
fn print_usage(program: &str, opts: Options) {
|
||||
let program_path = std::path::PathBuf::from(program);
|
||||
let program_name = program_path.file_stem().unwrap().to_string_lossy();
|
||||
let brief = format!("Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
|
||||
program_name);
|
||||
let brief = format!(
|
||||
"Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
|
||||
program_name
|
||||
);
|
||||
print!("{}", opts.usage(&brief));
|
||||
}
|
||||
|
||||
|
@ -21,14 +24,18 @@ async fn main() -> Result<(), BoxedError> {
|
|||
let program = args[0].clone();
|
||||
|
||||
let mut opts = Options::new();
|
||||
opts.optopt("b",
|
||||
opts.optopt(
|
||||
"b",
|
||||
"bind",
|
||||
"The address on which to listen for incoming requests, defaulting to localhost",
|
||||
"BIND_ADDR");
|
||||
opts.optopt("l",
|
||||
"BIND_ADDR",
|
||||
);
|
||||
opts.optopt(
|
||||
"l",
|
||||
"local-port",
|
||||
"The local port to which tcpproxy should bind to, randomly chosen otherwise",
|
||||
"LOCAL_PORT");
|
||||
"LOCAL_PORT",
|
||||
);
|
||||
opts.optflag("d", "debug", "Enable debug mode");
|
||||
|
||||
let matches = match opts.parse(&args[1..]) {
|
||||
|
@ -44,7 +51,7 @@ async fn main() -> Result<(), BoxedError> {
|
|||
_ => {
|
||||
print_usage(&program, opts);
|
||||
std::process::exit(-1);
|
||||
},
|
||||
}
|
||||
};
|
||||
|
||||
if !remote.contains(':') {
|
||||
|
@ -71,15 +78,52 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
|
|||
} else {
|
||||
format!("{}:{}", bind_ip, local_port)
|
||||
};
|
||||
let bind_sock = bind_addr.parse::<std::net::SocketAddr>().expect("Failed to parse bind address");
|
||||
let bind_sock = bind_addr
|
||||
.parse::<std::net::SocketAddr>()
|
||||
.expect("Failed to parse bind address");
|
||||
let listener = TcpListener::bind(&bind_sock).await?;
|
||||
println!("Listening on {}", listener.local_addr().unwrap());
|
||||
|
||||
// We have either been provided an IP address or a host name.
|
||||
// Instead of trying to check its format, just trying creating a SocketAddr from it.
|
||||
// let parse_result = remote.parse::<SocketAddr>();
|
||||
let remote = std::sync::Arc::new(remote.to_string());
|
||||
|
||||
async fn copy_with_abort<R, W>(
|
||||
read: &mut R,
|
||||
write: &mut W,
|
||||
cancel: &broadcast::Sender<()>,
|
||||
) -> tokio::io::Result<usize>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
W: tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let mut abort = cancel.subscribe();
|
||||
let mut copied = 0;
|
||||
let mut buf = [0u8; 1024];
|
||||
loop {
|
||||
let bytes_read;
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
result = read.read(&mut buf) => {
|
||||
bytes_read = result?;
|
||||
},
|
||||
_ = abort.recv() => {
|
||||
return Ok(copied);
|
||||
}
|
||||
}
|
||||
|
||||
if bytes_read == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
write.write_all(&buf[0..bytes_read]).await?;
|
||||
copied += bytes_read;
|
||||
}
|
||||
|
||||
let _ = cancel.send(());
|
||||
Ok(copied)
|
||||
}
|
||||
|
||||
loop {
|
||||
let remote = remote.clone();
|
||||
let (mut client, client_addr) = listener.accept().await?;
|
||||
|
@ -89,42 +133,48 @@ async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), Box
|
|||
|
||||
// Establish connection to upstream for each incoming client connection
|
||||
let mut remote = TcpStream::connect(remote.as_str()).await?;
|
||||
let (mut client_recv, mut client_send) = client.split();
|
||||
let (mut remote_recv, mut remote_send) = remote.split();
|
||||
let (mut client_read, mut client_write) = client.split();
|
||||
let (mut remote_read, mut remote_write) = remote.split();
|
||||
|
||||
// This version of the join! macro does not require that the futures are fused and
|
||||
// pinned prior to passing to join.
|
||||
let (remote_bytes_copied, client_bytes_copied) = join!(
|
||||
tokio::io::copy(&mut remote_recv, &mut client_send),
|
||||
tokio::io::copy(&mut client_recv, &mut remote_send),
|
||||
);
|
||||
let (cancel, _) = broadcast::channel::<()>(1);
|
||||
let (remote_copied, client_copied) = tokio::join! {
|
||||
copy_with_abort(&mut remote_read, &mut client_write, &cancel),
|
||||
copy_with_abort(&mut client_read, &mut remote_write, &cancel),
|
||||
};
|
||||
|
||||
match remote_bytes_copied {
|
||||
match client_copied {
|
||||
Ok(count) => {
|
||||
if DEBUG.load(Ordering::Relaxed) {
|
||||
eprintln!("Transferred {} bytes from remote client {} to upstream server",
|
||||
count, client_addr);
|
||||
eprintln!(
|
||||
"Transferred {} bytes from remote client {} to upstream server",
|
||||
count, client_addr
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Error writing from remote client {} to upstream server!",
|
||||
client_addr);
|
||||
eprintln!("{:?}", err);
|
||||
eprintln!(
|
||||
"Error writing bytes from remote client {} to upstream server",
|
||||
client_addr
|
||||
);
|
||||
eprintln!("{}", err);
|
||||
}
|
||||
};
|
||||
|
||||
match client_bytes_copied {
|
||||
match remote_copied {
|
||||
Ok(count) => {
|
||||
if DEBUG.load(Ordering::Relaxed) {
|
||||
eprintln!("Transferred {} bytes from upstream server to remote client {}",
|
||||
count, client_addr);
|
||||
eprintln!(
|
||||
"Transferred {} bytes from upstream server to remote client {}",
|
||||
count, client_addr
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Error writing bytes from upstream server to remote client {}",
|
||||
client_addr);
|
||||
eprintln!("{:?}", err);
|
||||
eprintln!(
|
||||
"Error writing from upstream server to remote client {}!",
|
||||
client_addr
|
||||
);
|
||||
eprintln!("{}", err);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
Loading…
Reference in New Issue