Upgrade to the latest version of tokio

Also simplify the code by dropping our own async resolution, since the newer
tokio can do that for us.
This commit is contained in:
Mahmoud Al-Qudsi 2022-06-30 16:51:49 -05:00
parent 677b6cd33a
commit b2c2876d03
3 changed files with 162 additions and 638 deletions

733
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -12,8 +12,8 @@ license = "MIT"
edition = "2018" edition = "2018"
[dependencies] [dependencies]
futures = "0.3.6" futures = "0.3.21"
getopts = "0.2.21" getopts = "0.2.21"
rand = "0.7.3" rand = "0.8.5"
tokio = { version = "0.3.1", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "stream", "macros", ] } tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", ] }
trust-dns-resolver = "0.19.5" # trust-dns-resolver = "0.19.5"

View File

@ -10,7 +10,7 @@ static DEBUG: AtomicBool = AtomicBool::new(false);
fn print_usage(program: &str, opts: Options) { fn print_usage(program: &str, opts: Options) {
let program_path = std::path::PathBuf::from(program); let program_path = std::path::PathBuf::from(program);
let program_name = program_path.file_stem().unwrap().to_string_lossy(); let program_name = program_path.file_stem().unwrap().to_string_lossy();
let brief = format!("Usage: {} [-b BIND_ADDR] -h REMOTE_HOST -r REMOTE_PORT [-l LOCAL_PORT]", let brief = format!("Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
program_name); program_name);
print!("{}", opts.usage(&brief)); print!("{}", opts.usage(&brief));
} }
@ -21,14 +21,6 @@ async fn main() -> Result<(), BoxedError> {
let program = args[0].clone(); let program = args[0].clone();
let mut opts = Options::new(); let mut opts = Options::new();
opts.reqopt("h",
"remote-host",
"The remote host (ip or host name) to which packets will be forwarded",
"REMOTE_HOST");
opts.reqopt("r",
"remote-port",
"The remote port to which TCP packets should be forwarded",
"REMOTE_PORT");
opts.optopt("b", opts.optopt("b",
"bind", "bind",
"The address on which to listen for incoming requests, defaulting to localhost", "The address on which to listen for incoming requests, defaulting to localhost",
@ -47,23 +39,34 @@ async fn main() -> Result<(), BoxedError> {
std::process::exit(-1); std::process::exit(-1);
} }
}; };
let remote = match matches.free.len() {
1 => matches.free[0].as_str(),
_ => {
print_usage(&program, opts);
std::process::exit(-1);
},
};
if !remote.contains(':') {
eprintln!("A remote port is required (REMOTE_ADDR:PORT)");
std::process::exit(-1);
}
DEBUG.store(matches.opt_present("d"), Ordering::Relaxed); DEBUG.store(matches.opt_present("d"), Ordering::Relaxed);
// let local_port: i32 = matches.opt_str("l").unwrap_or("0".to_string()).parse()?; // let local_port: i32 = matches.opt_str("l").unwrap_or("0".to_string()).parse()?;
let local_port: i32 = matches.opt_str("l").map(|s| s.parse()).unwrap_or(Ok(0))?; let local_port: i32 = matches.opt_str("l").map(|s| s.parse()).unwrap_or(Ok(0))?;
let remote_port: i32 = matches.opt_str("r").unwrap().parse()?;
let remote_host = matches.opt_str("h").unwrap();
let bind_addr = match matches.opt_str("b") { let bind_addr = match matches.opt_str("b") {
Some(addr) => addr, Some(addr) => addr,
None => "127.0.0.1".to_owned(), None => "127.0.0.1".to_owned(),
}; };
forward(&bind_addr, local_port, &remote_host, remote_port).await forward(&bind_addr, local_port, remote).await
} }
async fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port: i32) -> Result<(), BoxedError> { async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), BoxedError> {
// Listen on the specified IP and port // Listen on the specified IP and port
let bind_addr = if bind_ip.contains(':') { let bind_addr = if !bind_ip.starts_with('[') && bind_ip.contains(':') {
// Correctly format for IPv6 usage
format!("[{}]:{}", bind_ip, local_port) format!("[{}]:{}", bind_ip, local_port)
} else { } else {
format!("{}:{}", bind_ip, local_port) format!("{}:{}", bind_ip, local_port)
@ -74,39 +77,23 @@ async fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port:
// We have either been provided an IP address or a host name. // We have either been provided an IP address or a host name.
// Instead of trying to check its format, just trying creating a SocketAddr from it. // Instead of trying to check its format, just trying creating a SocketAddr from it.
let parse_result = format!("{}:{}", remote_host, remote_port).parse::<std::net::SocketAddr>(); // let parse_result = remote.parse::<SocketAddr>();
let remote_addr = match parse_result { let remote = std::sync::Arc::new(remote.to_string());
Ok(s) => s,
Err(_) => {
// It's a hostname; we're going to need to resolve it.
// Create an async dns resolver
use trust_dns_resolver::TokioAsyncResolver;
use trust_dns_resolver::config::*;
let resolver = TokioAsyncResolver::tokio(
ResolverConfig::default(),
ResolverOpts::default())
.await.expect("Failed to create DNS resolver");
let resolutions = resolver.lookup_ip(remote_host).await.expect("Failed to resolve server IP address!");
let remote_addr = resolutions.iter().nth(1).expect("Failed to resolve server IP address!");
println!("Resolved {} to {}", remote_host, remote_addr);
format!("{}:{}", remote_addr, remote_port).parse()?
},
};
loop { loop {
let remote = remote.clone();
let (mut client, client_addr) = listener.accept().await?; let (mut client, client_addr) = listener.accept().await?;
tokio::spawn(async move { tokio::spawn(async move {
println!("New connection from {}", client_addr); println!("New connection from {}", client_addr);
// Establish connection to upstream for each incoming client connection // Establish connection to upstream for each incoming client connection
let mut remote = TcpStream::connect(&remote_addr).await?; let mut remote = TcpStream::connect(remote.as_str()).await?;
let (mut client_recv, mut client_send) = client.split(); let (mut client_recv, mut client_send) = client.split();
let (mut remote_recv, mut remote_send) = remote.split(); let (mut remote_recv, mut remote_send) = remote.split();
// This version of the join! macro does not require that the futures are fused and
// pinned prior to passing to join.
let (remote_bytes_copied, client_bytes_copied) = join!( let (remote_bytes_copied, client_bytes_copied) = join!(
tokio::io::copy(&mut remote_recv, &mut client_send), tokio::io::copy(&mut remote_recv, &mut client_send),
tokio::io::copy(&mut client_recv, &mut remote_send), tokio::io::copy(&mut client_recv, &mut remote_send),