// buf_list/cursor/mod.rs

// Copyright (c) The buf-list Contributors
// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(feature = "futures03")]
5mod futures_imp;
6#[cfg(test)]
7mod tests;
8#[cfg(feature = "tokio1")]
9mod tokio_imp;
10
11use crate::{BufList, errors::ReadExactError};
12use bytes::{Buf, Bytes};
13use std::{
14    cmp::Ordering,
15    io::{self, IoSlice, IoSliceMut, SeekFrom},
16};
17
/// A `Cursor` wraps an in-memory `BufList` and provides it with a [`Seek`] implementation.
///
/// `Cursor`s allow `BufList`s to implement [`Read`] and [`BufRead`], allowing a `BufList` to be
/// used anywhere you might use a reader or writer that does actual I/O.
///
/// The cursor may either own or borrow a `BufList`: both `Cursor<BufList>` and `Cursor<&BufList>`
/// are supported.
///
/// # Optional features
///
/// * `tokio1`: With this feature enabled, [`Cursor`] implements the `tokio` crate's
///   [`AsyncSeek`](tokio::io::AsyncSeek), [`AsyncRead`](tokio::io::AsyncRead) and
///   [`AsyncBufRead`](tokio::io::AsyncBufRead).
/// * `futures03`: With this feature enabled, [`Cursor`] implements the `futures` crate's
///   [`AsyncSeek`](futures_io_03::AsyncSeek), [`AsyncRead`](futures_io_03::AsyncRead) and
///   [`AsyncBufRead`](futures_io_03::AsyncBufRead).
///
/// [`Read`]: std::io::Read
/// [`BufRead`]: std::io::BufRead
/// [`Seek`]: std::io::Seek
pub struct Cursor<T> {
    /// The wrapped list, owned or borrowed (any `T: AsRef<BufList>`).
    inner: T,

    /// Data associated with the cursor: the stream position and the chunk index
    /// it falls in, kept in sync with each other.
    data: CursorData,
}
44
45impl<T: AsRef<BufList>> Cursor<T> {
46    /// Creates a new cursor wrapping the provided `BufList`.
47    ///
48    /// # Examples
49    ///
50    /// ```
51    /// use buf_list::{BufList, Cursor};
52    ///
53    /// let cursor = Cursor::new(BufList::new());
54    /// ```
55    pub fn new(inner: T) -> Cursor<T> {
56        let data = CursorData::new();
57        Cursor { inner, data }
58    }
59
60    /// Consumes this cursor, returning the underlying value.
61    ///
62    /// # Examples
63    ///
64    /// ```
65    /// use buf_list::{BufList, Cursor};
66    ///
67    /// let cursor = Cursor::new(BufList::new());
68    ///
69    /// let vec = cursor.into_inner();
70    /// ```
71    pub fn into_inner(self) -> T {
72        self.inner
73    }
74
75    /// Gets a reference to the underlying value in this cursor.
76    ///
77    /// # Examples
78    ///
79    /// ```
80    /// use buf_list::{BufList, Cursor};
81    ///
82    /// let cursor = Cursor::new(BufList::new());
83    ///
84    /// let reference = cursor.get_ref();
85    /// ```
86    pub const fn get_ref(&self) -> &T {
87        &self.inner
88    }
89
90    /// Returns the current position of this cursor.
91    ///
92    /// # Examples
93    ///
94    /// ```
95    /// use buf_list::{BufList, Cursor};
96    /// use std::io::prelude::*;
97    /// use std::io::SeekFrom;
98    ///
99    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
100    ///
101    /// assert_eq!(cursor.position(), 0);
102    ///
103    /// cursor.seek(SeekFrom::Current(2)).unwrap();
104    /// assert_eq!(cursor.position(), 2);
105    ///
106    /// cursor.seek(SeekFrom::Current(-1)).unwrap();
107    /// assert_eq!(cursor.position(), 1);
108    /// ```
109    pub const fn position(&self) -> u64 {
110        self.data.pos
111    }
112
113    /// Sets the position of this cursor.
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use buf_list::{BufList, Cursor};
119    ///
120    /// let mut cursor = Cursor::new(BufList::from(&[1, 2, 3, 4, 5][..]));
121    ///
122    /// assert_eq!(cursor.position(), 0);
123    ///
124    /// cursor.set_position(2);
125    /// assert_eq!(cursor.position(), 2);
126    ///
127    /// cursor.set_position(4);
128    /// assert_eq!(cursor.position(), 4);
129    /// ```
130    pub fn set_position(&mut self, pos: u64) {
131        self.data.set_pos(self.inner.as_ref(), pos);
132    }
133
134    // ---
135    // Helper methods
136    // ---
137    #[cfg(test)]
138    fn assert_invariants(&self) -> anyhow::Result<()> {
139        self.data.assert_invariants(self.inner.as_ref())
140    }
141}
142
143impl<T> Clone for Cursor<T>
144where
145    T: Clone,
146{
147    #[inline]
148    fn clone(&self) -> Self {
149        Cursor {
150            inner: self.inner.clone(),
151            data: self.data.clone(),
152        }
153    }
154
155    #[inline]
156    fn clone_from(&mut self, other: &Self) {
157        self.inner.clone_from(&other.inner);
158        self.data = other.data.clone();
159    }
160}
161
162impl<T: AsRef<BufList>> io::Seek for Cursor<T> {
163    fn seek(&mut self, style: SeekFrom) -> io::Result<u64> {
164        self.data.seek_impl(self.inner.as_ref(), style)
165    }
166
167    fn stream_position(&mut self) -> io::Result<u64> {
168        Ok(self.data.pos)
169    }
170}
171
172impl<T: AsRef<BufList>> io::Read for Cursor<T> {
173    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
174        Ok(self.data.read_impl(self.inner.as_ref(), buf))
175    }
176
177    fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
178        Ok(self.data.read_vectored_impl(self.inner.as_ref(), bufs))
179    }
180
181    // TODO: is_read_vectored once that's available on stable Rust.
182
183    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
184        self.data.read_exact_impl(self.inner.as_ref(), buf)
185    }
186}
187
188impl<T: AsRef<BufList>> io::BufRead for Cursor<T> {
189    fn fill_buf(&mut self) -> io::Result<&[u8]> {
190        Ok(self.data.fill_buf_impl(self.inner.as_ref()))
191    }
192
193    fn consume(&mut self, amt: usize) {
194        self.data.consume_impl(self.inner.as_ref(), amt);
195    }
196}
197
impl<T: AsRef<BufList>> Buf for Cursor<T> {
    fn remaining(&self) -> usize {
        // Total length minus the current position; saturating because the
        // cursor may have been seeked past the end of the list.
        let total = self.data.num_bytes(self.inner.as_ref());
        total.saturating_sub(self.data.pos) as usize
    }

