// imageproc/union_find.rs

//! An implementation of disjoint set forests for union find.
/// Data structure for efficient union find.
///
/// Elements are identified by the indices `0..count`. Each element belongs to
/// exactly one tree, and the root of a tree identifies its set. Trees are
/// merged by size and paths are shortened during root lookups, which keeps
/// the trees shallow.
#[derive(Debug, Clone)]
pub struct DisjointSetForest {
    /// Number of forest elements.
    count: usize,
    /// `parent[i]` is the index of the parent of the element with index `i`.
    /// If `parent[i] == i` then `i` is a root.
    ///
    /// Invariant relied on by the `unsafe` blocks below: `parent.len() == count`
    /// and every entry is `< count`.
    parent: Vec<usize>,
    /// `tree_size[i]` is the size of the tree rooted at `i`. Only kept up to
    /// date for roots; entries for non-root elements are stale.
    tree_size: Vec<usize>,
}

impl DisjointSetForest {
    /// Constructs a forest of `count` singleton trees, one per element.
    pub fn new(count: usize) -> DisjointSetForest {
        DisjointSetForest {
            count,
            parent: (0..count).collect(),
            tree_size: vec![1_usize; count],
        }
    }

    /// Returns the number of trees in the forest.
    ///
    /// Scans every element and counts those that are their own parent (roots).
    pub fn num_trees(&self) -> usize {
        self.parent
            .iter()
            .enumerate()
            .filter(|&(i, &p)| i == p)
            .count()
    }

    /// Returns the index of the root of the tree containing `i`.
    ///
    /// Needs a mutable reference to `self` because it applies path splitting:
    /// every element visited on the way to the root is re-pointed at its
    /// grandparent, which shortens later lookups.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count`.
    pub fn root(&mut self, i: usize) -> usize {
        assert!(i < self.count);
        let mut j = i;
        loop {
            // SAFETY: `j < self.count == self.parent.len()`. This holds for the
            // initial `j = i` by the assertion above, and is preserved because
            // every stored parent is itself an index `< self.count`: `new`
            // stores `0..count`, and `union` only stores roots returned by
            // `root`. Therefore `p` is a valid index too.
            unsafe {
                let p = *self.parent.get_unchecked(j);
                let grandparent = *self.parent.get_unchecked(p);
                *self.parent.get_unchecked_mut(j) = grandparent;
                if j == p {
                    return j;
                }
                j = p;
            }
        }
    }

    /// Returns true if `i` and `j` are in the same tree.
    ///
    /// Needs a mutable reference to `self` because root lookups apply path
    /// splitting.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count` or `j >= count`.
    pub fn find(&mut self, i: usize, j: usize) -> bool {
        assert!(i < self.count && j < self.count);
        self.root(i) == self.root(j)
    }

    /// Unions the trees containing `i` and `j`.
    ///
    /// The root of the smaller tree is attached beneath the root of the larger
    /// tree (union by size). On a tie, the root of `i`'s tree stays the root.
    /// Does nothing if `i` and `j` are already in the same tree.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count` or `j >= count`.
    pub fn union(&mut self, i: usize, j: usize) {
        assert!(i < self.count && j < self.count);
        let p = self.root(i);
        let q = self.root(j);
        if p == q {
            return;
        }
        // SAFETY: `p` and `q` were returned by `root`, so both are
        // `< self.count`, which equals the length of `parent` and `tree_size`.
        unsafe {
            let p_size = *self.tree_size.get_unchecked(p);
            let q_size = *self.tree_size.get_unchecked(q);
            let (child, new_root) = if p_size < q_size { (p, q) } else { (q, p) };
            *self.parent.get_unchecked_mut(child) = new_root;
            *self.tree_size.get_unchecked_mut(new_root) = p_size + q_size;
        }
    }

