// You probably want https://lib.rs/crates/domain, this is just me having fun
// exploring DNS. References:
// * https://www.ietf.org/rfc/rfc1035.txt
// * https://www.ietf.org/rfc/rfc3596.txt
// * https://www.ietf.org/rfc/rfc4035.txt

use std::convert::TryInto;
use std::io::Read;
use std::net::{Ipv4Addr, Ipv6Addr};

/// CLASS/QCLASS value of the Internet class (IN), the only class handled here.
const QCLASS_IN: u16 = 1;

/// The QR bit of the header: whether a message is a query or a reply.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum QR {
    Query,
    Reply,
}

/// The RCODE values understood by this module (RFC 1035 §4.1.1, values 0–5).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum ResponseCode {
    NoError,
    FormatErr,
    ServerFailure,
    NameError,
    NotImplemented,
    Refused,
}

/// Decoded form of the 16-bit header flags word.
///
/// Only the standard-query opcode is accepted when parsing, so no opcode is
/// stored.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Flags {
    /// QR bit.
    qr: QR,
    /// AA bit.
    authoritative: bool,
    /// TC bit.
    truncated: bool,
    /// RD bit.
    recursion_desired: bool,
    /// RA bit.
    recursion_available: bool,
    /// RCODE nibble.
    response_code: ResponseCode,
}

/// The part of the 12-byte message header that is kept.
///
/// The name-server and additional-record counts are not stored.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
struct Header {
    id: u16,
    flags: Flags,
    question_count: u16,
    answer_count: u16,
}

/// Record types understood by this module: A (1) and AAAA (28, RFC 3596).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum Type {
    A,
    QuadA,
}

/// One entry of the question section; QCLASS is always IN and not stored.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Question {
    /// Dotted domain name, e.g. `www.example.com`.
    qname: String,
    qtype: Type,
}

/// RDATA of a supported record type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ResourceRecordData {
    A(Ipv4Addr),
    QuadA(Ipv6Addr),
}

/// An answer record; CLASS is always IN and the TYPE follows from `rdata`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ResourceRecord {
    /// Dotted owner name.
    name: String,
    ttl: u32,
    rdata: ResourceRecordData,
}

/// A DNS message reduced to its header, question section and answer section.
///
/// Authority and additional sections are not represented.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DnsPacket {
    header: Header,
    questions: Vec<Question>,
    answers: Vec<ResourceRecord>,
}

impl Flags {
    /// Decodes the 16-bit flags word of a DNS header (RFC 1035 §4.1.1).
    ///
    /// # Errors
    /// Checked in this order: any bit of `0x0070` set, a non-zero opcode
    /// field (`0x7800`), or an RCODE nibble above 5.
    // NOTE(review): RFC 4035 places AD (0x0020) and CD (0x0010) inside the
    // rejected 0x0070 mask, so messages carrying those bits are refused here.
    fn parse(flags_int: u16) -> std::io::Result<Flags> {
        let other = |reason: &str| std::io::Error::new(std::io::ErrorKind::Other, reason);
        let is_set = |mask: u16| flags_int & mask != 0;

        if is_set(0x0070) {
            return Err(other("Reserved bits are set in flags."));
        }
        if is_set(0x7800) {
            return Err(other("Only Query Opcode is supported."));
        }

        let response_code = match flags_int & 0x000f {
            0 => ResponseCode::NoError,
            1 => ResponseCode::FormatErr,
            2 => ResponseCode::ServerFailure,
            3 => ResponseCode::NameError,
            4 => ResponseCode::NotImplemented,
            5 => ResponseCode::Refused,
            _ => return Err(other("Received reserved Responsecode value")),
        };

        Ok(Flags {
            qr: if is_set(0x8000) { QR::Reply } else { QR::Query },
            authoritative: is_set(0x0400),
            truncated: is_set(0x0200),
            recursion_desired: is_set(0x0100),
            recursion_available: is_set(0x0080),
            response_code,
        })
    }