    fn has_remaining(&self) -> bool {
        self.data.num_bytes(self.inner.as_ref()) > self.data.pos
    }

    fn chunk(&self) -> &[u8] {
        // Same view as BufRead::fill_buf: the rest of the current chunk, or an
        // empty slice at end-of-list.
        self.data.fill_buf_impl(self.inner.as_ref())
    }

    fn advance(&mut self, amt: usize) {
        self.data.consume_impl(self.inner.as_ref(), amt);
    }

    fn chunks_vectored<'iovs>(&'iovs self, iovs: &mut [IoSlice<'iovs>]) -> usize {
        let list = self.inner.as_ref();

        if iovs.is_empty() || !self.has_remaining() {
            return 0;
        }

        // has_remaining() above guarantees the cursor points into a valid
        // chunk, so the indexing and expect() calls below cannot fail.
        let current_chunk = self.data.chunk;
        let chunk_start_pos = list.get_start_pos()[current_chunk];
        let offset_in_chunk = (self.data.pos - chunk_start_pos) as usize;

        // The first iov is the unread remainder of the current chunk.
        iovs[0] = IoSlice::new(
            &list.get_chunk(current_chunk).expect("chunk is in range")[offset_in_chunk..],
        );
        // Fill up the remaining iovs with as many slices as possible.
        let to_fill = (iovs.len()).min(list.num_chunks() - current_chunk);
        for (i, iov) in iovs.iter_mut().enumerate().take(to_fill).skip(1) {
            *iov = IoSlice::new(
                list.get_chunk(current_chunk + i)
                    .expect("chunk is in range"),
            );
        }

        to_fill
    }
}
242
#[derive(Clone, Debug)]
struct CursorData {
    /// The chunk number the cursor is pointing to. Kept in sync with pos.
    ///
    /// This is within the range [0, self.start_pos.len()). It is self.start_pos.len() - 1 iff pos
    /// is at or past the end of the list, i.e. pos >= list.num_bytes(). (Reading through the
    /// final chunk advances `chunk` one past the last chunk index.)
    chunk: usize,

    /// The overall position in the stream. Kept in sync with chunk.
    ///
    /// May exceed the list's total length if the cursor was seeked past the end.
    pos: u64,
}
254
impl CursorData {
    /// Creates cursor data pointing at the start of the list (chunk 0, position 0).
    fn new() -> Self {
        Self { chunk: 0, pos: 0 }
    }

