servo_media_audio/
iir_filter_node.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at https://mozilla.org/MPL/2.0/. */
4
5use std::collections::VecDeque;
6use std::sync::Arc;
7
8use log::warn;
9use num_complex::Complex64;
10
11use crate::block::Chunk;
12use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
13
/// Upper bound on the length of the feedforward and feedback coefficient
/// lists (checked with `debug_assert!` in `IIRFilterNode::new`), and on how
/// many past input/output samples each per-channel filter retains.
const MAX_COEFFS: usize = 20;
15
/// Construction parameters for an `IIRFilterNode`.
///
/// Both coefficient lists are reference-counted so that every per-channel
/// filter the node creates can share them without copying.
#[derive(Debug)]
pub struct IIRFilterNodeOptions {
    /// Numerator coefficients `b[0..]`, applied to the current and past inputs.
    pub feedforward: Arc<Vec<f64>>,
    /// Denominator coefficients `a[0..]`: the output is divided by `a[0]`, and
    /// the remaining coefficients are applied to past outputs.
    pub feedback: Arc<Vec<f64>>,
}
21
22#[derive(Clone)]
23struct IIRFilter {
24    feedforward: Arc<Vec<f64>>,
25    feedback: Arc<Vec<f64>>,
26    inputs: VecDeque<f64>,
27    outputs: VecDeque<f64>,
28}
29
30impl IIRFilter {
31    fn new(feedforward: Arc<Vec<f64>>, feedback: Arc<Vec<f64>>) -> Self {
32        Self {
33            feedforward,
34            feedback,
35            inputs: VecDeque::with_capacity(MAX_COEFFS),
36            outputs: VecDeque::with_capacity(MAX_COEFFS),
37        }
38    }
39
40    fn calculate_output(&mut self, input: f32) -> f32 {
41        self.inputs.push_front(input as f64);
42
43        if self.inputs.len() > MAX_COEFFS {
44            self.inputs.pop_back();
45        }
46
47        let inputs_sum = self
48            .feedforward
49            .iter()
50            .zip(self.inputs.iter())
51            .fold(0.0, |acc, (c, v)| acc + c * v);
52
53        let outputs_sum = self
54            .feedback
55            .iter()
56            .skip(1)
57            .zip(self.outputs.iter())
58            .fold(0.0, |acc, (c, v)| acc + c * v);
59
60        let output = (inputs_sum - outputs_sum) / self.feedback[0];
61
62        if output.is_nan() {
63            // Per spec:
64            // Note: The UA may produce a warning to notify the user that NaN values have occurred in the filter state.
65            // This is usually indicative of an unstable filter.
66            //
67            // But idk how to produce warnings
68            warn!("NaN in IIRFilter state");
69        }
70
71        self.outputs.push_front(output);
72
73        if self.outputs.len() > MAX_COEFFS {
74            self.outputs.pop_back();
75        }
76
77        output as f32
78    }
79}
80
/// Audio node that runs an independent IIR filter over each channel of its
/// single input.
#[derive(AudioNodeCommon)]
pub struct IIRFilterNode {
    channel_info: ChannelInfo,
    /// One filter state per channel, indexed by channel number in `process`.
    filters: Vec<IIRFilter>,
}
86
87impl IIRFilterNode {
88    pub fn new(options: IIRFilterNodeOptions, channel_info: ChannelInfo) -> Self {
89        debug_assert!(
90            !options.feedforward.is_empty(),
91            "NotSupportedError: feedforward must have at least one coeff"
92        );
93
94        debug_assert!(
95            options.feedforward.len() <= MAX_COEFFS,
96            "NotSupportedError: feedforward max length is {}",
97            MAX_COEFFS
98        );
99
100        debug_assert!(
101            options.feedforward.iter().any(|&v| v != 0.0_f64),
102            "InvalidStateError: all coeffs are zero"
103        );
104
105        debug_assert!(
106            !options.feedback.is_empty(),
107            "NotSupportedError: feedback must have at least one coeff"
108        );
109
110        debug_assert!(
111            options.feedback.len() <= MAX_COEFFS,
112            "NotSupportedError: feedback max length is {}",
113            MAX_COEFFS
114        );
115
116        debug_assert!(
117            options.feedback[0] != 0.0,
118            "InvalidStateError: first feedback coeff must not be zero"
119        );
120
121        let filter = IIRFilter::new(options.feedforward.clone(), options.feedback.clone());
122
123        Self {
124            filters: vec![filter; channel_info.computed_number_of_channels() as usize],
125            channel_info,
126        }
127    }
128
129    pub fn get_frequency_response(
130        feedforward: &[f64],
131        feedback: &[f64],
132        frequency_hz: &[f32],
133        mag_response: &mut [f32],
134        phase_response: &mut [f32],
135    ) {
136        debug_assert!(
137            frequency_hz.len() == mag_response.len() && frequency_hz.len() == phase_response.len(),
138            "get_frequency_response params are of different length"
139        );
140
141        frequency_hz.iter().enumerate().for_each(|(idx, &f)| {
142            if !(0.0..1.0).contains(&f) {
143                mag_response[idx] = f32::NAN;
144                phase_response[idx] = f32::NAN;
145            } else {
146                let f = (-f as f64) * std::f64::consts::PI;
147                let z = Complex64::new(f64::cos(f), f64::sin(f));
148                let numerator = Self::sum(feedforward, z);
149                let denominator = Self::sum(feedback, z);
150
151                let response = numerator / denominator;
152                mag_response[idx] = response.norm() as f32;
153                phase_response[idx] = response.arg() as f32;
154            }
155        });
156    }
157
158    fn sum(coeffs: &[f64], z: Complex64) -> Complex64 {
159        coeffs.iter().fold(Complex64::new(0.0, 0.0), |acc, &coeff| {
160            acc * z + Complex64::new(coeff, 0.0)
161        })
162    }
163}
164
165impl AudioNodeEngine for IIRFilterNode {
166    fn node_type(&self) -> AudioNodeType {
167        AudioNodeType::IIRFilterNode
168    }
169
170    fn process(&mut self, inputs: Chunk, _info: &BlockInfo) -> Chunk {
171        debug_assert!(inputs.len() == 1);
172
173        let mut inputs = if inputs.blocks[0].is_silence() {
174            Chunk::explicit_silence()
175        } else {
176            inputs
177        };
178
179        let mut iter = inputs.blocks[0].iter();
180
181        while let Some(mut frame) = iter.next() {
182            frame.mutate_with(|sample, chan_idx| {
183                *sample = self.filters[chan_idx as usize].calculate_output(*sample);
184            });
185        }
186        inputs
187    }
188}