    /// Encodes the flags back into the 16-bit header word.
    ///
    /// Opcode and reserved bits are always written as zero.
    fn serialise(&self) -> u16 {
        let rcode: u16 = match self.response_code {
            ResponseCode::NoError => 0,
            ResponseCode::FormatErr => 1,
            ResponseCode::ServerFailure => 2,
            ResponseCode::NameError => 3,
            ResponseCode::NotImplemented => 4,
            ResponseCode::Refused => 5,
        };
        let bits: [(bool, u16); 5] = [
            (matches!(self.qr, QR::Reply), 0x8000),
            (self.authoritative, 0x0400),
            (self.truncated, 0x0200),
            (self.recursion_desired, 0x0100),
            (self.recursion_available, 0x0080),
        ];
        bits.iter()
            .filter(|&&(set, _)| set)
            .fold(rcode, |word, &(_, mask)| word | mask)
    }
}

impl Header {
    /// Decodes the fixed 12-byte header at the start of `buf` (RFC 1035 §4.1.1).
    ///
    /// The name-server and additional-record counts (bytes 8..12) are read
    /// past but not kept.
    ///
    /// # Errors
    /// Fails if `buf` is shorter than 12 bytes or `Flags::parse` rejects the
    /// flags word.
    fn parse(buf: &[u8]) -> std::io::Result<Header> {
        let raw = match buf.get(..12) {
            Some(raw) => raw,
            None => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Buffer too small for header",
                ))
            }
        };
        // Big-endian 16-bit header field number `index` (0 = ID, 1 = flags, ...).
        let field = |index: usize| u16::from_be_bytes([raw[2 * index], raw[2 * index + 1]]);

        Ok(Header {
            id: field(0),
            flags: Flags::parse(field(1))?,
            question_count: field(2),
            answer_count: field(3),
        })
    }

    /// Encodes the header as 12 bytes; NSCOUNT and ARCOUNT are written as zero.
    fn serialise(&self) -> Vec<u8> {
        let fields = [
            self.id,
            self.flags.serialise(),
            self.question_count,
            self.answer_count,
            0,
            0,
        ];
        let mut out = Vec::with_capacity(2 * fields.len());
        for value in fields.iter() {
            out.extend_from_slice(&value.to_be_bytes());
        }
        out
    }
}

