Skip to main content

sinktools/
lazy_sink_source.rs

//! [`LazySinkSource`], a lazily initialized combined sink and stream, and related items.

3use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::sync::Arc;
7use std::task::Wake;
8
9use futures_util::task::AtomicWaker;
10use futures_util::{Sink, Stream, ready};
11
12#[derive(Default)]
13struct DualWaker {
14    sink: AtomicWaker,
15    stream: AtomicWaker,
16}
17
18impl DualWaker {
19    fn new() -> (Arc<Self>, Waker) {
20        let dual_waker = Arc::new(Self::default());
21        let waker = Waker::from(dual_waker.clone());
22        (dual_waker, waker)
23    }
24}
25
26impl Wake for DualWaker {
27    fn wake(self: Arc<Self>) {
28        self.wake_by_ref();
29    }
30
31    fn wake_by_ref(self: &Arc<Self>) {
32        self.sink.wake();
33        self.stream.wake();
34    }
35}
36
37enum SharedState<Fut, St, Si, Item> {
38    Uninit {
39        future: Pin<Box<Fut>>,
40    },
41    Thunkulating {
42        future: Pin<Box<Fut>>,
43        item: Option<Item>,
44        dual_waker_state: Arc<DualWaker>,
45        dual_waker_waker: Waker,
46    },
47    Done {
48        stream: Pin<Box<St>>,
49        sink: Pin<Box<Si>>,
50        buf: Option<Item>,
51    },
52    Taken,
53}
54
55/// A lazy sink-source, where the internal state is initialized when the first item is attempted to be pulled from the
56/// source, or when the first item is sent to the sink. To split into separate source and sink halves, use
57/// [`futures_util::StreamExt::split`].
58pub struct LazySinkSource<Fut, St, Si, Item, Error> {
59    state: SharedState<Fut, St, Si, Item>,
60    _phantom: PhantomData<Error>,
61}
62
63impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
64    /// Creates a new `LazySinkSource` with the given initialization future.
65    pub fn new(future: Fut) -> Self {
66        Self {
67            state: SharedState::Uninit {
68                future: Box::pin(future),
69            },
70            _phantom: PhantomData,
71        }
72    }
73}
74
75impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
76where
77    Self: Unpin,
78    Fut: Future<Output = Result<(St, Si), Error>>,
79    St: Stream,
80    Si: Sink<Item>,
81    Error: From<Si::Error>,
82{
83    type Error = Error;
84
85    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86        let state = &mut self.get_mut().state;
87
88        if let SharedState::Uninit { .. } = &*state {
89            return Poll::Ready(Ok(()));
90        }
91
92        if let SharedState::Thunkulating {
93            future,
94            item,
95            dual_waker_state,
96            dual_waker_waker,
97        } = &mut *state
98        {
99            dual_waker_state.sink.register(cx.waker());
100
101            let mut dual_context = Context::from_waker(dual_waker_waker);
102
103            match future.as_mut().poll(&mut dual_context) {
104                Poll::Ready(Ok((stream, sink))) => {
105                    let buf = item.take();
106                    *state = SharedState::Done {
107                        stream: Box::pin(stream),
108                        sink: Box::pin(sink),
109                        buf,
110                    };
111                }
112                Poll::Ready(Err(e)) => {
113                    return Poll::Ready(Err(e));
114                }
115                Poll::Pending => {
116                    return Poll::Pending;
117                }
118            }
119        }
120
121        if let SharedState::Done { sink, buf, .. } = &mut *state {
122            if buf.is_some() {
123                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
124                sink.as_mut().start_send(buf.take().unwrap())?;
125            }
126            let result = sink.as_mut().poll_ready(cx).map_err(From::from);
127            return result;
128        }
129
130        panic!("LazySinkSource in invalid state.");
131    }
132
133    fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
134        let state = &mut self.get_mut().state;
135
136        if let SharedState::Uninit { .. } = &*state {
137            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
138            if let SharedState::Uninit { future } = old_state {
139                let (dual_waker_state, dual_waker_waker) = DualWaker::new();
140                *state = SharedState::Thunkulating {
141                    future,
142                    item: Some(item),
143                    dual_waker_state,
144                    dual_waker_waker,
145                };
146
147                return Ok(());
148            }
149        }
150
151        if let SharedState::Thunkulating { .. } = &mut *state {
152            panic!("LazySinkSource not ready.");
153        }
154
155        if let SharedState::Done { sink, buf, .. } = &mut *state {
156            debug_assert!(buf.is_none());
157            let result = sink.as_mut().start_send(item).map_err(From::from);
158            return result;
159        }
160
161        panic!("LazySinkSource not ready.");
162    }
163
164    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165        let state = &mut self.get_mut().state;
166
167        if let SharedState::Uninit { .. } = &*state {
168            return Poll::Ready(Ok(()));
169        }
170
171        if let SharedState::Thunkulating {
172            future,
173            item,
174            dual_waker_state,
175            dual_waker_waker,
176        } = &mut *state
177        {
178            dual_waker_state.sink.register(cx.waker());
179
180            let mut new_context = Context::from_waker(dual_waker_waker);
181
182            match future.as_mut().poll(&mut new_context) {
183                Poll::Ready(Ok((stream, sink))) => {
184                    let buf = item.take();
185                    *state = SharedState::Done {
186                        stream: Box::pin(stream),
187                        sink: Box::pin(sink),
188                        buf,
189                    };
190                }
191                Poll::Ready(Err(e)) => {
192                    return Poll::Ready(Err(e));
193                }
194                Poll::Pending => {
195                    return Poll::Pending;
196                }
197            }
198        }
199
200        if let SharedState::Done { sink, buf, .. } = &mut *state {
201            if buf.is_some() {
202                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
203                sink.as_mut().start_send(buf.take().unwrap())?;
204            }
205            let result = sink.as_mut().poll_flush(cx).map_err(From::from);
206            return result;
207        }
208
209        panic!("LazySinkHalf in invalid state.");
210    }
211
212    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213        let state = &mut self.get_mut().state;
214
215        if let SharedState::Uninit { .. } = &*state {
216            return Poll::Ready(Ok(()));
217        }
218
219        if let SharedState::Thunkulating {
220            future,
221            item,
222            dual_waker_state,
223            dual_waker_waker,
224        } = &mut *state
225        {
226            dual_waker_state.sink.register(cx.waker());
227
228            let mut new_context = Context::from_waker(dual_waker_waker);
229
230            match future.as_mut().poll(&mut new_context) {
231                Poll::Ready(Ok((stream, sink))) => {
232                    let buf = item.take();
233                    *state = SharedState::Done {
234                        stream: Box::pin(stream),
235                        sink: Box::pin(sink),
236                        buf,
237                    };
238                }
239                Poll::Ready(Err(e)) => {
240                    return Poll::Ready(Err(e));
241                }
242                Poll::Pending => {
243                    return Poll::Pending;
244                }
245            }
246        }
247
248        if let SharedState::Done { sink, buf, .. } = &mut *state {
249            if buf.is_some() {
250                ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
251                sink.as_mut().start_send(buf.take().unwrap())?;
252            }
253            let result = sink.as_mut().poll_close(cx).map_err(From::from);
254            return result;
255        }
256
257        panic!("LazySinkHalf in invalid state.");
258    }
259}
260
261impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
262where
263    Self: Unpin,
264    Fut: Future<Output = Result<(St, Si), Error>>,
265    St: Stream,
266    Si: Sink<Item>,
267{
268    type Item = St::Item;
269
270    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271        let state = &mut self.get_mut().state;
272
273        if let SharedState::Uninit { .. } = &*state {
274            let old_state = std::mem::replace(&mut *state, SharedState::Taken);
275            if let SharedState::Uninit { future } = old_state {
276                let (dual_waker_state, dual_waker_waker) = DualWaker::new();
277                *state = SharedState::Thunkulating {
278                    future,
279                    item: None,
280                    dual_waker_state,
281                    dual_waker_waker,
282                };
283            } else {
284                unreachable!();
285            }
286        }
287
288        if let SharedState::Thunkulating {
289            future,
290            item,
291            dual_waker_state,
292            dual_waker_waker,
293        } = &mut *state
294        {
295            dual_waker_state.stream.register(cx.waker());
296
297            let mut new_context = Context::from_waker(dual_waker_waker);
298
299            match future.as_mut().poll(&mut new_context) {
300                Poll::Ready(Ok((stream, sink))) => {
301                    let buf = item.take();
302                    *state = SharedState::Done {
303                        stream: Box::pin(stream),
304                        sink: Box::pin(sink),
305                        buf,
306                    };
307                }
308
309                Poll::Ready(Err(_)) => {
310                    return Poll::Ready(None);
311                }
312
313                Poll::Pending => {
314                    return Poll::Pending;
315                }
316            }
317        }
318
319        if let SharedState::Done { stream, .. } = &mut *state {
320            let result = stream.as_mut().poll_next(cx);
321            match &result {
322                Poll::Ready(Some(_)) => {}
323                Poll::Ready(None) => {}
324                Poll::Pending => {}
325            }
326            return result;
327        }
328
329        panic!("LazySinkSource in invalid state.");
330    }
331}
332
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // The sink task waits for this signal, so the stream task polls first and is the
                // one that starts initialization.
                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Allow the initialization future to complete.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // The stream task waits for this signal, so the sink task is the one that starts
                // initialization.
                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Allow the initialization future to complete.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // Initialization has now at least partially started.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Ensure the runtime has started driving initialization via the `stream.next()`
                // call.
                initialization_rx.await.unwrap();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Give the sink task time to also start waiting on the same initialization
                // future.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Connect a client so `listener.accept()` inside the initialization future can
                // complete.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Give the effects of initialization completing time to propagate.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // Nothing has been sent by the client yet, so the stream cannot have produced an
                // item.
                assert!(!stream_task.is_finished());

                // Send an item so that the stream wakes up and has an item ready to pull.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                let sink_source = LazySinkSource::new(async move {
                    // Initialization has now at least partially started.
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Ensure the runtime has started driving initialization via the sink task's
                // `send` call.
                initialization_rx.await.unwrap();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Give the stream task time to also start waiting on the same initialization
                // future.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(
                    !sink_task.is_finished(),
                    "No client has connected yet, so initialization (and thus the sink task) cannot have completed."
                );

                // Connect a client so `listener.accept()` inside the initialization future can
                // complete.
                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // Give the effects of initialization completing time to propagate.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                assert!(sink_task.is_finished()); // Sink should have sent its item.

                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                // Send an item so that the stream wakes up and has an item ready to pull.
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}