1use malloc_size_of_derive::MallocSizeOf;
6use speexdsp_resampler::State as SpeexResamplerState;
7
8use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
9use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
10
11#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
12pub enum OverSampleType {
13 None,
14 Double,
15 Quadruple,
16}
17
18#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
19enum TailtimeBlocks {
20 Zero,
21 One,
22 Two,
23}
24
/// Quality argument handed to `SpeexResamplerState::new` for both the
/// upsampler and the downsampler.
// NOTE(review): presumably 0 is the lowest-quality / cheapest Speex setting — confirm against the speexdsp docs.
const OVERSAMPLING_QUALITY: usize = 0;
26
27impl OverSampleType {
28 fn value(&self) -> usize {
29 match self {
30 OverSampleType::None => 1,
31 OverSampleType::Double => 2,
32 OverSampleType::Quadruple => 4,
33 }
34 }
35}
36
/// Optional shaping curve: a list of `f32` points, or `None` when no curve is
/// installed.
type WaveShaperCurve = Option<Vec<f32>>;
38
39#[derive(Clone, Debug, MallocSizeOf)]
40pub struct WaveShaperNodeOptions {
41 pub curve: WaveShaperCurve,
42 pub oversample: OverSampleType,
43}
44
45impl Default for WaveShaperNodeOptions {
46 fn default() -> Self {
47 WaveShaperNodeOptions {
48 curve: None,
49 oversample: OverSampleType::None,
50 }
51 }
52}
53
54#[derive(Clone, Debug, MallocSizeOf)]
55pub enum WaveShaperNodeMessage {
56 SetCurve(WaveShaperCurve),
57}
58
59#[derive(AudioNodeCommon, MallocSizeOf)]
60pub(crate) struct WaveShaperNode {
61 curve_set: bool,
62 curve: WaveShaperCurve,
63 #[allow(dead_code)]
64 oversample: OverSampleType,
65 channel_info: ChannelInfo,
66 #[ignore_malloc_size_of = "External Type"]
67 upsampler: Option<SpeexResamplerState>,
68 #[ignore_malloc_size_of = "External Type"]
69 downsampler: Option<SpeexResamplerState>,
70 tailtime_blocks_left: TailtimeBlocks,
71}
72
73impl WaveShaperNode {
74 pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
75 if let Some(vec) = &options.curve {
76 assert!(
77 vec.len() > 1,
78 "WaveShaperNode curve must have length of 2 or more"
79 )
80 }
81
82 Self {
83 curve_set: options.curve.is_some(),
84 curve: options.curve,
85 oversample: options.oversample,
86 channel_info,
87 upsampler: None,
88 downsampler: None,
89 tailtime_blocks_left: TailtimeBlocks::Zero,
90 }
91 }
92
93 fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
94 match message {
95 WaveShaperNodeMessage::SetCurve(new_curve) => {
96 if self.curve_set && new_curve.is_some() {
97 panic!("InvalidStateError: cant set curve if it was already set");
98 }
99 self.curve_set = new_curve.is_some();
100 self.curve = new_curve;
101 },
102 }
103 }
104}
105
106impl AudioNodeEngine for WaveShaperNode {
107 fn node_type(&self) -> AudioNodeType {
108 AudioNodeType::WaveShaperNode
109 }
110
111 fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
112 debug_assert!(inputs.len() == 1);
113
114 if self.curve.is_none() {
115 return inputs;
116 }
117
118 let curve = &self.curve.as_ref().expect("Just checked for is_none()");
119
120 if inputs.blocks[0].is_silence() {
121 if WaveShaperNode::silence_produces_nonsilent_output(curve) {
122 inputs.blocks[0].explicit_silence();
123 self.tailtime_blocks_left = TailtimeBlocks::Two;
124 } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
125 inputs.blocks[0].explicit_silence();
126
127 self.tailtime_blocks_left = match self.tailtime_blocks_left {
128 TailtimeBlocks::Zero => TailtimeBlocks::Zero,
129 TailtimeBlocks::One => TailtimeBlocks::Zero,
130 TailtimeBlocks::Two => TailtimeBlocks::One,
131 }
132 } else {
133 return inputs;
134 }
135 } else {
136 self.tailtime_blocks_left = TailtimeBlocks::Two;
137 }
138
139 let block = &mut inputs.blocks[0];
140 let channels = block.chan_count();
141
142 if self.oversample != OverSampleType::None {
143 let rate: usize = info.sample_rate as usize;
144 let sampling_factor = self.oversample.value();
145
146 if self.upsampler.is_none() {
147 self.upsampler = Some(
148 SpeexResamplerState::new(
149 channels as usize,
150 rate,
151 rate * sampling_factor,
152 OVERSAMPLING_QUALITY,
153 )
154 .expect("Couldnt create upsampler"),
155 );
156 };
157
158 if self.downsampler.is_none() {
159 self.downsampler = Some(
160 SpeexResamplerState::new(
161 channels as usize,
162 rate * sampling_factor,
163 rate,
164 OVERSAMPLING_QUALITY,
165 )
166 .expect("Couldnt create downsampler"),
167 );
168 };
169
170 let upsampler = self.upsampler.as_mut().unwrap();
171 let downsampler = self.downsampler.as_mut().unwrap();
172
173 let mut oversampled_buffer: Vec<f32> =
174 vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
175
176 for chan in 0..channels {
177 let out_len = WaveShaperNode::resample(
178 upsampler,
179 chan,
180 block.data_chan(chan),
181 &mut oversampled_buffer,
182 );
183
184 debug_assert!(
185 out_len == 128 * sampling_factor,
186 "Expected {} samples in output after upsampling, got: {}",
187 128 * sampling_factor,
188 out_len
189 );
190
191 WaveShaperNode::apply_curve(&mut oversampled_buffer, curve);
192
193 let out_len = WaveShaperNode::resample(
194 downsampler,
195 chan,
196 &oversampled_buffer,
197 block.data_chan_mut(chan),
198 );
199
200 debug_assert!(
201 out_len == 128,
202 "Expected 128 samples in output after downsampling, got {}",
203 out_len
204 );
205 }
206 } else {
207 WaveShaperNode::apply_curve(block.data_mut(), curve);
208 }
209
210 inputs
211 }
212
213 make_message_handler!(WaveShaperNode: handle_waveshaper_message);
214}
215
216impl WaveShaperNode {
217 fn silence_produces_nonsilent_output(curve: &[f32]) -> bool {
218 let len = curve.len();
219 let len_halved = ((len - 1) as f32) / 2.;
220 let curve_index: f32 = len_halved;
221 let index_lo = curve_index as usize;
222 let index_hi = index_lo + 1;
223 let interp_factor: f32 = curve_index - index_lo as f32;
224 let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
225 shaped_val == 0.0
226 }
227
228 fn apply_curve(buf: &mut [f32], curve: &[f32]) {
229 let len = curve.len();
230 let len_halved = ((len - 1) as f32) / 2.;
231 buf.iter_mut().for_each(|sample| {
232 let curve_index: f32 = len_halved * (*sample + 1.);
233
234 if curve_index <= 0. {
235 *sample = curve[0];
236 } else if curve_index >= (len - 1) as f32 {
237 *sample = curve[len - 1];
238 } else {
239 let index_lo = curve_index as usize;
240 let index_hi = index_lo + 1;
241 let interp_factor: f32 = curve_index - index_lo as f32;
242 *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
243 }
244 });
245 }
246
247 fn resample(
248 st: &mut SpeexResamplerState,
249 chan: u8,
250 input: &[f32],
251 output: &mut [f32],
252 ) -> usize {
253 let (_in_len, out_len) = st
254 .process_float(chan as usize, input, output)
255 .expect("Resampling failed");
256 out_len
257 }
258}