Skip to main content

brotli/enc/
multithreading.rs

1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::marker::PhantomData;
5use core::mem;
6use std;
7// in-place thread create
8use std::sync::RwLock;
9use std::thread::JoinHandle;
10
11use crate::enc::backward_references::UnionHasher;
12use crate::enc::threading::{
13    AnyBoxConstructor, BatchSpawnable, BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti,
14    CompressionThreadResult, InternalOwned, InternalSendAlloc, Joinable, Owned, OwnedRetriever,
15    PoisonedThreadError, SendAlloc,
16};
17use crate::enc::{BrotliAlloc, BrotliEncoderParams};
18
/// Join handle for one worker thread launched by [`MultiThreadedSpawner`].
///
/// Field `0` holds the OS thread handle. It is `Some` until the handle is either
/// joined explicitly through `Joinable::join` or joined implicitly on drop.
/// Field `1` only records the error type `U` produced when the worker panicked.
pub struct MultiThreadedJoinable<T, U>(Option<JoinHandle<T>>, PhantomData<U>)
where
    T: Send + 'static,
    U: Send + 'static;
23
24impl<T: Send + 'static, U: Send + 'static + AnyBoxConstructor> Joinable<T, U>
25    for MultiThreadedJoinable<T, U>
26{
27    fn join(mut self) -> Result<T, U> {
28        match self.0.take().unwrap().join() {
29            Ok(t) => Ok(t),
30            Err(e) => Err(<U as AnyBoxConstructor>::new(e)),
31        }
32    }
33}
34
35impl<T: Send + 'static, U: Send + 'static> Drop for MultiThreadedJoinable<T, U> {
36    fn drop(&mut self) {
37        if let Some(join_handle) = self.0.take() {
38            let _ = join_handle.join();
39        }
40    }
41}
42
/// Wraps a value in an [`RwLock`] so it can be exposed via [`OwnedRetriever`].
pub struct MultiThreadedOwnedRetriever<U>(RwLock<U>)
where
    U: Send + 'static;
44
45impl<U: Send + 'static> OwnedRetriever<U> for MultiThreadedOwnedRetriever<U> {
46    fn view<T, F: FnOnce(&U) -> T>(&self, f: F) -> Result<T, PoisonedThreadError> {
47        match self.0.read() {
48            Ok(u) => Ok(f(&*u)),
49            Err(_) => Err(PoisonedThreadError::default()),
50        }
51    }
52    fn unwrap(self) -> Result<U, PoisonedThreadError> {
53        match self.0.into_inner() {
54            Ok(u) => Ok(u),
55            Err(_) => Err(PoisonedThreadError::default()),
56        }
57    }
58}
59
/// Spawner that runs each unit of compression work on its own `std::thread`.
///
/// It carries no state; its behavior is provided by its `BatchSpawnable` and
/// `BatchSpawnableLite` implementations.
#[derive(Default)]
pub struct MultiThreadedSpawner {
    // Intentionally empty.
}
62
63fn spawn_work<
64    ReturnValue: Send + 'static,
65    ExtraInput: Send + 'static,
66    F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static,
67    Alloc: BrotliAlloc + Send + 'static,
68    U: Send + 'static + Sync,
69>(
70    extra_input: ExtraInput,
71    index: usize,
72    num_threads: usize,
73    locked_input: std::sync::Arc<RwLock<U>>,
74    alloc: Alloc,
75    f: F,
76) -> std::thread::JoinHandle<ReturnValue>
77where
78    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
79{
80    std::thread::spawn(move || {
81        let t: ReturnValue = locked_input
82            .view(move |guard: &U| -> ReturnValue {
83                f(extra_input, index, num_threads, guard, alloc)
84            })
85            .unwrap();
86        t
87    })
88}
89
90impl<
91        ReturnValue: Send + 'static,
92        ExtraInput: Send + 'static,
93        Alloc: BrotliAlloc + Send + 'static,
94        U: Send + 'static + Sync,
95    > BatchSpawnable<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
96where
97    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
98{
99    type JoinHandle = MultiThreadedJoinable<ReturnValue, BrotliEncoderThreadError>;
100    type FinalJoinHandle = std::sync::Arc<RwLock<U>>;
101    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
102        std::sync::Arc::<RwLock<U>>::new(RwLock::new(
103            mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
104        ))
105    }
106    fn spawn<F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static + Copy>(
107        &mut self,
108        input: &mut Self::FinalJoinHandle,
109        work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
110        index: usize,
111        num_threads: usize,
112        f: F,
113    ) {
114        let (alloc, extra_input) = work.replace_with_default();
115        let ret = spawn_work(extra_input, index, num_threads, input.clone(), alloc, f);
116        *work = SendAlloc(InternalSendAlloc::Join(MultiThreadedJoinable(
117            Some(ret),
118            PhantomData,
119        )));
120    }
121}
122impl<
123        ReturnValue: Send + 'static,
124        ExtraInput: Send + 'static,
125        Alloc: BrotliAlloc + Send + 'static,
126        U: Send + 'static + Sync,
127    > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
128where
129    <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
130    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
131    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
132{
133    type JoinHandle =
134        <MultiThreadedSpawner as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::JoinHandle;
135    type FinalJoinHandle = <MultiThreadedSpawner as BatchSpawnable<
136        ReturnValue,
137        ExtraInput,
138        Alloc,
139        U,
140    >>::FinalJoinHandle;
141    fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
142        <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::make_spawner(self, input)
143    }
144    fn spawn(
145        &mut self,
146        handle: &mut Self::FinalJoinHandle,
147        alloc_per_thread: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
148        index: usize,
149        num_threads: usize,
150        f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
151    ) {
152        <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::spawn(
153            self,
154            handle,
155            alloc_per_thread,
156            index,
157            num_threads,
158            f,
159        )
160    }
161}
162
163pub fn compress_multi<
164    Alloc: BrotliAlloc + Send + 'static,
165    SliceW: SliceWrapper<u8> + Send + 'static + Sync,
166>(
167    params: &BrotliEncoderParams,
168    owned_input: &mut Owned<SliceW>,
169    output: &mut [u8],
170    alloc_per_thread: &mut [SendAlloc<
171        CompressionThreadResult<Alloc>,
172        UnionHasher<Alloc>,
173        Alloc,
174        <MultiThreadedSpawner as BatchSpawnable<
175            CompressionThreadResult<Alloc>,
176            UnionHasher<Alloc>,
177            Alloc,
178            SliceW,
179        >>::JoinHandle,
180    >],
181) -> Result<usize, BrotliEncoderThreadError>
182where
183    <Alloc as Allocator<u8>>::AllocatedMemory: Send,
184    <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
185    <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
186{
187    CompressMulti(
188        params,
189        owned_input,
190        output,
191        alloc_per_thread,
192        &mut MultiThreadedSpawner::default(),
193    )
194}