use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::IpAddr;
use std::net::Ipv4Addr;
use std::net::Ipv6Addr;
use std::net::SocketAddr;
use std::net::TcpStream;
use std::net::ToSocketAddrs;
use std::time::Duration;

/// The `CMD` field of a SOCKS5 request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Command {
    /// `CMD = 0x01`.
    Connect,
    /// `CMD = 0x02`.
    Bind,
    /// `CMD = 0x03`.
    UdpAssociate,
    /// Any other `CMD` byte.
    Invalid,
}

/// The `ATYP` field of a SOCKS5 request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AddressType {
    /// `ATYP = 0x01`: 4-byte IPv4 address follows.
    IPv4,
    /// `ATYP = 0x03`: length-prefixed domain name follows.
    DomainName,
    /// `ATYP = 0x04`: 16-byte IPv6 address follows.
    IPv6,
    /// Any other `ATYP` byte.
    Invalid,
}

/// Destination requested by a SOCKS5 `CONNECT` command.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ConnectTarget {
    /// A literal IPv4/IPv6 address and port.
    IPAddress(IpAddr, u16),
    /// A domain name (not yet resolved) and port.
    Domain(String, u16),
}

/// Reply codes (`REP` field) sent to the client in response to a request (RFC 1928, section 6).
///
/// The explicit discriminants are the on-the-wire `REP` byte values.
// Not every code is currently constructed by this server (e.g. `TtlExpired`); the full set is
// kept so the protocol table is complete, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
enum Reply {
    /// `0x00`: request granted.
    Succeeded = 0x00,
    /// `0x01`: general SOCKS server failure.
    GeneralServerFailure = 0x01,
    /// `0x02`: connection not allowed by ruleset.
    ConnectionNotAllowed = 0x02,
    /// `0x03`: network unreachable.
    NetworkUnreachable = 0x03,
    /// `0x04`: host unreachable.
    HostUnreachable = 0x04,
    /// `0x05`: connection refused.
    ConnectionRefused = 0x05,
    /// `0x06`: TTL expired.
    TtlExpired = 0x06,
    /// `0x07`: command not supported.
    CommandNotSupported = 0x07,
    /// `0x08`: address type not supported.
    AddressTypeNotSupported = 0x08,
}

/// Server side of a SOCKS5 handshake, operating on a borrowed client connection.
#[derive(Debug)]
pub struct Socks<'a> {
    /// Client connection; every protocol read and write goes through this borrow.
    stream: &'a mut TcpStream,
}

