servo_media_audio/
iir_filter_node.rs1use std::collections::VecDeque;
6use std::sync::Arc;
7
8use log::warn;
9use num_complex::Complex64;
10
11use crate::block::Chunk;
12use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
13
/// Upper bound on the number of feedforward or feedback coefficients accepted by an IIR
/// filter; each filter also keeps at most this many past input and output samples.
const MAX_COEFFS: usize = 20;
15
/// Construction parameters for `IIRFilterNode`.
///
/// `feedforward` holds the numerator coefficients and `feedback` the denominator
/// coefficients of the filter's transfer function. `IIRFilterNode::new` expects both to
/// contain between 1 and `MAX_COEFFS` values, `feedforward` to contain a non-zero value,
/// and `feedback[0]` to be non-zero; those checks are debug assertions only.
#[derive(Debug)]
pub struct IIRFilterNodeOptions {
    /// Numerator (`b`) coefficients.
    pub feedforward: Arc<Vec<f64>>,
    /// Denominator (`a`) coefficients.
    pub feedback: Arc<Vec<f64>>,
}
21
/// State of one channel's IIR filter.
///
/// Cloning is cheap for the coefficients (shared through `Arc`) while each clone gets
/// its own copy of the sample histories, so clones filter independently.
#[derive(Clone)]
struct IIRFilter {
    /// Numerator coefficients, applied to the input history.
    feedforward: Arc<Vec<f64>>,
    /// Denominator coefficients; the first one divides the output, the rest weight the
    /// output history.
    feedback: Arc<Vec<f64>>,
    /// Past input samples, newest first.
    inputs: VecDeque<f64>,
    /// Past output samples, newest first.
    outputs: VecDeque<f64>,
}
29
30impl IIRFilter {
31 fn new(feedforward: Arc<Vec<f64>>, feedback: Arc<Vec<f64>>) -> Self {
32 Self {
33 feedforward,
34 feedback,
35 inputs: VecDeque::with_capacity(MAX_COEFFS),
36 outputs: VecDeque::with_capacity(MAX_COEFFS),
37 }
38 }
39
40 fn calculate_output(&mut self, input: f32) -> f32 {
41 self.inputs.push_front(input as f64);
42
43 if self.inputs.len() > MAX_COEFFS {
44 self.inputs.pop_back();
45 }
46
47 let inputs_sum = self
48 .feedforward
49 .iter()
50 .zip(self.inputs.iter())
51 .fold(0.0, |acc, (c, v)| acc + c * v);
52
53 let outputs_sum = self
54 .feedback
55 .iter()
56 .skip(1)
57 .zip(self.outputs.iter())
58 .fold(0.0, |acc, (c, v)| acc + c * v);
59
60 let output = (inputs_sum - outputs_sum) / self.feedback[0];
61
62 if output.is_nan() {
63 warn!("NaN in IIRFilter state");
69 }
70
71 self.outputs.push_front(output);
72
73 if self.outputs.len() > MAX_COEFFS {
74 self.outputs.pop_back();
75 }
76
77 output as f32
78 }
79}
80
81#[derive(AudioNodeCommon)]
82pub struct IIRFilterNode {
83 channel_info: ChannelInfo,
84 filters: Vec<IIRFilter>,
85}
86
87impl IIRFilterNode {
88 pub fn new(options: IIRFilterNodeOptions, channel_info: ChannelInfo) -> Self {
89 debug_assert!(
90 !options.feedforward.is_empty(),
91 "NotSupportedError: feedforward must have at least one coeff"
92 );
93
94 debug_assert!(
95 options.feedforward.len() <= MAX_COEFFS,
96 "NotSupportedError: feedforward max length is {}",
97 MAX_COEFFS
98 );
99
100 debug_assert!(
101 options.feedforward.iter().any(|&v| v != 0.0_f64),
102 "InvalidStateError: all coeffs are zero"
103 );
104
105 debug_assert!(
106 !options.feedback.is_empty(),
107 "NotSupportedError: feedback must have at least one coeff"
108 );
109
110 debug_assert!(
111 options.feedback.len() <= MAX_COEFFS,
112 "NotSupportedError: feedback max length is {}",
113 MAX_COEFFS
114 );
115
116 debug_assert!(
117 options.feedback[0] != 0.0,
118 "InvalidStateError: first feedback coeff must not be zero"
119 );
120
121 let filter = IIRFilter::new(options.feedforward.clone(), options.feedback.clone());
122
123 Self {
124 filters: vec![filter; channel_info.computed_number_of_channels() as usize],
125 channel_info,
126 }
127 }
128
129 pub fn get_frequency_response(
130 feedforward: &[f64],
131 feedback: &[f64],
132 frequency_hz: &[f32],
133 mag_response: &mut [f32],
134 phase_response: &mut [f32],
135 ) {
136 debug_assert!(
137 frequency_hz.len() == mag_response.len() && frequency_hz.len() == phase_response.len(),
138 "get_frequency_response params are of different length"
139 );
140
141 frequency_hz.iter().enumerate().for_each(|(idx, &f)| {
142 if !(0.0..1.0).contains(&f) {
143 mag_response[idx] = f32::NAN;
144 phase_response[idx] = f32::NAN;
145 } else {
146 let f = (-f as f64) * std::f64::consts::PI;
147 let z = Complex64::new(f64::cos(f), f64::sin(f));
148 let numerator = Self::sum(feedforward, z);
149 let denominator = Self::sum(feedback, z);
150
151 let response = numerator / denominator;
152 mag_response[idx] = response.norm() as f32;
153 phase_response[idx] = response.arg() as f32;
154 }
155 });
156 }
157
158 fn sum(coeffs: &[f64], z: Complex64) -> Complex64 {
159 coeffs.iter().fold(Complex64::new(0.0, 0.0), |acc, &coeff| {
160 acc * z + Complex64::new(coeff, 0.0)
161 })
162 }
163}
164
165impl AudioNodeEngine for IIRFilterNode {
166 fn node_type(&self) -> AudioNodeType {
167 AudioNodeType::IIRFilterNode
168 }
169
170 fn process(&mut self, inputs: Chunk, _info: &BlockInfo) -> Chunk {
171 debug_assert!(inputs.len() == 1);
172
173 let mut inputs = if inputs.blocks[0].is_silence() {
174 Chunk::explicit_silence()
175 } else {
176 inputs
177 };
178
179 let mut iter = inputs.blocks[0].iter();
180
181 while let Some(mut frame) = iter.next() {
182 frame.mutate_with(|sample, chan_idx| {
183 *sample = self.filters[chan_idx as usize].calculate_output(*sample);
184 });
185 }
186 inputs
187 }
188}