use std::collections::HashMap; use std::future::Future; use std::io::{self, ErrorKind}; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::sync::atomic::{AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; use bytes::BytesMut; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf}; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tokio::task::JoinHandle; use tokio::time::sleep; #[tokio::main] async fn main() { let channel_id = AtomicU64::new(1); let (left_to_right_sender, left_to_right_receiver) = channel::(512); let (right_to_left_sender, right_to_left_receiver) = channel::(512); let right_consumer_stream_map = Arc::new(Mutex::new(HashMap::>::new())); let left_consumer_stream_map = Arc::new(Mutex::new(HashMap::>::new())); loop_consume(left_to_right_receiver, right_consumer_stream_map.clone()); loop_consume(right_to_left_receiver, left_consumer_stream_map.clone()); let mut handles = vec![]; for _ in 0..3 { let new_channel_id = channel_id.fetch_add(1, Ordering::SeqCst); let right_stream = create_stream(new_channel_id, right_to_left_sender.clone(), right_consumer_stream_map.clone()); let left_stream = create_stream(new_channel_id, left_to_right_sender.clone(), left_consumer_stream_map.clone()); let (right_stream_reader, right_stream_writer) = tokio::io::split(right_stream); let (left_stream_reader, left_stream_writer) = tokio::io::split(left_stream); let a = loop_read(right_stream_reader, format!("#{} right reader", new_channel_id)); let b = loop_read(left_stream_reader, format!("#{} left reader", new_channel_id)); let c = loop_send(new_channel_id, right_stream_writer, format!("right sender")); let d = loop_send(new_channel_id, left_stream_writer, format!("left sender")); handles.extend([a, b, c, d]); } for h in handles { let _ = tokio::join!(h); } } fn create_stream(channel_id: u64, sender: Sender, consumer_stream_map: Arc>>>) -> Stream { let (new_sender, receiver) = channel::(16); { let stream_map_left_locked = &mut consumer_stream_map.lock().unwrap(); stream_map_left_locked.insert(channel_id, new_sender); } Stream { channel_id, sender, receiver, } } fn loop_send(channel_id: u64, mut writer: WriteHalf, tag: String) -> JoinHandle<()> { tokio::spawn(async move { for i in 0..3 { sleep(Duration::from_millis(100)).await; let _ = writer.write_all(format!("Send message: [{}] {} - {}", &tag, channel_id, i).as_bytes()).await; } let _ = writer.shutdown().await; }) } fn loop_read(mut reader: ReadHalf, tag: String) -> JoinHandle<()> { tokio::spawn(async move { let mut buf = BytesMut::with_capacity(4096); // let mut buf = Vec::with_capacity(4096); loop { buf.clear(); match reader.read_buf(&mut buf).await { Ok(len) => println!("Tag: {}, Len: {}, Buf: [[ {} ]]", &tag, len, String::from_utf8_lossy(&buf[..len])), Err(e) => { println!("Tag: {}, Err: {:?}", &tag, e); if e.to_string().contains("broken pipe") { break; } } } } }) } fn loop_consume(mut receiver: Receiver, stream_map: Arc>>>) { tokio::spawn(async move { loop { match receiver.recv().await { None => { println!("Receiver none"); break; } Some(package) => { let channel_id = package.channel_id; let sender = { let stream_map_locked = &stream_map.lock().unwrap(); stream_map_locked.get(&channel_id).map(|s| s.clone()) }; match sender { None => println!("Channel id not found: {}", channel_id), Some(sender) => { // TODO process result let _ = sender.send(package).await; } } } } } }); } #[derive(Debug)] struct StreamPackage { channel_id: u64, message: StreamPackageMessage, } #[derive(Debug)] enum StreamPackageMessage { Data(Vec), Flush, Close, } struct Stream { channel_id: u64, sender: Sender, receiver: Receiver, } impl AsyncRead for Stream { fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { let m = self.receiver.poll_recv(cx); match m { Poll::Pending => Poll::Pending, Poll::Ready(m_opt) => match m_opt { None => Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe"))), Some(package) => match package.message { StreamPackageMessage::Close => { self.receiver.close(); // Poll::Ready(Ok(())) Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe"))) } StreamPackageMessage::Flush => Poll::Ready(Ok(())), StreamPackageMessage::Data(mm) => { buf.put_slice(mm.as_slice()); Poll::Ready(Ok(())) } } } } } } impl AsyncWrite for Stream { fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Data(Vec::from(buf)) }; pool_send(self, cx, stream_package, buf.len()) } fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Flush }; pool_send(self, cx, stream_package, ()) } fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Close }; pool_send(self, cx, stream_package, ()) } } fn pool_send(s: Pin<&mut Stream>, cx: &mut Context<'_>, stream_package: StreamPackage, r: R) -> Poll> { let fut = s.sender.send(stream_package); tokio::pin!( fut); match fut.poll(cx) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(_)) => Poll::Ready(Ok(r)), Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e.to_string()))), } }