/// Reads exactly `N` bytes from `input` into a fixed-size array, advancing it.
fn read_array<const N: usize>(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<[u8; N]> {
    let mut bytes = [0u8; N];
    input.read_exact(&mut bytes)?;
    Ok(bytes)
}

/// Reads one byte from `input`.
fn read_u8(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<u8> {
    read_array::<1>(input).map(u8::from_be_bytes)
}

/// Reads a big-endian (network order) `u16` from `input`.
fn read_u16(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<u16> {
    read_array::<2>(input).map(u16::from_be_bytes)
}

/// Reads a big-endian (network order) `u32` from `input`.
fn read_u32(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<u32> {
    read_array::<4>(input).map(u32::from_be_bytes)
}

/// Upper bound on compression pointers followed while decoding one name.
///
/// Legitimate names need only a few; the cap turns pointer loops in hostile
/// messages into an error instead of unbounded recursion.
const MAX_POINTER_JUMPS: usize = 64;

/// Reads a (possibly compressed) domain name starting at the cursor position.
///
/// `input` is advanced past the name as encoded at that position, i.e. up to
/// and including either the terminating zero-length label or the first
/// compression pointer. Pointer targets are resolved against `message`, which
/// must be the whole DNS message because RFC 1035 §4.1.4 offsets count from
/// its first byte.
///
/// Labels are decoded lossily as UTF-8 and joined with `.`; the root name
/// yields the empty string.
///
/// # Errors
/// Fails if the input ends mid-name, a length byte uses a reserved label type
/// (top bits `01` or `10`), a pointer lies outside `message`, or more than
/// `MAX_POINTER_JUMPS` pointers are followed.
fn read_labels(input: &mut std::io::Cursor<&[u8]>, message: &[u8]) -> std::io::Result<String> {
    let mut labels = Vec::new();
    read_labels_into(input, message, &mut labels, 0)?;
    Ok(labels.join("."))
}

/// Appends the labels found at `input` to `labels`, following compression
/// pointers recursively; `jumps` counts pointers already followed.
fn read_labels_into(
    input: &mut std::io::Cursor<&[u8]>,
    message: &[u8],
    labels: &mut Vec<String>,
    jumps: usize,
) -> std::io::Result<()> {
    let mut buf = [0u8; 63];
    loop {
        let length_byte = read_u8(input)?;
        // The two top bits select the label type (RFC 1035 §4.1.4).
        match length_byte & 0xc0 {
            0x00 => {
                let label_length = usize::from(length_byte);
                if label_length == 0 {
                    return Ok(());
                }
                input.read_exact(&mut buf[..label_length])?;
                labels.push(String::from_utf8_lossy(&buf[..label_length]).into_owned());
            }
            0xc0 => {
                // 14-bit offset: low 6 bits of this byte, then the next byte.
                let offset =
                    (usize::from(length_byte & 0x3f) << 8) | usize::from(read_u8(input)?);
                if jumps >= MAX_POINTER_JUMPS {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Too many compression pointers in domain name",
                    ));
                }
                let target = message.get(offset..).ok_or_else(|| {
                    std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "Compression pointer points outside the message",
                    )
                })?;
                let mut pointer_input = std::io::Cursor::new(target);
                // A pointer always ends the name at the current position.
                return read_labels_into(&mut pointer_input, message, labels, jumps + 1);
            }
            _ => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Reserved label type in domain name",
                ))
            }
        }
    }
}

/// Encodes a dotted domain name as length-prefixed labels terminated by a
/// zero byte (RFC 1035 §3.1), without compression.
///
/// Empty labels are skipped, so the root name (`""` or `"."`) encodes as a
/// single zero byte and a trailing dot (`"example.com."`) is accepted.
/// Labels longer than 63 bytes are truncated to 63 bytes: a larger length
/// byte would set the top bits that mark a compression pointer or a reserved
/// label type, making the message unparseable.
fn write_labels(domain: &str) -> Vec<u8> {
    const MAX_LABEL_LEN: usize = 63;

    let mut buf = Vec::with_capacity(domain.len() + 2);
    for label in domain.split('.').filter(|label| !label.is_empty()) {
        let bytes = label.as_bytes();
        let bytes = &bytes[..bytes.len().min(MAX_LABEL_LEN)];
        // Lossless: bytes.len() <= 63.
        buf.push(bytes.len() as u8);
        buf.extend_from_slice(bytes);
    }
    buf.push(0);
    buf
}

impl Type {
    /// Wire value of the A record type.
    const A_CODE: u16 = 1;
    /// Wire value of the AAAA record type (RFC 3596).
    const QUAD_A_CODE: u16 = 28;

    /// Reads a 16-bit TYPE/QTYPE value and maps it onto a supported [`Type`].
    ///
    /// The two bytes are consumed before the value is checked.
    ///
    /// # Errors
    /// Fails on end of input or on any value other than 1 (A) or 28 (AAAA).
    fn parse(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<Type> {
        match read_u16(input)? {
            Self::A_CODE => Ok(Type::A),
            Self::QUAD_A_CODE => Ok(Type::QuadA),
            unsupported => Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                format!("Received unsupported Type value {unsupported}"),
            )),
        }
    }

    /// Returns the 16-bit wire value of this type.
    fn serialise(&self) -> u16 {
        match *self {
            Type::A => Self::A_CODE,
            Type::QuadA => Self::QUAD_A_CODE,
        }
    }
}

fn read_qclass(input: &mut std::io::Cursor<&[u8]>) -> std::io::Result<()> {
    let qclass_int = read_u16(input)?;
    if qclass_int != QCLASS_IN {
        // 255 ANY should also be fine.
        return Err(std::io::Error::new(
            std::io::ErrorKind::Other,
            "Only IN QClass value is supported.",
        ));
    };
    Ok(())
}

