servo_media_audio/
wave_shaper_node.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use malloc_size_of_derive::MallocSizeOf;
6use speexdsp_resampler::State as SpeexResamplerState;
7
8use crate::block::{Chunk, FRAMES_PER_BLOCK_USIZE};
9use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
10
11#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
12pub enum OverSampleType {
13    None,
14    Double,
15    Quadruple,
16}
17
18#[derive(Clone, Debug, PartialEq, MallocSizeOf)]
19enum TailtimeBlocks {
20    Zero,
21    One,
22    Two,
23}
24
/// Quality argument handed to both Speex resamplers used for oversampling.
/// NOTE(review): presumably the cheapest Speex setting; confirm against the
/// speexdsp quality scale.
const OVERSAMPLING_QUALITY: usize = 0;
26
27impl OverSampleType {
28    fn value(&self) -> usize {
29        match self {
30            OverSampleType::None => 1,
31            OverSampleType::Double => 2,
32            OverSampleType::Quadruple => 4,
33        }
34    }
35}
36
/// Shaping curve. `None` means no curve is set, and the node then passes its
/// input through unchanged; `Some` holds the sampled transfer function.
type WaveShaperCurve = Option<Vec<f32>>;
38
39#[derive(Clone, Debug, MallocSizeOf)]
40pub struct WaveShaperNodeOptions {
41    pub curve: WaveShaperCurve,
42    pub oversample: OverSampleType,
43}
44
45impl Default for WaveShaperNodeOptions {
46    fn default() -> Self {
47        WaveShaperNodeOptions {
48            curve: None,
49            oversample: OverSampleType::None,
50        }
51    }
52}
53
54#[derive(Clone, Debug, MallocSizeOf)]
55pub enum WaveShaperNodeMessage {
56    SetCurve(WaveShaperCurve),
57}
58
59#[derive(AudioNodeCommon, MallocSizeOf)]
60pub(crate) struct WaveShaperNode {
61    curve_set: bool,
62    curve: WaveShaperCurve,
63    #[allow(dead_code)]
64    oversample: OverSampleType,
65    channel_info: ChannelInfo,
66    #[ignore_malloc_size_of = "External Type"]
67    upsampler: Option<SpeexResamplerState>,
68    #[ignore_malloc_size_of = "External Type"]
69    downsampler: Option<SpeexResamplerState>,
70    tailtime_blocks_left: TailtimeBlocks,
71}
72
73impl WaveShaperNode {
74    pub fn new(options: WaveShaperNodeOptions, channel_info: ChannelInfo) -> Self {
75        if let Some(vec) = &options.curve {
76            assert!(
77                vec.len() > 1,
78                "WaveShaperNode curve must have length of 2 or more"
79            )
80        }
81
82        Self {
83            curve_set: options.curve.is_some(),
84            curve: options.curve,
85            oversample: options.oversample,
86            channel_info,
87            upsampler: None,
88            downsampler: None,
89            tailtime_blocks_left: TailtimeBlocks::Zero,
90        }
91    }
92
93    fn handle_waveshaper_message(&mut self, message: WaveShaperNodeMessage, _sample_rate: f32) {
94        match message {
95            WaveShaperNodeMessage::SetCurve(new_curve) => {
96                if self.curve_set && new_curve.is_some() {
97                    panic!("InvalidStateError: cant set curve if it was already set");
98                }
99                self.curve_set = new_curve.is_some();
100                self.curve = new_curve;
101            },
102        }
103    }
104}
105
106impl AudioNodeEngine for WaveShaperNode {
107    fn node_type(&self) -> AudioNodeType {
108        AudioNodeType::WaveShaperNode
109    }
110
111    fn process(&mut self, mut inputs: Chunk, info: &BlockInfo) -> Chunk {
112        debug_assert!(inputs.len() == 1);
113
114        if self.curve.is_none() {
115            return inputs;
116        }
117
118        let curve = &self.curve.as_ref().expect("Just checked for is_none()");
119
120        if inputs.blocks[0].is_silence() {
121            if WaveShaperNode::silence_produces_nonsilent_output(curve) {
122                inputs.blocks[0].explicit_silence();
123                self.tailtime_blocks_left = TailtimeBlocks::Two;
124            } else if self.tailtime_blocks_left != TailtimeBlocks::Zero {
125                inputs.blocks[0].explicit_silence();
126
127                self.tailtime_blocks_left = match self.tailtime_blocks_left {
128                    TailtimeBlocks::Zero => TailtimeBlocks::Zero,
129                    TailtimeBlocks::One => TailtimeBlocks::Zero,
130                    TailtimeBlocks::Two => TailtimeBlocks::One,
131                }
132            } else {
133                return inputs;
134            }
135        } else {
136            self.tailtime_blocks_left = TailtimeBlocks::Two;
137        }
138
139        let block = &mut inputs.blocks[0];
140        let channels = block.chan_count();
141
142        if self.oversample != OverSampleType::None {
143            let rate: usize = info.sample_rate as usize;
144            let sampling_factor = self.oversample.value();
145
146            if self.upsampler.is_none() {
147                self.upsampler = Some(
148                    SpeexResamplerState::new(
149                        channels as usize,
150                        rate,
151                        rate * sampling_factor,
152                        OVERSAMPLING_QUALITY,
153                    )
154                    .expect("Couldnt create upsampler"),
155                );
156            };
157
158            if self.downsampler.is_none() {
159                self.downsampler = Some(
160                    SpeexResamplerState::new(
161                        channels as usize,
162                        rate * sampling_factor,
163                        rate,
164                        OVERSAMPLING_QUALITY,
165                    )
166                    .expect("Couldnt create downsampler"),
167                );
168            };
169
170            let upsampler = self.upsampler.as_mut().unwrap();
171            let downsampler = self.downsampler.as_mut().unwrap();
172
173            let mut oversampled_buffer: Vec<f32> =
174                vec![0.; FRAMES_PER_BLOCK_USIZE * sampling_factor];
175
176            for chan in 0..channels {
177                let out_len = WaveShaperNode::resample(
178                    upsampler,
179                    chan,
180                    block.data_chan(chan),
181                    &mut oversampled_buffer,
182                );
183
184                debug_assert!(
185                    out_len == 128 * sampling_factor,
186                    "Expected {} samples in output after upsampling, got: {}",
187                    128 * sampling_factor,
188                    out_len
189                );
190
191                WaveShaperNode::apply_curve(&mut oversampled_buffer, curve);
192
193                let out_len = WaveShaperNode::resample(
194                    downsampler,
195                    chan,
196                    &oversampled_buffer,
197                    block.data_chan_mut(chan),
198                );
199
200                debug_assert!(
201                    out_len == 128,
202                    "Expected 128 samples in output after downsampling, got {}",
203                    out_len
204                );
205            }
206        } else {
207            WaveShaperNode::apply_curve(block.data_mut(), curve);
208        }
209
210        inputs
211    }
212
213    make_message_handler!(WaveShaperNode: handle_waveshaper_message);
214}
215
216impl WaveShaperNode {
217    fn silence_produces_nonsilent_output(curve: &[f32]) -> bool {
218        let len = curve.len();
219        let len_halved = ((len - 1) as f32) / 2.;
220        let curve_index: f32 = len_halved;
221        let index_lo = curve_index as usize;
222        let index_hi = index_lo + 1;
223        let interp_factor: f32 = curve_index - index_lo as f32;
224        let shaped_val = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
225        shaped_val == 0.0
226    }
227
228    fn apply_curve(buf: &mut [f32], curve: &[f32]) {
229        let len = curve.len();
230        let len_halved = ((len - 1) as f32) / 2.;
231        buf.iter_mut().for_each(|sample| {
232            let curve_index: f32 = len_halved * (*sample + 1.);
233
234            if curve_index <= 0. {
235                *sample = curve[0];
236            } else if curve_index >= (len - 1) as f32 {
237                *sample = curve[len - 1];
238            } else {
239                let index_lo = curve_index as usize;
240                let index_hi = index_lo + 1;
241                let interp_factor: f32 = curve_index - index_lo as f32;
242                *sample = (1. - interp_factor) * curve[index_lo] + interp_factor * curve[index_hi];
243            }
244        });
245    }
246
247    fn resample(
248        st: &mut SpeexResamplerState,
249        chan: u8,
250        input: &[f32],
251        output: &mut [f32],
252    ) -> usize {
253        let (_in_len, out_len) = st
254            .process_float(chan as usize, input, output)
255            .expect("Resampling failed");
256        out_len
257    }
258}