impl<'a> Socks<'a> {
    pub fn new(stream: &'a mut TcpStream) -> Self {
        Self { stream }
    }
    fn read_version(&mut self) -> std::io::Result<()> {
        let mut buf = [0u8; 1];
        self.stream.read_exact(&mut buf)?;
        let version = buf[0];
        if version != 0x05 {
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "expect socks version 5",
            ));
        }
        Ok(())
    }

    fn read_client_methods(&mut self) -> std::io::Result<()> {
        let mut buf = [0u8; 255];
        self.stream.read_exact(&mut buf[..1])?;
        let num_methods = buf[0] as usize;
        self.stream.read_exact(&mut buf[..num_methods])?;
        Ok(())
    }

    fn respond_no_authentication_method(&mut self) -> std::io::Result<()> {
        let version = 0x05;
        let method_no_authentication_required = 0x00;
        self.stream
            .write_all(&[version, method_no_authentication_required])?;
        Ok(())
    }

    fn read_command(&mut self) -> std::io::Result<Command> {
        let mut buf = [0u8; 1];
        self.stream.read_exact(&mut buf)?;
        let command = buf[0];
        Ok(match command {
            0x01 => Command::Connect,
            0x02 => Command::Bind,
            0x03 => Command::UdpAssociate,
            _ => Command::Invalid,
        })
    }

    fn read_reserved(&mut self) -> std::io::Result<()> {
        let mut buf = [0u8; 1];
        self.stream.read_exact(&mut buf)?;
        Ok(())
    }

    fn read_ipv4_address(&mut self) -> std::io::Result<Ipv4Addr> {
        let mut buf = [0u8; 4];
        self.stream.read_exact(&mut buf)?;
        Ok(Ipv4Addr::from(buf))
    }

    fn read_ipv6_address(&mut self) -> std::io::Result<Ipv6Addr> {
        let mut buf = [0u8; 16];
        self.stream.read_exact(&mut buf)?;
        Ok(Ipv6Addr::from(buf))
    }

    fn read_port(&mut self) -> std::io::Result<u16> {
        let mut buf = [0u8; 2];
        self.stream.read_exact(&mut buf)?;
        Ok(u16::from_be_bytes(buf))
    }

    fn read_domain_name(&mut self) -> std::io::Result<String> {
        let mut buf = [0u8; 255];
        self.stream.read_exact(&mut buf[..1])?;
        let domain_name_size = buf[0] as usize;
        self.stream.read_exact(&mut buf[..domain_name_size])?;
        let domain_name = std::str::from_utf8(&buf[..domain_name_size]).map_err(|_| {
            std::io::Error::new(
                ErrorKind::InvalidData,
                "domain name contained non-utf8 characters",
            )
        })?;
        Ok(domain_name.to_owned())
    }

    fn resolve_domain_name(&mut self, domain_name: &str) -> std::io::Result<IpAddr> {
        let mut socket_addrs = (domain_name, 0).to_socket_addrs()?;
        match socket_addrs.next() {
            Some(socket_addr) => Ok(socket_addr.ip()),
            None => Err(std::io::Error::new(
                ErrorKind::NotFound,
                format!("Resolution of address '{domain_name} failed."),
            )),
        }
    }

    fn read_address_type(&mut self) -> std::io::Result<AddressType> {
        let mut buf = [0u8; 1];
        self.stream.read_exact(&mut buf)?;
        let address_type = buf[0];
        Ok(match address_type {
            0x01 => AddressType::IPv4,
            0x03 => AddressType::DomainName,
            0x04 => AddressType::IPv6,
            _ => AddressType::Invalid,
        })
    }

    fn read_connect_target(&mut self, address_type: AddressType) -> std::io::Result<ConnectTarget> {
        match address_type {
            AddressType::IPv4 => Ok(ConnectTarget::IPAddress(
                IpAddr::V4(self.read_ipv4_address()?),
                self.read_port()?,
            )),
            AddressType::DomainName => Ok(ConnectTarget::Domain(
                self.read_domain_name()?,
                self.read_port()?,
            )),
            AddressType::IPv6 => Ok(ConnectTarget::IPAddress(
                IpAddr::V6(self.read_ipv6_address()?),
                self.read_port()?,
            )),
            AddressType::Invalid => Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "Invalid address type",
            )),
        }
    }

    fn respond_reply(&mut self, reply: Reply) -> std::io::Result<()> {
        let version: u8 = 0x05;
        let reply_byte: u8 = match reply {
            Reply::Succeeded => 0x00,
            Reply::GeneralServerFailure => 0x01,
            Reply::ConnectionNotAllowed => 0x02,
            Reply::NetworkUnreachable => 0x03,
            Reply::HostUnreachable => 0x04,
            Reply::ConnectionRefused => 0x05,
            Reply::TtlExpired => 0x06,
            Reply::CommandNotSupported => 0x07,
            Reply::AddressTypeNotSupported => 0x08,
        };
        let reserved: u8 = 0x00;
        // Return dummy values instead of actual proxy IPs, we don't support
        // active FTP or similar protocols.
        let address_type: u8 = 0x01; // IPv4
        let address = [0u8; 4]; // 0.0.0.0
        let port = 1080u16.to_be_bytes();
        self.stream.write_all(&[
            version,
            reply_byte,
            reserved,
            address_type,
            address[0],
            address[1],
            address[2],
            address[3],
            port[0],
            port[1],
        ])?;
        Ok(())
    }

    pub fn handle_authentication(&mut self) -> std::io::Result<()> {
        self.read_version()?;
        self.read_client_methods()?;

        self.respond_no_authentication_method()
    }

    pub fn get_connect_target(&mut self) -> std::io::Result<ConnectTarget> {
        self.read_version()?;
        let command = self.read_command()?;
        if !matches!(command, Command::Connect) {
            self.respond_reply(Reply::CommandNotSupported)?;
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "unsupported socks command",
            ));
        }
        self.read_reserved()?;
        let address_type = self.read_address_type()?;
        if matches!(address_type, self::AddressType::Invalid) {
            self.respond_reply(self::Reply::AddressTypeNotSupported)?;
            return Err(std::io::Error::new(
                ErrorKind::InvalidInput,
                "unsupported socks address type",
            ));
        }
        self.read_connect_target(address_type)
    }

    pub fn refuse_connection(mut self) -> std::io::Result<()> {
        self.respond_reply(Reply::ConnectionNotAllowed)
    }

    pub fn establish_connection_to_target(
        &mut self,
        target: ConnectTarget,
    ) -> std::io::Result<TcpStream> {
        let socket_addr = match target {
            ConnectTarget::IPAddress(ip, port) => SocketAddr::new(ip, port),
            ConnectTarget::Domain(domain, port) => {
                let ip = self.resolve_domain_name(&domain)?;
                SocketAddr::new(ip, port)
            }
        };

        let connect = TcpStream::connect_timeout(&socket_addr, Duration::new(10, 0));
        let reply = match &connect {
            Ok(_) => Reply::Succeeded,
            Err(e) if e.kind() == ErrorKind::TimedOut => Reply::HostUnreachable,
            Err(e) if e.kind() == ErrorKind::ConnectionRefused => Reply::ConnectionRefused,
            Err(_) => Reply::GeneralServerFailure,
        };
        self.respond_reply(reply)?;
        connect.map_err(|e| std::io::Error::new(e.kind(), format!("{}: {}", socket_addr, e)))
    }

    #[allow(unused)]
    pub fn read_request_and_connect(&mut self) -> std::io::Result<TcpStream> {
        self.handle_authentication()?;

        let connect_target = self.get_connect_target()?;
        self.establish_connection_to_target(connect_target)
    }
}