    /// Returns the elements of each tree.
    ///
    /// Trees are listed in order of their smallest element. Within each tree,
    /// elements appear in increasing order.
    pub fn trees(&mut self) -> Vec<Vec<usize>> {
        // `set_of_root[r]` is the index into `sets` of the tree rooted at `r`.
        // Roots are element indices, so a dense Vec does the job of a HashMap
        // without hashing every element.
        let mut set_of_root: Vec<Option<usize>> = vec![None; self.count];
        let mut sets: Vec<Vec<usize>> = Vec::new();
        for i in 0..self.count {
            let root = self.root(i);
            match set_of_root[root] {
                Some(set_idx) => sets[set_idx].push(i),
                None => {
                    set_of_root[root] = Some(sets.len());
                    sets.push(vec![i]);
                }
            }
        }
        sets
    }
}

#[cfg(test)]
mod tests {
    use super::DisjointSetForest;
    use ::test;
    use rand::{rngs::StdRng, SeedableRng};
    use rand_distr::{Distribution, Uniform};

    #[test]
    fn test_trees() {
        //    3         4
        //    |        /  \
        //    1       5    7
        //   /  \     |
        //  0    2    6
        #[rustfmt::skip]
        let mut forest = DisjointSetForest {
            count: 8,
            // element:     0, 1, 2, 3, 4, 5, 6, 7
            parent:    vec![1, 3, 1, 3, 4, 4, 5, 4],
            tree_size: vec![1, 3, 1, 4, 4, 2, 1, 1],
        };

        assert_eq!(forest.trees(), vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7]]);
    }

    #[test]
    fn test_find() {
        let mut forest = DisjointSetForest::new(4);
        assert!(forest.find(2, 2));
        assert!(!forest.find(0, 1));

        forest.union(0, 1);
        forest.union(2, 3);
        assert!(forest.find(1, 0));
        assert!(!forest.find(1, 3));

        forest.union(1, 3);
        assert!(forest.find(0, 2));
    }

    #[test]
    fn test_union_find_sequence() {
        let mut forest = DisjointSetForest::new(6);
        // 0  1  2  3  4  5

        //                             0, 1, 2, 3, 4, 5
        assert_eq!(forest.parent, vec![0, 1, 2, 3, 4, 5]);
        assert_eq!(forest.num_trees(), 6);

        forest.union(0, 4);
        // 0  1  2  3  5
        // |
        // 4

        //                             0, 1, 2, 3, 4, 5
        assert_eq!(forest.parent, vec![0, 1, 2, 3, 0, 5]);
        assert_eq!(forest.num_trees(), 5);

        forest.union(1, 3);
        // 0  1  2  5
        // |  |
        // 4  3

        //                             0, 1, 2, 3, 4, 5
        assert_eq!(forest.parent, vec![0, 1, 2, 1, 0, 5]);
        assert_eq!(forest.num_trees(), 4);

        forest.union(3, 2);
        // 0    1     5
        // |   / \
        // 4  3   2

        //                             0, 1, 2, 3, 4, 5
        assert_eq!(forest.parent, vec![0, 1, 1, 1, 0, 5]);
        assert_eq!(forest.num_trees(), 3);

        forest.union(2, 4);
        //    1     5
        //  / | \
        // 0  3  2
        // |
        // 4

        //                             0, 1, 2, 3, 4, 5
        assert_eq!(forest.parent, vec![1, 1, 1, 1, 0, 5]);
        assert_eq!(forest.num_trees(), 2);
    }

    #[bench]
    fn bench_disjoint_set_forest(b: &mut test::Bencher) {
        let num_nodes = 500;
        let num_edges = 20 * num_nodes;

        let mut rng: StdRng = SeedableRng::seed_from_u64(1);
        let uniform = Uniform::new(0, num_nodes);

        // Sample the edges once, outside the timed closure, so the benchmark
        // measures only union-find work and every iteration does the same work.
        let edges: Vec<(usize, usize)> = (0..num_edges)
            .map(|_| (uniform.sample(&mut rng), uniform.sample(&mut rng)))
            .collect();

        b.iter(|| {
            // Build a fresh forest on each iteration. Reusing one forest would
            // leave it fully merged after the first iteration, so later
            // iterations would only time no-op unions.
            let mut forest = DisjointSetForest::new(num_nodes);
            for &(u, v) in &edges {
                forest.union(u, v);
            }
            test::black_box(forest.num_trees())
        });
    }
}