diff --git a/src/main.rs b/src/main.rs index 1bb6200..973fb89 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ extern crate tokio_io; use abstract_ns::Resolver; use futures::{Future, Stream}; +use futures::future; use getopts::Options; use ns_dns_tokio::DnsResolver; use std::env; @@ -91,18 +92,24 @@ fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port: i32) .expect(&format!("Unable to bind to {}", &bind_addr)); println!("Listening on {}", listener.local_addr().unwrap()); - //create an async dns resolver - let resolver = DnsResolver::system_config(&handle).unwrap(); + let handle2 = handle.clone(); + //we have either been provided an IP address or a host name + //instead of trying to check its format, just trying creating a SocketAddr from it + let parse_result = format!("{}:{}", remote_host, remote_port).parse::(); + let server = future::result(parse_result) + .or_else(|_| { + //it's a hostname; we're going to need to resolve it + //create an async dns resolver + let resolver = DnsResolver::system_config(&handle2).unwrap(); - let server = resolver.resolve(&format!("{}:{}", remote_host, remote_port)) - .map_err(|err| { - println!("{:?}", err); - () + resolver.resolve(&format!("{}:{}", remote_host, remote_port)) + .map(move |resolved| { + resolved.pick_one() + .expect(&format!("No valid IP addresses for target {}", remote_host)) + }) + .map_err(|err| println!("{:?}", err)) }) - .and_then(move |resolved| { - let remote_addr = resolved.pick_one() - .expect(&format!("No valid IP addresses for host {}!", remote_host)); - + .and_then(move |remote_addr| { println!("Resolved {}:{} to {}", remote_host, remote_port, @@ -134,17 +141,19 @@ fn forward(bind_ip: &str, local_port: i32, remote_host: &str, remote_port: i32) let client_addr_clone = client_addr.clone(); let async1 = remote_bytes_copied.map(move |(count, _, _)| { - debug(format!("Transferred {} bytes from upstream server to remote \ - client {}", - count, client_addr_clone)) + debug(format!("Transferred {} bytes from upstream server to \ + remote client {}", + count, + client_addr_clone)) }) .map_err(move |err| error_handler(err, client_addr_clone)); let client_addr_clone = client_addr; let async2 = client_bytes_copied.map(move |(count, _, _)| { - debug(format!("Transferred {} bytes from remote client {} to upstream \ - server", - count, client_addr_clone)) + debug(format!("Transferred {} bytes from remote client {} to \ + upstream server", + count, + client_addr_clone)) }) .map_err(move |err| error_handler(err, client_addr_clone));