1use crate::block::{Block, Chunk};
2use crate::destination_node::DestinationNode;
3use crate::listener::AudioListenerNode;
4use crate::node::{AudioNodeEngine, BlockInfo, ChannelCountMode, ChannelInterpretation};
5use crate::param::ParamType;
6use petgraph::graph::DefaultIx;
7use petgraph::stable_graph::NodeIndex;
8use petgraph::stable_graph::StableGraph;
9use petgraph::visit::{DfsPostOrder, EdgeRef, Reversed};
10use petgraph::Direction;
11use smallvec::SmallVec;
12use std::cell::{RefCell, RefMut};
13use std::{cmp, fmt, hash};
14
15#[derive(Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash, Debug)]
16pub struct NodeId(NodeIndex<DefaultIx>);
19
20impl NodeId {
21 pub fn input(self, port: u32) -> PortId<InputPort> {
22 PortId(self, PortIndex::Port(port))
23 }
24 pub fn param(self, param: ParamType) -> PortId<InputPort> {
25 PortId(self, PortIndex::Param(param))
26 }
27 pub fn output(self, port: u32) -> PortId<OutputPort> {
28 PortId(self, PortIndex::Port(port))
29 }
30 pub(crate) fn listener(self) -> PortId<InputPort> {
31 PortId(self, PortIndex::Listener(()))
32 }
33}
34
35#[derive(Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash, Debug)]
45pub enum PortIndex<Kind: PortKind> {
46 Port(u32),
47 Param(Kind::ParamId),
48 Listener(Kind::Listener),
51}
52
53impl<Kind: PortKind> PortId<Kind> {
54 pub fn node(&self) -> NodeId {
55 self.0
56 }
57}
58
/// Selects the side (input or output) of a [`PortIndex`] / [`PortId`].
///
/// Each side picks the payload types of the `Param` and `Listener` variants.
/// `Ord` already implies `Eq + PartialOrd`, and `Eq` implies `PartialEq`,
/// so both payloads support the full set of comparisons.
pub trait PortKind {
    /// Payload of `PortIndex::Param` on this side.
    type ParamId: Copy + Ord + hash::Hash + fmt::Debug;
    /// Payload of `PortIndex::Listener` on this side.
    type Listener: Copy + Ord + hash::Hash + fmt::Debug;
}
63
64#[derive(Clone, Copy, PartialEq, Eq, Ord, PartialOrd, Hash, Debug)]
66pub struct PortId<Kind: PortKind>(NodeId, PortIndex<Kind>);
67
/// Marker selecting the input side of a connection (`PortIndex<InputPort>`,
/// `PortId<InputPort>`).
///
/// `Hash` is derived so that the derived `Hash` impls of `PortIndex` and
/// `PortId`, which are bounded on `Kind: Hash`, actually apply to input ids.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct InputPort;

