1use core::marker::PhantomData;
4use core::pin::Pin;
5use core::task::{Context, Poll, Waker};
6use std::sync::Arc;
7use std::task::Wake;
8
9use futures_util::task::AtomicWaker;
10use futures_util::{Sink, Stream, ready};
11
12#[derive(Default)]
13struct DualWaker {
14 sink: AtomicWaker,
15 stream: AtomicWaker,
16}
17
18impl DualWaker {
19 fn new() -> (Arc<Self>, Waker) {
20 let dual_waker = Arc::new(Self::default());
21 let waker = Waker::from(dual_waker.clone());
22 (dual_waker, waker)
23 }
24}
25
26impl Wake for DualWaker {
27 fn wake(self: Arc<Self>) {
28 self.wake_by_ref();
29 }
30
31 fn wake_by_ref(self: &Arc<Self>) {
32 self.sink.wake();
33 self.stream.wake();
34 }
35}
36
37enum SharedState<Fut, St, Si, Item> {
38 Uninit {
39 future: Pin<Box<Fut>>,
40 },
41 Thunkulating {
42 future: Pin<Box<Fut>>,
43 item: Option<Item>,
44 dual_waker_state: Arc<DualWaker>,
45 dual_waker_waker: Waker,
46 },
47 Done {
48 stream: Pin<Box<St>>,
49 sink: Pin<Box<Si>>,
50 buf: Option<Item>,
51 },
52 Taken,
53}
54
55pub struct LazySinkSource<Fut, St, Si, Item, Error> {
59 state: SharedState<Fut, St, Si, Item>,
60 _phantom: PhantomData<Error>,
61}
62
63impl<Fut, St, Si, Item, Error> LazySinkSource<Fut, St, Si, Item, Error> {
64 pub fn new(future: Fut) -> Self {
66 Self {
67 state: SharedState::Uninit {
68 future: Box::pin(future),
69 },
70 _phantom: PhantomData,
71 }
72 }
73}
74
75impl<Fut, St, Si, Item, Error> Sink<Item> for LazySinkSource<Fut, St, Si, Item, Error>
76where
77 Self: Unpin,
78 Fut: Future<Output = Result<(St, Si), Error>>,
79 St: Stream,
80 Si: Sink<Item>,
81 Error: From<Si::Error>,
82{
83 type Error = Error;
84
85 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86 let state = &mut self.get_mut().state;
87
88 if let SharedState::Uninit { .. } = &*state {
89 return Poll::Ready(Ok(()));
90 }
91
92 if let SharedState::Thunkulating {
93 future,
94 item,
95 dual_waker_state,
96 dual_waker_waker,
97 } = &mut *state
98 {
99 dual_waker_state.sink.register(cx.waker());
100
101 let mut dual_context = Context::from_waker(dual_waker_waker);
102
103 match future.as_mut().poll(&mut dual_context) {
104 Poll::Ready(Ok((stream, sink))) => {
105 let buf = item.take();
106 *state = SharedState::Done {
107 stream: Box::pin(stream),
108 sink: Box::pin(sink),
109 buf,
110 };
111 }
112 Poll::Ready(Err(e)) => {
113 return Poll::Ready(Err(e));
114 }
115 Poll::Pending => {
116 return Poll::Pending;
117 }
118 }
119 }
120
121 if let SharedState::Done { sink, buf, .. } = &mut *state {
122 if buf.is_some() {
123 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
124 sink.as_mut().start_send(buf.take().unwrap())?;
125 }
126 let result = sink.as_mut().poll_ready(cx).map_err(From::from);
127 return result;
128 }
129
130 panic!("LazySinkSource in invalid state.");
131 }
132
133 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
134 let state = &mut self.get_mut().state;
135
136 if let SharedState::Uninit { .. } = &*state {
137 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
138 if let SharedState::Uninit { future } = old_state {
139 let (dual_waker_state, dual_waker_waker) = DualWaker::new();
140 *state = SharedState::Thunkulating {
141 future,
142 item: Some(item),
143 dual_waker_state,
144 dual_waker_waker,
145 };
146
147 return Ok(());
148 }
149 }
150
151 if let SharedState::Thunkulating { .. } = &mut *state {
152 panic!("LazySinkSource not ready.");
153 }
154
155 if let SharedState::Done { sink, buf, .. } = &mut *state {
156 debug_assert!(buf.is_none());
157 let result = sink.as_mut().start_send(item).map_err(From::from);
158 return result;
159 }
160
161 panic!("LazySinkSource not ready.");
162 }
163
164 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165 let state = &mut self.get_mut().state;
166
167 if let SharedState::Uninit { .. } = &*state {
168 return Poll::Ready(Ok(()));
169 }
170
171 if let SharedState::Thunkulating {
172 future,
173 item,
174 dual_waker_state,
175 dual_waker_waker,
176 } = &mut *state
177 {
178 dual_waker_state.sink.register(cx.waker());
179
180 let mut new_context = Context::from_waker(dual_waker_waker);
181
182 match future.as_mut().poll(&mut new_context) {
183 Poll::Ready(Ok((stream, sink))) => {
184 let buf = item.take();
185 *state = SharedState::Done {
186 stream: Box::pin(stream),
187 sink: Box::pin(sink),
188 buf,
189 };
190 }
191 Poll::Ready(Err(e)) => {
192 return Poll::Ready(Err(e));
193 }
194 Poll::Pending => {
195 return Poll::Pending;
196 }
197 }
198 }
199
200 if let SharedState::Done { sink, buf, .. } = &mut *state {
201 if buf.is_some() {
202 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
203 sink.as_mut().start_send(buf.take().unwrap())?;
204 }
205 let result = sink.as_mut().poll_flush(cx).map_err(From::from);
206 return result;
207 }
208
209 panic!("LazySinkHalf in invalid state.");
210 }
211
212 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
213 let state = &mut self.get_mut().state;
214
215 if let SharedState::Uninit { .. } = &*state {
216 return Poll::Ready(Ok(()));
217 }
218
219 if let SharedState::Thunkulating {
220 future,
221 item,
222 dual_waker_state,
223 dual_waker_waker,
224 } = &mut *state
225 {
226 dual_waker_state.sink.register(cx.waker());
227
228 let mut new_context = Context::from_waker(dual_waker_waker);
229
230 match future.as_mut().poll(&mut new_context) {
231 Poll::Ready(Ok((stream, sink))) => {
232 let buf = item.take();
233 *state = SharedState::Done {
234 stream: Box::pin(stream),
235 sink: Box::pin(sink),
236 buf,
237 };
238 }
239 Poll::Ready(Err(e)) => {
240 return Poll::Ready(Err(e));
241 }
242 Poll::Pending => {
243 return Poll::Pending;
244 }
245 }
246 }
247
248 if let SharedState::Done { sink, buf, .. } = &mut *state {
249 if buf.is_some() {
250 ready!(sink.as_mut().poll_ready(cx).map_err(From::from)?);
251 sink.as_mut().start_send(buf.take().unwrap())?;
252 }
253 let result = sink.as_mut().poll_close(cx).map_err(From::from);
254 return result;
255 }
256
257 panic!("LazySinkHalf in invalid state.");
258 }
259}
260
261impl<Fut, St, Si, Item, Error> Stream for LazySinkSource<Fut, St, Si, Item, Error>
262where
263 Self: Unpin,
264 Fut: Future<Output = Result<(St, Si), Error>>,
265 St: Stream,
266 Si: Sink<Item>,
267{
268 type Item = St::Item;
269
270 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
271 let state = &mut self.get_mut().state;
272
273 if let SharedState::Uninit { .. } = &*state {
274 let old_state = std::mem::replace(&mut *state, SharedState::Taken);
275 if let SharedState::Uninit { future } = old_state {
276 let (dual_waker_state, dual_waker_waker) = DualWaker::new();
277 *state = SharedState::Thunkulating {
278 future,
279 item: None,
280 dual_waker_state,
281 dual_waker_waker,
282 };
283 } else {
284 unreachable!();
285 }
286 }
287
288 if let SharedState::Thunkulating {
289 future,
290 item,
291 dual_waker_state,
292 dual_waker_waker,
293 } = &mut *state
294 {
295 dual_waker_state.stream.register(cx.waker());
296
297 let mut new_context = Context::from_waker(dual_waker_waker);
298
299 match future.as_mut().poll(&mut new_context) {
300 Poll::Ready(Ok((stream, sink))) => {
301 let buf = item.take();
302 *state = SharedState::Done {
303 stream: Box::pin(stream),
304 sink: Box::pin(sink),
305 buf,
306 };
307 }
308
309 Poll::Ready(Err(_)) => {
310 return Poll::Ready(None);
311 }
312
313 Poll::Pending => {
314 return Poll::Pending;
315 }
316 }
317 }
318
319 if let SharedState::Done { stream, .. } = &mut *state {
320 let result = stream.as_mut().poll_next(cx);
321 match &result {
322 Poll::Ready(Some(_)) => {}
323 Poll::Ready(None) => {}
324 Poll::Pending => {}
325 }
326 return result;
327 }
328
329 panic!("LazySinkSource in invalid state.");
330 }
331}
332
// Tests for `LazySinkSource`. Each test builds a `LazySinkSource` from a gated initialization
// future (gated by a oneshot channel or by a TCP `accept`), splits it into sink and stream
// halves, and checks that whichever half is used first drives initialization and that items
// flow through afterwards.
#[cfg(test)]
mod test {
    use futures_util::{SinkExt, StreamExt};
    use tokio_util::sync::PollSendError;

