servo_media_audio/
iir_filter_node.rs1use crate::block::Chunk;
2use log::warn;
3use crate::node::{AudioNodeEngine, AudioNodeType, BlockInfo, ChannelInfo};
4use num_complex::Complex64;
5use std::collections::VecDeque;
6use std::sync::Arc;
7
/// Upper bound on the length of both coefficient arrays, and the number of past
/// input/output samples each `IIRFilter` keeps (the Web Audio limit is 20).
const MAX_COEFFS: usize = 20;
9
/// Construction parameters for an [`IIRFilterNode`].
///
/// Both arrays are shared (`Arc`) so every per-channel filter can reference the same
/// coefficients without copying them.
#[derive(Debug)]
pub struct IIRFilterNodeOptions {
    /// Numerator coefficients; element `k` multiplies the input delayed by `k` samples.
    pub feedforward: Arc<Vec<f64>>,
    /// Denominator coefficients; element 0 normalises the output, element `k >= 1`
    /// multiplies the output delayed by `k` samples.
    pub feedback: Arc<Vec<f64>>,
}
15
16#[derive(Clone)]
17struct IIRFilter {
18 feedforward: Arc<Vec<f64>>,
19 feedback: Arc<Vec<f64>>,
20 inputs: VecDeque<f64>,
21 outputs: VecDeque<f64>,
22}
23
24impl IIRFilter {
25 fn new(feedforward: Arc<Vec<f64>>, feedback: Arc<Vec<f64>>) -> Self {
26 Self {
27 feedforward,
28 feedback,
29 inputs: VecDeque::with_capacity(MAX_COEFFS),
30 outputs: VecDeque::with_capacity(MAX_COEFFS),
31 }
32 }
33
34 fn calculate_output(&mut self, input: f32) -> f32 {
35 self.inputs.push_front(input as f64);
36
37 if self.inputs.len() > MAX_COEFFS {
38 self.inputs.pop_back();
39 }
40
41 let inputs_sum = self
42 .feedforward
43 .iter()
44 .zip(self.inputs.iter())
45 .fold(0.0, |acc, (c, v)| acc + c * v);
46
47 let outputs_sum = self
48 .feedback
49 .iter()
50 .skip(1)
51 .zip(self.outputs.iter())
52 .fold(0.0, |acc, (c, v)| acc + c * v);
53
54 let output = (inputs_sum - outputs_sum) / self.feedback[0];
55
56 if output.is_nan() {
57 warn!("NaN in IIRFilter state");
63 }
64
65 self.outputs.push_front(output);
66
67 if self.outputs.len() > MAX_COEFFS {
68 self.outputs.pop_back();
69 }
70
71 output as f32
72 }
73}
74
75#[derive(AudioNodeCommon)]
76pub struct IIRFilterNode {
77 channel_info: ChannelInfo,
78 filters: Vec<IIRFilter>,
79}
80
81impl IIRFilterNode {
82 pub fn new(options: IIRFilterNodeOptions, channel_info: ChannelInfo) -> Self {
83 debug_assert!(
84 options.feedforward.len() > 0,
85 "NotSupportedError: feedforward must have at least one coeff"
86 );
87
88 debug_assert!(
89 options.feedforward.len() <= MAX_COEFFS,
90 "NotSupportedError: feedforward max length is {}",
91 MAX_COEFFS
92 );
93
94 debug_assert!(
95 options.feedforward.iter().any(|&v| v != 0.0_f64),
96 "InvalidStateError: all coeffs are zero"
97 );
98
99 debug_assert!(
100 options.feedback.len() > 0,
101 "NotSupportedError: feedback must have at least one coeff"
102 );
103
104 debug_assert!(
105 options.feedback.len() <= MAX_COEFFS,
106 "NotSupportedError: feedback max length is {}",
107 MAX_COEFFS
108 );
109
110 debug_assert!(
111 options.feedback[0] != 0.0,
112 "InvalidStateError: first feedback coeff must not be zero"
113 );
114
115 let filter = IIRFilter::new(options.feedforward.clone(), options.feedback.clone());
116
117 Self {
118 filters: vec![filter; channel_info.computed_number_of_channels() as usize],
119 channel_info,
120 }
121 }
122
123 pub fn get_frequency_response(
124 feedforward: &[f64],
125 feedback: &[f64],
126 frequency_hz: &[f32],
127 mag_response: &mut [f32],
128 phase_response: &mut [f32],
129 ) {
130 debug_assert!(
131 frequency_hz.len() == mag_response.len() && frequency_hz.len() == phase_response.len(),
132 "get_frequency_response params are of different length"
133 );
134
135 frequency_hz.iter().enumerate().for_each(|(idx, &f)| {
136 if f < 0.0 || f >= 1.0 {
137 mag_response[idx] = std::f32::NAN;
138 phase_response[idx] = std::f32::NAN;
139 } else {
140 let f = (-f as f64) * std::f64::consts::PI;
141 let z = Complex64::new(f64::cos(f), f64::sin(f));
142 let numerator = Self::sum(feedforward, z);
143 let denominator = Self::sum(feedback, z);
144
145 let response = numerator / denominator;
146 mag_response[idx] = response.norm() as f32;
147 phase_response[idx] = response.arg() as f32;
148 }
149 });
150 }
151
152 fn sum(coeffs: &[f64], z: Complex64) -> Complex64 {
153 coeffs.iter().fold(Complex64::new(0.0, 0.0), |acc, &coeff| {
154 acc * z + Complex64::new(coeff, 0.0)
155 })
156 }
157}
158
159impl AudioNodeEngine for IIRFilterNode {
160 fn node_type(&self) -> AudioNodeType {
161 AudioNodeType::IIRFilterNode
162 }
163
164 fn process(&mut self, inputs: Chunk, _info: &BlockInfo) -> Chunk {
165 debug_assert!(inputs.len() == 1);
166
167 let mut inputs = if inputs.blocks[0].is_silence() {
168 Chunk::explicit_silence()
169 } else {
170 inputs
171 };
172
173 let mut iter = inputs.blocks[0].iter();
174
175 while let Some(mut frame) = iter.next() {
176 frame.mutate_with(|sample, chan_idx| {
177 *sample = self.filters[chan_idx as usize].calculate_output(*sample);
178 });
179 }
180 inputs
181 }
182}