impl Question {
    /// Parses one entry of the question section starting at the cursor position.
    ///
    /// `message` must be the whole DNS message so compressed names resolve.
    /// QTYPE and QCLASS are both read before either is validated. As a result,
    /// even a rejected question leaves the cursor just past its last field,
    /// and the next question or record can still be parsed from an aligned
    /// position.
    ///
    /// # Errors
    /// Fails if the name cannot be read, the input ends early, or the
    /// QTYPE/QCLASS is unsupported (QTYPE errors take precedence).
    fn parse(input: &mut std::io::Cursor<&[u8]>, message: &[u8]) -> std::io::Result<Question> {
        let qname = read_labels(input, message)?;
        let qtype_result = Type::parse(input);
        let qclass_result = read_qclass(input);

        // Fail with invalid values only after reading the full question.
        let qtype = qtype_result?;
        qclass_result?;

        Ok(Question { qname, qtype })
    }

    /// Encodes the question as name, QTYPE and QCLASS (always IN).
    fn serialise(&self) -> Vec<u8> {
        let mut serialised = Vec::with_capacity(128);
        serialised.extend(write_labels(&self.qname));
        serialised.extend(self.qtype.serialise().to_be_bytes());
        serialised.extend(QCLASS_IN.to_be_bytes());
        serialised
    }
}

impl ResourceRecord {
    /// Parses one resource record (RFC 1035 §4.1.3) starting at the cursor position.
    ///
    /// `message` must be the whole DNS message so compressed names resolve.
    /// The entire record, including all RDLENGTH bytes of RDATA, is consumed
    /// before an unsupported TYPE/CLASS or a malformed RDATA is reported, so a
    /// caller can skip a rejected record and continue with the next one.
    ///
    /// # Errors
    /// Fails if the input ends early, the TYPE or CLASS is unsupported, or
    /// RDLENGTH does not match the record type (4 bytes for A, 16 for AAAA).
    fn parse(
        input: &mut std::io::Cursor<&[u8]>,
        message: &[u8],
    ) -> std::io::Result<ResourceRecord> {
        let name = read_labels(input, message)?;
        let type_result = Type::parse(input);
        let qclass_result = read_qclass(input);
        let ttl = read_u32(input)?;
        // RDLENGTH is a full u16 taken from the wire (and may exceed 512 over
        // TCP/EDNS), so size the buffer from it rather than using a fixed array.
        let rdlength = usize::from(read_u16(input)?);
        let mut rdata_bytes = vec![0u8; rdlength];
        input.read_exact(&mut rdata_bytes)?;

        // Fail with invalid values only after reading the full record.
        let rtype = type_result?;
        qclass_result?;

        let invalid_length = || {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                "RDATA length does not match record type",
            )
        };
        let rdata = match rtype {
            Type::A => {
                let octets: [u8; 4] = rdata_bytes
                    .as_slice()
                    .try_into()
                    .map_err(|_| invalid_length())?;
                ResourceRecordData::A(Ipv4Addr::from(octets))
            }
            Type::QuadA => {
                let octets: [u8; 16] = rdata_bytes
                    .as_slice()
                    .try_into()
                    .map_err(|_| invalid_length())?;
                ResourceRecordData::QuadA(Ipv6Addr::from(octets))
            }
        };