    use super::*;

    /// The stream half is used first: the stream task signals `stream_init_send` and then calls
    /// `next` in the same poll, while the sink task only sends after receiving that signal.
    #[tokio::test(flavor = "current_thread")]
    async fn stream_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gate that keeps the initialization future pending until released below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                // Initialization yields an in-memory loopback: items sent into the sink come
                // back out of the stream.
                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Orders the two tasks: the sink task waits until the stream task has started.
                let (stream_init_send, stream_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    stream_init_send.send(()).unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    stream_init_recv.await.unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Release the initialization future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// Mirror of `stream_drives_initialization`: the sink task signals `sink_init_send` and then
    /// sends in the same poll, while the stream task only reads after receiving that signal.
    #[tokio::test(flavor = "current_thread")]
    async fn sink_drives_initialization() {
        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Gate that keeps the initialization future pending until released below.
                let (init_lazy_send, init_lazy_recv) = tokio::sync::oneshot::channel::<()>();

                // Initialization yields an in-memory loopback: items sent into the sink come
                // back out of the stream.
                let sink_source = LazySinkSource::new(async move {
                    let () = init_lazy_recv.await.unwrap();
                    let (send, recv) = tokio::sync::mpsc::channel(1);
                    let sink = tokio_util::sync::PollSender::new(send);
                    let stream = tokio_stream::wrappers::ReceiverStream::new(recv);
                    Ok::<_, PollSendError<_>>((stream, sink))
                });

                let (mut sink, mut stream) = sink_source.split();

                // Orders the two tasks: the stream task waits until the sink task has started.
                let (sink_init_send, sink_init_recv) = tokio::sync::oneshot::channel::<()>();
                let stream_task = tokio::task::spawn_local(async move {
                    sink_init_recv.await.unwrap();
                    (stream.next().await.unwrap(), stream.next().await.unwrap())
                });
                let sink_task = tokio::task::spawn_local(async move {
                    sink_init_send.send(()).unwrap();
                    SinkExt::send(&mut sink, "test1").await.unwrap();
                    SinkExt::send(&mut sink, "test2").await.unwrap();
                });

                // Release the initialization future.
                init_lazy_send.send(()).unwrap();

                tokio::task::yield_now().await;

                assert!(sink_task.is_finished());
                assert_eq!(("test1", "test2"), stream_task.await.unwrap());
                sink_task.await.unwrap();
            })
            .await;
    }

    /// Over a real TCP connection: the stream task is the only task running when the
    /// initialization future starts (the sink task is spawned after `initialization_rx` fires),
    /// so the stream half drives the `accept`.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_stream_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Fired by the initialization future as soon as it is first polled.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Port 0 lets the OS pick a free port; the actual address is read back below.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                // Initialization accepts one connection and frames both directions with
                // length-delimited codecs.
                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                // Wait until the stream half has started the initialization future.
                initialization_rx.await.unwrap();
                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Give both tasks a chance to run while initialization is blocked in `accept`.
                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                // The client has not written anything yet, so the stream half cannot have
                // produced an item.
                assert!(!stream_task.is_finished());
                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();

                // The item sent through the sink half arrives at the client.
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");
            })
            .await;
    }

    /// Over a real TCP connection: the sink task is the only task running when the
    /// initialization future starts (the stream task is spawned after `initialization_rx`
    /// fires), so the sink half's first send buffers its item and drives the `accept`.
    #[tokio::test(flavor = "current_thread")]
    async fn tcp_sink_drives_initialization() {
        use tokio::net::{TcpListener, TcpStream};
        use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec};

        // Fired by the initialization future as soon as it is first polled.
        let (initialization_tx, initialization_rx) = tokio::sync::oneshot::channel::<()>();

        let local = tokio::task::LocalSet::new();
        local
            .run_until(async {
                // Port 0 lets the OS pick a free port; the actual address is read back below.
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();

                // Initialization accepts one connection and frames both directions with
                // length-delimited codecs.
                let sink_source = LazySinkSource::new(async move {
                    initialization_tx.send(()).unwrap();

                    let (stream, _) = listener.accept().await.unwrap();
                    let (rx, tx) = stream.into_split();
                    let fr = FramedRead::new(rx, LengthDelimitedCodec::new());
                    let fw = FramedWrite::new(tx, LengthDelimitedCodec::new());
                    Ok::<_, std::io::Error>((fr, fw))
                });

                let (mut sink, mut stream) = sink_source.split();

                let sink_task = tokio::task::spawn_local(async move {
                    SinkExt::send(&mut sink, bytes::Bytes::from("test2"))
                        .await
                        .unwrap();
                });

                // Wait until the sink half's first send has started the initialization future.
                initialization_rx.await.unwrap();
                let stream_task = tokio::task::spawn_local(async move { stream.next().await });

                for _ in 0..20 {
                    tokio::task::yield_now().await
                }

                assert!(
                    !sink_task.is_finished(),
                    "We haven't sent anything yet, so the sink should definitely not be resolved now."
                );

                let mut socket = TcpStream::connect(addr).await.unwrap();
                let (client_rx, client_tx) = socket.split();
                let mut client_tx = FramedWrite::new(client_tx, LengthDelimitedCodec::new());
                let mut client_rx = FramedRead::new(client_rx, LengthDelimitedCodec::new());

                // NOTE(review): checking `is_finished` after a fixed 10 ms sleep is
                // timing-dependent and may be flaky on a heavily loaded machine.
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;

                assert!(sink_task.is_finished());
                // The item buffered before initialization reaches the client.
                assert_eq!(&client_rx.next().await.unwrap().unwrap()[..], b"test2");

                SinkExt::send(&mut client_tx, bytes::Bytes::from("test"))
                    .await
                    .unwrap();

                assert_eq!(&stream_task.await.unwrap().unwrap().unwrap()[..], b"test");
                sink_task.await.unwrap();
            })
            .await;
    }
}