evobench_tools/
join.rs

1//! Joins (intersections) of sorted sequences
2
3use itertools::{EitherOrBoth, Itertools};
4
/// A value paired with the key it is joined on.
///
/// Sequences of `KeyVal` handed to the join functions in this module must
/// be sorted in ascending order by `key`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KeyVal<K, V> {
    /// The join key.
    pub key: K,
    /// The payload carried alongside `key`.
    pub val: V,
}
10
11pub fn keyval_inner_join_2<K: Ord, V1, V2>(
12    a: impl IntoIterator<Item = KeyVal<K, V1>>,
13    b: impl IntoIterator<Item = KeyVal<K, V2>>,
14) -> impl Iterator<Item = KeyVal<K, (V1, V2)>> {
15    a.into_iter()
16        .merge_join_by(b.into_iter(), |a, b| a.key.cmp(&b.key))
17        .filter_map(|eob| match eob {
18            EitherOrBoth::Both(a, b) => Some(KeyVal {
19                key: a.key,
20                val: (a.val, b.val),
21            }),
22            EitherOrBoth::Left(_) => None,
23            EitherOrBoth::Right(_) => None,
24        })
25}
26
27/// Join any number of sequences of `KeyVal`, ordered by
28/// `KeyVal.key`. Only keys present in all sequences are
29/// preserved. I.e. the resulting sequence has `Vec`s of the same
30/// number of elements as `sequences`. Returns None if `sequences` is
31/// empty.
32pub fn keyval_inner_join<'s, 'i, 'k, 'v, K: Ord + 'k, V: 'v>(
33    sequences: &'s mut [Option<impl IntoIterator<Item = KeyVal<K, V>> + 'i>],
34) -> Option<Box<dyn Iterator<Item = KeyVal<K, Vec<V>>> + 'i>>
35where
36    's: 'i,
37    'k: 'i,
38    'v: 'i,
39{
40    match sequences.len() {
41        0 => None,
42        1 => Some(Box::new(
43            sequences[0]
44                .take()
45                .expect("checked")
46                .into_iter()
47                .map(|KeyVal { key, val }| KeyVal {
48                    key,
49                    val: vec![val],
50                }),
51        )),
52        2 => Some(Box::new(
53            keyval_inner_join_2(
54                sequences[0].take().expect("checked"),
55                sequences[1].take().expect("checked"),
56            )
57            .map(|KeyVal { key, val: (v1, v2) }| KeyVal {
58                key,
59                val: vec![v1, v2],
60            }),
61        )),
62        n => {
63            let (a, b) = sequences.split_at_mut(n / 2);
64            let ar = keyval_inner_join(a).expect("at least 1 out of 3+");
65            let br = keyval_inner_join(b).expect("at least 1 out of 3+");
66            Some(Box::new(keyval_inner_join_2(ar, br).map(
67                |KeyVal {
68                     key,
69                     val: (mut val1, mut val2),
70                 }| {
71                    val1.append(&mut val2);
72                    KeyVal { key, val: val1 }
73                },
74            )))
75        }
76    }
77}
78
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for `KeyVal`.
    fn k<K, V>(k: K, v: V) -> KeyVal<K, V> {
        KeyVal { key: k, val: v }
    }

    /// Three key-sorted sequences. The keys common to all three are
    /// "a2", "t" and "u".
    fn seqs() -> (
        Vec<KeyVal<&'static str, i32>>,
        Vec<KeyVal<&'static str, i32>>,
        Vec<KeyVal<&'static str, i32>>,
    ) {
        (
            vec![k("a", 1), k("a2", 2), k("b", 3), k("t", 4), k("u", 5)],
            vec![
                k("a03", 10),
                k("a2", 20),
                k("b2", 30),
                k("t", 40),
                k("u", 50),
            ],
            vec![
                k("a01", 100),
                k("a2", 200),
                k("b3", 300),
                k("t", 400),
                k("u", 500),
                k("v", 600),
            ],
        )
    }

    #[test]
    fn t_2() {
        let (a, b, _c) = seqs();

        // Borrowed inputs: join `KeyVal<&K, &V>` views without consuming
        // `a` and `b`.
        let res = keyval_inner_join_2(
            a.iter().map(|KeyVal { key, val }| KeyVal { key, val }),
            b.iter().map(|KeyVal { key, val }| KeyVal { key, val }),
        )
        .collect::<Vec<_>>();
        assert_eq!(
            res,
            vec![k(&"a2", (&2, &20)), k(&"t", (&4, &40)), k(&"u", (&5, &50))]
        );

        // Owned inputs:
        assert_eq!(
            keyval_inner_join_2(a, b).collect::<Vec<_>>(),
            vec![k("a2", (2, 20)), k("t", (4, 40)), k("u", (5, 50))]
        );
    }

    #[test]
    fn t_2_no_common_keys() {
        let disjoint =
            keyval_inner_join_2(vec![k(1, 'a'), k(3, 'c')], vec![k(2, "b")]).collect::<Vec<_>>();
        assert!(disjoint.is_empty());

        let with_empty =
            keyval_inner_join_2(Vec::<KeyVal<i32, ()>>::new(), vec![k(1, ())]).count();
        assert_eq!(with_empty, 0);
    }

    #[test]
    fn t_0_and_1() {
        let mut empty: [Option<Vec<KeyVal<&str, i32>>>; 0] = [];
        assert!(keyval_inner_join(&mut empty).is_none());

        let (a, _b, _c) = seqs();
        let mut v = [Some(a)];
        let r = keyval_inner_join(&mut v)
            .expect("given inputs")
            .collect::<Vec<_>>();
        assert_eq!(
            r,
            vec![
                k("a", vec![1]),
                k("a2", vec![2]),
                k("b", vec![3]),
                k("t", vec![4]),
                k("u", vec![5])
            ]
        );
        // The sequence has been moved out of the slice.
        assert!(v[0].is_none());
    }

    #[test]
    #[should_panic]
    fn t_consumed_input_panics() {
        let (a, b, _c) = seqs();
        let mut v = [Some(a), Some(b)];
        // The first call takes both sequences out of `v` ...
        let _ = keyval_inner_join(&mut v);
        // ... so a second call on the same slice must panic.
        let _ = keyval_inner_join(&mut v);
    }

    #[test]
    fn t_3() {
        let (a, b, c) = seqs();
        let r = keyval_inner_join(&mut [Some(a), Some(b), Some(c)])
            .expect("given inputs")
            .collect::<Vec<_>>();
        assert_eq!(
            r,
            vec![
                k("a2", vec![2, 20, 200]),
                k("t", vec![4, 40, 400]),
                k("u", vec![5, 50, 500])
            ]
        );
        let (a, b, c) = seqs();
        let d = vec![k("a", 1), k("a2", 2), k("b", 3), k("u", 5)];
        let mut v = [a, b, c, d].map(Some);
        let r = keyval_inner_join(&mut v)
            .expect("given inputs")
            .collect::<Vec<_>>();
        assert_eq!(
            r,
            vec![k("a2", vec![2, 20, 200, 2]), k("u", vec![5, 50, 500, 5])]
        );
    }
}
170}