        Ok(ResourceRecord { name, ttl, rdata })
    }

    /// Encodes the record as name, TYPE, CLASS (always IN), TTL, RDLENGTH and RDATA.
    fn serialise(&self) -> Vec<u8> {
        let mut serialised = Vec::with_capacity(128);
        serialised.extend(write_labels(&self.name));
        serialised.extend(
            match self.rdata {
                ResourceRecordData::A(_) => Type::A,
                ResourceRecordData::QuadA(_) => Type::QuadA,
            }
            .serialise()
            .to_be_bytes(),
        );
        serialised.extend(QCLASS_IN.to_be_bytes());
        serialised.extend(self.ttl.to_be_bytes());
        serialised.extend(
            match self.rdata {
                ResourceRecordData::A(_) => 4u16,
                ResourceRecordData::QuadA(_) => 16,
            }
            .to_be_bytes(),
        );
        match self.rdata {
            ResourceRecordData::A(ip) => serialised.extend(u32::from(ip).to_be_bytes()),
            ResourceRecordData::QuadA(ip) => serialised.extend(u128::from(ip).to_be_bytes()),
        };
        serialised
    }
}

impl DnsPacket {
    /// Parses a complete DNS message from `buf`.
    ///
    /// Parsing is best-effort past the header. Questions and answer records
    /// that fail to parse are skipped. The header counts are then rewritten to
    /// the number actually kept, so `serialise` stays consistent. Authority and
    /// additional sections are not parsed.
    ///
    /// # Errors
    /// Returns an error only when `Header::parse` fails.
    pub fn parse(buf: &[u8]) -> std::io::Result<DnsPacket> {
        let mut header = Header::parse(buf)?;
        // `Header::parse` succeeded, so `buf` holds at least the 12 header bytes.
        // The full `buf` is passed along too because compression pointers are
        // offsets from the start of the whole message.
        let mut input = std::io::Cursor::new(&buf[12..]);
        let questions: Vec<Question> = (0..header.question_count)
            .filter_map(|_| Question::parse(&mut input, buf).ok())
            .collect();
        let answers: Vec<ResourceRecord> = (0..header.answer_count)
            .filter_map(|_| ResourceRecord::parse(&mut input, buf).ok())
            .collect();

        // Lossless: at most `question_count` / `answer_count` (both u16) items were kept.
        header.question_count = questions.len() as u16;
        header.answer_count = answers.len() as u16;

        Ok(DnsPacket {
            header,
            questions,
            answers,
        })
    }

    /// Returns the name of the first question, if there is one.
    pub fn query_target(&self) -> Option<String> {
        self.questions.first().map(|question| question.qname.clone())
    }

    /// Builds a reply to this packet: sets QR to Reply and appends an A
    /// record for the first question's name pointing at `target` (TTL 300).
    ///
    /// # Errors
    /// Fails if the packet has no question, or if the answer count is already
    /// at the `u16` maximum.
    pub fn convert_into_reply(&self, target: Ipv4Addr) -> std::io::Result<DnsPacket> {
        let name = self.query_target().ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                "Tried to convert DnsPacket without query target",
            )
        })?;
        let answer_count = self.header.answer_count.checked_add(1).ok_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::Other,
                "Too many answers to add a reply record",
            )
        })?;

        let mut reply = self.clone();
        reply.header.flags.qr = QR::Reply;
        reply.header.answer_count = answer_count;
        reply.answers.push(ResourceRecord {
            name,
            ttl: 300,
            rdata: ResourceRecordData::A(target),
        });
        Ok(reply)
    }

    /// Encodes the header, then every question, then every answer record.
    ///
    /// The counts written come from the header, which `parse` and
    /// `convert_into_reply` keep equal to the vector lengths.
    pub fn serialise(&self) -> Vec<u8> {
        let mut serialised = Vec::with_capacity(512);
        serialised.extend(self.header.serialise());
        for question in &self.questions {
            serialised.extend(question.serialise());
        }
        for answer in &self.answers {
            serialised.extend(answer.serialise());
        }

        serialised
    }
}

#[cfg(test)]
mod tests {
    use std::net::Ipv4Addr;

    /// Standard query (RD set) for `www.winterkongress.ch`, type A, class IN.
    const QUERY: &[u8] = b"\x86\xcb\x01\x00\x00\x01\x00\x00\x00\x00\x00\x00\x03\x77\x77\x77\x0e\x77\x69\x6e\x74\x65\x72\x6b\x6f\x6e\x67\x72\x65\x73\x73\x02\x63\x68\x00\x00\x01\x00\x01";

