diff --git a/__concurrent/async_study/Cargo.lock b/__concurrent/async_study/Cargo.lock index 0453b09..784fb33 100644 --- a/__concurrent/async_study/Cargo.lock +++ b/__concurrent/async_study/Cargo.lock @@ -6,6 +6,7 @@ version = 3 name = "async_study" version = "0.1.0" dependencies = [ + "bytes", "futures-util", "tokio", ] diff --git a/__concurrent/async_study/Cargo.toml b/__concurrent/async_study/Cargo.toml index de8e4e3..f8e3525 100644 --- a/__concurrent/async_study/Cargo.toml +++ b/__concurrent/async_study/Cargo.toml @@ -8,3 +8,4 @@ edition = "2021" [dependencies] tokio = { version = "1.0", features = ["full"] } futures-util = { version = "0.3", default-features = false } +bytes = "1.1" diff --git a/__concurrent/async_study/src/main.rs b/__concurrent/async_study/src/main.rs index e68090e..06d7767 100644 --- a/__concurrent/async_study/src/main.rs +++ b/__concurrent/async_study/src/main.rs @@ -1,19 +1,194 @@ +use std::collections::HashMap; use std::future::Future; use std::io::{self, ErrorKind}; use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::sync::atomic::{AtomicU64, Ordering}; use std::task::{Context, Poll}; use std::time::Duration; -use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf}; +use bytes::BytesMut; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf}; use tokio::sync::mpsc::{channel, Receiver, Sender}; +use tokio::task::JoinHandle; use tokio::time::sleep; #[tokio::main] -async fn main() {} +async fn main() { + let stream_id = AtomicU64::new(1); + + let (left_sender, left_receiver) = channel::(512); + let (right_sender, right_receiver) = channel::(512); + + let left_stream_map = Arc::new(Mutex::new(HashMap::>::new())); + let right_stream_map = Arc::new(Mutex::new(HashMap::>::new())); + + + let cloned_left_stream_map = left_stream_map.clone(); + let cloned_right_stream_map = right_stream_map.clone(); + loop_consume(left_receiver, cloned_left_stream_map); + loop_consume(right_receiver, cloned_right_stream_map); + + let new_stream_id = stream_id.fetch_add(1, Ordering::SeqCst); + + let (left_new_sender, left_new_receiver) = channel::(16); + { + let stream_map_left_locked = &mut left_stream_map.lock().unwrap(); + stream_map_left_locked.insert(new_stream_id, left_new_sender); + } + let left_stream = Stream { + channel_id: new_stream_id, + sender: right_sender.clone(), + receiver: left_new_receiver, + }; + let (right_new_sender, right_new_receiver) = channel::(16); + { + let stream_map_right_locked = &mut right_stream_map.lock().unwrap(); + stream_map_right_locked.insert(new_stream_id, right_new_sender); + } + let right_stream = Stream { + channel_id: new_stream_id, + sender: left_sender.clone(), + receiver: right_new_receiver, + }; + + let (left_stream_reader, mut left_stream_writer) = tokio::io::split(left_stream); + let (right_stream_reader, mut right_stream_writer) = tokio::io::split(right_stream); + + let a = loop_read(left_stream_reader, format!("#{} left reader ", new_stream_id)); + let b = loop_read(right_stream_reader, format!("#{} right reader", new_stream_id)); + + let c = tokio::spawn(async move { + for i in 0..10 { + sleep(Duration::from_millis(100)).await; + let _ = left_stream_writer.write_all(format!("Left message: {}", i).as_bytes()).await; + } + let _ = left_stream_writer.shutdown().await; + }); + let d = tokio::spawn(async move { + for i in 0..10 { + sleep(Duration::from_millis(100)).await; + let _ = right_stream_writer.write_all(format!("Right message: {}", i).as_bytes()).await; + } + let _ = right_stream_writer.shutdown().await; + }); + + let _ = tokio::join!(a, b, c, d); +} + +fn loop_read(mut reader: ReadHalf, tag: String) -> JoinHandle<()> { + tokio::spawn(async move { + let mut buf = BytesMut::with_capacity(4096); + // let mut buf = Vec::with_capacity(4096); + loop { + buf.clear(); + match reader.read_buf(&mut buf).await { + Ok(len) => println!("Tag: {}, Len: {}, Buf: [[ {} ]]", &tag, len, String::from_utf8_lossy(&buf[..len])), + Err(e) => { + println!("Tag: {}, Err: {:?}", &tag, e); + if e.to_string().contains("broken pipe") { + break; + } + } + } + } + }) +} + +fn loop_consume(mut receiver: Receiver, stream_map: Arc>>>) { + tokio::spawn(async move { + loop { + match receiver.recv().await { + None => { + println!("Receiver none"); + break; + } + Some(package) => { + let channel_id = package.channel_id; + let sender = { + let stream_map_locked = &stream_map.lock().unwrap(); + stream_map_locked.get(&channel_id).map(|s| s.clone()) + }; + match sender { + None => println!("Channel id not found: {}", channel_id), + Some(sender) => { + // TODO process result + let _ = sender.send(package).await; + } + } + } + } + } + }); +} + #[derive(Debug)] -enum StreamPackage { - Data(u64, Vec), - Flush(u64), - Close(u64), +struct StreamPackage { + channel_id: u64, + message: StreamPackageMessage, } + +#[derive(Debug)] +enum StreamPackageMessage { + Data(Vec), + Flush, + Close, +} + +struct Stream { + channel_id: u64, + sender: Sender, + receiver: Receiver, +} + +impl AsyncRead for Stream { + fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll> { + let m = self.receiver.poll_recv(cx); + match m { + Poll::Pending => Poll::Pending, + Poll::Ready(m_opt) => match m_opt { + None => Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe"))), + Some(package) => match package.message { + StreamPackageMessage::Close => { + self.receiver.close(); + // Poll::Ready(Ok(())) + Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe"))) + } + StreamPackageMessage::Flush => Poll::Ready(Ok(())), + StreamPackageMessage::Data(mm) => { + buf.put_slice(mm.as_slice()); + Poll::Ready(Ok(())) + } + } + } + } + } +} + +impl AsyncWrite for Stream { + fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Data(Vec::from(buf)) }; + pool_send(self, cx, stream_package, buf.len()) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Flush }; + pool_send(self, cx, stream_package, ()) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let stream_package = StreamPackage { channel_id: self.channel_id, message: StreamPackageMessage::Close }; + pool_send(self, cx, stream_package, ()) + } +} + +fn pool_send(s: Pin<&mut Stream>, cx: &mut Context<'_>, stream_package: StreamPackage, r: R) -> Poll> { + let fut = s.sender.send(stream_package); + tokio::pin!( fut); + match fut.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(Ok(_)) => Poll::Ready(Ok(r)), + Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e.to_string()))), + } +} \ No newline at end of file