// h2/proto/streams/store.rs

1use super::*;
2
3use indexmap::{self, IndexMap};
4
5use std::convert::Infallible;
6use std::fmt;
7use std::marker::PhantomData;
8use std::ops;
9
10/// Storage for streams
11#[derive(Debug)]
12pub(super) struct Store {
13    slab: slab::Slab<Stream>,
14    ids: IndexMap<StreamId, SlabIndex>,
15}
16
17/// "Pointer" to an entry in the store
18pub(super) struct Ptr<'a> {
19    key: Key,
20    store: &'a mut Store,
21}
22
23/// References an entry in the store.
24#[derive(Debug, Clone, Copy, PartialEq, Eq)]
25pub(crate) struct Key {
26    index: SlabIndex,
27    /// Keep the stream ID in the key as an ABA guard, since slab indices
28    /// could be re-used with a new stream.
29    stream_id: StreamId,
30}
31
/// Slot index into the store's slab.
///
/// We can never have more than `StreamId::MAX` streams in the store, so a
/// `u32` is stored instead of a `usize` to keep keys small.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
struct SlabIndex(u32);
36
37pub(super) struct Queue<N> {
38    indices: Option<store::Indices>,
39    _p: PhantomData<N>,
40}
41
42pub(super) trait Next {
43    fn next(stream: &Stream) -> Option<Key>;
44
45    fn set_next(stream: &mut Stream, key: Option<Key>);
46
47    fn take_next(stream: &mut Stream) -> Option<Key>;
48
49    fn is_queued(stream: &Stream) -> bool;
50
51    fn set_queued(stream: &mut Stream, val: bool);
52}
53
54/// A linked list
55#[derive(Debug, Clone, Copy)]
56struct Indices {
57    pub head: Key,
58    pub tail: Key,
59}
60
61pub(super) enum Entry<'a> {
62    Occupied(OccupiedEntry<'a>),
63    Vacant(VacantEntry<'a>),
64}
65
66pub(super) struct OccupiedEntry<'a> {
67    ids: indexmap::map::OccupiedEntry<'a, StreamId, SlabIndex>,
68}
69
70pub(super) struct VacantEntry<'a> {
71    ids: indexmap::map::VacantEntry<'a, StreamId, SlabIndex>,
72    slab: &'a mut slab::Slab<Stream>,
73}
74
75pub(super) trait Resolve {
76    fn resolve(&mut self, key: Key) -> Ptr<'_>;
77}
78
79// ===== impl Store =====
80
81impl Store {
82    pub fn new() -> Self {
83        Store {
84            slab: slab::Slab::new(),
85            ids: IndexMap::new(),
86        }
87    }
88
89    pub fn find_mut(&mut self, id: &StreamId) -> Option<Ptr<'_>> {
90        let index = match self.ids.get(id) {
91            Some(key) => *key,
92            None => return None,
93        };
94
95        Some(Ptr {
96            key: Key {
97                index,
98                stream_id: *id,
99            },
100            store: self,
101        })
102    }
103
104    pub fn insert(&mut self, id: StreamId, val: Stream) -> Ptr<'_> {
105        let index = SlabIndex(self.slab.insert(val) as u32);
106        assert!(self.ids.insert(id, index).is_none());
107
108        Ptr {
109            key: Key {
110                index,
111                stream_id: id,
112            },
113            store: self,
114        }
115    }
116
117    pub fn find_entry(&mut self, id: StreamId) -> Entry<'_> {
118        use self::indexmap::map::Entry::*;
119
120        match self.ids.entry(id) {
121            Occupied(e) => Entry::Occupied(OccupiedEntry { ids: e }),
122            Vacant(e) => Entry::Vacant(VacantEntry {
123                ids: e,
124                slab: &mut self.slab,
125            }),
126        }
127    }
128
129    #[allow(clippy::blocks_in_conditions)]
130    pub(crate) fn for_each<F>(&mut self, mut f: F)
131    where
132        F: FnMut(Ptr),
133    {
134        match self.try_for_each(|ptr| {
135            f(ptr);
136            Ok::<_, Infallible>(())
137        }) {
138            Ok(()) => (),
139            #[allow(unused)]
140            Err(infallible) => match infallible {},
141        }
142    }
143
144    pub fn try_for_each<F, E>(&mut self, mut f: F) -> Result<(), E>
145    where
146        F: FnMut(Ptr) -> Result<(), E>,
147    {
148        let mut len = self.ids.len();
149        let mut i = 0;
150
151        while i < len {
152            // Get the key by index, this makes the borrow checker happy
153            let (stream_id, index) = {
154                let entry = self.ids.get_index(i).unwrap();
155                (*entry.0, *entry.1)
156            };
157
158            f(Ptr {
159                key: Key { index, stream_id },
160                store: self,
161            })?;
162
163            // TODO: This logic probably could be better...
164            let new_len = self.ids.len();
165
166            if new_len < len {
167                debug_assert!(new_len == len - 1);
168                len -= 1;
169            } else {
170                i += 1;
171            }
172        }
173
174        Ok(())
175    }
176}
177
178impl Resolve for Store {
179    fn resolve(&mut self, key: Key) -> Ptr<'_> {
180        Ptr { key, store: self }
181    }
182}
183
184impl ops::Index<Key> for Store {
185    type Output = Stream;
186
187    fn index(&self, key: Key) -> &Self::Output {
188        self.slab
189            .get(key.index.0 as usize)
190            .filter(|s| s.id == key.stream_id)
191            .unwrap_or_else(|| {
192                panic!("dangling store key for stream_id={:?}", key.stream_id);
193            })
194    }
195}
196
197impl ops::IndexMut<Key> for Store {
198    fn index_mut(&mut self, key: Key) -> &mut Self::Output {
199        self.slab
200            .get_mut(key.index.0 as usize)
201            .filter(|s| s.id == key.stream_id)
202            .unwrap_or_else(|| {
203                panic!("dangling store key for stream_id={:?}", key.stream_id);
204            })
205    }
206}
207
208impl Store {
209    #[cfg(feature = "unstable")]
210    pub fn num_active_streams(&self) -> usize {
211        self.ids.len()
212    }
213
214    #[cfg(feature = "unstable")]
215    pub fn num_wired_streams(&self) -> usize {
216        self.slab.len()
217    }
218}
219
// Debug check that every stream has been cleaned up by the time the store is
// dropped. It is compiled only with the `unstable` feature, which is meant to
// be enabled while running h2's unit/integration tests.
//
// Production code doesn't need this guarantee. In tests, though, it catches
// cases where cleanup could have happened and didn't (as opposed to, say, the
// runtime dropping the task for unknown reasons).
#[cfg(feature = "unstable")]
impl Drop for Store {
    fn drop(&mut self) {
        // While unwinding, skip the check: a second panic here would abort.
        if std::thread::panicking() {
            return;
        }

        debug_assert!(self.slab.is_empty());
    }
}
235
236// ===== impl Queue =====
237
238impl<N> Queue<N>
239where
240    N: Next,
241{
242    pub fn new() -> Self {
243        Queue {
244            indices: None,
245            _p: PhantomData,
246        }
247    }
248
249    pub fn take(&mut self) -> Self {
250        Queue {
251            indices: self.indices.take(),
252            _p: PhantomData,
253        }
254    }
255
256    /// Queue the stream.
257    ///
258    /// If the stream is already contained by the list, return `false`.
259    pub fn push(&mut self, stream: &mut store::Ptr) -> bool {
260        tracing::trace!("Queue::push_back");
261
262        if N::is_queued(stream) {
263            tracing::trace!(" -> already queued");
264            return false;
265        }
266
267        N::set_queued(stream, true);
268
269        // The next pointer shouldn't be set
270        debug_assert!(N::next(stream).is_none());
271
272        // Queue the stream
273        match self.indices {
274            Some(ref mut idxs) => {
275                tracing::trace!(" -> existing entries");
276
277                // Update the current tail node to point to `stream`
278                let key = stream.key();
279                N::set_next(&mut stream.resolve(idxs.tail), Some(key));
280
281                // Update the tail pointer
282                idxs.tail = stream.key();
283            }
284            None => {
285                tracing::trace!(" -> first entry");
286                self.indices = Some(store::Indices {
287                    head: stream.key(),
288                    tail: stream.key(),
289                });
290            }
291        }
292
293        true
294    }
295
296    /// Queue the stream
297    ///
298    /// If the stream is already contained by the list, return `false`.
299    pub fn push_front(&mut self, stream: &mut store::Ptr) -> bool {
300        tracing::trace!("Queue::push_front");
301
302        if N::is_queued(stream) {
303            tracing::trace!(" -> already queued");
304            return false;
305        }
306
307        N::set_queued(stream, true);
308
309        // The next pointer shouldn't be set
310        debug_assert!(N::next(stream).is_none());
311
312        // Queue the stream
313        match self.indices {
314            Some(ref mut idxs) => {
315                tracing::trace!(" -> existing entries");
316
317                // Update the provided stream to point to the head node
318                let head_key = stream.resolve(idxs.head).key();
319                N::set_next(stream, Some(head_key));
320
321                // Update the head pointer
322                idxs.head = stream.key();
323            }
324            None => {
325                tracing::trace!(" -> first entry");
326                self.indices = Some(store::Indices {
327                    head: stream.key(),
328                    tail: stream.key(),
329                });
330            }
331        }
332
333        true
334    }
335
336    pub fn pop<'a, R>(&mut self, store: &'a mut R) -> Option<store::Ptr<'a>>
337    where
338        R: Resolve,
339    {
340        if let Some(mut idxs) = self.indices {
341            let mut stream = store.resolve(idxs.head);
342
343            if idxs.head == idxs.tail {
344                assert!(N::next(&stream).is_none());
345                self.indices = None;
346            } else {
347                idxs.head = N::take_next(&mut stream).unwrap();
348                self.indices = Some(idxs);
349            }
350
351            debug_assert!(N::is_queued(&stream));
352            N::set_queued(&mut stream, false);
353
354            return Some(stream);
355        }
356
357        None
358    }
359
360    pub fn is_empty(&self) -> bool {
361        self.indices.is_none()
362    }
363
364    pub fn pop_if<'a, R, F>(&mut self, store: &'a mut R, f: F) -> Option<store::Ptr<'a>>
365    where
366        R: Resolve,
367        F: Fn(&Stream) -> bool,
368    {
369        if let Some(idxs) = self.indices {
370            let should_pop = f(&store.resolve(idxs.head));
371            if should_pop {
372                return self.pop(store);
373            }
374        }
375
376        None
377    }
378}
379
380impl<N> fmt::Debug for Queue<N> {
381    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
382        f.debug_struct("Queue")
383            .field("indices", &self.indices)
384            // skip phantom data
385            .finish()
386    }
387}
388
389// ===== impl Ptr =====
390
391impl<'a> Ptr<'a> {
392    /// Returns the Key associated with the stream
393    pub fn key(&self) -> Key {
394        self.key
395    }
396
397    pub fn store_mut(&mut self) -> &mut Store {
398        self.store
399    }
400
401    /// Remove the stream from the store
402    pub fn remove(self) -> StreamId {
403        // The stream must have been unlinked before this point
404        debug_assert!(!self.store.ids.contains_key(&self.key.stream_id));
405
406        // Remove the stream state
407        let stream = self.store.slab.remove(self.key.index.0 as usize);
408        assert_eq!(stream.id, self.key.stream_id);
409        stream.id
410    }
411
412    /// Remove the StreamId -> stream state association.
413    ///
414    /// This will effectively remove the stream as far as the H2 protocol is
415    /// concerned.
416    pub fn unlink(&mut self) {
417        let id = self.key.stream_id;
418        self.store.ids.swap_remove(&id);
419    }
420}
421
422impl<'a> Resolve for Ptr<'a> {
423    fn resolve(&mut self, key: Key) -> Ptr<'_> {
424        Ptr {
425            key,
426            store: &mut *self.store,
427        }
428    }
429}
430
431impl<'a> ops::Deref for Ptr<'a> {
432    type Target = Stream;
433
434    fn deref(&self) -> &Stream {
435        &self.store[self.key]
436    }
437}
438
439impl<'a> ops::DerefMut for Ptr<'a> {
440    fn deref_mut(&mut self) -> &mut Stream {
441        &mut self.store[self.key]
442    }
443}
444
445impl<'a> fmt::Debug for Ptr<'a> {
446    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
447        (**self).fmt(fmt)
448    }
449}
450
451// ===== impl OccupiedEntry =====
452
453impl<'a> OccupiedEntry<'a> {
454    pub fn key(&self) -> Key {
455        let stream_id = *self.ids.key();
456        let index = *self.ids.get();
457        Key { index, stream_id }
458    }
459}
460
461// ===== impl VacantEntry =====
462
463impl<'a> VacantEntry<'a> {
464    pub fn insert(self, value: Stream) -> Key {
465        // Insert the value in the slab
466        let stream_id = value.id;
467        let index = SlabIndex(self.slab.insert(value) as u32);
468
469        // Insert the handle in the ID map
470        self.ids.insert(index);
471
472        Key { index, stream_id }
473    }
474}