fst/raw/
ops.rs

1use std::cmp;
2use std::collections::BinaryHeap;
3use std::iter::FromIterator;
4
5use crate::raw::Output;
6use crate::stream::{IntoStreamer, Streamer};
7
8/// Permits stream operations to be hetergeneous with respect to streams.
9type BoxedStream<'f> =
10    Box<dyn for<'a> Streamer<'a, Item = (&'a [u8], Output)> + 'f>;
11
/// A value tagged with the stream it came from.
///
/// Set operations report, alongside each key they emit, a list of these to
/// say which participating streams the key was found in. `index` is the
/// position at which the stream was added to the operation (the first stream
/// added is `0`), and `value` is the value that stream associates with the key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct IndexedValue {
    /// Position of the originating stream, counting from `0`.
    pub index: usize,
    /// The value the originating stream associates with the key.
    pub value: u64,
}
26
27/// A builder for collecting fst streams on which to perform set operations
28/// on the keys of fsts.
29///
30/// Set operations include intersection, union, difference and symmetric
31/// difference. The result of each set operation is itself a stream that emits
32/// pairs of keys and a sequence of each occurrence of that key in the
33/// participating streams. This information allows one to perform set
34/// operations on fsts and customize how conflicting output values are handled.
35///
36/// All set operations work efficiently on an arbitrary number of
37/// streams with memory proportional to the number of streams.
38///
39/// The algorithmic complexity of all set operations is `O(n1 + n2 + n3 + ...)`
40/// where `n1, n2, n3, ...` correspond to the number of elements in each
41/// stream.
42///
43/// The `'f` lifetime parameter refers to the lifetime of the underlying set.
44pub struct OpBuilder<'f> {
45    streams: Vec<BoxedStream<'f>>,
46}
47
48impl<'f> OpBuilder<'f> {
49    /// Create a new set operation builder.
50    #[inline]
51    pub fn new() -> OpBuilder<'f> {
52        OpBuilder { streams: vec![] }
53    }
54
55    /// Add a stream to this set operation.
56    ///
57    /// This is useful for a chaining style pattern, e.g.,
58    /// `builder.add(stream1).add(stream2).union()`.
59    ///
60    /// The stream must emit a lexicographically ordered sequence of key-value
61    /// pairs.
62    pub fn add<I, S>(mut self, stream: I) -> OpBuilder<'f>
63    where
64        I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
65        S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
66    {
67        self.push(stream);
68        self
69    }
70
71    /// Add a stream to this set operation.
72    ///
73    /// The stream must emit a lexicographically ordered sequence of key-value
74    /// pairs.
75    pub fn push<I, S>(&mut self, stream: I)
76    where
77        I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
78        S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
79    {
80        self.streams.push(Box::new(stream.into_stream()));
81    }
82
83    /// Performs a union operation on all streams that have been added.
84    ///
85    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
86    /// first element of the tuple is the byte string key. The second element
87    /// of the tuple is a list of all occurrences of that key in participating
88    /// streams. The `IndexedValue` contains an index and the value associated
89    /// with that key in that stream. The index uniquely identifies each
90    /// stream, which is an integer that is auto-incremented when a stream
91    /// is added to this operation (starting at `0`).
92    #[inline]
93    pub fn union(self) -> Union<'f> {
94        Union {
95            heap: StreamHeap::new(self.streams),
96            outs: vec![],
97            cur_slot: None,
98        }
99    }
100
101    /// Performs an intersection operation on all streams that have been added.
102    ///
103    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
104    /// first element of the tuple is the byte string key. The second element
105    /// of the tuple is a list of all occurrences of that key in participating
106    /// streams. The `IndexedValue` contains an index and the value associated
107    /// with that key in that stream. The index uniquely identifies each
108    /// stream, which is an integer that is auto-incremented when a stream
109    /// is added to this operation (starting at `0`).
110    #[inline]
111    pub fn intersection(self) -> Intersection<'f> {
112        Intersection {
113            heap: StreamHeap::new(self.streams),
114            outs: vec![],
115            cur_slot: None,
116        }
117    }
118
119    /// Performs a difference operation with respect to the first stream added.
120    /// That is, this returns a stream of all elements in the first stream
121    /// that don't exist in any other stream that has been added.
122    ///
123    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
124    /// first element of the tuple is the byte string key. The second element
125    /// of the tuple is a list of all occurrences of that key in participating
126    /// streams. The `IndexedValue` contains an index and the value associated
127    /// with that key in that stream. The index uniquely identifies each
128    /// stream, which is an integer that is auto-incremented when a stream
129    /// is added to this operation (starting at `0`).
130    ///
131    /// The interface is the same for all the operations, but due to the nature
132    /// of `difference`, each yielded key contains exactly one `IndexValue` with
133    /// `index` set to 0.
134    #[inline]
135    pub fn difference(mut self) -> Difference<'f> {
136        let first = self.streams.swap_remove(0);
137        Difference {
138            set: first,
139            key: vec![],
140            heap: StreamHeap::new(self.streams),
141            outs: vec![],
142        }
143    }
144
145    /// Performs a symmetric difference operation on all of the streams that
146    /// have been added.
147    ///
148    /// When there are only two streams, then the keys returned correspond to
149    /// keys that are in either stream but *not* in both streams.
150    ///
151    /// More generally, for any number of streams, keys that occur in an odd
152    /// number of streams are returned.
153    ///
154    /// Note that this returns a stream of `(&[u8], &[IndexedValue])`. The
155    /// first element of the tuple is the byte string key. The second element
156    /// of the tuple is a list of all occurrences of that key in participating
157    /// streams. The `IndexedValue` contains an index and the value associated
158    /// with that key in that stream. The index uniquely identifies each
159    /// stream, which is an integer that is auto-incremented when a stream
160    /// is added to this operation (starting at `0`).
161    #[inline]
162    pub fn symmetric_difference(self) -> SymmetricDifference<'f> {
163        SymmetricDifference {
164            heap: StreamHeap::new(self.streams),
165            outs: vec![],
166            cur_slot: None,
167        }
168    }
169}
170
171impl<'f, I, S> Extend<I> for OpBuilder<'f>
172where
173    I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
174    S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
175{
176    fn extend<T>(&mut self, it: T)
177    where
178        T: IntoIterator<Item = I>,
179    {
180        for stream in it {
181            self.push(stream);
182        }
183    }
184}
185
186impl<'f, I, S> FromIterator<I> for OpBuilder<'f>
187where
188    I: for<'a> IntoStreamer<'a, Into = S, Item = (&'a [u8], Output)>,
189    S: 'f + for<'a> Streamer<'a, Item = (&'a [u8], Output)>,
190{
191    fn from_iter<T>(it: T) -> OpBuilder<'f>
192    where
193        T: IntoIterator<Item = I>,
194    {
195        let mut op = OpBuilder::new();
196        op.extend(it);
197        op
198    }
199}
200
201/// A stream of set union over multiple fst streams in lexicographic order.
202///
203/// The `'f` lifetime parameter refers to the lifetime of the underlying map.
204pub struct Union<'f> {
205    heap: StreamHeap<'f>,
206    outs: Vec<IndexedValue>,
207    cur_slot: Option<Slot>,
208}
209
210impl<'a, 'f> Streamer<'a> for Union<'f> {
211    type Item = (&'a [u8], &'a [IndexedValue]);
212
213    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
214        if let Some(slot) = self.cur_slot.take() {
215            self.heap.refill(slot);
216        }
217        let slot = match self.heap.pop() {
218            None => return None,
219            Some(slot) => {
220                self.cur_slot = Some(slot);
221                self.cur_slot.as_ref().unwrap()
222            }
223        };
224        self.outs.clear();
225        self.outs.push(slot.indexed_value());
226        while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
227            self.outs.push(slot2.indexed_value());
228            self.heap.refill(slot2);
229        }
230        Some((slot.input(), &self.outs))
231    }
232}
233
234/// A stream of set intersection over multiple fst streams in lexicographic
235/// order.
236///
237/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
238pub struct Intersection<'f> {
239    heap: StreamHeap<'f>,
240    outs: Vec<IndexedValue>,
241    cur_slot: Option<Slot>,
242}
243
244impl<'a, 'f> Streamer<'a> for Intersection<'f> {
245    type Item = (&'a [u8], &'a [IndexedValue]);
246
247    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
248        if let Some(slot) = self.cur_slot.take() {
249            self.heap.refill(slot);
250        }
251        loop {
252            let slot = match self.heap.pop() {
253                None => return None,
254                Some(slot) => slot,
255            };
256            self.outs.clear();
257            self.outs.push(slot.indexed_value());
258            let mut popped: usize = 1;
259            while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
260                self.outs.push(slot2.indexed_value());
261                self.heap.refill(slot2);
262                popped += 1;
263            }
264            if popped < self.heap.num_slots() {
265                self.heap.refill(slot);
266            } else {
267                self.cur_slot = Some(slot);
268                let key = self.cur_slot.as_ref().unwrap().input();
269                return Some((key, &self.outs));
270            }
271        }
272    }
273}
274
275/// A stream of set difference over multiple fst streams in lexicographic
276/// order.
277///
278/// The difference operation is taken with respect to the first stream and the
279/// rest of the streams. i.e., All elements in the first stream that do not
280/// appear in any other streams.
281///
282/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
283pub struct Difference<'f> {
284    set: BoxedStream<'f>,
285    key: Vec<u8>,
286    heap: StreamHeap<'f>,
287    outs: Vec<IndexedValue>,
288}
289
290impl<'a, 'f> Streamer<'a> for Difference<'f> {
291    type Item = (&'a [u8], &'a [IndexedValue]);
292
293    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
294        loop {
295            match self.set.next() {
296                None => return None,
297                Some((key, out)) => {
298                    self.key.clear();
299                    self.key.extend(key);
300                    self.outs.clear();
301                    self.outs
302                        .push(IndexedValue { index: 0, value: out.value() });
303                }
304            };
305            let mut unique = true;
306            while let Some(slot) = self.heap.pop_if_le(&self.key) {
307                if slot.input() == &*self.key {
308                    unique = false;
309                }
310                self.heap.refill(slot);
311            }
312            if unique {
313                return Some((&self.key, &self.outs));
314            }
315        }
316    }
317}
318
319/// A stream of set symmetric difference over multiple fst streams in
320/// lexicographic order.
321///
322/// The `'f` lifetime parameter refers to the lifetime of the underlying fst.
323pub struct SymmetricDifference<'f> {
324    heap: StreamHeap<'f>,
325    outs: Vec<IndexedValue>,
326    cur_slot: Option<Slot>,
327}
328
329impl<'a, 'f> Streamer<'a> for SymmetricDifference<'f> {
330    type Item = (&'a [u8], &'a [IndexedValue]);
331
332    fn next(&'a mut self) -> Option<(&'a [u8], &'a [IndexedValue])> {
333        if let Some(slot) = self.cur_slot.take() {
334            self.heap.refill(slot);
335        }
336        loop {
337            let slot = match self.heap.pop() {
338                None => return None,
339                Some(slot) => slot,
340            };
341            self.outs.clear();
342            self.outs.push(slot.indexed_value());
343            let mut popped: usize = 1;
344            while let Some(slot2) = self.heap.pop_if_equal(slot.input()) {
345                self.outs.push(slot2.indexed_value());
346                self.heap.refill(slot2);
347                popped += 1;
348            }
349            // This key is in the symmetric difference if and only if it
350            // appears in an odd number of sets.
351            if popped % 2 == 0 {
352                self.heap.refill(slot);
353            } else {
354                self.cur_slot = Some(slot);
355                let key = self.cur_slot.as_ref().unwrap().input();
356                return Some((key, &self.outs));
357            }
358        }
359    }
360}
361
362struct StreamHeap<'f> {
363    rdrs: Vec<BoxedStream<'f>>,
364    heap: BinaryHeap<Slot>,
365}
366
367impl<'f> StreamHeap<'f> {
368    fn new(streams: Vec<BoxedStream<'f>>) -> StreamHeap<'f> {
369        let mut u = StreamHeap { rdrs: streams, heap: BinaryHeap::new() };
370        for i in 0..u.rdrs.len() {
371            u.refill(Slot::new(i));
372        }
373        u
374    }
375
376    fn pop(&mut self) -> Option<Slot> {
377        self.heap.pop()
378    }
379
380    fn peek_is_duplicate(&self, key: &[u8]) -> bool {
381        self.heap.peek().map(|s| s.input() == key).unwrap_or(false)
382    }
383
384    fn pop_if_equal(&mut self, key: &[u8]) -> Option<Slot> {
385        if self.peek_is_duplicate(key) {
386            self.pop()
387        } else {
388            None
389        }
390    }
391
392    fn pop_if_le(&mut self, key: &[u8]) -> Option<Slot> {
393        if self.heap.peek().map(|s| s.input() <= key).unwrap_or(false) {
394            self.pop()
395        } else {
396            None
397        }
398    }
399
400    fn num_slots(&self) -> usize {
401        self.rdrs.len()
402    }
403
404    fn refill(&mut self, mut slot: Slot) {
405        if let Some((input, output)) = self.rdrs[slot.idx].next() {
406            slot.set_input(input);
407            slot.set_output(output);
408            self.heap.push(slot);
409        }
410    }
411}
412
413#[derive(Debug, Eq, PartialEq)]
414struct Slot {
415    idx: usize,
416    input: Vec<u8>,
417    output: Output,
418}
419
420impl Slot {
421    fn new(rdr_idx: usize) -> Slot {
422        Slot {
423            idx: rdr_idx,
424            input: Vec::with_capacity(64),
425            output: Output::zero(),
426        }
427    }
428
429    fn indexed_value(&self) -> IndexedValue {
430        IndexedValue { index: self.idx, value: self.output.value() }
431    }
432
433    fn input(&self) -> &[u8] {
434        &self.input
435    }
436
437    fn set_input(&mut self, input: &[u8]) {
438        self.input.clear();
439        self.input.extend(input);
440    }
441
442    fn set_output(&mut self, output: Output) {
443        self.output = output;
444    }
445}
446
447impl PartialOrd for Slot {
448    fn partial_cmp(&self, other: &Slot) -> Option<cmp::Ordering> {
449        (&self.input, self.output)
450            .partial_cmp(&(&other.input, other.output))
451            .map(|ord| ord.reverse())
452    }
453}
454
455impl Ord for Slot {
456    fn cmp(&self, other: &Slot) -> cmp::Ordering {
457        self.partial_cmp(other).unwrap()
458    }
459}
460
#[cfg(test)]
mod tests {
    use crate::raw::tests::{fst_map, fst_set};
    use crate::raw::Fst;
    use crate::stream::{IntoStreamer, Streamer};

    use super::OpBuilder;

    // Shorthand for an owned `String`, used when spelling out expected
    // `(String, u64)` results.
    fn s(string: &str) -> String {
        string.to_owned()
    }

    // Defines `$name(sets)`: builds one fst per input list with `fst_set`,
    // combines them via `OpBuilder::$op`, and returns the emitted keys as
    // `String`s in stream order (the `IndexedValue`s are ignored).
    macro_rules! create_set_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<&str>>) -> Vec<String> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_set).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, _)) = stream.next() {
                    keys.push(String::from_utf8(key.to_vec()).unwrap());
                }
                keys
            }
        };
    }

    // Defines `$name(maps)`: like `create_set_op!`, but builds the fsts with
    // `fst_map` and pairs each emitted key with the sum of the `value`s of
    // all `IndexedValue`s reported for it.
    macro_rules! create_map_op {
        ($name:ident, $op:ident) => {
            fn $name(sets: Vec<Vec<(&str, u64)>>) -> Vec<(String, u64)> {
                let fsts: Vec<Fst<_>> =
                    sets.into_iter().map(fst_map).collect();
                let op: OpBuilder = fsts.iter().collect();
                let mut stream = op.$op().into_stream();
                let mut keys = vec![];
                while let Some((key, outs)) = stream.next() {
                    let merged = outs.iter().fold(0, |a, b| a + b.value);
                    let s = String::from_utf8(key.to_vec()).unwrap();
                    keys.push((s, merged));
                }
                keys
            }
        };
    }

    create_set_op!(fst_union, union);
    create_set_op!(fst_intersection, intersection);
    create_set_op!(fst_symmetric_difference, symmetric_difference);
    create_set_op!(fst_difference, difference);
    create_map_op!(fst_union_map, union);
    create_map_op!(fst_intersection_map, intersection);
    create_map_op!(fst_symmetric_difference_map, symmetric_difference);
    create_map_op!(fst_difference_map, difference);

    // Disjoint sets: the union is simply every key, in sorted order.
    #[test]
    fn union_set() {
        let v = fst_union(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, vec!["a", "b", "c", "x", "y", "z"]);
    }

    // Keys shared by both sets ("b", "cc") are emitted only once.
    #[test]
    fn union_set_dupes() {
        let v = fst_union(vec![vec!["aa", "b", "cc"], vec!["b", "cc", "z"]]);
        assert_eq!(v, vec!["aa", "b", "cc", "z"]);
    }

    // Disjoint maps: each key keeps its single value.
    #[test]
    fn union_map() {
        let v = fst_union_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(
            v,
            vec![
                (s("a"), 1),
                (s("b"), 2),
                (s("c"), 3),
                (s("x"), 1),
                (s("y"), 2),
                (s("z"), 3),
            ]
        );
    }

    // "b" is in all three maps (2 + 1 + 1 = 4) and "cc" in two (3 + 2 = 5),
    // so the summed values show that every occurrence is reported.
    #[test]
    fn union_map_dupes() {
        let v = fst_union_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(
            v,
            vec![(s("aa"), 1), (s("b"), 4), (s("cc"), 5), (s("z"), 3),]
        );
    }

    // Disjoint sets have an empty intersection.
    #[test]
    fn intersect_set() {
        let v =
            fst_intersection(vec![vec!["a", "b", "c"], vec!["x", "y", "z"]]);
        assert_eq!(v, Vec::<String>::new());
    }

    #[test]
    fn intersect_set_dupes() {
        let v = fst_intersection(vec![
            vec!["aa", "b", "cc"],
            vec!["b", "cc", "z"],
        ]);
        assert_eq!(v, vec!["b", "cc"]);
    }

    #[test]
    fn intersect_map() {
        let v = fst_intersection_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("x", 1), ("y", 2), ("z", 3)],
        ]);
        assert_eq!(v, Vec::<(String, u64)>::new());
    }

    // Only "b" is in all three maps ("cc" is missing from the third);
    // its values sum to 2 + 1 + 1 = 4.
    #[test]
    fn intersect_map_dupes() {
        let v = fst_intersection_map(vec![
            vec![("aa", 1), ("b", 2), ("cc", 3)],
            vec![("b", 1), ("cc", 2), ("z", 3)],
            vec![("b", 1)],
        ]);
        assert_eq!(v, vec![(s("b"), 4)]);
    }

    // "a" occurs in 3 sets and "c" in 1 (both odd, so kept); "b" occurs in
    // 2 (even, so dropped).
    #[test]
    fn symmetric_difference() {
        let v = fst_symmetric_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["a", "c"]);
    }

    // Same key layout as above; "a" sums 1 + 1 + 1 = 3 across its three
    // occurrences.
    #[test]
    fn symmetric_difference_map() {
        let v = fst_symmetric_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("a"), 3), (s("c"), 3)]);
    }

    // Only "c" from the first set is absent from every other set.
    #[test]
    fn difference() {
        let v = fst_difference(vec![
            vec!["a", "b", "c"],
            vec!["a", "b"],
            vec!["a"],
        ]);
        assert_eq!(v, vec!["c"]);
    }

    #[test]
    fn difference2() {
        // Regression test: https://github.com/BurntSushi/fst/issues/19
        let v = fst_difference(vec![vec!["a", "c"], vec!["b", "c"]]);
        assert_eq!(v, vec!["a"]);
        let v = fst_difference(vec![vec!["bar", "foo"], vec!["baz", "foo"]]);
        assert_eq!(v, vec!["bar"]);
    }

    // Difference reports only the first map's value for each surviving key.
    #[test]
    fn difference_map() {
        let v = fst_difference_map(vec![
            vec![("a", 1), ("b", 2), ("c", 3)],
            vec![("a", 1), ("b", 2)],
            vec![("a", 1)],
        ]);
        assert_eq!(v, vec![(s("c"), 3)]);
    }
}