1use crate::definitions::Image;
2use image::{GenericImageView, Pixel};
3use std::cmp::{max, min};
4
5#[must_use = "the function does not modify the original image"]
109pub fn median_filter<P>(image: &Image<P>, x_radius: u32, y_radius: u32) -> Image<P>
110where
111 P: Pixel<Subpixel = u8>,
112{
113 let (width, height) = image.dimensions();
114
115 if width == 0 || height == 0 {
116 return image.clone();
117 }
118
119 let mut out = Image::<P>::new(width, height);
120 let rx = x_radius as i32;
121 let ry = y_radius as i32;
122
123 let mut hist = initialise_histogram_for_top_left_pixel(image, x_radius, y_radius);
124 slide_down_column(&mut hist, image, &mut out, 0, rx, ry);
125
126 for x in 1..width {
127 if x % 2 == 0 {
128 slide_right(&mut hist, image, x, 0, rx, ry);
129 slide_down_column(&mut hist, image, &mut out, x, rx, ry);
130 } else {
131 slide_right(&mut hist, image, x, height - 1, rx, ry);
132 slide_up_column(&mut hist, image, &mut out, x, rx, ry);
133 }
134 }
135 out
136}
137
138fn initialise_histogram_for_top_left_pixel<P>(
139 image: &Image<P>,
140 x_radius: u32,
141 y_radius: u32,
142) -> HistSet
143where
144 P: Pixel<Subpixel = u8>,
145{
146 let (width, height) = image.dimensions();
147 let kernel_size = (2 * x_radius + 1) * (2 * y_radius + 1);
148 let num_channels = P::CHANNEL_COUNT;
149
150 let mut hist = HistSet::new(num_channels, kernel_size);
151 let rx = x_radius as i32;
152 let ry = y_radius as i32;
153
154 for dy in -ry..(ry + 1) {
155 let py = min(max(0, dy), height as i32 - 1) as u32;
156
157 for dx in -rx..(rx + 1) {
158 let px = min(max(0, dx), width as i32 - 1) as u32;
159
160 hist.incr(image, px, py);
161 }
162 }
163
164 hist
165}
166
167fn slide_right<P>(hist: &mut HistSet, image: &Image<P>, x: u32, y: u32, rx: i32, ry: i32)
168where
169 P: Pixel<Subpixel = u8>,
170{
171 let (width, height) = image.dimensions();
172
173 let prev_x = max(0, x as i32 - rx - 1) as u32;
174 let next_x = min(x as i32 + rx, width as i32 - 1) as u32;
175
176 for dy in -ry..(ry + 1) {
177 let py = min(max(0, y as i32 + dy), (height - 1) as i32) as u32;
178
179 hist.decr(image, prev_x, py);
180 hist.incr(image, next_x, py);
181 }
182}
183
184fn slide_down_column<P>(
185 hist: &mut HistSet,
186 image: &Image<P>,
187 out: &mut Image<P>,
188 x: u32,
189 rx: i32,
190 ry: i32,
191) where
192 P: Pixel<Subpixel = u8>,
193{
194 let (width, height) = image.dimensions();
195 hist.set_to_median(out, x, 0);
196
197 for y in 1..height {
198 let prev_y = max(0, y as i32 - ry - 1) as u32;
199 let next_y = min(y as i32 + ry, height as i32 - 1) as u32;
200
201 for dx in -rx..(rx + 1) {
202 let px = min(max(0, x as i32 + dx), (width - 1) as i32) as u32;
203
204 hist.decr(image, px, prev_y);
205 hist.incr(image, px, next_y);
206 }
207
208 hist.set_to_median(out, x, y);
209 }
210}
211
212fn slide_up_column<P>(
213 hist: &mut HistSet,
214 image: &Image<P>,
215 out: &mut Image<P>,
216 x: u32,
217 rx: i32,
218 ry: i32,
219) where
220 P: Pixel<Subpixel = u8>,
221{
222 let (width, height) = image.dimensions();
223 hist.set_to_median(out, x, height - 1);
224
225 for y in (0..(height - 1)).rev() {
226 let prev_y = min(y as i32 + ry + 1, height as i32 - 1) as u32;
227 let next_y = max(0, y as i32 - ry) as u32;
228
229 for dx in -rx..(rx + 1) {
230 let px = min(max(0, x as i32 + dx), (width - 1) as i32) as u32;
231
232 hist.decr(image, px, prev_y);
233 hist.incr(image, px, next_y);
234 }
235
236 hist.set_to_median(out, x, y);
237 }
238}
239
240struct HistSet {
243 data: Vec<[u32; 256]>,
245 expected_count: u32,
249}
250
251impl HistSet {
252 fn new(num_channels: u8, expected_count: u32) -> HistSet {
253 let mut data = vec![];
256 for _ in 0..num_channels {
257 data.push([0u32; 256]);
258 }
259
260 HistSet {
261 data,
262 expected_count,
263 }
264 }
265
266 fn incr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
267 where
268 P: Pixel<Subpixel = u8>,
269 {
270 unsafe {
271 let pixel = image.unsafe_get_pixel(x, y);
272 let channels = pixel.channels();
273 for c in 0..channels.len() {
274 let p = *channels.get_unchecked(c) as usize;
275 let hist = self.data.get_unchecked_mut(c);
276 *hist.get_unchecked_mut(p) += 1;
277 }
278 }
279 }
280
281 fn decr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
282 where
283 P: Pixel<Subpixel = u8>,
284 {
285 unsafe {
286 let pixel = image.unsafe_get_pixel(x, y);
287 let channels = pixel.channels();
288 for c in 0..channels.len() {
289 let p = *channels.get_unchecked(c) as usize;
290 let hist = self.data.get_unchecked_mut(c);
291 *hist.get_unchecked_mut(p) -= 1;
292 }
293 }
294 }
295
296 fn set_to_median<P>(&self, image: &mut Image<P>, x: u32, y: u32)
297 where
298 P: Pixel<Subpixel = u8>,
299 {
300 unsafe {
301 let target = image.get_pixel_mut(x, y);
302 let channels = target.channels_mut();
303 for c in 0..channels.len() {
304 *channels.get_unchecked_mut(c) = self.channel_median(c as u8);
305 }
306 }
307 }
308
309 fn channel_median(&self, c: u8) -> u8 {
310 let hist = unsafe { self.data.get_unchecked(c as usize) };
311
312 let mut count = 0;
313
314 for i in 0..256 {
315 unsafe {
316 count += *hist.get_unchecked(i);
317 }
318
319 if 2 * count >= self.expected_count {
320 return i as u8;
321 }
322 }
323
324 255
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;
    use crate::property_testing::GrayTestImage;
    use crate::utils::gray_bench_image;
    use crate::utils::pixel_diff_summary;
    use image::{GrayImage, Luma, Rgb, RgbImage};
    use quickcheck::{quickcheck, TestResult};
    use std::cmp::{max, min};
    use test::{black_box, Bencher};

    macro_rules! bench_median_filter {
        ($name:ident, side: $s:expr, x_radius: $rx:expr, y_radius: $ry:expr) => {
            #[bench]
            fn $name(b: &mut Bencher) {
                let image = gray_bench_image($s, $s);
                b.iter(|| {
                    let filtered = median_filter(&image, $rx, $ry);
                    black_box(filtered);
                })
            }
        };
    }

    bench_median_filter!(bench_median_filter_s100_r1, side: 100, x_radius: 1, y_radius: 1);
    bench_median_filter!(bench_median_filter_s100_r4, side: 100, x_radius: 4, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_r8, side: 100, x_radius: 8, y_radius: 8);

    // Non-square kernels. Each benchmark name encodes the radii it actually uses.
    bench_median_filter!(bench_median_filter_s100_rx1_ry4, side: 100, x_radius: 1, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_rx1_ry8, side: 100, x_radius: 1, y_radius: 8);
    bench_median_filter!(bench_median_filter_s100_rx4_ry8, side: 100, x_radius: 4, y_radius: 8);
    bench_median_filter!(bench_median_filter_s100_rx8_ry1, side: 100, x_radius: 8, y_radius: 1);

    /// Naive reference: for every pixel, collect the clamped window, sort it, and
    /// take the middle element.
    fn reference_median_filter(image: &GrayImage, x_radius: u32, y_radius: u32) -> GrayImage {
        let (width, height) = image.dimensions();

        if width == 0 || height == 0 {
            return image.clone();
        }

        let mut out = GrayImage::new(width, height);
        let x_filter_side = (2 * x_radius + 1) as usize;
        let y_filter_side = (2 * y_radius + 1) as usize;
        let mut neighbors = vec![0u8; x_filter_side * y_filter_side];

        let rx = x_radius as i32;
        let ry = y_radius as i32;

        for y in 0..height {
            for x in 0..width {
                let mut idx = 0;

                for dy in -ry..(ry + 1) {
                    let py = min(max(0, y as i32 + dy), (height - 1) as i32) as u32;
                    for dx in -rx..(rx + 1) {
                        let px = min(max(0, x as i32 + dx), (width - 1) as i32) as u32;
                        neighbors[idx] = image.get_pixel(px, py)[0];
                        idx += 1;
                    }
                }

                neighbors.sort_unstable();

                let m = median(&neighbors);
                out.put_pixel(x, y, Luma([m]));
            }
        }

        out
    }

    /// Middle element of an odd-length sorted slice.
    fn median(sorted: &[u8]) -> u8 {
        let mid = sorted.len() / 2;
        sorted[mid]
    }

    #[test]
    fn test_median_filter_matches_reference_implementation() {
        fn prop(image: GrayTestImage, x_radius: u32, y_radius: u32) -> TestResult {
            let x_radius = x_radius % 5;
            let y_radius = y_radius % 5;
            let expected = reference_median_filter(&image.0, x_radius, y_radius);
            let actual = median_filter(&image.0, x_radius, y_radius);

            match pixel_diff_summary(&actual, &expected) {
                None => TestResult::passed(),
                Some(err) => TestResult::error(err),
            }
        }
        quickcheck(prop as fn(GrayTestImage, u32, u32) -> TestResult);
    }

    // The property test above only covers single-channel images; this checks that
    // each channel of a multi-channel image is filtered independently.
    #[test]
    fn test_median_filter_filters_each_channel_independently() {
        let (width, height) = (7, 5);
        let rgb = RgbImage::from_fn(width, height, |x, y| {
            Rgb([
                ((x * 37 + y * 11) % 256) as u8,
                ((x * x * 13 + y * 71) % 256) as u8,
                ((x * y * 29 + 5) % 256) as u8,
            ])
        });
        let (x_radius, y_radius) = (2, 1);
        let actual = median_filter(&rgb, x_radius, y_radius);

        for c in 0..3 {
            let channel =
                GrayImage::from_fn(width, height, |x, y| Luma([rgb.get_pixel(x, y)[c]]));
            let expected = reference_median_filter(&channel, x_radius, y_radius);
            for (x, y, p) in expected.enumerate_pixels() {
                assert_eq!(
                    actual.get_pixel(x, y)[c],
                    p[0],
                    "channel {} at ({}, {})",
                    c,
                    x,
                    y
                );
            }
        }
    }
}