servo_media_audio/
iir_filter_node.rs

1use crate::block::Chunk;
2use log::warn;
3use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
4use num_complex::Complex64;
5use std::collections::VecDeque;
6use std::sync::Arc;
7
/// Upper bound on the number of feedforward or feedback coefficients a filter
/// may have; it also bounds how many past samples each history buffer keeps.
const MAX_COEFFS: usize = 20;
9
/// Construction parameters for an IIR filter node.
///
/// Both coefficient lists are reference-counted so they can be shared
/// without copying the underlying vectors.
#[derive(Debug)]
pub struct IIRFilterNodeOptions {
    /// Feedforward (numerator) coefficients of the filter's transfer function.
    pub feedforward: Arc<Vec<f64>>,
    /// Feedback (denominator) coefficients of the filter's transfer function.
    pub feedback: Arc<Vec<f64>>,
}
15
/// Per-channel filter state: the shared coefficients plus the recent input
/// and output samples needed to evaluate the difference equation.
#[derive(Clone)]
struct IIRFilter {
    /// Feedforward coefficients, applied to the input history.
    feedforward: Arc<Vec<f64>>,
    /// Feedback coefficients; the first one normalizes the result and the
    /// remaining ones are applied to the output history.
    feedback: Arc<Vec<f64>>,
    /// Most recent input samples, newest at the front.
    inputs: VecDeque<f64>,
    /// Most recent output samples, newest at the front.
    outputs: VecDeque<f64>,
}
23
24impl IIRFilter {
25    fn new(feedforward: Arc<Vec<f64>>, feedback: Arc<Vec<f64>>) -> Self {
26        Self {
27            feedforward,
28            feedback,
29            inputs: VecDeque::with_capacity(MAX_COEFFS),
30            outputs: VecDeque::with_capacity(MAX_COEFFS),
31        }
32    }
33
34    fn calculate_output(&mut self, input: f32) -> f32 {
35        self.inputs.push_front(input as f64);
36
37        if self.inputs.len() > MAX_COEFFS {
38            self.inputs.pop_back();
39        }
40
41        let inputs_sum = self
42            .feedforward
43            .iter()
44            .zip(self.inputs.iter())
45            .fold(0.0, |acc, (c, v)| acc + c * v);
46
47        let outputs_sum = self
48            .feedback
49            .iter()
50            .skip(1)
51            .zip(self.outputs.iter())
52            .fold(0.0, |acc, (c, v)| acc + c * v);
53
54        let output = (inputs_sum - outputs_sum) / self.feedback[0];
55
56        if output.is_nan() {
57            // Per spec:
58            // Note: The UA may produce a warning to notify the user that NaN values have occurred in the filter state.
59            // This is usually indicative of an unstable filter.
60            //
61            // But idk how to produce warnings
62            warn!("NaN in IIRFilter state");
63        }
64
65        self.outputs.push_front(output);
66
67        if self.outputs.len() > MAX_COEFFS {
68            self.outputs.pop_back();
69        }
70
71        output as f32
72    }
73}
74
75#[derive(AudioNodeCommon)]
76pub struct IIRFilterNode {
77    channel_info: ChannelInfo,
78    filters: Vec<IIRFilter>,
79}
80
81impl IIRFilterNode {
82    pub fn new(options: IIRFilterNodeOptions, channel_info: ChannelInfo) -> Self {
83        debug_assert!(
84            options.feedforward.len() > 0,
85            "NotSupportedError: feedforward must have at least one coeff"
86        );
87
88        debug_assert!(
89            options.feedforward.len() <= MAX_COEFFS,
90            "NotSupportedError: feedforward max length is {}",
91            MAX_COEFFS
92        );
93
94        debug_assert!(
95            options.feedforward.iter().any(|&v| v != 0.0_f64),
96            "InvalidStateError: all coeffs are zero"
97        );
98
99        debug_assert!(
100            options.feedback.len() > 0,
101            "NotSupportedError: feedback must have at least one coeff"
102        );
103
104        debug_assert!(
105            options.feedback.len() <= MAX_COEFFS,
106            "NotSupportedError: feedback max length is {}",
107            MAX_COEFFS
108        );
109
110        debug_assert!(
111            options.feedback[0] != 0.0,
112            "InvalidStateError: first feedback coeff must not be zero"
113        );
114
115        let filter = IIRFilter::new(options.feedforward.clone(), options.feedback.clone());
116
117        Self {
118            filters: vec![filter; channel_info.computed_number_of_channels() as usize],
119            channel_info,
120        }
121    }
122
123    pub fn get_frequency_response(
124        feedforward: &[f64],
125        feedback: &[f64],
126        frequency_hz: &[f32],
127        mag_response: &mut [f32],
128        phase_response: &mut [f32],
129    ) {
130        debug_assert!(
131            frequency_hz.len() == mag_response.len() && frequency_hz.len() == phase_response.len(),
132            "get_frequency_response params are of different length"
133        );
134
135        frequency_hz.iter().enumerate().for_each(|(idx, &f)| {
136            if f < 0.0 || f >= 1.0 {
137                mag_response[idx] = std::f32::NAN;
138                phase_response[idx] = std::f32::NAN;
139            } else {
140                let f = (-f as f64) * std::f64::consts::PI;
141                let z = Complex64::new(f64::cos(f), f64::sin(f));
142                let numerator = Self::sum(feedforward, z);
143                let denominator = Self::sum(feedback, z);
144
145                let response = numerator / denominator;
146                mag_response[idx] = response.norm() as f32;
147                phase_response[idx] = response.arg() as f32;
148            }
149        });
150    }
151
152    fn sum(coeffs: &[f64], z: Complex64) -> Complex64 {
153        coeffs.iter().fold(Complex64::new(0.0, 0.0), |acc, &coeff| {
154            acc * z + Complex64::new(coeff, 0.0)
155        })
156    }
157}
158
159impl AudioNodeEngine for IIRFilterNode {
160    fn node_type(&self) -> AudioNodeType {
161        AudioNodeType::IIRFilterNode
162    }
163
164    fn process(&mut self, inputs: Chunk, _info: &BlockInfo) -> Chunk {
165        debug_assert!(inputs.len() == 1);
166
167        let mut inputs = if inputs.blocks[0].is_silence() {
168            Chunk::explicit_silence()
169        } else {
170            inputs
171        };
172
173        let mut iter = inputs.blocks[0].iter();
174
175        while let Some(mut frame) = iter.next() {
176            frame.mutate_with(|sample, chan_idx| {
177                *sample = self.filters[chan_idx as usize].calculate_output(*sample);
178            });
179        }
180        inputs
181    }
182}