/// Marker selecting the output side of a connection (`PortIndex<OutputPort>`,
/// `PortId<OutputPort>`).
///
/// `Hash` is derived for the same reason as on `InputPort`.
#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub struct OutputPort;
76
77impl PortKind for InputPort {
78 type ParamId = ParamType;
79 type Listener = ();
80}
81
/// An uninhabited type: no value of it exists, so any enum variant that
/// carries a `Void` can never be constructed.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Void {}
84
85impl PortKind for OutputPort {
86 type ParamId = Void;
91 type Listener = Void;
92}
93
94pub struct AudioGraph {
95 graph: StableGraph<Node, Edge>,
96 dest_id: NodeId,
97 dests: Vec<NodeId>,
98 listener_id: NodeId,
99}
100
101pub(crate) struct Node {
102 node: RefCell<Box<dyn AudioNodeEngine>>,
103}
104
105pub(crate) struct Edge {
113 connections: SmallVec<[Connection; 1]>,
114}
115
116impl Edge {
117 fn has_between(
119 &self,
120 output_idx: PortIndex<OutputPort>,
121 input_idx: PortIndex<InputPort>,
122 ) -> bool {
123 self.connections
124 .iter()
125 .find(|e| e.input_idx == input_idx && e.output_idx == output_idx)
126 .is_some()
127 }
128
129 fn remove_by_output(&mut self, output_idx: PortIndex<OutputPort>) {
130 self.connections.retain(|i| i.output_idx != output_idx)
131 }
132
133 fn remove_by_input(&mut self, input_idx: PortIndex<InputPort>) {
134 self.connections.retain(|i| i.input_idx != input_idx)
135 }
136
137 fn remove_by_pair(
138 &mut self,
139 output_idx: PortIndex<OutputPort>,
140 input_idx: PortIndex<InputPort>,
141 ) {
142 self.connections
143 .retain(|i| i.output_idx != output_idx || i.input_idx != input_idx)
144 }
145}
146
147struct Connection {
149 input_idx: PortIndex<InputPort>,
152 output_idx: PortIndex<OutputPort>,
155 cache: RefCell<Option<Block>>,
158}
159
160impl AudioGraph {
161 pub fn new(channel_count: u8) -> Self {
162 let mut graph = StableGraph::new();
163 let dest_id =
164 NodeId(graph.add_node(Node::new(Box::new(DestinationNode::new(channel_count)))));
165 let listener_id = NodeId(graph.add_node(Node::new(Box::new(AudioListenerNode::new()))));
166 AudioGraph {
167 graph,
168 dest_id,
169 dests: vec![dest_id],
170 listener_id,
171 }
172 }
173
174 pub(crate) fn add_node(&mut self, node: Box<dyn AudioNodeEngine>) -> NodeId {
176 NodeId(self.graph.add_node(Node::new(node)))
177 }
178
179 pub fn add_edge(&mut self, out: PortId<OutputPort>, inp: PortId<InputPort>) {
183 let edge = self
184 .graph
185 .edges(out.node().0)
186 .find(|e| e.target() == inp.node().0)
187 .map(|e| e.id());
188 if let Some(e) = edge {
189 let w = self
191 .graph
192 .edge_weight_mut(e)
193 .expect("This edge is known to exist");
194 if w.has_between(out.1, inp.1) {
195 return;
196 }
197 w.connections.push(Connection::new(inp.1, out.1))
198 } else {
199 self.graph
201 .add_edge(out.node().0, inp.node().0, Edge::new(inp.1, out.1));
202 }
203 }
204
205 pub fn disconnect_all_from(&mut self, node: NodeId) {
209 let edges = self.graph.edges(node.0).map(|e| e.id()).collect::<Vec<_>>();
210 for edge in edges {
211 self.graph.remove_edge(edge);
212 }
213 }
214
215 pub fn disconnect_output(&mut self, out: PortId<OutputPort>) {
219 let candidates: Vec<_> = self
220 .graph
221 .edges(out.node().0)
222 .map(|e| (e.id(), e.target()))
223 .collect();
224 for (edge, to) in candidates {
225 let mut e = self
226 .graph
227 .remove_edge(edge)
228 .expect("Edge index is known to exist");
229 e.remove_by_output(out.1);
230 if !e.connections.is_empty() {
231 self.graph.add_edge(out.node().0, to, e);
232 }
233 }
234 }
235
236 pub fn disconnect_between(&mut self, from: NodeId, to: NodeId) {
240 let edge = self
241 .graph
242 .edges(from.0)
243 .find(|e| e.target() == to.0)
244 .map(|e| e.id());
245 if let Some(i) = edge {
246 self.graph.remove_edge(i);
247 }
248 }
249
250 pub fn disconnect_output_between(&mut self, out: PortId<OutputPort>, to: NodeId) {
254 let edge = self
255 .graph
256 .edges(out.node().0)
257 .find(|e| e.target() == to.0)
258 .map(|e| e.id());
259 if let Some(edge) = edge {
260 let mut e = self
261 .graph
262 .remove_edge(edge)
263 .expect("Edge index is known to exist");
264 e.remove_by_output(out.1);
265 if !e.connections.is_empty() {
266 self.graph.add_edge(out.node().0, to.0, e);
267 }
268 }
269 }
270
271 pub fn disconnect_to(&mut self, node: NodeId, inp: PortId<InputPort>) {
277 let edge = self
278 .graph
279 .edges(node.0)
280 .find(|e| e.target() == inp.node().0)
281 .map(|e| e.id());
282 if let Some(edge) = edge {
283 let mut e = self
284 .graph
285 .remove_edge(edge)
286 .expect("Edge index is known to exist");
287 e.remove_by_input(inp.1);
288 if !e.connections.is_empty() {
289 self.graph.add_edge(node.0, inp.node().0, e);
290 }
291 }
292 }
293
294 pub fn disconnect_output_between_to(
299 &mut self,
300 out: PortId<OutputPort>,
301 inp: PortId<InputPort>,
302 ) {
303 let edge = self
304 .graph
305 .edges(out.node().0)
306 .find(|e| e.target() == inp.node().0)
307 .map(|e| e.id());
308 if let Some(edge) = edge {
309 let mut e = self
310 .graph
311 .remove_edge(edge)
312 .expect("Edge index is known to exist");
313 e.remove_by_pair(out.1, inp.1);
314 if !e.connections.is_empty() {
315 self.graph.add_edge(out.node().0, inp.node().0, e);
316 }
317 }
318 }
319
320 pub fn dest_id(&self) -> NodeId {
324 self.dest_id
325 }
326
327 pub fn add_extra_dest(&mut self, dest: NodeId) {
329 self.dests.push(dest);
330 }
331
332 pub fn listener_id(&self) -> NodeId {
340 self.listener_id
341 }
342
343 pub fn process(&mut self, info: &BlockInfo) -> Chunk {
345 let reversed = Reversed(&self.graph);
351
352 let mut blocks: SmallVec<[SmallVec<[Block; 1]>; 1]> = SmallVec::new();
353 let mut output_counts: SmallVec<[u32; 1]> = SmallVec::new();
354
355 let mut visit = DfsPostOrder::empty(reversed);
356
357 for dest in &self.dests {
358 visit.move_to(dest.0);
359
360 while let Some(ix) = visit.next(reversed) {
361 let mut curr = self.graph[ix].node.borrow_mut();
362
363 let mut chunk = Chunk::default();
364 chunk
365 .blocks
366 .resize(curr.input_count() as usize, Default::default());
367
368 blocks.clear();
373 blocks.resize(curr.input_count() as usize, Default::default());
374
375 let mode = curr.channel_count_mode();
376 let count = curr.channel_count();
377 let interpretation = curr.channel_interpretation();
378
379 for edge in self.graph.edges_directed(ix, Direction::Incoming) {
381 let edge = edge.weight();
382 for connection in &edge.connections {
383 let mut block = connection
384 .cache
385 .borrow_mut()
386 .take()
387 .expect("Cache should have been filled from traversal");
388
389 match connection.input_idx {
390 PortIndex::Port(idx) => {
391 blocks[idx as usize].push(block);
392 }
393 PortIndex::Param(param) => {
394 block.mix(1, ChannelInterpretation::Speakers);
397 curr.get_param(param).add_block(block)
398 }
399 PortIndex::Listener(_) => curr.set_listenerdata(block),
400 }
401 }
402 }
403
404 for (i, mut blocks) in blocks.drain(..).enumerate() {
405 if blocks.len() == 0 {
406 if mode == ChannelCountMode::Explicit {
407 chunk.blocks[i].mix(count, interpretation);
409 }
410 } else if blocks.len() == 1 {
411 chunk.blocks[i] = blocks.pop().expect("`blocks` had length 1");
412 match mode {
413 ChannelCountMode::Explicit => {
414 chunk.blocks[i].mix(count, interpretation);
415 }
416 ChannelCountMode::ClampedMax => {
417 if chunk.blocks[i].chan_count() > count {
418 chunk.blocks[i].mix(count, interpretation);
419 }
420 }
421 ChannelCountMode::Max => (),
423 }
424 } else {
425 let mix_count = match mode {
426 ChannelCountMode::Explicit => count,
427 _ => {
428 let mut max = 0; for block in &blocks {
430 max = cmp::max(max, block.chan_count());
431 }
432 if mode == ChannelCountMode::ClampedMax {
433 max = cmp::min(max, count);
434 }
435 max
436 }
437 };
438 let block = blocks.into_iter().fold(Block::default(), |acc, mut block| {
439 block.mix(mix_count, interpretation);
440 acc.sum(block)
441 });
442 chunk.blocks[i] = block;
443 }
444 }
445
446 let mut out = curr.process(chunk, info);
448
449 assert_eq!(out.len(), curr.output_count() as usize);
450 if curr.output_count() == 0 {
451 continue;
452 }
453
454 output_counts.clear();
460 output_counts.resize(curr.output_count() as usize, 0);
461 for edge in self.graph.edges(ix) {
462 let edge = edge.weight();
463 for conn in &edge.connections {
464 if let PortIndex::Port(idx) = conn.output_idx {
465 output_counts[idx as usize] += 1;
466 } else {
467 unreachable!()
468 }
469 }
470 }
471
472 for edge in self.graph.edges(ix) {
475 let edge = edge.weight();
476 for conn in &edge.connections {
477 if let PortIndex::Port(idx) = conn.output_idx {
478 output_counts[idx as usize] -= 1;
479 let block = if output_counts[idx as usize] == 0 {
481 out[conn.output_idx].take()
482 } else {
483 out[conn.output_idx].clone()
484 };
485 *conn.cache.borrow_mut() = Some(block);
486 } else {
487 unreachable!()
488 }
489 }
490 }
491 }
492 }
493 self.graph[self.dest_id.0]
495 .node
496 .borrow_mut()
497 .destination_data()
498 .expect("Destination node should have data cached")
499 }
500
501 pub(crate) fn node_mut(&self, ix: NodeId) -> RefMut<Box<dyn AudioNodeEngine>> {
503 self.graph[ix.0].node.borrow_mut()
504 }
505}
506
507impl Node {
508 pub fn new(node: Box<dyn AudioNodeEngine>) -> Self {
509 Node {
510 node: RefCell::new(node),
511 }
512 }
513}
514
515impl Edge {
516 pub fn new(input_idx: PortIndex<InputPort>, output_idx: PortIndex<OutputPort>) -> Self {
517 Edge {
518 connections: SmallVec::from_buf([Connection::new(input_idx, output_idx)]),
519 }
520 }
521}
522
523impl Connection {
524 pub fn new(input_idx: PortIndex<InputPort>, output_idx: PortIndex<OutputPort>) -> Self {
525 Connection {
526 input_idx,
527 output_idx,
528 cache: RefCell::new(None),
529 }
530 }
531}