    /// Verifies that `chunk` and `pos` are in sync: `pos` must lie within the
    /// half-open byte range covered by `chunk`.
    #[cfg(test)]
    fn assert_invariants(&self, list: &BufList) -> anyhow::Result<()> {
        use anyhow::ensure;

        ensure!(
            self.pos >= list.get_start_pos()[self.chunk],
            "invariant failed: current position {} >= start position {} (chunk = {})",
            self.pos,
            list.get_start_pos()[self.chunk],
            self.chunk
        );

        // `next_pos` becomes `Offset::Eof` when `chunk` is the final start_pos
        // entry; `Eof` compares greater than every `Offset::Value`.
        let next_pos = list.get_start_pos().get(self.chunk + 1).copied().into();
        ensure!(
            Offset::Value(self.pos) < next_pos,
            "invariant failed: next start position {:?} > current position {} (chunk = {})",
            next_pos,
            self.pos,
            self.chunk
        );

        Ok(())
    }

    /// Shared implementation of `io::Seek::seek`.
    ///
    /// Seeking past the end of the list is allowed; only a resulting position
    /// that would be negative or overflow `u64` is an error.
    fn seek_impl(&mut self, list: &BufList, style: SeekFrom) -> io::Result<u64> {
        let (base_pos, offset) = match style {
            SeekFrom::Start(n) => {
                self.set_pos(list, n);
                return Ok(n);
            }
            SeekFrom::End(n) => (self.num_bytes(list), n),
            SeekFrom::Current(n) => (self.pos, n),
        };
        // Can't use checked_add_signed since it was only stabilized in Rust 1.66. This is adapted
        // from
        // https://github.com/rust-lang/rust/blame/ed937594d3/library/std/src/io/cursor.rs#L295-L299.
        let new_pos = if offset >= 0 {
            base_pos.checked_add(offset as u64)
        } else {
            base_pos.checked_sub(offset.wrapping_neg() as u64)
        };
        match new_pos {
            Some(n) => {
                self.set_pos(list, n);
                Ok(self.pos)
            }
            None => Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )),
        }
    }

    /// Shared implementation of `io::Read::read`: copies bytes starting at
    /// `pos` into `buf`, returning the number of bytes copied (0 at end).
    fn read_impl(&mut self, list: &BufList, buf: &mut [u8]) -> usize {
        // Read as much as possible until we fill up the buffer.
        let mut buf_pos = 0;
        while buf_pos < buf.len() {
            let (chunk, chunk_pos) = match self.get_chunk_and_pos(list) {
                Some(value) => value,
                None => break,
            };
            // The number of bytes to copy is the smaller of the two:
            // - the length of the chunk - the position in it.
            // - the number of bytes remaining, which is buf.len() - buf_pos.
            let n_to_copy = (chunk.len() - chunk_pos).min(buf.len() - buf_pos);
            let chunk_bytes = chunk.as_ref();

            let bytes_to_copy = &chunk_bytes[chunk_pos..(chunk_pos + n_to_copy)];
            let dest = &mut buf[buf_pos..(buf_pos + n_to_copy)];
            dest.copy_from_slice(bytes_to_copy);
            buf_pos += n_to_copy;

            // Increment the position.
            self.pos += n_to_copy as u64;
            // If we've finished reading through the chunk, move to the next chunk.
            if n_to_copy == chunk.len() - chunk_pos {
                self.chunk += 1;
            }
        }

        buf_pos
    }

    /// Shared implementation of `io::Read::read_vectored`: fills each buffer in
    /// turn, stopping early once the list runs out of bytes.
    fn read_vectored_impl(&mut self, list: &BufList, bufs: &mut [IoSliceMut<'_>]) -> usize {
        let mut nread = 0;
        for buf in bufs {
            // Copy data from the buffer until we run out of bytes to copy.
            let n = self.read_impl(list, buf);
            nread += n;
            if n < buf.len() {
                break;
            }
        }
        nread
    }

    /// Shared implementation of `io::Read::read_exact`: errors with
    /// `UnexpectedEof` (and moves the position to the end) if fewer than
    /// `buf.len()` bytes remain.
    fn read_exact_impl(&mut self, list: &BufList, buf: &mut [u8]) -> io::Result<()> {
        // This is the same as read_impl as long as there's enough space.
        let total = self.num_bytes(list);
        let remaining = total.saturating_sub(self.pos);
        let buf_len = buf.len();
        if remaining < buf_len as u64 {
            // Rust 1.80 and above will cause the position to be set to the end
            // of the buffer, due to (apparently)
            // https://github.com/rust-lang/rust/pull/125404. Follow that
            // behavior.
            self.set_pos(list, total);
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                ReadExactError { remaining, buf_len },
            ));
        }

        self.read_impl(list, buf);
        Ok(())
    }

    /// Shared implementation of `io::BufRead::fill_buf` and `Buf::chunk`:
    /// returns the unread remainder of the current chunk, or an empty slice at
    /// end-of-list.
    fn fill_buf_impl<'a>(&'a self, list: &'a BufList) -> &'a [u8] {
        const EMPTY_SLICE: &[u8] = &[];
        match self.get_chunk_and_pos(list) {
            Some((chunk, chunk_pos)) => &chunk.as_ref()[chunk_pos..],
            // An empty return value means the end of the buffer has been reached.
            None => EMPTY_SLICE,
        }
    }

    /// Shared implementation of `io::BufRead::consume` and `Buf::advance`.
    fn consume_impl(&mut self, list: &BufList, amt: usize) {
        self.set_pos(list, self.pos + amt as u64);
    }

    /// Moves the cursor to `new_pos`, keeping `chunk` in sync.
    ///
    /// Uses a binary search over the chunk start positions, restricted to the
    /// part of the list ahead of (or behind) the current chunk depending on the
    /// direction of movement. Positions past the end of the list are allowed.
    fn set_pos(&mut self, list: &BufList, new_pos: u64) {
        match new_pos.cmp(&self.pos) {
            Ordering::Greater => {
                let start_pos = list.get_start_pos();
                let next_start = start_pos.get(self.chunk + 1).copied().into();
                if Offset::Value(new_pos) < next_start {
                    // Within the same chunk.
                } else {
                    // The above check ensures that we're not currently pointing to the last index
                    // (since it would have returned Eof, which is greater than Offset(n) for any
                    // n).
                    //
                    // Do a binary search for this element.
                    match start_pos[self.chunk + 1..].binary_search(&new_pos) {
                        // We're starting the search from self.chunk + 1, which means that the value
                        // returned from binary_search is 1 less than the actual delta.
                        Ok(delta_minus_one) => {
                            // Exactly at the start point of a chunk.
                            self.chunk += 1 + delta_minus_one;
                        }
                        // The value returned in the error case (not at the start point of a chunk)
                        // is (delta - 1) + 1, so just delta.
                        Err(delta) => {
                            debug_assert!(
                                delta > 0,
                                "delta must be at least 1 since we already \
                                checked the same chunk (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk += delta;
                        }
                    }
                }
            }
            Ordering::Equal => {}
            Ordering::Less => {
                let start_pos = list.get_start_pos();
                if start_pos.get(self.chunk).copied() <= Some(new_pos) {
                    // Within the same chunk.
                } else {
                    match start_pos[..self.chunk].binary_search(&new_pos) {
                        Ok(chunk) => {
                            // Exactly at the start point of a chunk.
                            self.chunk = chunk;
                        }
                        Err(chunk_plus_1) => {
                            debug_assert!(
                                chunk_plus_1 > 0,
                                "chunk_plus_1 must be at least 1 since self.start_pos[0] is 0 \
                                 (self.chunk = {})",
                                self.chunk,
                            );
                            self.chunk = chunk_plus_1 - 1;
                        }
                    }
                }
            }
        }
        self.pos = new_pos;
    }

    /// Returns the chunk the cursor points into and the byte offset within it,
    /// or `None` if the position is at or past the end of the list.
    #[inline]
    fn get_chunk_and_pos<'b>(&self, list: &'b BufList) -> Option<(&'b Bytes, usize)> {
        match list.get_chunk(self.chunk) {
            Some(chunk) => {
                // This guarantees that pos is not past the end of the list.
                debug_assert!(
                    self.pos < self.num_bytes(list),
                    "self.pos ({}) is less than num_bytes ({})",
                    self.pos,
                    self.num_bytes(list)
                );
                Some((
                    chunk,
                    (self.pos - list.get_start_pos()[self.chunk]) as usize,
                ))
            }
            None => {
                // pos is past the end of the list.
                None
            }
        }
    }

    /// Total number of bytes in the list; equal to the final start_pos entry.
    fn num_bytes(&self, list: &BufList) -> u64 {
        *list
            .get_start_pos()
            .last()
            .expect("start_pos always has at least one element")
    }
}
481
/// A position that is either a concrete offset or end-of-file.
///
/// Equivalent to `Option<T>` except that the ordering is flipped: `Eof`
/// compares greater than `Value(v)` for every `v`. (The derived `Ord` ranks
/// later-declared variants higher, so the variant order here is significant.)
#[derive(Copy, Clone, Debug, Eq, PartialEq, PartialOrd, Ord)]
enum Offset<T> {
    Value(T),
    Eof,
}

impl<T> From<Option<T>> for Offset<T> {
    fn from(value: Option<T>) -> Self {
        // Some -> Value, None -> Eof.
        value.map_or(Self::Eof, Self::Value)
    }
}
497}