mirror of https://github.com/neosmart/tcpproxy.git
Upgrade to the latest version of tokio
Also simplify the code by dropping our own async resolution, since the newer tokio can do that for us.
This commit is contained in:
parent
677b6cd33a
commit
b2c2876d03
File diff suppressed because it is too large
Load Diff
|
@ -12,8 +12,8 @@ license = "MIT"
|
||||||
edition = "2018"
|
edition = "2018"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
futures = "0.3.6"
|
futures = "0.3.21"
|
||||||
getopts = "0.2.21"
|
getopts = "0.2.21"
|
||||||
rand = "0.7.3"
|
rand = "0.8.5"
|
||||||
tokio = { version = "0.3.1", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "stream", "macros", ] }
|
tokio = { version = "1.19.2", features = [ "io-util", "net", "rt-multi-thread", "parking_lot", "macros", ] }
|
||||||
trust-dns-resolver = "0.19.5"
|
# trust-dns-resolver = "0.19.5"
|
||||||
|
|
59
src/main.rs
59
src/main.rs
|
@ -10,7 +10,7 @@ static DEBUG: AtomicBool = AtomicBool::new(false);
|
||||||
fn print_usage(program: &str, opts: Options) {
|
fn print_usage(program: &str, opts: Options) {
|
||||||
let program_path = std::path::PathBuf::from(program);
|
let program_path = std::path::PathBuf::from(program);
|
||||||
let program_name = program_path.file_stem().unwrap().to_string_lossy();
|
let program_name = program_path.file_stem().unwrap().to_string_lossy();
|
||||||
let brief = format!("Usage: {} [-b BIND_ADDR] -h REMOTE_HOST -r REMOTE_PORT [-l LOCAL_PORT]",
|
let brief = format!("Usage: {} REMOTE_HOST:PORT [-b BIND_ADDR] [-l LOCAL_PORT]",
|
||||||
program_name);
|
program_name);
|
||||||
print!("{}", opts.usage(&brief));
|
print!("{}", opts.usage(&brief));
|
||||||
}
|
}
|
||||||
|
@ -21,14 +21,6 @@ async fn main() -> Result<(), BoxedError> {
|
||||||
let program = args[0].clone();
|
let program = args[0].clone();
|
||||||
|
|
||||||
let mut opts = Options::new();
|
let mut opts = Options::new();
|
||||||
opts.reqopt("h",
|
|
||||||
"remote-host",
|
|
||||||
"The remote host (ip or host name) to which packets will be forwarded",
|
|
||||||
"REMOTE_HOST");
|
|
||||||
opts.reqopt("r",
|
|
||||||
"remote-port",
|
|
||||||
"The remote port to which TCP packets should be forwarded",
|
|
||||||
"REMOTE_PORT");
|
|
||||||
opts.optopt("b",
|
opts.optopt("b",
|
||||||
"bind",
|
"bind",
|
||||||
"The address on which to listen for incoming requests, defaulting to localhost",
|
"The address on which to listen for incoming requests, defaulting to localhost",
|
||||||
|
@ -47,23 +39,34 @@ async fn main() -> Result<(), BoxedError> {
|
||||||
std::process::exit(-1);
|
std::process::exit(-1);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
let remote = match matches.free.len() {
|
||||||
|
1 => matches.free[0].as_str(),
|
||||||
|
_ => {
|
||||||
|
print_usage(&program, opts);
|
||||||
|
std::process::exit(-1);
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
if !remote.contains(':') {
|
||||||
|
eprintln!("A remote port is required (REMOTE_ADDR:PORT)");
|
||||||
|
std::process::exit(-1);
|
||||||
|
}
|
||||||
|
|
||||||
DEBUG.store(matches.opt_present("d"), Ordering::Relaxed);
|
DEBUG.store(matches.opt_present("d"), Ordering::Relaxed);
|
||||||
// let local_port: i32 = matches.opt_str("l").unwrap_or("0".to_string()).parse()?;
|
// let local_port: i32 = matches.opt_str("l").unwrap_or("0".to_string()).parse()?;
|
||||||
let local_port: i32 = matches.opt_str("l").map(|s| s.parse()).unwrap_or(Ok(0))?;
|
let local_port: i32 = matches.opt_str("l").map(|s| s.parse()).unwrap_or(Ok(0))?;
|
||||||
let remote_port: i32 = matches.opt_str("r").unwrap().parse()?;
|
|
||||||
let remote_host = matches.opt_str("h").unwrap();
|
|
||||||
let bind_addr = match matches.opt_str("b") {
|
let bind_addr = match matches.opt_str("b") {
|
||||||
Some(addr) => addr,
|
Some(addr) => addr,
|
||||||
None => "127.0.0.1".to_owned(),
|
None => "127.0.0.1".to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
forward(&bind_addr, local_port, &remote_host, remote_port).await
|
forward(&bind_addr, local_port, remote).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port: i32) -> Result<(), BoxedError> {
|
async fn forward(bind_ip: &str, local_port: i32, remote: &str) -> Result<(), BoxedError> {
|
||||||
// Listen on the specified IP and port
|
// Listen on the specified IP and port
|
||||||
let bind_addr = if bind_ip.contains(':') {
|
let bind_addr = if !bind_ip.starts_with('[') && bind_ip.contains(':') {
|
||||||
|
// Correctly format for IPv6 usage
|
||||||
format!("[{}]:{}", bind_ip, local_port)
|
format!("[{}]:{}", bind_ip, local_port)
|
||||||
} else {
|
} else {
|
||||||
format!("{}:{}", bind_ip, local_port)
|
format!("{}:{}", bind_ip, local_port)
|
||||||
|
@ -74,39 +77,23 @@ async fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port:
|
||||||
|
|
||||||
// We have either been provided an IP address or a host name.
|
// We have either been provided an IP address or a host name.
|
||||||
// Instead of trying to check its format, just trying creating a SocketAddr from it.
|
// Instead of trying to check its format, just trying creating a SocketAddr from it.
|
||||||
let parse_result = format!("{}:{}", remote_host, remote_port).parse::<std::net::SocketAddr>();
|
// let parse_result = remote.parse::<SocketAddr>();
|
||||||
let remote_addr = match parse_result {
|
let remote = std::sync::Arc::new(remote.to_string());
|
||||||
Ok(s) => s,
|
|
||||||
Err(_) => {
|
|
||||||
// It's a hostname; we're going to need to resolve it.
|
|
||||||
// Create an async dns resolver
|
|
||||||
|
|
||||||
use trust_dns_resolver::TokioAsyncResolver;
|
|
||||||
use trust_dns_resolver::config::*;
|
|
||||||
|
|
||||||
let resolver = TokioAsyncResolver::tokio(
|
|
||||||
ResolverConfig::default(),
|
|
||||||
ResolverOpts::default())
|
|
||||||
.await.expect("Failed to create DNS resolver");
|
|
||||||
|
|
||||||
let resolutions = resolver.lookup_ip(remote_host).await.expect("Failed to resolve server IP address!");
|
|
||||||
let remote_addr = resolutions.iter().nth(1).expect("Failed to resolve server IP address!");
|
|
||||||
println!("Resolved {} to {}", remote_host, remote_addr);
|
|
||||||
format!("{}:{}", remote_addr, remote_port).parse()?
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
let remote = remote.clone();
|
||||||
let (mut client, client_addr) = listener.accept().await?;
|
let (mut client, client_addr) = listener.accept().await?;
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
println!("New connection from {}", client_addr);
|
println!("New connection from {}", client_addr);
|
||||||
|
|
||||||
// Establish connection to upstream for each incoming client connection
|
// Establish connection to upstream for each incoming client connection
|
||||||
let mut remote = TcpStream::connect(&remote_addr).await?;
|
let mut remote = TcpStream::connect(remote.as_str()).await?;
|
||||||
let (mut client_recv, mut client_send) = client.split();
|
let (mut client_recv, mut client_send) = client.split();
|
||||||
let (mut remote_recv, mut remote_send) = remote.split();
|
let (mut remote_recv, mut remote_send) = remote.split();
|
||||||
|
|
||||||
|
// This version of the join! macro does not require that the futures are fused and
|
||||||
|
// pinned prior to passing to join.
|
||||||
let (remote_bytes_copied, client_bytes_copied) = join!(
|
let (remote_bytes_copied, client_bytes_copied) = join!(
|
||||||
tokio::io::copy(&mut remote_recv, &mut client_send),
|
tokio::io::copy(&mut remote_recv, &mut client_send),
|
||||||
tokio::io::copy(&mut client_recv, &mut remote_send),
|
tokio::io::copy(&mut client_recv, &mut remote_send),
|
||||||
|
|
Loading…
Reference in New Issue