mod dns;
mod dns_lookup;
mod stream_copy;
mod unshare;

use dns_lookup::DnsLookupTable;
use unshare::CloneFlags;
use unshare::UidGid;
use unshare::Unshare;

use passfd::FdPassingExt;
use std::io::Write;
use std::net::{Ipv4Addr, UdpSocket};
use std::net::{TcpListener, TcpStream};
use std::os::unix::io::AsRawFd;
use std::os::unix::io::FromRawFd;
use std::os::unix::net::UnixStream;
use std::os::unix::process::CommandExt;
use std::path::Path;
use std::path::PathBuf;
use std::sync::Arc;

/// Relays one accepted proxy connection to the host it was meant for.
///
/// The local (destination) IPv4 address the client connected to is mapped back to a hostname
/// via `dns_lookup_table`; the connection is then forwarded to that hostname on port 80 and
/// both directions are copied until the streams finish.
///
/// # Errors
///
/// Returns an error if the local address cannot be read, if the connection arrived on an
/// IPv6 address (not handled here), if `dns_lookup_table` has no hostname for the address,
/// or if connecting to / relaying with the upstream host fails. Nothing in this function
/// panics, so a single bad connection cannot take down the caller's accept loop.
fn handle_proxy(stream: TcpStream, dns_lookup_table: Arc<DnsLookupTable>) -> std::io::Result<()> {
    let local_ip = match stream.local_addr()?.ip() {
        std::net::IpAddr::V4(ip) => ip,
        std::net::IpAddr::V6(ip) => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("IPv6 inbound connection to {} is not supported.", ip),
            ))
        }
    };
    let hostname = match dns_lookup_table.get_hostname_for_ip(local_ip) {
        Some(hostname) => hostname,
        None => {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "No hostname for inbound IP.",
            ))
        }
    };
    let outgoing = TcpStream::connect((hostname, 80))?;
    stream_copy::BidirectionalStreamCopy::new(stream, outgoing)?.copy_streams()?;
    Ok(())
}

fn main() -> std::io::Result<()> {
    let (parent_sock, netns_sock) = UnixStream::pair().unwrap();
    // let parent_uid_gid = UidGid::from_current();
    let temp_resolv_conf = unshare::TempFile::new(|f| f.write_all(b"nameserver 127.0.0.1\n"))?;
    let temp_resolv_conf_file_name = PathBuf::from(temp_resolv_conf.file_name());
    let temp_nsswitch_conf = unshare::TempFile::new(|f| {
        let mut nsswitch_content = std::fs::read_to_string("/etc/nsswitch.conf")?;
        let needle = "resolve [!UNAVAIL=return] ";
        if let Some(pos) = nsswitch_content.find(needle) {
            nsswitch_content.drain(pos..pos + needle.len());
        }
        f.write_all(nsswitch_content.as_bytes())
    })?;
    let temp_nsswitch_conf_file_name = PathBuf::from(temp_nsswitch_conf.file_name());

    let mut command = std::process::Command::new("bash");
    let command = unsafe {
        command.pre_exec(move || {
            // Create new namespace mapped to root and bind mount resolv.conf
            let unshare = Unshare::new(&[CloneFlags::Newuser, CloneFlags::Newns])?;
            unshare.map_uid_gid(UidGid::root())?;
            unshare.bind_mount_file(&temp_resolv_conf_file_name, Path::new("/etc/resolv.conf"))?;
            unshare.bind_mount_file(
                &temp_nsswitch_conf_file_name,
                Path::new("/etc/nsswitch.conf"),
            )?;

            // Create new network namespace mapped to parent user.
            let _unshare = Unshare::new(&[CloneFlags::Newuser, CloneFlags::Newnet])?;
            // unshare.map_uid_gid(parent_uid_gid)?;

            let proxy_listener = TcpListener::bind((Ipv4Addr::UNSPECIFIED, 80)).map_err(|e| {
                std::io::Error::new(e.kind(), format!("binding to TCP port failed: {}", e))
            })?;
            netns_sock.send_fd(proxy_listener.as_raw_fd())?;

            let dns_listener = UdpSocket::bind((Ipv4Addr::new(127, 0, 0, 1), 53)).map_err(|e| {
                std::io::Error::new(e.kind(), format!("binding to UDP port failed: {}", e))
            })?;
            netns_sock.send_fd(dns_listener.as_raw_fd())
        })
    };
    let mut child = command.spawn().unwrap();

    let proxy_listener_fd = parent_sock.recv_fd()?;
    let proxy_listener = unsafe { TcpListener::from_raw_fd(proxy_listener_fd) };

    let dns_listener_fd = parent_sock.recv_fd()?;
    let dns_listener = unsafe { UdpSocket::from_raw_fd(dns_listener_fd) };

    let dns_lookup_table = std::sync::Arc::new(DnsLookupTable::new(Ipv4Addr::new(127, 1, 0, 1)));
    let dns_lookup_table_proxy = std::sync::Arc::clone(&dns_lookup_table);

    std::thread::spawn(move || -> std::io::Result<()> {
        for stream in proxy_listener.incoming() {
            let dns_lookup_table = Arc::clone(&dns_lookup_table_proxy);
            if let Err(e) = handle_proxy(stream?, dns_lookup_table) {
                eprintln!("{}", e);
            }
        }
        Ok(())
    });
    std::thread::spawn(move || -> std::io::Result<()> {
        let mut buf = [0u8; 512];
        loop {
            let (size, requester) = dns_listener.recv_from(&mut buf)?;
            let request = dns::DnsPacket::parse(&buf[..size])?;
            let hostname = request.query_target().unwrap();
            if hostname != "www.digiges.ch" && hostname != "www.winterkongress.ch" {
                eprintln!("Proxy: Verbindung zu {} nicht erlaubt", hostname);
                continue;
            }
            let target = dns_lookup_table.get_or_insert(&hostname);
            let reply = request.convert_into_reply(target)?;
            dns_listener.send_to(&reply.serialise(), requester)?;
        }
    });
    match child.wait() {
        Ok(code) => {
            drop(temp_resolv_conf);
            drop(temp_nsswitch_conf);
            std::process::exit(code.code().unwrap())
        }
        Err(e) => Err(e),
    }
}
