servo_media_audio/
wave_shaper_node.rs1use speexdsp_resampler::State as SpeexResamplerState;
6
7use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
8use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
9
/// Oversampling mode applied around the wave shaping curve.
#[derive(Debug, Clone, PartialEq)]
pub enum OverSampleType {
    /// The curve is applied at the context sample rate.
    None,
    /// The signal is upsampled by a factor of two before shaping.
    Double,
    /// The signal is upsampled by a factor of four before shaping.
    Quadruple,
}
16
/// Countdown of how many more blocks of silent input still get run through the curve
/// after the most recent non-silent input.
#[derive(Debug, Clone, PartialEq)]
enum TailtimeBlocks {
    Zero,
    One,
    Two,
}
23
/// Quality level handed to both Speex resamplers when oversampling.
/// NOTE(review): Speex documents quality as 0 (fastest) ..= 10 (best) — confirm.
const OVERSAMPLING_QUALITY: usize = 0;
25
26impl OverSampleType {
27 fn value(&self) -> usize {
28 match self {
29 OverSampleType::None => 1,
30 OverSampleType::Double => 2,
31 OverSampleType::Quadruple => 4,
32 }
33 }
34}
35
/// Optional shaping curve; `None` means no curve has been supplied.
type WaveShaperCurve = Option<Vec<f32>>;
37
38#[derive(Clone, Debug)]
39pub struct WaveShaperNodeOptions {
40 pub curve: WaveShaperCurve,
41 pub oversample: OverSampleType,
42}
43
44impl Default for WaveShaperNodeOptions {
45 fn default() -> Self {
46 WaveShaperNodeOptions {
47 curve: None,
48 oversample: OverSampleType::None,
49 }
50 }
51}
52
53#[derive(Clone, Debug)]
54pub enum WaveShaperNodeMessage {
55 SetCurve(WaveShaperCurve),
56}
57
58#[derive(AudioNodeCommon)]
59pub(crate) struct WaveShaperNode {
60 curve_set: bool,
61 curve: WaveShaperCurve,
62 #[allow(dead_code)]
63 oversample: OverSampleType,
64 channel_info: ChannelInfo,
65 upsampler: Option<SpeexResamplerState>,
66 downsampler: Option<SpeexResamplerState>,
67 tailtime_blocks_left: TailtimeBlocks,
68}
69
70impl WaveShaperNode {
71 pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
72 if let Some(vec) = &options.curve {
73 assert!(
74 vec.len() > 1,
75 "WaveShaperNode curve must have length of 2 or more"
76 )
77 }
78
79 Self {
80 curve_set: options.curve.is_some(),
81 curve: options.curve,
82 oversample: options.oversample,
83 channel_info,
84 upsampler: None,
85 downsampler: None,
86 tailtime_blocks_left: TailtimeBlocks::Zero,
87 }
88 }
89
90 fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
91 match message {
92 WaveShaperNodeMessage::SetCurve(new_curve) => {
93 if self.curve_set && new_curve.is_some() {
94 panic!("InvalidStateError: cant set curve if it was already set");
95 }
96 self.curve_set = new_curve.is_some();
97 self.curve = new_curve;
98 },
99 }
100 }
101}
102
103impl AudioNodeEngine for WaveShaperNode {
104 fn node_type(&self) -> AudioNodeType {
105 AudioNodeType::WaveShaperNode
106 }
107
108 fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
109 debug_assert!(inputs.len() == 1);
110
111 if self.curve.is_none() {
112 return inputs;
113 }
114
115 let curve = &self.curve.as_ref().expect("Just checked for is_none()");
116
117 if inputs.blocks[0].is_silence() {
118 if WaveShaperNode::silence_produces_nonsilent_output(curve) {
119 inputs.blocks[0].explicit_silence();
120 self.tailtime_blocks_left = TailtimeBlocks::Two;
121 } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
122 inputs.blocks[0].explicit_silence();
123
124 self.tailtime_blocks_left = match self.tailtime_blocks_left {
125 TailtimeBlocks::Zero => TailtimeBlocks::Zero,
126 TailtimeBlocks::One => TailtimeBlocks::Zero,
127 TailtimeBlocks::Two => TailtimeBlocks::One,
128 }
129 } else {
130 return inputs;
131 }
132 } else {
133 self.tailtime_blocks_left = TailtimeBlocks::Two;
134 }
135
136 let block = &mut inputs.blocks[0];
137 let channels = block.chan_count();
138
139 if self.oversample != OverSampleType::None {
140 let rate: usize = info.sample_rate as usize;
141 let sampling_factor = self.oversample.value();
142
143 if self.upsampler.is_none() {
144 self.upsampler = Some(
145 SpeexResamplerState::new(
146 channels as usize,
147 rate,
148 rate * sampling_factor,
149 OVERSAMPLING_QUALITY,
150 )
151 .expect("Couldnt create upsampler"),
152 );
153 };
154
155 if self.downsampler.is_none() {
156 self.downsampler = Some(
157 SpeexResamplerState::new(
158 channels as usize,
159 rate * sampling_factor,
160 rate,
161 OVERSAMPLING_QUALITY,
162 )
163 .expect("Couldnt create downsampler"),
164 );
165 };
166
167 let upsampler = self.upsampler.as_mut().unwrap();
168 let downsampler = self.downsampler.as_mut().unwrap();
169
170 let mut oversampled_buffer: Vec<f32> =
171 vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
172
173 for chan in 0..channels {
174 let out_len = WaveShaperNode::resample(
175 upsampler,
176 chan,
177 block.data_chan(chan),
178 &mut oversampled_buffer,
179 );
180
181 debug_assert!(
182 out_len == 128 * sampling_factor,
183 "Expected {} samples in output after upsampling, got: {}",
184 128 * sampling_factor,
185 out_len
186 );
187
188 WaveShaperNode::apply_curve(&mut oversampled_buffer, curve);
189
190 let out_len = WaveShaperNode::resample(
191 downsampler,
192 chan,
193 &oversampled_buffer,
194 block.data_chan_mut(chan),
195 );
196
197 debug_assert!(
198 out_len == 128,
199 "Expected 128 samples in output after downsampling, got {}",
200 out_len
201 );
202 }
203 } else {
204 WaveShaperNode::apply_curve(block.data_mut(), curve);
205 }
206
207 inputs
208 }
209
210 make_message_handler!(WaveShaperNode: handle_waveshaper_message);
211}
212
213impl WaveShaperNode {
214 fn silence_produces_nonsilent_output(curve: &[f32]) -> bool {
215 let len = curve.len();
216 let len_halved = ((len - 1) as f32) / 2.;
217 let curve_index: f32 = len_halved;
218 let index_lo = curve_index as usize;
219 let index_hi = index_lo + 1;
220 let interp_factor: f32 = curve_index - index_lo as f32;
221 let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
222 shaped_val == 0.0
223 }
224
225 fn apply_curve(buf: &mut [f32], curve: &[f32]) {
226 let len = curve.len();
227 let len_halved = ((len - 1) as f32) / 2.;
228 buf.iter_mut().for_each(|sample| {
229 let curve_index: f32 = len_halved * (*sample + 1.);
230
231 if curve_index <= 0. {
232 *sample = curve[0];
233 } else if curve_index >= (len - 1) as f32 {
234 *sample = curve[len - 1];
235 } else {
236 let index_lo = curve_index as usize;
237 let index_hi = index_lo + 1;
238 let interp_factor: f32 = curve_index - index_lo as f32;
239 *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
240 }
241 });
242 }
243
244 fn resample(
245 st: &mut SpeexResamplerState,
246 chan: u8,
247 input: &[f32],
248 output: &mut [f32],
249 ) -> usize {
250 let (_in_len, out_len) = st
251 .process_float(chan as usize, input, output)
252 .expect("Resampling failed");
253 out_len
254 }
255}