/// A disjoint-set forest (union-find) over the elements `0..count`.
///
/// Every element belongs to exactly one tree. Two elements are in the same
/// set exactly when they share a root. Trees are merged by size and lookups
/// shorten the paths they walk.
#[derive(Debug, Clone)]
pub struct DisjointSetForest {
    /// Number of elements tracked by the forest.
    count: usize,
    /// `parent[i]` is the parent of element `i`; a root is its own parent.
    parent: Vec<usize>,
    /// `tree_size[r]` is the number of elements in the tree rooted at `r`.
    /// Only the entries of current roots are kept up to date.
    tree_size: Vec<usize>,
}

impl DisjointSetForest {
    /// Creates a forest of `count` singleton trees.
    pub fn new(count: usize) -> Self {
        Self {
            count,
            parent: (0..count).collect(),
            tree_size: vec![1; count],
        }
    }

    /// Returns the number of trees (disjoint sets), i.e. the number of
    /// elements that are their own parent.
    pub fn num_trees(&self) -> usize {
        self.parent
            .iter()
            .enumerate()
            .filter(|&(i, &p)| i == p)
            .count()
    }

    /// Returns the root of the tree containing `i`.
    ///
    /// While walking towards the root, each visited node is re-pointed at its
    /// grandparent (path splitting), which is why this takes `&mut self`.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count`, or if the internal vectors are inconsistent
    /// with `count` (a parent index out of range).
    pub fn root(&mut self, i: usize) -> usize {
        assert!(i < self.count);
        let mut node = i;
        loop {
            let parent = self.parent[node];
            if parent == node {
                return node;
            }
            // Skip one level: point `node` at its grandparent, then continue
            // the walk from the old parent.
            self.parent[node] = self.parent[parent];
            node = parent;
        }
    }

    /// Returns `true` if `i` and `j` belong to the same tree.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count` or `j >= count`.
    pub fn find(&mut self, i: usize, j: usize) -> bool {
        // Validate both indices up front so that a bad `j` does not leave the
        // forest partially compressed by the lookup of `i`.
        assert!(i < self.count && j < self.count);
        self.root(i) == self.root(j)
    }

    /// Merges the trees containing `i` and `j`.
    ///
    /// The root of the strictly smaller tree is attached under the root of
    /// the other one; on a tie the root of `j`'s tree is attached under the
    /// root of `i`'s tree. Does nothing if both are already in the same tree.
    ///
    /// # Panics
    ///
    /// Panics if `i >= count` or `j >= count`.
    pub fn union(&mut self, i: usize, j: usize) {
        assert!(i < self.count && j < self.count);
        let p = self.root(i);
        let q = self.root(j);
        if p == q {
            return;
        }

        let p_size = self.tree_size[p];
        let q_size = self.tree_size[q];
        let (child, new_root) = if p_size < q_size { (p, q) } else { (q, p) };
        self.parent[child] = new_root;
        self.tree_size[new_root] = p_size + q_size;
    }

    /// Returns the elements of every tree, grouped per tree.
    ///
    /// Groups appear in order of their smallest element, and the elements
    /// inside each group are in ascending order.
    pub fn trees(&mut self) -> Vec<Vec<usize>> {
        use std::collections::HashMap;

        // Maps a root to the position of its group inside `sets`.
        let mut group_of_root: HashMap<usize, usize> = HashMap::new();
        let mut sets: Vec<Vec<usize>> = Vec::new();

        for i in 0..self.count {
            let root = self.root(i);
            let idx = *group_of_root.entry(root).or_insert_with(|| {
                sets.push(Vec::new());
                sets.len() - 1
            });
            sets[idx].push(i);
        }
        sets
    }
}
107
#[cfg(test)]
mod tests {
    use super::DisjointSetForest;
    use ::test;
    use rand::{rngs::StdRng, SeedableRng};
    use rand_distr::{Distribution, Uniform};

    /// Asserts the raw parent array and the tree count of `forest`.
    fn assert_state(forest: &DisjointSetForest, expected_parent: &[usize], expected_trees: usize) {
        assert_eq!(forest.parent, expected_parent);
        assert_eq!(forest.num_trees(), expected_trees);
    }

    /// `trees` groups a hand-built forest with two roots (3 and 4) into two
    /// sets of four elements each.
    #[test]
    fn test_trees() {
        #[rustfmt::skip]
        let mut forest = DisjointSetForest {
            count: 8,
            parent: vec![1, 3, 1, 3, 4, 4, 5, 4],
            tree_size: vec![1, 3, 1, 4, 4, 2, 1, 1],
        };

        let expected: Vec<Vec<usize>> = vec![vec![0, 1, 2, 3], vec![4, 5, 6, 7]];
        assert_eq!(forest.trees(), expected);
    }

    /// Walks through a fixed sequence of unions and checks the parent array
    /// and the number of trees after every step.
    #[test]
    fn test_union_find_sequence() {
        let mut forest = DisjointSetForest::new(6);
        assert_state(&forest, &[0, 1, 2, 3, 4, 5], 6);

        // Two singletons: the second argument's root goes under the first's.
        forest.union(0, 4);
        assert_state(&forest, &[0, 1, 2, 3, 0, 5], 5);

        forest.union(1, 3);
        assert_state(&forest, &[0, 1, 2, 1, 0, 5], 4);

        // The singleton {2} is attached under the larger tree rooted at 1.
        forest.union(3, 2);
        assert_state(&forest, &[0, 1, 1, 1, 0, 5], 3);

        // The tree rooted at 0 (size 2) goes under the tree rooted at 1 (size 3).
        forest.union(2, 4);
        assert_state(&forest, &[1, 1, 1, 1, 0, 5], 2);
    }

    /// Benchmarks a batch of random unions followed by a tree count.
    ///
    /// Note that the forest and the RNG are created once and shared by all
    /// iterations, so later iterations run on a forest that earlier
    /// iterations have already merged.
    #[bench]
    fn bench_disjoint_set_forest(b: &mut test::Bencher) {
        let num_nodes = 500;
        let num_edges = 20 * num_nodes;

        let mut rng: StdRng = StdRng::seed_from_u64(1);
        let endpoint = Uniform::new(0, num_nodes);

        let mut forest = DisjointSetForest::new(num_nodes);
        b.iter(|| {
            for _ in 0..num_edges {
                // Sample the endpoints in a fixed order to keep the edge
                // sequence deterministic for the given seed.
                let u = endpoint.sample(&mut rng);
                let v = endpoint.sample(&mut rng);
                forest.union(u, v);
            }
            test::black_box(forest.num_trees());
        });
    }
}