1#![cfg(feature = "std")]
2
3use alloc::{Allocator, SliceWrapper};
4use core::marker::PhantomData;
5use core::mem;
6use std;
7use std::sync::RwLock;
9use std::thread::JoinHandle;
10
11use crate::enc::backward_references::UnionHasher;
12use crate::enc::threading::{
13 AnyBoxConstructor, BatchSpawnable, BatchSpawnableLite, BrotliEncoderThreadError, CompressMulti,
14 CompressionThreadResult, InternalOwned, InternalSendAlloc, Joinable, Owned, OwnedRetriever,
15 PoisonedThreadError, SendAlloc,
16};
17use crate::enc::{BrotliAlloc, BrotliEncoderParams};
18
/// Handle to one worker thread started by the multithreaded spawner.
///
/// The `JoinHandle` is wrapped in an `Option` so it can be taken out exactly once,
/// either when the result is explicitly joined or when the value is dropped.
/// `U` is only a type marker for the error type reported when joining.
pub struct MultiThreadedJoinable<T, U>(Option<JoinHandle<T>>, PhantomData<U>)
where
    T: Send + 'static,
    U: Send + 'static;
23
24impl<T: Send + 'static, U: Send + 'static + AnyBoxConstructor> Joinable<T, U>
25 for MultiThreadedJoinable<T, U>
26{
27 fn join(mut self) -> Result<T, U> {
28 match self.0.take().unwrap().join() {
29 Ok(t) => Ok(t),
30 Err(e) => Err(<U as AnyBoxConstructor>::new(e)),
31 }
32 }
33}
34
35impl<T: Send + 'static, U: Send + 'static> Drop for MultiThreadedJoinable<T, U> {
36 fn drop(&mut self) {
37 if let Some(join_handle) = self.0.take() {
38 let _ = join_handle.join();
39 }
40 }
41}
42
/// Owns a value behind an `RwLock` so it can be viewed through `OwnedRetriever`.
pub struct MultiThreadedOwnedRetriever<U>(RwLock<U>)
where
    U: Send + 'static;
44
45impl<U: Send + 'static> OwnedRetriever<U> for MultiThreadedOwnedRetriever<U> {
46 fn view<T, F: FnOnce(&U) -> T>(&self, f: F) -> Result<T, PoisonedThreadError> {
47 match self.0.read() {
48 Ok(u) => Ok(f(&*u)),
49 Err(_) => Err(PoisonedThreadError::default()),
50 }
51 }
52 fn unwrap(self) -> Result<U, PoisonedThreadError> {
53 match self.0.into_inner() {
54 Ok(u) => Ok(u),
55 Err(_) => Err(PoisonedThreadError::default()),
56 }
57 }
58}
59
/// Stateless spawner that runs every work item on its own `std::thread`.
///
/// It carries no data, so it is cheap to copy, prints as its bare name in debug
/// output, and is normally created with `MultiThreadedSpawner::default()`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MultiThreadedSpawner {}
62
63fn spawn_work<
64 ReturnValue: Send + 'static,
65 ExtraInput: Send + 'static,
66 F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static,
67 Alloc: BrotliAlloc + Send + 'static,
68 U: Send + 'static + Sync,
69>(
70 extra_input: ExtraInput,
71 index: usize,
72 num_threads: usize,
73 locked_input: std::sync::Arc<RwLock<U>>,
74 alloc: Alloc,
75 f: F,
76) -> std::thread::JoinHandle<ReturnValue>
77where
78 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
79{
80 std::thread::spawn(move || {
81 let t: ReturnValue = locked_input
82 .view(move |guard: &U| -> ReturnValue {
83 f(extra_input, index, num_threads, guard, alloc)
84 })
85 .unwrap();
86 t
87 })
88}
89
90impl<
91 ReturnValue: Send + 'static,
92 ExtraInput: Send + 'static,
93 Alloc: BrotliAlloc + Send + 'static,
94 U: Send + 'static + Sync,
95 > BatchSpawnable<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
96where
97 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
98{
99 type JoinHandle = MultiThreadedJoinable<ReturnValue, BrotliEncoderThreadError>;
100 type FinalJoinHandle = std::sync::Arc<RwLock<U>>;
101 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
102 std::sync::Arc::<RwLock<U>>::new(RwLock::new(
103 mem::replace(input, Owned(InternalOwned::Borrowed)).unwrap(),
104 ))
105 }
106 fn spawn<F: Fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue + Send + 'static + Copy>(
107 &mut self,
108 input: &mut Self::FinalJoinHandle,
109 work: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
110 index: usize,
111 num_threads: usize,
112 f: F,
113 ) {
114 let (alloc, extra_input) = work.replace_with_default();
115 let ret = spawn_work(extra_input, index, num_threads, input.clone(), alloc, f);
116 *work = SendAlloc(InternalSendAlloc::Join(MultiThreadedJoinable(
117 Some(ret),
118 PhantomData,
119 )));
120 }
121}
122impl<
123 ReturnValue: Send + 'static,
124 ExtraInput: Send + 'static,
125 Alloc: BrotliAlloc + Send + 'static,
126 U: Send + 'static + Sync,
127 > BatchSpawnableLite<ReturnValue, ExtraInput, Alloc, U> for MultiThreadedSpawner
128where
129 <Alloc as Allocator<u8>>::AllocatedMemory: Send + 'static,
130 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
131 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
132{
133 type JoinHandle =
134 <MultiThreadedSpawner as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::JoinHandle;
135 type FinalJoinHandle = <MultiThreadedSpawner as BatchSpawnable<
136 ReturnValue,
137 ExtraInput,
138 Alloc,
139 U,
140 >>::FinalJoinHandle;
141 fn make_spawner(&mut self, input: &mut Owned<U>) -> Self::FinalJoinHandle {
142 <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::make_spawner(self, input)
143 }
144 fn spawn(
145 &mut self,
146 handle: &mut Self::FinalJoinHandle,
147 alloc_per_thread: &mut SendAlloc<ReturnValue, ExtraInput, Alloc, Self::JoinHandle>,
148 index: usize,
149 num_threads: usize,
150 f: fn(ExtraInput, usize, usize, &U, Alloc) -> ReturnValue,
151 ) {
152 <Self as BatchSpawnable<ReturnValue, ExtraInput, Alloc, U>>::spawn(
153 self,
154 handle,
155 alloc_per_thread,
156 index,
157 num_threads,
158 f,
159 )
160 }
161}
162
163pub fn compress_multi<
164 Alloc: BrotliAlloc + Send + 'static,
165 SliceW: SliceWrapper<u8> + Send + 'static + Sync,
166>(
167 params: &BrotliEncoderParams,
168 owned_input: &mut Owned<SliceW>,
169 output: &mut [u8],
170 alloc_per_thread: &mut [SendAlloc<
171 CompressionThreadResult<Alloc>,
172 UnionHasher<Alloc>,
173 Alloc,
174 <MultiThreadedSpawner as BatchSpawnable<
175 CompressionThreadResult<Alloc>,
176 UnionHasher<Alloc>,
177 Alloc,
178 SliceW,
179 >>::JoinHandle,
180 >],
181) -> Result<usize, BrotliEncoderThreadError>
182where
183 <Alloc as Allocator<u8>>::AllocatedMemory: Send,
184 <Alloc as Allocator<u16>>::AllocatedMemory: Send + Sync,
185 <Alloc as Allocator<u32>>::AllocatedMemory: Send + Sync,
186{
187 CompressMulti(
188 params,
189 owned_input,
190 output,
191 alloc_per_thread,
192 &mut MultiThreadedSpawner::default(),
193 )
194}