servo_media_audio/
wave_shaper_node.rs

1use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
2use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
3use speexdsp_resampler::State as SpeexResamplerState;
4
/// Oversampling applied around the wave-shaping curve.
#[derive(Debug, PartialEq, Clone)]
pub enum OverSampleType {
    /// Shape at the block's own sample rate.
    None,
    /// Shape at twice the block's sample rate.
    Double,
    /// Shape at four times the block's sample rate.
    Quadruple,
}
11
/// How many more blocks the node keeps processing after its input has
/// become silent.
#[derive(Debug, PartialEq, Clone)]
enum TailtimeBlocks {
    Zero,
    One,
    Two,
}

// NOTE(review): speexdsp documents resampler quality on a 0 (fastest) to
// 10 (best) scale, which would make this the cheapest setting — confirm
// against the `speexdsp_resampler` crate in use.
/// Quality level handed to the Speex resamplers used for oversampling.
const OVERSAMPLING_QUALITY: usize = 0;
20
21impl OverSampleType {
22    fn value(&self) -> usize {
23        match self {
24            OverSampleType::None => 1,
25            OverSampleType::Double => 2,
26            OverSampleType::Quadruple => 4,
27        }
28    }
29}
30
31type WaveShaperCurve = Option<Vec<f32>>;
32
33#[derive(Clone, Debug)]
34pub struct WaveShaperNodeOptions {
35    pub curve: WaveShaperCurve,
36    pub oversample: OverSampleType,
37}
38
39impl Default for WaveShaperNodeOptions {
40    fn default() -> Self {
41        WaveShaperNodeOptions {
42            curve: None,
43            oversample: OverSampleType::None,
44        }
45    }
46}
47
48#[derive(Clone, Debug)]
49pub enum WaveShaperNodeMessage {
50    SetCurve(WaveShaperCurve),
51}
52
53#[derive(AudioNodeCommon)]
54pub(crate) struct WaveShaperNode {
55    curve_set: bool,
56    curve: WaveShaperCurve,
57    #[allow(dead_code)]
58    oversample: OverSampleType,
59    channel_info: ChannelInfo,
60    upsampler: Option<SpeexResamplerState>,
61    downsampler: Option<SpeexResamplerState>,
62    tailtime_blocks_left: TailtimeBlocks,
63}
64
65impl WaveShaperNode {
66    pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
67        if let Some(vec) = &options.curve {
68            assert!(
69                vec.len() > 1,
70                "WaveShaperNode curve must have length of 2 or more"
71            )
72        }
73
74        Self {
75            curve_set: options.curve.is_some(),
76            curve: options.curve,
77            oversample: options.oversample,
78            channel_info,
79            upsampler: None,
80            downsampler: None,
81            tailtime_blocks_left: TailtimeBlocks::Zero,
82        }
83    }
84
85    fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
86        match message {
87            WaveShaperNodeMessage::SetCurve(new_curve) => {
88                if self.curve_set && new_curve.is_some() {
89                    panic!("InvalidStateError: cant set curve if it was already set");
90                }
91                self.curve_set = new_curve.is_some();
92                self.curve = new_curve;
93            }
94        }
95    }
96}
97
98impl AudioNodeEngine for WaveShaperNode {
99    fn node_type(&self) -> AudioNodeType {
100        AudioNodeType::WaveShaperNode
101    }
102
103    fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
104        debug_assert!(inputs.len() == 1);
105
106        if self.curve.is_none() {
107            return inputs;
108        }
109
110        let curve = &self.curve.as_ref().expect("Just checked for is_none()");
111
112        if inputs.blocks[0].is_silence() {
113            if WaveShaperNode::silence_produces_nonsilent_output(curve) {
114                inputs.blocks[0].explicit_silence();
115                self.tailtime_blocks_left = TailtimeBlocks::Two;
116            } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
117                inputs.blocks[0].explicit_silence();
118
119                self.tailtime_blocks_left = match self.tailtime_blocks_left {
120                    TailtimeBlocks::Zero => TailtimeBlocks::Zero,
121                    TailtimeBlocks::One => TailtimeBlocks::Zero,
122                    TailtimeBlocks::Two => TailtimeBlocks::One,
123                }
124            } else {
125                return inputs;
126            }
127        } else {
128            self.tailtime_blocks_left = TailtimeBlocks::Two;
129        }
130
131        let block = &mut inputs.blocks[0];
132        let channels = block.chan_count();
133
134        if self.oversample != OverSampleType::None {
135            let rate: usize = info.sample_rate as usize;
136            let sampling_factor = self.oversample.value();
137
138            if self.upsampler.is_none() {
139                self.upsampler = Some(
140                    SpeexResamplerState::new(
141                        channels as usize,
142                        rate,
143                        rate * sampling_factor,
144                        OVERSAMPLING_QUALITY,
145                    )
146                    .expect("Couldnt create upsampler"),
147                );
148            };
149
150            if self.downsampler.is_none() {
151                self.downsampler = Some(
152                    SpeexResamplerState::new(
153                        channels as usize,
154                        rate * sampling_factor,
155                        rate,
156                        OVERSAMPLING_QUALITY,
157                    )
158                    .expect("Couldnt create downsampler"),
159                );
160            };
161
162            let mut upsampler = self.upsampler.as_mut().unwrap();
163            let mut downsampler = self.downsampler.as_mut().unwrap();
164
165            let mut oversampled_buffer: Vec<f32> =
166                vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
167
168            for chan in 0..channels {
169                let out_len = WaveShaperNode::resample(
170                    &mut upsampler,
171                    chan,
172                    block.data_chan(chan),
173                    &mut oversampled_buffer,
174                );
175
176                debug_assert!(
177                    out_len == 128 * sampling_factor,
178                    "Expected {} samples in output after upsampling, got: {}",
179                    128 * sampling_factor,
180                    out_len
181                );
182
183                WaveShaperNode::apply_curve(&mut oversampled_buffer, &curve);
184
185                let out_len = WaveShaperNode::resample(
186                    &mut downsampler,
187                    chan,
188                    &oversampled_buffer,
189                    &mut block.data_chan_mut(chan),
190                );
191
192                debug_assert!(
193                    out_len == 128,
194                    "Expected 128 samples in output after downsampling, got {}",
195                    out_len
196                );
197            }
198        } else {
199            WaveShaperNode::apply_curve(block.data_mut(), &curve);
200        }
201
202        inputs
203    }
204
205    make_message_handler!(WaveShaperNode: handle_waveshaper_message);
206}
207
208impl WaveShaperNode {
209    fn silence_produces_nonsilent_output(curve: &Vec<f32>) -> bool {
210        let len = curve.len();
211        let len_halved = ((len - 1) as f32) / 2.;
212        let curve_index: f32 = len_halved;
213        let index_lo = curve_index as usize;
214        let index_hi = index_lo + 1;
215        let interp_factor: f32 = curve_index - index_lo as f32;
216        let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
217        shaped_val == 0.0
218    }
219
220    fn apply_curve(buf: &mut [f32], curve: &Vec<f32>) {
221        let len = curve.len();
222        let len_halved = ((len - 1) as f32) / 2.;
223        buf.iter_mut().for_each(|sample| {
224            let curve_index: f32 = len_halved * (*sample + 1.);
225
226            if curve_index <= 0. {
227                *sample = curve[0];
228            } else if curve_index >= (len - 1) as f32 {
229                *sample = curve[len - 1];
230            } else {
231                let index_lo = curve_index as usize;
232                let index_hi = index_lo + 1;
233                let interp_factor: f32 = curve_index - index_lo as f32;
234                *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
235            }
236        });
237    }
238
239    fn resample(
240        st: &mut SpeexResamplerState,
241        chan: u8,
242        input: &[f32],
243        output: &mut [f32],
244    ) -> usize {
245        let (_in_len, out_len) = st
246            .process_float(chan as usize, input, output)
247            .expect("Resampling failed");
248        out_len
249    }
250}