1use std::cmp;
2use std::collections::BinaryHeap;
3use std::iter::FromIterator;
4
5use crate::raw::Output;
6use crate::stream::{IntoStreamer, Streamer};
7
/// A heap-allocated, type-erased stream of `(key, output)` pairs.
///
/// Boxing lets streams of different concrete types be stored side by side in
/// one `Vec`. The `'f` bound is the lifetime the boxed stream itself must
/// outlive; the higher-ranked `'a` is the per-call borrow of each yielded key.
type BoxedStream<'f> =
    Box<dyn for<'a> Streamer<'a, Item = (&'a [u8], Output)> + 'f>;
11
/// A value paired with the position of the stream that produced it.
///
/// The set operations in this module yield a slice of these for every key,
/// one entry per contributing stream.
#[derive(Copy, Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct IndexedValue {
    /// Position of the originating stream within the builder (starts at 0).
    pub index: usize,
    /// The value the stream associated with the key.
    pub value: u64,
}
26
/// Collects streams and then turns them into a single set-operation stream
/// (union, intersection, difference or symmetric difference).
///
/// `'f` is the lifetime that every added stream must outlive.
pub struct OpBuilder<'f> {
    // The boxed input streams, in the order they were added.
    streams: Vec<BoxedStream<'f>>,
}
47
48impl<'f> OpBuilder<'f> {
49 #[inline]
51 pub fn new() -> OpBuilder<'f> {
52 OpBuilder { streams: vec![] }
53 }
54
55 pub fn add<I, S>(mut self, stream: I) -> OpBuilder<'f>
63 where
64 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
65 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
66 {
67 self.push(stream);
68 self
69 }
70
71 pub fn push<I, S>(&mut self, stream: I)
76 where
77 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
78 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
79 {
80 self.streams.push(Box::new(stream.into_stream()));
81 }
82
83 #[inline]
93 pub fn union(self) -> Union<'f> {
94 Union {
95 heap: StreamHeap::new(self.streams),
96 outs: vec![],
97 cur_slot: None,
98 }
99 }
100
101 #[inline]
111 pub fn intersection(self) -> Intersection<'f> {
112 Intersection {
113 heap: StreamHeap::new(self.streams),
114 outs: vec![],
115 cur_slot: None,
116 }
117 }
118
119 #[inline]
135 pub fn difference(mut self) -> Difference<'f> {
136 let first = self.streams.swap_remove(0);
137 Difference {
138 set: first,
139 key: vec![],
140 heap: StreamHeap::new(self.streams),
141 outs: vec![],
142 }
143 }
144
145 #[inline]
162 pub fn symmetric_difference(self) -> SymmetricDifference<'f> {
163 SymmetricDifference {
164 heap: StreamHeap::new(self.streams),
165 outs: vec![],
166 cur_slot: None,
167 }
168 }
169}
170
171impl<'f, I, S> Extend<I> for OpBuilder<'f>
172where
173 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
174 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
175{
176 fn extend<T>(&mut self, it: T)
177 where
178 T: IntoIterator<Item = I>,
179 {
180 for stream in it {
181 self.push(stream);
182 }
183 }
184}
185
186impl<'f, I, S> FromIterator<I> for OpBuilder<'f>
187where
188 I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
189 S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
190{
191 fn from_iter<T>(it: T) -> OpBuilder<'f>
192 where
193 T: IntoIterator<Item = I>,
194 {
195 let mut op = OpBuilder::new();
196 op.extend(it);
197 op
198 }
199}
200
/// A stream over the union of several streams.
///
/// `'f` is the lifetime of the underlying streams.
pub struct Union<'f> {
    // Holds the pending head element of every stream that is not exhausted.
    heap: StreamHeap<'f>,
    // Scratch buffer, reused across calls, with the values for the key that
    // is currently being yielded.
    outs: Vec<IndexedValue>,
    // The slot whose key was lent out by the previous call to `next`. Its
    // stream is advanced at the start of the following call.
    cur_slot: Option<Slot>,
}
209
210impl<'a, 'f> Streamer<'a> for Union<'f> {
211 type Item = (&'a [u8], &'a [IndexedValue]);
212
213 fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
214 if let Some(slot) = self.cur_slot.take() {
215 self.heap.refill(slot);
216 }
217 let slot = match self.heap.pop() {
218 None => return None,
219 Some(slot) => {
220 self.cur_slot = Some(slot);
221 self.cur_slot.as_ref().unwrap()
222 }
223 };
224 self.outs.clear();
225 self.outs.push(slot.indexed_value());
226 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
227 self.outs.push(slot2.indexed_value());
228 self.heap.refill(slot2);
229 }
230 Some((slot.input(), &self.outs))
231 }
232}
233
/// A stream over the intersection of several streams.
///
/// `'f` is the lifetime of the underlying streams.
pub struct Intersection<'f> {
    // Holds the pending head element of every stream that is not exhausted.
    heap: StreamHeap<'f>,
    // Scratch buffer, reused across calls, with the values gathered for the
    // key under consideration.
    outs: Vec<IndexedValue>,
    // The slot whose key was lent out by the previous call to `next`. Its
    // stream is advanced at the start of the following call.
    cur_slot: Option<Slot>,
}
243
244impl<'a, 'f> Streamer<'a> for Intersection<'f> {
245 type Item = (&'a [u8], &'a [IndexedValue]);
246
247 fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
248 if let Some(slot) = self.cur_slot.take() {
249 self.heap.refill(slot);
250 }
251 loop {
252 let slot = match self.heap.pop() {
253 None => return None,
254 Some(slot) => slot,
255 };
256 self.outs.clear();
257 self.outs.push(slot.indexed_value());
258 let mut popped: usize = 1;
259 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
260 self.outs.push(slot2.indexed_value());
261 self.heap.refill(slot2);
262 popped += 1;
263 }
264 if popped < self.heap.num_slots() {
265 self.heap.refill(slot);
266 } else {
267 self.cur_slot = Some(slot);
268 let key = self.cur_slot.as_ref().unwrap().input();
269 return Some((key, &self.outs));
270 }
271 }
272 }
273}
274
/// A stream over the keys of one stream that are absent from a group of
/// other streams.
///
/// `'f` is the lifetime of the underlying streams.
pub struct Difference<'f> {
    // The stream whose keys are candidates for being yielded.
    set: BoxedStream<'f>,
    // Owned copy of the candidate key most recently read from `set`.
    key: Vec<u8>,
    // The streams whose keys are subtracted from `set`.
    heap: StreamHeap<'f>,
    // Scratch buffer holding the single value for the current candidate key.
    outs: Vec<IndexedValue>,
}
289
290impl<'a, 'f> Streamer<'a> for Difference<'f> {
291 type Item = (&'a [u8], &'a [IndexedValue]);
292
293 fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
294 loop {
295 match self.set.next() {
296 None => return None,
297 Some((key, out)) => {
298 self.key.clear();
299 self.key.extend(key);
300 self.outs.clear();
301 self.outs
302 .push(IndexedValue { index: 0, value: out.value() });
303 }
304 };
305 let mut unique = true;
306 while let Some(slot) = self.heap.pop_if_le(&self.key) {
307 if slot.input() == &*self.key {
308 unique = false;
309 }
310 self.heap.refill(slot);
311 }
312 if unique {
313 return Some((&self.key, &self.outs));
314 }
315 }
316 }
317}
318
/// A stream over the symmetric difference of several streams.
///
/// `'f` is the lifetime of the underlying streams.
pub struct SymmetricDifference<'f> {
    // Holds the pending head element of every stream that is not exhausted.
    heap: StreamHeap<'f>,
    // Scratch buffer, reused across calls, with the values gathered for the
    // key under consideration.
    outs: Vec<IndexedValue>,
    // The slot whose key was lent out by the previous call to `next`. Its
    // stream is advanced at the start of the following call.
    cur_slot: Option<Slot>,
}
328
329impl<'a, 'f> Streamer<'a> for SymmetricDifference<'f> {
330 type Item = (&'a [u8], &'a [IndexedValue]);
331
332 fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
333 if let Some(slot) = self.cur_slot.take() {
334 self.heap.refill(slot);
335 }
336 loop {
337 let slot = match self.heap.pop() {
338 None => return None,
339 Some(slot) => slot,
340 };
341 self.outs.clear();
342 self.outs.push(slot.indexed_value());
343 let mut popped: usize = 1;
344 while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
345 self.outs.push(slot2.indexed_value());
346 self.heap.refill(slot2);
347 popped += 1;
348 }
349 if popped % 2 == 0 {
352 self.heap.refill(slot);
353 } else {
354 self.cur_slot = Some(slot);
355 let key = self.cur_slot.as_ref().unwrap().input();
356 return Some((key, &self.outs));
357 }
358 }
359 }
360}
361
/// A group of streams plus a heap holding the current head element of each
/// stream that still has one.
struct StreamHeap<'f> {
    // The underlying streams; a `Slot`'s `idx` indexes into this vector.
    rdrs: Vec<BoxedStream<'f>>,
    // At most one slot per stream, ordered by `Slot`'s `Ord` implementation.
    heap: BinaryHeap<Slot>,
}
366
367impl<'f> StreamHeap<'f> {
368 fn new(streams: Vec<BoxedStream<'f>>) -> StreamHeap<'f> {
369 let mut u = StreamHeap { rdrs: streams, heap: BinaryHeap::new() };
370 for i in 0..u.rdrs.len() {
371 u.refill(Slot::new(i));
372 }
373 u
374 }
375
376 fn pop(&mut self) -> Option<Slot> {
377 self.heap.pop()
378 }
379
380 fn peek_is_duplicate(&self, key: &[u8]) -> bool {
381 self.heap.peek().map(|s| s.input() == key).unwrap_or(false)
382 }
383
384 fn pop_if_equal(&mut self, key: &[u8]) -> Option<Slot> {
385 if self.peek_is_duplicate(key) {
386 self.pop()
387 } else {
388 None
389 }
390 }
391
392 fn pop_if_le(&mut self, key: &[u8]) -> Option<Slot> {
393 if self.heap.peek().map(|s| s.input() <= key).unwrap_or(false) {
394 self.pop()
395 } else {
396 None
397 }
398 }
399
400 fn num_slots(&self) -> usize {
401 self.rdrs.len()
402 }
403
404 fn refill(&mut self, mut slot: Slot) {
405 if let Some((input, output)) = self.rdrs[slot.idx].next() {
406 slot.set_input(input);
407 slot.set_output(output);
408 self.heap.push(slot);
409 }
410 }
411}
412
/// The current head element of one stream, stored in owned form so it can
/// live on the heap.
///
/// NOTE(review): the derived `PartialEq` compares `idx` as well, whereas the
/// hand-written ordering for this type looks only at `input` and `output`.
#[derive(Debug, Eq, PartialEq)]
struct Slot {
    // Index of the originating stream in `StreamHeap::rdrs`.
    idx: usize,
    // Owned copy of the stream's current key; the buffer is reused on refill.
    input: Vec<u8>,
    // The output the stream produced alongside the key.
    output: Output,
}
419
420impl Slot {
421 fn new(rdr_idx: usize) -> Slot {
422 Slot {
423 idx: rdr_idx,
424 input: Vec::with_capacity(64),
425 output: Output::zero(),
426 }
427 }
428
429 fn indexed_value(&self) -> IndexedValue {
430 IndexedValue { index: self.idx, value: self.output.value() }
431 }
432
433 fn input(&self) -> &[u8] {
434 &self.input
435 }
436
437 fn set_input(&mut self, input: &[u8]) {
438 self.input.clear();
439 self.input.extend(input);
440 }
441
442 fn set_output(&mut self, output: Output) {
443 self.output = output;
444 }
445}
446
447impl PartialOrd for Slot {
448 fn partial_cmp(&self, other: &Slot) -> Option<cmp::Ordering> {
449 (&self.input, self.output)
450 .partial_cmp(&(&other.input, other.output))
451 .map(|ord| ord.reverse())
452 }
453}
454
455impl Ord for Slot {
456 fn cmp(&self, other: &Slot) -> cmp::Ordering {
457 self.partial_cmp(other).unwrap()
458 }
459}
460
// Tests for the set operations. Each one builds small FSTs from literal
// inputs, runs one operation over them and compares the collected output.
#[cfg(test)]
mod tests {
    use crate::raw::tests::{fst_map, fst_set};
    use crate::raw::Fst;
    use crate::stream::{IntoStreamer, Streamer};

    use super::OpBuilder;

    // Shorthand for building the owned strings used in expected values.
    fn s(string: &str) -> String {
        string.to_owned()
    }

    // Generates `fn $name(sets) -> Vec<String>`, which builds one FST set per
    // inner vector, applies the builder method `$op` to all of them and
    // returns the yielded keys in order.
    macro_rules! create_set_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<&str>>) -> Vec<String> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_set).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, _)) = stream.next() {
                    keys.push(String::from_utf8(key.to_vec()).unwrap());
                }
                keys
            }
        };
    }

    // Generates `fn $name(sets) -> Vec<(String, u64)>`, the map counterpart
    // of `create_set_op!`. The values yielded for a key are summed so the
    // expected output can be written as a single number per key.
    macro_rules! create_map_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<(&str, u64)>>) -> Vec<(String, u64)> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_map).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, outs)) = stream.next() {
                    let merged = outs.iter().fold(0, |a, b| a + b.value);
                    let s = String::from_utf8(key.to_vec()).unwrap();
                    keys.push((s, merged));
                }
                keys
            }
        };
    }

    create_set_op!(fst_union, union);
    create_set_op!(fst_intersection, intersection);
    create_set_op!(fst_symmetric_difference, symmetric_difference);
    create_set_op!(fst_difference, difference);
    create_map_op!(fst_union_map, union);
    create_map_op!(fst_intersection_map, intersection);
    create_map_op!(fst_symmetric_difference_map, symmetric_difference);
    create_map_op!(fst_difference_map, difference);

    // Disjoint inputs: the union is the concatenation in key order.
    #[test]
    fn union_set() {
        let v = fst_union(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, vec!["a", "b", "c", "x", "y", "z"]);
    }

    // Keys shared by both inputs appear only once.
    #[test]
    fn union_set_dupes() {
        let v = fst_union(vec![vec!["aa", "b", "cc"], vec!["b", "cc", "z"]]);
        assert_eq!(v, vec!["aa", "b", "cc", "z"]);
    }

    // Disjoint maps: every key keeps its own value.
    #[test]
    fn union_map() {
        let v = fst_union_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(
            v,
            vec![
                (s("a"), 1),
                (s("b"), 2),
                (s("c"), 3),
                (s("x"), 1),
                (s("y"), 2),
                (s("z"), 3),
            ]
        );
    }

    // Shared keys report the values of all contributing maps (summed here).
    #[test]
    fn union_map_dupes() {
        let v = fst_union_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(
            v,
            vec![(s("aa"), 1), (s("b"), 4), (s("cc"), 5), (s("z"), 3),]
        );
    }

    // Disjoint inputs have an empty intersection.
    #[test]
    fn intersect_set() {
        let v =
            fst_intersection(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, Vec::<String>::new());
    }

    // Only the keys present in both inputs survive.
    #[test]
    fn intersect_set_dupes() {
        let v = fst_intersection(vec![
            vec!["aa", "b", "cc"],
            vec!["b", "cc", "z"],
        ]);
        assert_eq!(v, vec!["b", "cc"]);
    }

    // Disjoint maps have an empty intersection.
    #[test]
    fn intersect_map() {
        let v = fst_intersection_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(v, Vec::<(String, u64)>::new());
    }

    // "cc" is missing from the third map, so only "b" is in all three.
    #[test]
    fn intersect_map_dupes() {
        let v = fst_intersection_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(v, vec![(s("b"), 4)]);
    }

    // "a" occurs three times and "c" once (odd counts); "b" occurs twice.
    #[test]
    fn symmetric_difference() {
        let v = fst_symmetric_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["a", "c"]);
    }

    // Same inputs as above with values attached; "a" sums to 1 + 1 + 1.
    #[test]
    fn symmetric_difference_map() {
        let v = fst_symmetric_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("a"), 3), (s("c"), 3)]);
    }

    // Only "c" is in the first input and in none of the others.
    #[test]
    fn difference() {
        let v = fst_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["c"]);
    }

    // The second input contains a key that sorts between the first input's
    // keys without matching either of them.
    #[test]
    fn difference2() {
        let v = fst_difference(vec![vec!["a", "c"], vec!["b", "c"]]);
        assert_eq!(v, vec!["a"]);
        let v = fst_difference(vec![vec!["bar", "foo"], vec!["baz", "foo"]]);
        assert_eq!(v, vec!["bar"]);
    }

    // The surviving key keeps the value from the first map.
    #[test]
    fn difference_map() {
        let v = fst_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("c"), 3)]);
    }
}