// servo_media_audio/wave_shaper_node.rs

/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use speexdsp_resampler::State as SpeexResamplerState;
6
7use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
8use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
9
/// How much the wave shaper oversamples its input before applying the curve.
#[derive(Debug, PartialEq, Clone)]
pub enum OverSampleType {
    /// No oversampling: the curve runs at the context's sample rate.
    None,
    /// 2x oversampling around the curve.
    Double,
    /// 4x oversampling around the curve.
    Quadruple,
}
16
/// Count of extra blocks still to be run through the shaper after the input
/// has gone silent.
///
/// `process` sets this to `Two` whenever a block must be shaped and counts it
/// down on silent input; presumably this lets the oversampling resamplers
/// drain their internal delay (TODO confirm).
#[derive(Debug, PartialEq, Clone)]
enum TailtimeBlocks {
    /// Nothing left to flush.
    Zero,
    /// One more block to process.
    One,
    /// Two more blocks to process.
    Two,
}
23
24const OVERSAMPLING_QUALITY: usize = 0;
25
26impl OverSampleType {
27    fn value(&self) -> usize {
28        match self {
29            OverSampleType::None => 1,
30            OverSampleType::Double => 2,
31            OverSampleType::Quadruple => 4,
32        }
33    }
34}
35
36type WaveShaperCurve = Option<Vec<f32>>;
37
38#[derive(Clone, Debug)]
39pub struct WaveShaperNodeOptions {
40    pub curve: WaveShaperCurve,
41    pub oversample: OverSampleType,
42}
43
44impl Default for WaveShaperNodeOptions {
45    fn default() -> Self {
46        WaveShaperNodeOptions {
47            curve: None,
48            oversample: OverSampleType::None,
49        }
50    }
51}
52
53#[derive(Clone, Debug)]
54pub enum WaveShaperNodeMessage {
55    SetCurve(WaveShaperCurve),
56}
57
58#[derive(AudioNodeCommon)]
59pub(crate) struct WaveShaperNode {
60    curve_set: bool,
61    curve: WaveShaperCurve,
62    #[allow(dead_code)]
63    oversample: OverSampleType,
64    channel_info: ChannelInfo,
65    upsampler: Option<SpeexResamplerState>,
66    downsampler: Option<SpeexResamplerState>,
67    tailtime_blocks_left: TailtimeBlocks,
68}
69
70impl WaveShaperNode {
71    pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
72        if let Some(vec) = &options.curve {
73            assert!(
74                vec.len() > 1,
75                "WaveShaperNode curve must have length of 2 or more"
76            )
77        }
78
79        Self {
80            curve_set: options.curve.is_some(),
81            curve: options.curve,
82            oversample: options.oversample,
83            channel_info,
84            upsampler: None,
85            downsampler: None,
86            tailtime_blocks_left: TailtimeBlocks::Zero,
87        }
88    }
89
90    fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
91        match message {
92            WaveShaperNodeMessage::SetCurve(new_curve) => {
93                if self.curve_set && new_curve.is_some() {
94                    panic!("InvalidStateError: cant set curve if it was already set");
95                }
96                self.curve_set = new_curve.is_some();
97                self.curve = new_curve;
98            },
99        }
100    }
101}
102
103impl AudioNodeEngine for WaveShaperNode {
104    fn node_type(&self) -> AudioNodeType {
105        AudioNodeType::WaveShaperNode
106    }
107
108    fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
109        debug_assert!(inputs.len() == 1);
110
111        if self.curve.is_none() {
112            return inputs;
113        }
114
115        let curve = &self.curve.as_ref().expect("Just checked for is_none()");
116
117        if inputs.blocks[0].is_silence() {
118            if WaveShaperNode::silence_produces_nonsilent_output(curve) {
119                inputs.blocks[0].explicit_silence();
120                self.tailtime_blocks_left = TailtimeBlocks::Two;
121            } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
122                inputs.blocks[0].explicit_silence();
123
124                self.tailtime_blocks_left = match self.tailtime_blocks_left {
125                    TailtimeBlocks::Zero => TailtimeBlocks::Zero,
126                    TailtimeBlocks::One => TailtimeBlocks::Zero,
127                    TailtimeBlocks::Two => TailtimeBlocks::One,
128                }
129            } else {
130                return inputs;
131            }
132        } else {
133            self.tailtime_blocks_left = TailtimeBlocks::Two;
134        }
135
136        let block = &mut inputs.blocks[0];
137        let channels = block.chan_count();
138
139        if self.oversample != OverSampleType::None {
140            let rate: usize = info.sample_rate as usize;
141            let sampling_factor = self.oversample.value();
142
143            if self.upsampler.is_none() {
144                self.upsampler = Some(
145                    SpeexResamplerState::new(
146                        channels as usize,
147                        rate,
148                        rate * sampling_factor,
149                        OVERSAMPLING_QUALITY,
150                    )
151                    .expect("Couldnt create upsampler"),
152                );
153            };
154
155            if self.downsampler.is_none() {
156                self.downsampler = Some(
157                    SpeexResamplerState::new(
158                        channels as usize,
159                        rate * sampling_factor,
160                        rate,
161                        OVERSAMPLING_QUALITY,
162                    )
163                    .expect("Couldnt create downsampler"),
164                );
165            };
166
167            let upsampler = self.upsampler.as_mut().unwrap();
168            let downsampler = self.downsampler.as_mut().unwrap();
169
170            let mut oversampled_buffer: Vec<f32> =
171                vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
172
173            for chan in 0..channels {
174                let out_len = WaveShaperNode::resample(
175                    upsampler,
176                    chan,
177                    block.data_chan(chan),
178                    &mut oversampled_buffer,
179                );
180
181                debug_assert!(
182                    out_len == 128 * sampling_factor,
183                    "Expected {} samples in output after upsampling, got: {}",
184                    128 * sampling_factor,
185                    out_len
186                );
187
188                WaveShaperNode::apply_curve(&mut oversampled_buffer, curve);
189
190                let out_len = WaveShaperNode::resample(
191                    downsampler,
192                    chan,
193                    &oversampled_buffer,
194                    block.data_chan_mut(chan),
195                );
196
197                debug_assert!(
198                    out_len == 128,
199                    "Expected 128 samples in output after downsampling, got {}",
200                    out_len
201                );
202            }
203        } else {
204            WaveShaperNode::apply_curve(block.data_mut(), curve);
205        }
206
207        inputs
208    }
209
210    make_message_handler!(WaveShaperNode: handle_waveshaper_message);
211}
212
213impl WaveShaperNode {
214    fn silence_produces_nonsilent_output(curve: &[f32]) -> bool {
215        let len = curve.len();
216        let len_halved = ((len - 1) as f32) / 2.;
217        let curve_index: f32 = len_halved;
218        let index_lo = curve_index as usize;
219        let index_hi = index_lo + 1;
220        let interp_factor: f32 = curve_index - index_lo as f32;
221        let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
222        shaped_val == 0.0
223    }
224
225    fn apply_curve(buf: &mut [f32], curve: &[f32]) {
226        let len = curve.len();
227        let len_halved = ((len - 1) as f32) / 2.;
228        buf.iter_mut().for_each(|sample| {
229            let curve_index: f32 = len_halved * (*sample + 1.);
230
231            if curve_index <= 0. {
232                *sample = curve[0];
233            } else if curve_index >= (len - 1) as f32 {
234                *sample = curve[len - 1];
235            } else {
236                let index_lo = curve_index as usize;
237                let index_hi = index_lo + 1;
238                let interp_factor: f32 = curve_index - index_lo as f32;
239                *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
240            }
241        });
242    }
243
244    fn resample(
245        st: &mut SpeexResamplerState,
246        chan: u8,
247        input: &[f32],
248        output: &mut [f32],
249    ) -> usize {
250        let (_in_len, out_len) = st
251            .process_float(chan as usize, input, output)
252            .expect("Resampling failed");
253        out_len
254    }
255}