    /// Reply to `QUERY` with two answers, both using compression pointers:
    /// a CNAME `www.winterkongress.ch` -> `winterkongress.ch`, then an A
    /// record `winterkongress.ch` -> 89.45.227.53.
    const REPLY: &[u8] = b"\x86\xcb\x81\x80\x00\x01\x00\x02\x00\x00\x00\x00\x03\x77\x77\x77\x0e\x77\x69\x6e\x74\x65\x72\x6b\x6f\x6e\x67\x72\x65\x73\x73\x02\x63\x68\x00\x00\x01\x00\x01\xc0\x0c\x00\x05\x00\x01\x00\x00\x01\x2c\x00\x02\xc0\x10\xc0\x10\x00\x01\x00\x01\x00\x00\x01\x2c\x00\x04\x59\x2d\xe3\x35";

    #[test]
    fn test_flags() {
        // Inclusive upper bound so that 0xffff is exercised as well.
        for i in 0..=u16::MAX {
            let parse_result = super::Flags::parse(i);
            // Accepted iff the opcode (0x7800) and reserved (0x0070) bits are
            // clear and the RCODE nibble is one of the six known values.
            if i & 0x7870 != 0 || i & 0xf > 5 {
                assert!(
                    parse_result.is_err(),
                    "Expected error for input value 0x{i:04x}"
                );
                continue;
            }
            let flags = parse_result
                .unwrap_or_else(|e| panic!("Expected ok for input value 0x{i:04x}: {e}"));
            assert_eq!(i & 0x8000 > 0, flags.qr == super::QR::Reply);
            assert_eq!(i & 0x400 > 0, flags.authoritative);
            assert_eq!(i & 0x200 > 0, flags.truncated);
            assert_eq!(i & 0x100 > 0, flags.recursion_desired);
            assert_eq!(i & 0x80 > 0, flags.recursion_available);

            assert_eq!(i, flags.serialise());
        }
    }

    #[test]
    fn test_extract_query() {
        let dns_packet = super::DnsPacket::parse(QUERY).expect("query should parse");
        assert_eq!(
            dns_packet.query_target().as_deref(),
            Some("www.winterkongress.ch")
        );
    }

    #[test]
    fn test_extract_response() {
        let dns_packet = super::DnsPacket::parse(REPLY).expect("reply should parse");
        assert_eq!(
            dns_packet.query_target().as_deref(),
            Some("www.winterkongress.ch")
        );
        // The CNAME is an unsupported type and is skipped; only the A record remains.
        assert_eq!(dns_packet.answers.len(), 1);
        assert_eq!(dns_packet.header.answer_count, 1);
        let a_record = dns_packet.answers.iter().find_map(|r| match r.rdata {
            super::ResourceRecordData::A(ip) => Some(ip),
            super::ResourceRecordData::QuadA(_) => None,
        });
        assert_eq!(a_record, Some(Ipv4Addr::new(89, 45, 227, 53)));
    }

    #[test]
    fn test_serialise_parse() {
        let dns_packet = super::DnsPacket::parse(REPLY).expect("reply should parse");
        let reparsed = super::DnsPacket::parse(&dns_packet.serialise())
            .expect("serialised packet should parse");
        assert_eq!(dns_packet, reparsed);
    }

    #[test]
    fn test_convert_into_reply() {
        let target = Ipv4Addr::new(10, 0, 0, 1);
        let reply = super::DnsPacket::parse(QUERY)
            .expect("query should parse")
            .convert_into_reply(target)
            .expect("query has a target");
        assert_eq!(reply.header.flags.qr, super::QR::Reply);
        assert_eq!(reply.header.answer_count, 1);
        assert_eq!(
            reply.answers.last().map(|r| r.rdata),
            Some(super::ResourceRecordData::A(target))
        );
        let reparsed = super::DnsPacket::parse(&reply.serialise())
            .expect("serialised reply should parse");
        assert_eq!(reparsed, reply);
    }
}
