Files
simple-rust-tests/__concurrent/async_study/src/main.rs
2022-03-19 23:36:45 +08:00

190 lines
7.0 KiB
Rust

use std::collections::HashMap;
use std::future::Future;
use std::io::{self, ErrorKind};
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::{Context, Poll};
use std::time::Duration;
use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf, ReadHalf, WriteHalf};
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::task::JoinHandle;
use tokio::time::sleep;
/// Demo entry point: creates three pairs of in-memory `Stream`s that share two
/// multiplexing channels, then runs a reader task and a writer task on each end.
///
/// Packages written on a "left" stream go through `left_to_right_sender`. A
/// consumer task routes them to the "right" stream with the same channel id,
/// and the reverse direction works the same way.
#[tokio::main]
async fn main() {
    let channel_id = AtomicU64::new(1);
    // One shared multiplexing pipe per direction.
    let (left_to_right_sender, left_to_right_receiver) = channel::<StreamPackage>(512);
    let (right_to_left_sender, right_to_left_receiver) = channel::<StreamPackage>(512);
    // Per-side routing tables: channel id -> sender feeding that stream's inbound receiver.
    let right_consumer_stream_map = Arc::new(Mutex::new(HashMap::<u64, Sender<StreamPackage>>::new()));
    let left_consumer_stream_map = Arc::new(Mutex::new(HashMap::<u64, Sender<StreamPackage>>::new()));
    // Packages written on the left side are delivered to right-side streams, and vice versa.
    loop_consume(left_to_right_receiver, right_consumer_stream_map.clone());
    loop_consume(right_to_left_receiver, left_consumer_stream_map.clone());
    let mut joins = vec![];
    for _ in 0..3 {
        let new_channel_id = channel_id.fetch_add(1, Ordering::SeqCst);
        let right_stream = create_stream(new_channel_id, right_to_left_sender.clone(), right_consumer_stream_map.clone());
        let left_stream = create_stream(new_channel_id, left_to_right_sender.clone(), left_consumer_stream_map.clone());
        let (right_stream_reader, right_stream_writer) = tokio::io::split(right_stream);
        let (left_stream_reader, left_stream_writer) = tokio::io::split(left_stream);
        let a = loop_read(right_stream_reader, format!("#{} right reader", new_channel_id));
        let b = loop_read(left_stream_reader, format!("#{} left reader", new_channel_id));
        let c = loop_send(new_channel_id, right_stream_writer, "right sender".to_string());
        let d = loop_send(new_channel_id, left_stream_writer, "left sender".to_string());
        joins.extend([a, b, c, d]);
    }
    // Wait for every reader/writer task, and report panicked tasks instead of discarding them.
    for join in joins {
        if let Err(e) = join.await {
            println!("Task failed: {:?}", e);
        }
    }
}
/// Builds one end of a multiplexed in-memory stream.
///
/// - Inbound: a new bounded channel (capacity 16) is created. Its sending half is
///   registered in `consumer_stream_map` under `channel_id`, replacing any
///   previous entry for that id, so the consumer task can route packages here.
/// - Outbound: packages written to the returned `Stream` go to the shared `sender`.
fn create_stream(channel_id: u64, sender: Sender<StreamPackage>, consumer_stream_map: Arc<Mutex<HashMap<u64, Sender<StreamPackage>>>>) -> Stream {
    let (inbound_tx, inbound_rx) = channel::<StreamPackage>(16);
    // The lock guard is a temporary, so the mutex is released at the end of this statement.
    consumer_stream_map.lock().unwrap().insert(channel_id, inbound_tx);
    Stream { channel_id, sender, receiver: inbound_rx }
}
/// Spawns a task that writes three tagged messages to `writer`, 100 ms apart,
/// and then shuts the writer down.
///
/// Write and shutdown failures are printed. A failed write stops the loop,
/// because later writes on the same closed channel would fail the same way.
fn loop_send(channel_id: u64, mut writer: WriteHalf<Stream>, tag: String) -> JoinHandle<()> {
    tokio::spawn(async move {
        for i in 0..3 {
            sleep(Duration::from_millis(100)).await;
            let message = format!("Send message: [{}] {} - {}", &tag, channel_id, i);
            if let Err(e) = writer.write_all(message.as_bytes()).await {
                println!("Tag: {}, Write err: {:?}", &tag, e);
                break;
            }
        }
        // Still attempt the shutdown so the peer gets a Close package if possible.
        if let Err(e) = writer.shutdown().await {
            println!("Tag: {}, Shutdown err: {:?}", &tag, e);
        }
    })
}
fn loop_read(mut reader: ReadHalf<Stream>, tag: String) -> JoinHandle<()> {
tokio::spawn(async move {
let mut buf = BytesMut::with_capacity(4096);
// let mut buf = Vec::with_capacity(4096);
loop {
buf.clear();
match reader.read_buf(&mut buf).await {
Ok(len) => println!("Tag: {}, Len: {}, Buf: [[ {} ]]", &tag, len, String::from_utf8_lossy(&buf[..len])),
Err(e) => {
println!("Tag: {}, Err: {:?}", &tag, e);
if e.to_string().contains("broken pipe") {
break;
}
}
}
}
})
}
/// Spawns a task that routes every package arriving on `receiver` to the
/// stream registered under the package's `channel_id` in `stream_map`.
///
/// - Packages for unknown channel ids are dropped with a log line.
/// - If a registered stream's inbound channel has been closed, the package is
///   dropped and that entry is removed from `stream_map`.
/// - The task ends once `receiver` yields `None`, i.e. every sender of the
///   shared channel has been dropped.
fn loop_consume(mut receiver: Receiver<StreamPackage>, stream_map: Arc<Mutex<HashMap<u64, Sender<StreamPackage>>>>) {
    tokio::spawn(async move {
        while let Some(package) = receiver.recv().await {
            let channel_id = package.channel_id;
            // Clone the sender out so the std mutex guard is dropped before any `.await`.
            let sender = {
                let stream_map_locked = stream_map.lock().unwrap();
                stream_map_locked.get(&channel_id).cloned()
            };
            match sender {
                None => println!("Channel id not found: {}", channel_id),
                Some(sender) => {
                    // NOTE(review): awaiting here means one stream with a full inbound
                    // channel stalls routing for every other channel id.
                    if sender.send(package).await.is_err() {
                        // The stream's receiver is closed or dropped; forget the dead route.
                        println!("Channel closed, dropping package: {}", channel_id);
                        stream_map.lock().unwrap().remove(&channel_id);
                    }
                }
            }
        }
        println!("Receiver none");
    });
}
/// A unit of traffic on the shared multiplexing channels, tagged with the id
/// of the logical stream it belongs to.
#[derive(Debug)]
struct StreamPackage {
    /// Key the consumer task looks up in its stream map to route this package.
    channel_id: u64,
    /// Payload or control signal carried by this package.
    message: StreamPackageMessage,
}
/// The content of a `StreamPackage`.
#[derive(Debug)]
enum StreamPackageMessage {
    /// Bytes passed to `poll_write`.
    Data(Vec<u8>),
    /// Sent by `poll_flush`.
    Flush,
    /// Sent by `poll_shutdown`; the reading side closes its inbound channel when it sees this.
    Close,
}
/// One end of an in-memory, channel-backed duplex byte stream.
struct Stream {
    /// Id stamped on every outbound package; also the key under which
    /// `create_stream` registers this stream's inbound sender.
    channel_id: u64,
    /// Shared outbound channel, drained by a consumer task that routes by `channel_id`.
    sender: Sender<StreamPackage>,
    /// This stream's own inbound channel.
    receiver: Receiver<StreamPackage>,
}
impl AsyncRead for Stream {
    /// Copies the payload of the next non-empty `Data` package into `buf`.
    ///
    /// - `Flush` packages and empty `Data` packages carry no bytes and are
    ///   skipped. Completing a read without filling `buf` would tell the caller
    ///   the stream has ended.
    /// - A `Close` package closes the inbound channel and is reported as
    ///   `ErrorKind::BrokenPipe`, the same as a closed and drained channel.
    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<io::Result<()>> {
        loop {
            let package = match self.receiver.poll_recv(cx) {
                // The waker was registered by `poll_recv`.
                Poll::Pending => return Poll::Pending,
                // Channel closed (all senders dropped or `close()` called) and drained.
                Poll::Ready(None) => return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe"))),
                Poll::Ready(Some(package)) => package,
            };
            match package.message {
                StreamPackageMessage::Close => {
                    self.receiver.close();
                    return Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, "broken pipe")));
                }
                // No payload: poll for the next package instead of signalling EOF.
                StreamPackageMessage::Flush => continue,
                StreamPackageMessage::Data(data) if data.is_empty() => continue,
                StreamPackageMessage::Data(data) => {
                    // NOTE(review): `put_slice` panics if `data` is longer than
                    // `buf.remaining()`; `Stream` has no field to keep leftover bytes.
                    buf.put_slice(&data);
                    return Poll::Ready(Ok(()));
                }
            }
        }
    }
}
impl AsyncWrite for Stream {
    /// Queues a copy of `buf` as a `Data` package. Once queued, the whole slice
    /// is reported as written.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        let accepted = buf.len();
        let id = self.channel_id;
        let payload = buf.to_vec();
        pool_send(self, cx, StreamPackage { channel_id: id, message: StreamPackageMessage::Data(payload) }, accepted)
    }

    /// Queues a `Flush` marker for the peer. It does not wait for earlier data to be read.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let id = self.channel_id;
        pool_send(self, cx, StreamPackage { channel_id: id, message: StreamPackageMessage::Flush }, ())
    }

    /// Queues a `Close` marker telling the peer this direction is finished.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let id = self.channel_id;
        pool_send(self, cx, StreamPackage { channel_id: id, message: StreamPackageMessage::Close }, ())
    }
}
fn pool_send<R>(s: Pin<&mut Stream>, cx: &mut Context<'_>, stream_package: StreamPackage, r: R) -> Poll<io::Result<R>> {
let fut = s.sender.send(stream_package);
tokio::pin!( fut);
match fut.poll(cx) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(_)) => Poll::Ready(Ok(r)),
Poll::Ready(Err(e)) => Poll::Ready(Err(io::Error::new(ErrorKind::BrokenPipe, e.to_string()))),
}
}