use mio::{Events, Interest, Poll, Token};
use std::io::ErrorKind;
use std::io::{Read, Write};
use std::net::TcpStream;

/// State of a source stream after one copy pass over it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectionState {
    /// The source has not reported end-of-stream; more data may arrive later.
    Open,
    /// The source reported end-of-stream (a `read` returned 0 bytes).
    Closed,
}

/// Copies bytes between two TCP streams in both directions, driven by one mio [`Poll`].
#[derive(Debug)]
pub struct BidirectionalStreamCopy {
    /// Event loop that both streams are registered with.
    poll: Poll,
    /// First stream (non-blocking); data read from it is written to `outgoing`.
    incoming: mio::net::TcpStream,
    /// Second stream (non-blocking); data read from it is written to `incoming`.
    outgoing: mio::net::TcpStream,
}

impl BidirectionalStreamCopy {
    const INCOMING: Token = Token(0);
    const OUTGOING: Token = Token(1);

    pub fn new(incoming: TcpStream, outgoing: TcpStream) -> std::io::Result<Self> {
        let poll = Poll::new()?;
        let mut incoming = Self::transform_tcp_stream(incoming)?;
        let mut outgoing = Self::transform_tcp_stream(outgoing)?;

        poll.registry()
            .register(&mut incoming, Self::INCOMING, Interest::READABLE)?;
        poll.registry()
            .register(&mut outgoing, Self::OUTGOING, Interest::READABLE)?;

        Ok(Self {
            poll,
            incoming,
            outgoing,
        })
    }

    fn transform_tcp_stream(stream: TcpStream) -> std::io::Result<mio::net::TcpStream> {
        stream.set_nonblocking(true)?;
        Ok(mio::net::TcpStream::from_std(stream))
        // Ok((mio_stream, token.into()))
    }

    pub fn copy_streams(&mut self) -> std::io::Result<()> {
        let mut events = Events::with_capacity(2);
        let mut terminate = false;
        while !terminate {
            self.poll.poll(&mut events, None)?;

            for event in events.iter() {
                let result = match event.token() {
                    Self::INCOMING => Self::stream_copy(&mut self.incoming, &mut self.outgoing)?,
                    Self::OUTGOING => Self::stream_copy(&mut self.outgoing, &mut self.incoming)?,
                    _ => unreachable!(),
                };
                if matches!(result, ConnectionState::Closed) {
                    terminate = true;
                }
            }

            if terminate {
                break;
            }
        }
        Ok(())
    }
    fn stream_copy(
        source: &mut mio::net::TcpStream,
        target: &mut mio::net::TcpStream,
    ) -> std::io::Result<ConnectionState> {
        let mut buf = [0u8; 4096];
        loop {
            match source.read(&mut buf) {
                Ok(0) => {
                    return Ok(ConnectionState::Closed);
                }
                Ok(len) => target.write_all(&buf[..len])?,
                Err(e) if e.kind() == ErrorKind::WouldBlock => {
                    return Ok(ConnectionState::Open);
                }
                Err(e) => return Err(e),
            }
        }
    }
}
