servo_media_audio/
iir_filter_node.rs1use std::collections::VecDeque;
6use std::sync::Arc;
7
8use log::warn;
9use malloc_size_of_derive::MallocSizeOf;
10use num_complex::Complex64;
11
12use crate::block::Chunk;
13use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
14
/// Maximum number of feedforward or feedback coefficients an IIR filter accepts, and the
/// number of past input/output samples each `IIRFilter` keeps in its history.
const MAX_COEFFS: usize = 20;
16
/// Construction options for an `IIRFilterNode`: the coefficients of the filter's
/// difference equation.
#[derive(Debug, MallocSizeOf)]
pub struct IIRFilterNodeOptions {
    /// Numerator coefficients `b[k]`, applied to the current and past inputs `x[n-k]`.
    /// Expected to hold 1..=`MAX_COEFFS` values, not all zero (debug-asserted by the node).
    #[conditional_malloc_size_of]
    pub feedforward: Arc<Vec<f64>>,
    /// Denominator coefficients `a[k]`: `a[0]` divides the output and must be non-zero;
    /// `a[1..]` are applied to past outputs `y[n-k]`. Expected to hold 1..=`MAX_COEFFS` values.
    #[conditional_malloc_size_of]
    pub feedback: Arc<Vec<f64>>,
}
24
/// State of a single IIR filter instance (the node keeps one per channel): the shared
/// coefficients plus the recent input and output history of the difference equation.
#[derive(Clone)]
struct IIRFilter {
    /// Numerator (`b[k]`) coefficients; the `Arc` is shared by every channel's filter.
    feedforward: Arc<Vec<f64>>,
    /// Denominator (`a[k]`) coefficients; the `Arc` is shared by every channel's filter.
    feedback: Arc<Vec<f64>>,
    /// Recent inputs, newest at the front (`inputs[k]` is `x[n-k]` while computing `y[n]`);
    /// trimmed to at most `MAX_COEFFS` entries.
    inputs: VecDeque<f64>,
    /// Previously produced outputs, newest at the front (`outputs[k]` is `y[n-1-k]` while
    /// computing `y[n]`); trimmed to at most `MAX_COEFFS` entries.
    outputs: VecDeque<f64>,
}
32
33impl IIRFilter {
34 fn new(feedforward: Arc<Vec<f64>>, feedback: Arc<Vec<f64>>) -> Self {
35 Self {
36 feedforward,
37 feedback,
38 inputs: VecDeque::with_capacity(MAX_COEFFS),
39 outputs: VecDeque::with_capacity(MAX_COEFFS),
40 }
41 }
42
43 fn calculate_output(&mut self, input: f32) -> f32 {
44 self.inputs.push_front(input as f64);
45
46 if self.inputs.len() > MAX_COEFFS {
47 self.inputs.pop_back();
48 }
49
50 let inputs_sum = self
51 .feedforward
52 .iter()
53 .zip(self.inputs.iter())
54 .fold(0.0, |acc, (c, v)| acc + c * v);
55
56 let outputs_sum = self
57 .feedback
58 .iter()
59 .skip(1)
60 .zip(self.outputs.iter())
61 .fold(0.0, |acc, (c, v)| acc + c * v);
62
63 let output = (inputs_sum - outputs_sum) / self.feedback[0];
64
65 if output.is_nan() {
66 warn!("NaN in IIRFilter state");
72 }
73
74 self.outputs.push_front(output);
75
76 if self.outputs.len() > MAX_COEFFS {
77 self.outputs.pop_back();
78 }
79
80 output as f32
81 }
82}
83
/// Audio node that runs each channel of its input through its own IIR filter.
#[derive(AudioNodeCommon)]
pub struct IIRFilterNode {
    /// Channel configuration of this node.
    channel_info: ChannelInfo,
    /// One filter per channel, indexed by channel number while processing; sized from
    /// `channel_info.computed_number_of_channels()` at construction.
    filters: Vec<IIRFilter>,
}
89
90impl IIRFilterNode {
91 pub fn new(options: IIRFilterNodeOptions, channel_info: ChannelInfo) -> Self {
92 debug_assert!(
93 !options.feedforward.is_empty(),
94 "NotSupportedError: feedforward must have at least one coeff"
95 );
96
97 debug_assert!(
98 options.feedforward.len() <= MAX_COEFFS,
99 "NotSupportedError: feedforward max length is {}",
100 MAX_COEFFS
101 );
102
103 debug_assert!(
104 options.feedforward.iter().any(|&v| v != 0.0_f64),
105 "InvalidStateError: all coeffs are zero"
106 );
107
108 debug_assert!(
109 !options.feedback.is_empty(),
110 "NotSupportedError: feedback must have at least one coeff"
111 );
112
113 debug_assert!(
114 options.feedback.len() <= MAX_COEFFS,
115 "NotSupportedError: feedback max length is {}",
116 MAX_COEFFS
117 );
118
119 debug_assert!(
120 options.feedback[0] != 0.0,
121 "InvalidStateError: first feedback coeff must not be zero"
122 );
123
124 let filter = IIRFilter::new(options.feedforward.clone(), options.feedback.clone());
125
126 Self {
127 filters: vec![filter; channel_info.computed_number_of_channels() as usize],
128 channel_info,
129 }
130 }
131
132 pub fn get_frequency_response(
133 feedforward: &[f64],
134 feedback: &[f64],
135 frequency_hz: &[f32],
136 mag_response: &mut [f32],
137 phase_response: &mut [f32],
138 ) {
139 debug_assert!(
140 frequency_hz.len() == mag_response.len() && frequency_hz.len() == phase_response.len(),
141 "get_frequency_response params are of different length"
142 );
143
144 frequency_hz.iter().enumerate().for_each(|(idx, &f)| {
145 if !(0.0..1.0).contains(&f) {
146 mag_response[idx] = f32::NAN;
147 phase_response[idx] = f32::NAN;
148 } else {
149 let f = (-f as f64) * std::f64::consts::PI;
150 let z = Complex64::new(f64::cos(f), f64::sin(f));
151 let numerator = Self::sum(feedforward, z);
152 let denominator = Self::sum(feedback, z);
153
154 let response = numerator / denominator;
155 mag_response[idx] = response.norm() as f32;
156 phase_response[idx] = response.arg() as f32;
157 }
158 });
159 }
160
161 fn sum(coeffs: &[f64], z: Complex64) -> Complex64 {
162 coeffs.iter().fold(Complex64::new(0.0, 0.0), |acc, &coeff| {
163 acc * z + Complex64::new(coeff, 0.0)
164 })
165 }
166}
167
168impl AudioNodeEngine for IIRFilterNode {
169 fn node_type(&self) -> AudioNodeType {
170 AudioNodeType::IIRFilterNode
171 }
172
173 fn process(&mut self, inputs: Chunk, _info: &BlockInfo) -> Chunk {
174 debug_assert!(inputs.len() == 1);
175
176 let mut inputs = if inputs.blocks[0].is_silence() {
177 Chunk::explicit_silence()
178 } else {
179 inputs
180 };
181
182 let mut iter = inputs.blocks[0].iter();
183
184 while let Some(mut frame) = iter.next() {
185 frame.mutate_with(|sample, chan_idx| {
186 *sample = self.filters[chan_idx as usize].calculate_output(*sample);
187 });
188 }
189 inputs
190 }
191}