futures_util/
abortable.rs

1use crate::task::AtomicWaker;
2use alloc::sync::Arc;
3use core::fmt;
4use core::pin::Pin;
5use core::sync::atomic::{AtomicBool, Ordering};
6use futures_core::future::Future;
7use futures_core::task::{Context, Poll};
8use futures_core::Stream;
9use pin_project_lite::pin_project;
10
pin_project! {
    /// A future/stream which can be remotely short-circuited using an `AbortHandle`.
    ///
    /// Built with [`Abortable::new`] from an `AbortRegistration`. Polled as a
    /// future it resolves to `Ok(output)` or `Err(Aborted)`; polled as a stream
    /// it yields `None` once an abort has been observed.
    #[derive(Debug, Clone)]
    #[must_use = "futures/streams do nothing unless you poll them"]
    pub struct Abortable<T> {
        // The wrapped future or stream. It is structurally pinned so that
        // `try_poll` can project to `Pin<&mut T>`.
        #[pin]
        task: T,
        // Abort state shared with the paired `AbortHandle`. The derived `Clone`
        // clones this `Arc`, so clones share one abort flag and one `AtomicWaker`.
        // NOTE(review): if the `AtomicWaker` keeps only the last registered waker,
        // clones polled from different tasks may miss the abort wakeup; confirm.
        inner: Arc<AbortInner>,
    }
}
21
22impl<T> Abortable<T> {
23    /// Creates a new `Abortable` future/stream using an existing `AbortRegistration`.
24    /// `AbortRegistration`s can be acquired through `AbortHandle::new`.
25    ///
26    /// When `abort` is called on the handle tied to `reg` or if `abort` has
27    /// already been called, the future/stream will complete immediately without making
28    /// any further progress.
29    ///
30    /// # Examples:
31    ///
32    /// Usage with futures:
33    ///
34    /// ```
35    /// # futures::executor::block_on(async {
36    /// use futures::future::{Abortable, AbortHandle, Aborted};
37    ///
38    /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
39    /// let future = Abortable::new(async { 2 }, abort_registration);
40    /// abort_handle.abort();
41    /// assert_eq!(future.await, Err(Aborted));
42    /// # });
43    /// ```
44    ///
45    /// Usage with streams:
46    ///
47    /// ```
48    /// # futures::executor::block_on(async {
49    /// # use futures::future::{Abortable, AbortHandle};
50    /// # use futures::stream::{self, StreamExt};
51    ///
52    /// let (abort_handle, abort_registration) = AbortHandle::new_pair();
53    /// let mut stream = Abortable::new(stream::iter(vec![1, 2, 3]), abort_registration);
54    /// abort_handle.abort();
55    /// assert_eq!(stream.next().await, None);
56    /// # });
57    /// ```
58    pub fn new(task: T, reg: AbortRegistration) -> Self {
59        Self { task, inner: reg.inner }
60    }
61
62    /// Checks whether the task has been aborted. Note that all this
63    /// method indicates is whether [`AbortHandle::abort`] was *called*.
64    /// This means that it will return `true` even if:
65    /// * `abort` was called after the task had completed.
66    /// * `abort` was called while the task was being polled - the task may still be running and
67    /// will not be stopped until `poll` returns.
68    pub fn is_aborted(&self) -> bool {
69        self.inner.aborted.load(Ordering::Relaxed)
70    }
71}
72
73/// A registration handle for an `Abortable` task.
74/// Values of this type can be acquired from `AbortHandle::new` and are used
75/// in calls to `Abortable::new`.
76#[derive(Debug)]
77pub struct AbortRegistration {
78    pub(crate) inner: Arc<AbortInner>,
79}
80
81/// A handle to an `Abortable` task.
82#[derive(Debug, Clone)]
83pub struct AbortHandle {
84    inner: Arc<AbortInner>,
85}
86
87impl AbortHandle {
88    /// Creates an (`AbortHandle`, `AbortRegistration`) pair which can be used
89    /// to abort a running future or stream.
90    ///
91    /// This function is usually paired with a call to [`Abortable::new`].
92    pub fn new_pair() -> (Self, AbortRegistration) {
93        let inner =
94            Arc::new(AbortInner { waker: AtomicWaker::new(), aborted: AtomicBool::new(false) });
95
96        (Self { inner: inner.clone() }, AbortRegistration { inner })
97    }
98}
99
100// Inner type storing the waker to awaken and a bool indicating that it
101// should be aborted.
102#[derive(Debug)]
103pub(crate) struct AbortInner {
104    pub(crate) waker: AtomicWaker,
105    pub(crate) aborted: AtomicBool,
106}
107
/// Error value signalling that an `Abortable` task was aborted through its
/// `AbortHandle` instead of completing normally.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Aborted;

impl fmt::Display for Aborted {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("`Abortable` future has been aborted")
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Aborted {}
120
impl<T> Abortable<T> {
    /// Abort-aware polling routine shared by the `Future` and `Stream` impls.
    ///
    /// `poll` is applied to the pinned wrapped task (the `Future` impl passes a
    /// closure calling `poll`, the `Stream` impl one calling `poll_next`).
    ///
    /// Returns:
    /// * `Ready(Err(Aborted))` if an abort is observed before the task is polled,
    ///   or after the waker has been registered;
    /// * `Ready(Ok(x))` if the task itself is ready with `x` (this takes priority
    ///   over an abort that happens during that poll);
    /// * `Pending` otherwise, with `cx`'s waker registered for abort wakeups.
    fn try_poll<I>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        poll: impl Fn(Pin<&mut T>, &mut Context<'_>) -> Poll<I>,
    ) -> Poll<Result<I, Aborted>> {
        // Check if the task has been aborted
        if self.is_aborted() {
            return Poll::Ready(Err(Aborted));
        }

        // attempt to complete the task
        if let Poll::Ready(x) = poll(self.as_mut().project().task, cx) {
            return Poll::Ready(Ok(x));
        }

        // Register to receive a wakeup if the task is aborted in the future
        self.inner.waker.register(cx.waker());

        // Check to see if the task was aborted between the first check and
        // registration.
        // Checking with `is_aborted` which uses `Relaxed` is sufficient because
        // `register` introduces an `AcqRel` barrier.
        if self.is_aborted() {
            return Poll::Ready(Err(Aborted));
        }

        Poll::Pending
    }
}
151
152impl<Fut> Future for Abortable<Fut>
153where
154    Fut: Future,
155{
156    type Output = Result<Fut::Output, Aborted>;
157
158    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
159        self.try_poll(cx, |fut, cx| fut.poll(cx))
160    }
161}
162
163impl<St> Stream for Abortable<St>
164where
165    St: Stream,
166{
167    type Item = St::Item;
168
169    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
170        self.try_poll(cx, |stream, cx| stream.poll_next(cx)).map(Result::ok).map(Option::flatten)
171    }
172}
173
174impl AbortHandle {
175    /// Abort the `Abortable` stream/future associated with this handle.
176    ///
177    /// Notifies the Abortable task associated with this handle that it
178    /// should abort. Note that if the task is currently being polled on
179    /// another thread, it will not immediately stop running. Instead, it will
180    /// continue to run until its poll method returns.
181    pub fn abort(&self) {
182        self.inner.aborted.store(true, Ordering::Relaxed);
183        self.inner.waker.wake();
184    }
185}