servo_media_audio/
wave_shaper_node.rs1use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
2use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
3use speexdsp_resampler::State as SpeexResamplerState;
4
/// Oversampling mode applied around the wave shaping curve.
///
/// The integer factor for each variant is returned by `OverSampleType::value`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OverSampleType {
    /// Shape at the context sample rate (factor 1).
    None,
    /// Shape at twice the context sample rate (factor 2).
    Double,
    /// Shape at four times the context sample rate (factor 4).
    Quadruple,
}
11
/// How many more silent input blocks `process` still pushes through the
/// shaper. It is reset to `Two` on audible input and counts down on silent
/// input.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TailtimeBlocks {
    Zero,
    One,
    Two,
}
18
/// Quality setting passed to every `SpeexResamplerState` created for
/// oversampling (both the upsampler and the downsampler).
/// NOTE(review): presumably speexdsp's 0 (fastest, lowest quality) to 10
/// (best) scale — confirm against the resampler crate.
const OVERSAMPLING_QUALITY: usize = 0;
20
21impl OverSampleType {
22 fn value(&self) -> usize {
23 match self {
24 OverSampleType::None => 1,
25 OverSampleType::Double => 2,
26 OverSampleType::Quadruple => 4,
27 }
28 }
29}
30
/// Sample points of a shaping curve. `None` means no curve is set, in which
/// case the node returns its input unchanged.
type WaveShaperCurve = Option<Vec<f32>>;
32
/// Construction options for a wave shaper node.
#[derive(Clone, Debug)]
pub struct WaveShaperNodeOptions {
    /// Initial shaping curve. When present it must have at least two points
    /// (asserted by `WaveShaperNode::new`).
    pub curve: WaveShaperCurve,
    /// Oversampling mode applied around the curve.
    pub oversample: OverSampleType,
}
38
39impl Default for WaveShaperNodeOptions {
40 fn default() -> Self {
41 WaveShaperNodeOptions {
42 curve: None,
43 oversample: OverSampleType::None,
44 }
45 }
46}
47
/// Messages that reconfigure an existing wave shaper node.
#[derive(Clone, Debug)]
pub enum WaveShaperNodeMessage {
    /// Replaces the curve (`None` clears it). The handler panics if a curve is
    /// already set and another `Some` curve is sent.
    SetCurve(WaveShaperCurve),
}
52
53#[derive(AudioNodeCommon)]
54pub(crate) struct WaveShaperNode {
55 curve_set: bool,
56 curve: WaveShaperCurve,
57 #[allow(dead_code)]
58 oversample: OverSampleType,
59 channel_info: ChannelInfo,
60 upsampler: Option<SpeexResamplerState>,
61 downsampler: Option<SpeexResamplerState>,
62 tailtime_blocks_left: TailtimeBlocks,
63}
64
65impl WaveShaperNode {
66 pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
67 if let Some(vec) = &options.curve {
68 assert!(
69 vec.len() > 1,
70 "WaveShaperNode curve must have length of 2 or more"
71 )
72 }
73
74 Self {
75 curve_set: options.curve.is_some(),
76 curve: options.curve,
77 oversample: options.oversample,
78 channel_info,
79 upsampler: None,
80 downsampler: None,
81 tailtime_blocks_left: TailtimeBlocks::Zero,
82 }
83 }
84
85 fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
86 match message {
87 WaveShaperNodeMessage::SetCurve(new_curve) => {
88 if self.curve_set && new_curve.is_some() {
89 panic!("InvalidStateError: cant set curve if it was already set");
90 }
91 self.curve_set = new_curve.is_some();
92 self.curve = new_curve;
93 }
94 }
95 }
96}
97
98impl AudioNodeEngine for WaveShaperNode {
99 fn node_type(&self) -> AudioNodeType {
100 AudioNodeType::WaveShaperNode
101 }
102
103 fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
104 debug_assert!(inputs.len() == 1);
105
106 if self.curve.is_none() {
107 return inputs;
108 }
109
110 let curve = &self.curve.as_ref().expect("Just checked for is_none()");
111
112 if inputs.blocks[0].is_silence() {
113 if WaveShaperNode::silence_produces_nonsilent_output(curve) {
114 inputs.blocks[0].explicit_silence();
115 self.tailtime_blocks_left = TailtimeBlocks::Two;
116 } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
117 inputs.blocks[0].explicit_silence();
118
119 self.tailtime_blocks_left = match self.tailtime_blocks_left {
120 TailtimeBlocks::Zero => TailtimeBlocks::Zero,
121 TailtimeBlocks::One => TailtimeBlocks::Zero,
122 TailtimeBlocks::Two => TailtimeBlocks::One,
123 }
124 } else {
125 return inputs;
126 }
127 } else {
128 self.tailtime_blocks_left = TailtimeBlocks::Two;
129 }
130
131 let block = &mut inputs.blocks[0];
132 let channels = block.chan_count();
133
134 if self.oversample != OverSampleType::None {
135 let rate: usize = info.sample_rate as usize;
136 let sampling_factor = self.oversample.value();
137
138 if self.upsampler.is_none() {
139 self.upsampler = Some(
140 SpeexResamplerState::new(
141 channels as usize,
142 rate,
143 rate * sampling_factor,
144 OVERSAMPLING_QUALITY,
145 )
146 .expect("Couldnt create upsampler"),
147 );
148 };
149
150 if self.downsampler.is_none() {
151 self.downsampler = Some(
152 SpeexResamplerState::new(
153 channels as usize,
154 rate * sampling_factor,
155 rate,
156 OVERSAMPLING_QUALITY,
157 )
158 .expect("Couldnt create downsampler"),
159 );
160 };
161
162 let mut upsampler = self.upsampler.as_mut().unwrap();
163 let mut downsampler = self.downsampler.as_mut().unwrap();
164
165 let mut oversampled_buffer: Vec<f32> =
166 vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
167
168 for chan in 0..channels {
169 let out_len = WaveShaperNode::resample(
170 &mut upsampler,
171 chan,
172 block.data_chan(chan),
173 &mut oversampled_buffer,
174 );
175
176 debug_assert!(
177 out_len == 128 * sampling_factor,
178 "Expected {} samples in output after upsampling, got: {}",
179 128 * sampling_factor,
180 out_len
181 );
182
183 WaveShaperNode::apply_curve(&mut oversampled_buffer, &curve);
184
185 let out_len = WaveShaperNode::resample(
186 &mut downsampler,
187 chan,
188 &oversampled_buffer,
189 &mut block.data_chan_mut(chan),
190 );
191
192 debug_assert!(
193 out_len == 128,
194 "Expected 128 samples in output after downsampling, got {}",
195 out_len
196 );
197 }
198 } else {
199 WaveShaperNode::apply_curve(block.data_mut(), &curve);
200 }
201
202 inputs
203 }
204
205 make_message_handler!(WaveShaperNode: handle_waveshaper_message);
206}
207
208impl WaveShaperNode {
209 fn silence_produces_nonsilent_output(curve: &Vec<f32>) -> bool {
210 let len = curve.len();
211 let len_halved = ((len - 1) as f32) / 2.;
212 let curve_index: f32 = len_halved;
213 let index_lo = curve_index as usize;
214 let index_hi = index_lo + 1;
215 let interp_factor: f32 = curve_index - index_lo as f32;
216 let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
217 shaped_val == 0.0
218 }
219
220 fn apply_curve(buf: &mut [f32], curve: &Vec<f32>) {
221 let len = curve.len();
222 let len_halved = ((len - 1) as f32) / 2.;
223 buf.iter_mut().for_each(|sample| {
224 let curve_index: f32 = len_halved * (*sample + 1.);
225
226 if curve_index <= 0. {
227 *sample = curve[0];
228 } else if curve_index >= (len - 1) as f32 {
229 *sample = curve[len - 1];
230 } else {
231 let index_lo = curve_index as usize;
232 let index_hi = index_lo + 1;
233 let interp_factor: f32 = curve_index - index_lo as f32;
234 *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
235 }
236 });
237 }
238
239 fn resample(
240 st: &mut SpeexResamplerState,
241 chan: u8,
242 input: &[f32],
243 output: &mut [f32],
244 ) -> usize {
245 let (_in_len, out_len) = st
246 .process_float(chan as usize, input, output)
247 .expect("Resampling failed");
248 out_len
249 }
250}