1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    fmt,
5    hash::{BuildHasherDefault, Hasher},
6};
7
/// Pass-through hasher: the `u64` most recently written becomes the hash.
///
/// It is plugged into the `Extensions` map, whose keys are `TypeId`s, through
/// `BuildHasherDefault`.
///
/// NOTE(review): this presumably relies on `TypeId` feeding its hash through a
/// single `write_u64` call — confirm against the standard library.
#[derive(Debug, Default)]
struct NoOpHasher(u64);

impl Hasher for NoOpHasher {
    /// Arbitrary byte input is not supported; calling this always panics.
    fn write(&mut self, _: &[u8]) {
        unimplemented!("This NoOpHasher can only handle u64s")
    }

    /// Stores `value` as the hash, replacing whatever was held before.
    fn write_u64(&mut self, value: u64) {
        let Self(state) = self;
        *state = value;
    }

    /// Returns the value last given to `write_u64`, or `0` for a hasher built
    /// through `Default` that has not been written to.
    fn finish(&self) -> u64 {
        let Self(state) = self;
        *state
    }
}
28
29#[derive(Default)]
33pub struct Extensions {
34    map: HashMap<TypeId, Box<dyn Any>, BuildHasherDefault<NoOpHasher>>,
36}
37
38impl Extensions {
39    #[inline]
41    pub fn new() -> Extensions {
42        Extensions {
43            map: HashMap::default(),
44        }
45    }
46
47    pub fn insert<T: 'static>(&mut self, val: T) -> Option<T> {
60        self.map
61            .insert(TypeId::of::<T>(), Box::new(val))
62            .and_then(downcast_owned)
63    }
64
65    pub fn contains<T: 'static>(&self) -> bool {
76        self.map.contains_key(&TypeId::of::<T>())
77    }
78
79    pub fn get<T: 'static>(&self) -> Option<&T> {
88        self.map
89            .get(&TypeId::of::<T>())
90            .and_then(|boxed| boxed.downcast_ref())
91    }
92
93    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
102        self.map
103            .get_mut(&TypeId::of::<T>())
104            .and_then(|boxed| boxed.downcast_mut())
105    }
106
107    pub fn get_or_insert<T: 'static>(&mut self, value: T) -> &mut T {
122        self.get_or_insert_with(|| value)
123    }
124
125    pub fn get_or_insert_with<T: 'static, F: FnOnce() -> T>(&mut self, default: F) -> &mut T {
140        self.map
141            .entry(TypeId::of::<T>())
142            .or_insert_with(|| Box::new(default()))
143            .downcast_mut()
144            .expect("extensions map should now contain a T value")
145    }
146
147    pub fn remove<T: 'static>(&mut self) -> Option<T> {
162        self.map.remove(&TypeId::of::<T>()).and_then(downcast_owned)
163    }
164
165    #[inline]
178    pub fn clear(&mut self) {
179        self.map.clear();
180    }
181
182    pub fn extend(&mut self, other: Extensions) {
184        self.map.extend(other.map);
185    }
186}
187
188impl fmt::Debug for Extensions {
189    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
190        f.debug_struct("Extensions").finish()
191    }
192}
193
/// Converts a type-erased box back into an owned `T`.
///
/// Returns `None` (dropping the box) when the boxed value is not a `T`.
fn downcast_owned<T: 'static>(boxed: Box<dyn Any>) -> Option<T> {
    match boxed.downcast::<T>() {
        Ok(concrete) => Some(*concrete),
        Err(_) => None,
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// A value removed by type can no longer be looked up.
    #[test]
    fn test_remove() {
        let mut ext = Extensions::new();

        ext.insert(123i8);
        assert!(ext.get::<i8>().is_some());

        ext.remove::<i8>();
        assert!(ext.get::<i8>().is_none());
    }

    /// `clear` drops every entry and leaves the container usable afterwards.
    #[test]
    fn test_clear() {
        let mut ext = Extensions::new();

        ext.insert(8i8);
        ext.insert(16i16);
        ext.insert(32i32);

        assert!(ext.contains::<i8>());
        assert!(ext.contains::<i16>());
        assert!(ext.contains::<i32>());

        ext.clear();

        assert!(!ext.contains::<i8>());
        assert!(!ext.contains::<i16>());
        assert!(!ext.contains::<i32>());

        ext.insert(10i8);
        assert_eq!(*ext.get::<i8>().unwrap(), 10);
    }

    /// Every integer width, plus a `'static` reference, gets its own slot.
    #[test]
    fn test_integers() {
        static A: u32 = 8;

        // Asserts that a value of each listed type can be looked up.
        macro_rules! assert_present {
            ($ext:expr, $($ty:ty),+ $(,)?) => {
                $( assert!($ext.get::<$ty>().is_some()); )+
            };
        }

        let mut ext = Extensions::new();

        ext.insert(8i8);
        ext.insert(16i16);
        ext.insert(32i32);
        ext.insert(64i64);
        ext.insert(128i128);
        ext.insert(8u8);
        ext.insert(16u16);
        ext.insert(32u32);
        ext.insert(64u64);
        ext.insert(128u128);
        ext.insert::<&'static u32>(&A);

        assert_present!(
            ext,
            i8,
            i16,
            i32,
            i64,
            i128,
            u8,
            u16,
            u32,
            u64,
            u128,
            &'static u32,
        );
    }

    /// Distinct instantiations of one generic wrapper are distinct keys.
    #[test]
    fn test_composition() {
        struct Magi<T>(pub T);

        struct Madoka {
            pub god: bool,
        }

        struct Homura {
            pub attempts: usize,
        }

        struct Mami {
            pub guns: usize,
        }

        let mut ext = Extensions::new();

        ext.insert(Magi(Madoka { god: false }));
        ext.insert(Magi(Homura { attempts: 0 }));
        ext.insert(Magi(Mami { guns: 999 }));

        let Magi(madoka) = ext.get::<Magi<Madoka>>().unwrap();
        assert!(!madoka.god);

        let Magi(homura) = ext.get::<Magi<Homura>>().unwrap();
        assert_eq!(0, homura.attempts);

        let Magi(mami) = ext.get::<Magi<Mami>>().unwrap();
        assert_eq!(999, mami.guns);
    }

    /// Lookup, mutable lookup and removal on a container with two types.
    #[test]
    fn test_extensions() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut ext = Extensions::new();

        ext.insert(5i32);
        ext.insert(MyType(10));

        assert_eq!(ext.get::<i32>(), Some(&5));
        assert_eq!(ext.get_mut::<i32>(), Some(&mut 5));

        assert_eq!(ext.remove::<i32>(), Some(5));
        assert!(ext.get::<i32>().is_none());

        assert_eq!(ext.get::<bool>(), None);
        assert_eq!(ext.get::<MyType>(), Some(&MyType(10)));
    }

    /// `extend` overwrites types present in both containers and adds the rest.
    #[test]
    fn test_extend() {
        #[derive(Debug, PartialEq)]
        struct MyType(i32);

        let mut target = Extensions::new();

        target.insert(5i32);
        target.insert(MyType(10));

        let mut source = Extensions::new();

        source.insert(15i32);
        source.insert(20u8);

        target.extend(source);

        // The `i32` from `source` replaced the one already in `target`.
        assert_eq!(target.get::<i32>(), Some(&15));
        assert_eq!(target.get_mut::<i32>(), Some(&mut 15));

        assert_eq!(target.remove::<i32>(), Some(15));
        assert!(target.get::<i32>().is_none());

        assert_eq!(target.get::<bool>(), None);
        assert_eq!(target.get::<MyType>(), Some(&MyType(10)));

        assert_eq!(target.get::<u8>(), Some(&20));
        assert_eq!(target.get_mut::<u8>(), Some(&mut 20));
    }
}