use std::{
    collections::{hash_map::Entry, HashMap},
    hash::Hash,
    sync::{Arc, Weak},
};

use once_cell::sync::OnceCell;

use crate::lock::{rank, Mutex};
use crate::FastHashMap;

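// Each pool slot is shared (`Arc`) so it can be initialized outside of the map
// lock, is initialized at most once (`OnceCell`), and holds only a `Weak`
// reference so the pool itself never keeps the resource alive.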
type SlotInner<V> = Weak<V>;
type ResourcePoolSlot<V> = Arc<OnceCell<SlotInner<V>>>;

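/// A deduplicating pool of reference-counted resources, keyed by `K`.
///
/// The pool holds only [`Weak`] references, so it never keeps a resource alive
/// on its own. A resource's `Drop` implementation is expected to call
/// [`ResourcePool::remove`] to clean up its entry once the resource is gone.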
pub struct ResourcePool<K, V> {
    inner: Mutex<FastHashMap<K, ResourcePoolSlot<V>>>,
}

impl<K: Clone + Eq + Hash, V> ResourcePool<K, V> {
    pub fn new() -> Self {
        Self {
            inner: Mutex::new(rank::RESOURCE_POOL_INNER, HashMap::default()),
        }
    }

    /// Get a resource from the pool with the given key, or create a new one
    /// if it doesn't exist, using the given constructor.
    ///
    /// Behaves such that only one resource will be created for each unique
    /// key at any one time.
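    ///
    /// # Example
    ///
    /// A minimal usage sketch (illustrative, not from the original docs),
    /// assuming a `u32`-keyed pool and an infallible constructor, hence the
    /// `()` error type:
    ///
    /// ```ignore
    /// use std::sync::Arc;
    ///
    /// let pool = ResourcePool::<u32, u32>::new();
    ///
    /// // The first call runs the constructor and stores a Weak reference to
    /// // the freshly created resource.
    /// let first = pool.get_or_init::<_, ()>(42, |key| Ok(Arc::new(key))).unwrap();
    ///
    /// // A second call with the same key returns the same Arc; the
    /// // constructor passed here is never run.
    /// let second = pool.get_or_init::<_, ()>(42, |key| Ok(Arc::new(key))).unwrap();
    /// assert!(Arc::ptr_eq(&first, &second));
    /// ```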
    pub fn get_or_init<F, E>(&self, key: K, constructor: F) -> Result<Arc<V>, E>
    where
        F: FnOnce(K) -> Result<Arc<V>, E>,
    {
        // We can't prove at compile time that the key and constructor will only
        // ever be consumed once, so we move them into Options and check at runtime.
        let mut key = Some(key);
        let mut constructor = Some(constructor);

        'race: loop {
            let mut map_guard = self.inner.lock();

            let entry = match map_guard.entry(key.clone().unwrap()) {
                // An entry exists for this resource.
                //
                // We know that either:
                // - The resource is still alive, and Weak::upgrade will succeed.
                // - The resource is in the process of being dropped, and Weak::upgrade will fail.
                //
                // The entry will never be empty while the resource is still alive.
                Entry::Occupied(entry) => Arc::clone(entry.get()),
                // No entry exists for this resource.
                //
                // We know that the resource is not alive, so we can create a new entry.
                Entry::Vacant(entry) => Arc::clone(entry.insert(Arc::new(OnceCell::new()))),
            };

            drop(map_guard);

            // Some other thread may beat us to initializing the entry, but OnceCell
            // guarantees that only one thread will actually run the initializer.
            //
            // We pass the strong reference out of the closure so it stays alive while
            // we're the only ones holding a reference to it.
            let mut strong = None;
            let weak = entry.get_or_try_init(|| {
                let strong_inner = constructor.take().unwrap()(key.take().unwrap())?;
                let weak = Arc::downgrade(&strong_inner);
                strong = Some(strong_inner);
                Ok(weak)
            })?;

            // If strong is Some, that means we just initialized the entry, so we can just return it.
            if let Some(strong) = strong {
                return Ok(strong);
            }

            // The entry was already initialized by someone else, so we need to try to upgrade it.
            if let Some(strong) = weak.upgrade() {
                // We succeeded: the resource is still alive, so just return it.
                return Ok(strong);
            }

            // Upgrade failed, so the resource is in the process of being dropped.
            // The entry still exists in the map, but it points to nothing.
            //
            // We're in a race with the drop implementation of the resource,
            // so let's just go around again. When we go around again:
            // - If the entry exists, we might need to go around a few more times.
            // - If the entry doesn't exist, we'll create a new one.
            continue 'race;
        }
    }

    /// Remove the entry for the given key from the pool.
    ///
    /// Must *only* be called in the Drop impl of [`BindGroupLayout`].
    ///
    /// [`BindGroupLayout`]: crate::binding_model::BindGroupLayout
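    ///
    /// A sketch of the intended pattern, using a hypothetical `MyResource`
    /// type in place of the real `BindGroupLayout`:
    ///
    /// ```ignore
    /// struct MyResource {
    ///     key: u32,
    ///     pool: Arc<ResourcePool<u32, MyResource>>,
    /// }
    ///
    /// impl Drop for MyResource {
    ///     fn drop(&mut self) {
    ///         // By this point every strong reference is gone, so Weak::upgrade
    ///         // on this resource's pool slot is already failing.
    ///         self.pool.remove(&self.key);
    ///     }
    /// }
    /// ```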
    pub fn remove(&self, key: &K) {
        let mut map_guard = self.inner.lock();

        // Weak::upgrade will be failing long before this code is called. All threads trying to access the resource will be spinning,
        // waiting for the entry to be removed. It is safe to remove the entry from the map.
        map_guard.remove(key);
    }
}

#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicU32, Ordering},
        Barrier,
    };

    use super::*;

    #[test]
    fn deduplication() {
        let pool = ResourcePool::<u32, u32>::new();

        let mut counter = 0_u32;

        let arc1 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert_eq!(*arc1, 0);
        assert_eq!(counter, 1);

        let arc2 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert!(Arc::ptr_eq(&arc1, &arc2));
        assert_eq!(*arc2, 0);
        assert_eq!(counter, 1);

        drop(arc1);
        drop(arc2);
        pool.remove(&0);

        let arc3 = pool
            .get_or_init::<_, ()>(0, |key| {
                counter += 1;
                Ok(Arc::new(key))
            })
            .unwrap();

        assert_eq!(*arc3, 0);
        assert_eq!(counter, 2);
    }

    // The test name contains "2_threads" so that nextest reserves two threads for it.
    #[test]
    fn concurrent_creation_2_threads() {
        struct Resources {
            pool: ResourcePool<u32, u32>,
            counter: AtomicU32,
            barrier: Barrier,
        }

        let resources = Arc::new(Resources {
            pool: ResourcePool::<u32, u32>::new(),
            counter: AtomicU32::new(0),
            barrier: Barrier::new(2),
        });

        // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
        //
        // To validate the expected order of events, we've put print statements in the code,
        // marking when each thread reaches each stage.
        // The output will look something like this if the test is working as expected:
        //
        // ```
        // 0: prewait
        // 1: prewait
        // 1: postwait
        // 0: postwait
        // 1: init
        // 1: postget
        // 0: postget
        // ```
        fn thread_inner(idx: u8, resources: &Resources) -> Arc<u32> {
            eprintln!("{idx}: prewait");

            // Once this returns, both threads should hit get_or_init at about the same time,
            // allowing us to actually test concurrent creation.
            //
            // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
            resources.barrier.wait();

            eprintln!("{idx}: postwait");

            let ret = resources
                .pool
                .get_or_init::<_, ()>(0, |key| {
                    eprintln!("{idx}: init");

                    // Simulate a long-running constructor, ensuring that both threads
                    // will be inside get_or_init at the same time.
                    std::thread::sleep(std::time::Duration::from_millis(250));

                    resources.counter.fetch_add(1, Ordering::SeqCst);

                    Ok(Arc::new(key))
                })
                .unwrap();

            eprintln!("{idx}: postget");

            ret
        }

        let thread1 = std::thread::spawn({
            let resource_clone = Arc::clone(&resources);
            move || thread_inner(1, &resource_clone)
        });

        let arc0 = thread_inner(0, &resources);

        assert_eq!(resources.counter.load(Ordering::Acquire), 1);

        let arc1 = thread1.join().unwrap();

        assert!(Arc::ptr_eq(&arc0, &arc1));
    }

    // The test name contains "2_threads" so that nextest reserves two threads for it.
    #[test]
    fn create_while_drop_2_threads() {
        struct Resources {
            pool: ResourcePool<u32, u32>,
            barrier: Barrier,
        }

        let resources = Arc::new(Resources {
            pool: ResourcePool::<u32, u32>::new(),
            barrier: Barrier::new(2),
        });

        // Like all races, this is not inherently guaranteed to work, but in practice it should work fine.
        //
        // To validate the expected order of events, we've put print statements in the code,
        // marking when each thread reaches each stage.
        // The output will look something like this if the test is working as expected:
        //
        // ```
        // 0: prewait
        // 1: prewait
        // 1: postwait
        // 0: postwait
        // 1: postsleep
        // 1: removal
        // 0: postget
        // ```
        //
        // The last two _may_ be flipped.

        let existing_entry = resources
            .pool
            .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
            .unwrap();

        // Drop the entry, but do _not_ remove it from the pool.
        // This simulates the situation where the resource's Arc has been dropped, but the Drop
        // implementation, which calls remove, has not yet run.
        drop(existing_entry);

        fn thread0_inner(resources: &Resources) {
            eprintln!("0: prewait");
            resources.barrier.wait();

            eprintln!("0: postwait");
            // We try to create a new entry, but the entry already exists.
            //
            // As Weak::upgrade is failing, we will just keep spinning until remove is called.
            resources
                .pool
                .get_or_init::<_, ()>(0, |key| Ok(Arc::new(key)))
                .unwrap();
            eprintln!("0: postget");
        }

        fn thread1_inner(resources: &Resources) {
            eprintln!("1: prewait");
            resources.barrier.wait();

            eprintln!("1: postwait");
            // We wait a little bit, making sure that thread0_inner has started spinning.
            std::thread::sleep(std::time::Duration::from_millis(250));
            eprintln!("1: postsleep");

            // We remove the entry from the pool, allowing thread0_inner to re-create it.
            resources.pool.remove(&0);
            eprintln!("1: removal");
        }

        let thread1 = std::thread::spawn({
            let resource_clone = Arc::clone(&resources);
            move || thread1_inner(&resource_clone)
        });

        thread0_inner(&resources);

        thread1.join().unwrap();
    }
}