use crate::definitions::Image;
use image::{GenericImageView, Pixel};
use std::cmp::{max, min};
/// Applies a median filter to every channel of `image`.
///
/// Each output channel value is the median of that channel over the
/// `(2 * x_radius + 1) x (2 * y_radius + 1)` window centred on the pixel. Window
/// coordinates that fall outside the image are clamped to the nearest edge pixel.
///
/// Rather than re-sorting every window, a per-channel histogram of the window is kept
/// and updated incrementally while the window snakes through the image. It moves down
/// even-indexed columns and up odd-indexed columns, stepping one column to the right
/// between them.
///
/// If `image` has zero width or height, a clone of it is returned.
#[must_use = "the function does not modify the original image"]
pub fn median_filter<P>(image: &Image<P>, x_radius: u32, y_radius: u32) -> Image<P>
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    if width == 0 || height == 0 {
        return image.clone();
    }

    let mut filtered = Image::<P>::new(width, height);
    let (rx, ry) = (x_radius as i32, y_radius as i32);
    let mut window = initialise_histogram_for_top_left_pixel(image, x_radius, y_radius);

    // The first column is always walked top-to-bottom.
    slide_down_column(&mut window, image, &mut filtered, 0, rx, ry);

    for x in 1..width {
        let heading_down = x % 2 == 0;
        // The window enters column `x` on the row where the previous column finished.
        let entry_row = if heading_down { 0 } else { height - 1 };
        slide_right(&mut window, image, x, entry_row, rx, ry);
        if heading_down {
            slide_down_column(&mut window, image, &mut filtered, x, rx, ry);
        } else {
            slide_up_column(&mut window, image, &mut filtered, x, rx, ry);
        }
    }

    filtered
}
/// Builds the window histogram for the pixel at `(0, 0)`.
///
/// The window spans `(2 * x_radius + 1) x (2 * y_radius + 1)` pixels. Coordinates
/// outside the image are clamped to the nearest edge, so edge pixels are counted
/// multiple times.
fn initialise_histogram_for_top_left_pixel<P>(
    image: &Image<P>,
    x_radius: u32,
    y_radius: u32,
) -> HistSet
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    let window_area = (2 * x_radius + 1) * (2 * y_radius + 1);
    let mut window = HistSet::new(P::CHANNEL_COUNT, window_area);

    let (rx, ry) = (x_radius as i32, y_radius as i32);
    let last_col = width as i32 - 1;
    let last_row = height as i32 - 1;

    let rows = (-ry..(ry + 1)).map(|dy| min(max(0, dy), last_row) as u32);
    for row in rows {
        let cols = (-rx..(rx + 1)).map(|dx| min(max(0, dx), last_col) as u32);
        for col in cols {
            window.incr(image, col, row);
        }
    }

    window
}
/// Shifts the window one column to the right so that it becomes centred on `(x, y)`.
///
/// The column that leaves on the left (`x - rx - 1`) is removed and the column that
/// enters on the right (`x + rx`) is added. Both are clamped to the image, as are the
/// rows `y - ry ..= y + ry`.
fn slide_right<P>(hist: &mut HistSet, image: &Image<P>, x: u32, y: u32, rx: i32, ry: i32)
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    let leaving_col = max(0, x as i32 - rx - 1) as u32;
    let entering_col = min(x as i32 + rx, width as i32 - 1) as u32;

    let clamp_row = move |dy: i32| min(max(0, y as i32 + dy), (height - 1) as i32) as u32;
    for row in (-ry..(ry + 1)).map(clamp_row) {
        hist.decr(image, leaving_col, row);
        hist.incr(image, entering_col, row);
    }
}
/// Walks the window down column `x`, writing the median for every row into `out`.
///
/// `hist` must already describe the window centred on `(x, 0)`. After each output
/// pixel, the window moves down one row: the row leaving at the top is removed and the
/// row entering at the bottom is added. Row and column indices are clamped to the image.
fn slide_down_column<P>(
    hist: &mut HistSet,
    image: &Image<P>,
    out: &mut Image<P>,
    x: u32,
    rx: i32,
    ry: i32,
) where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    let clamp_col = move |dx: i32| min(max(0, x as i32 + dx), (width - 1) as i32) as u32;

    hist.set_to_median(out, x, 0);
    for y in 1..height {
        let leaving_row = max(0, y as i32 - ry - 1) as u32;
        let entering_row = min(y as i32 + ry, height as i32 - 1) as u32;
        for col in (-rx..(rx + 1)).map(clamp_col) {
            hist.decr(image, col, leaving_row);
            hist.incr(image, col, entering_row);
        }
        hist.set_to_median(out, x, y);
    }
}
/// Walks the window up column `x`, writing the median for every row into `out`.
///
/// `hist` must already describe the window centred on `(x, height - 1)`. After each
/// output pixel, the window moves up one row: the row leaving at the bottom is removed
/// and the row entering at the top is added. Row and column indices are clamped to the
/// image.
fn slide_up_column<P>(
    hist: &mut HistSet,
    image: &Image<P>,
    out: &mut Image<P>,
    x: u32,
    rx: i32,
    ry: i32,
) where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    let clamp_col = move |dx: i32| min(max(0, x as i32 + dx), (width - 1) as i32) as u32;
    let bottom_row = height - 1;

    hist.set_to_median(out, x, bottom_row);
    for y in (0..bottom_row).rev() {
        let leaving_row = min(y as i32 + ry + 1, height as i32 - 1) as u32;
        let entering_row = max(0, y as i32 - ry) as u32;
        for col in (-rx..(rx + 1)).map(clamp_col) {
            hist.decr(image, col, leaving_row);
            hist.incr(image, col, entering_row);
        }
        hist.set_to_median(out, x, y);
    }
}
/// Per-channel 256-bin histograms of the `u8` subpixel values inside the current
/// filter window.
struct HistSet {
    /// One histogram per channel: `data[c][v]` is the number of window pixels whose
    /// channel `c` has value `v`.
    data: Vec<[u32; 256]>,
    /// Number of pixels in the filter window. Once the window is populated, every
    /// channel histogram sums to this value.
    expected_count: u32,
}

impl HistSet {
    /// Creates `num_channels` all-zero histograms for a window of `expected_count` pixels.
    fn new(num_channels: u8, expected_count: u32) -> HistSet {
        HistSet {
            data: vec![[0u32; 256]; usize::from(num_channels)],
            expected_count,
        }
    }

    /// Reads the pixel at `(x, y)` without a release-mode bounds check.
    ///
    /// `(x, y)` must lie inside `image`. This is checked in debug builds.
    fn pixel_at<P>(image: &Image<P>, x: u32, y: u32) -> P
    where
        P: Pixel<Subpixel = u8>,
    {
        let (width, height) = image.dimensions();
        debug_assert!(
            x < width && y < height,
            "pixel ({}, {}) out of bounds for {}x{} image",
            x,
            y,
            width,
            height
        );
        // SAFETY: every caller in this module clamps `x` into `0..width` and `y` into
        // `0..height` before calling `incr`/`decr`. The debug assertion above verifies
        // this invariant in test builds.
        unsafe { image.unsafe_get_pixel(x, y) }
    }

    /// Adds every channel of the pixel at `(x, y)` to the matching histogram.
    fn incr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        let pixel = Self::pixel_at(image, x, y);
        let channels = pixel.channels();
        debug_assert_eq!(channels.len(), self.data.len());
        // Indexing a `[u32; 256]` by a `u8` value can never go out of bounds, so no
        // unchecked access is needed here.
        for (hist, &value) in self.data.iter_mut().zip(channels) {
            hist[usize::from(value)] += 1;
        }
    }

    /// Removes every channel of the pixel at `(x, y)` from the matching histogram.
    ///
    /// The pixel must currently be counted in the window. Otherwise the count underflows.
    fn decr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        let pixel = Self::pixel_at(image, x, y);
        let channels = pixel.channels();
        debug_assert_eq!(channels.len(), self.data.len());
        for (hist, &value) in self.data.iter_mut().zip(channels) {
            hist[usize::from(value)] -= 1;
        }
    }

    /// Writes the per-channel window median into the pixel at `(x, y)` of `image`.
    fn set_to_median<P>(&self, image: &mut Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        let target = image.get_pixel_mut(x, y);
        for (c, out) in target.channels_mut().iter_mut().enumerate() {
            // `c < P::CHANNEL_COUNT`, which is a `u8`, so this cast is lossless.
            *out = self.channel_median(c as u8);
        }
    }

    /// Returns the median value of channel `c`.
    ///
    /// This is the smallest value `v` such that at least half (rounded up) of the
    /// `expected_count` window pixels have channel `c` less than or equal to `v`.
    fn channel_median(&self, c: u8) -> u8 {
        let hist = &self.data[usize::from(c)];
        // `count >= ceil(expected_count / 2)` is equivalent to
        // `2 * count >= expected_count`, but it cannot overflow for huge windows.
        let threshold = self.expected_count / 2 + self.expected_count % 2;
        let mut count = 0u32;
        for (value, &n) in (0..=u8::MAX).zip(hist.iter()) {
            count += n;
            if count >= threshold {
                return value;
            }
        }
        u8::MAX
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::property_testing::GrayTestImage;
    use crate::utils::gray_bench_image;
    use crate::utils::pixel_diff_summary;
    use image::{GrayImage, Luma};
    use quickcheck::{quickcheck, TestResult};
    use std::cmp::{max, min};
    use test::{black_box, Bencher};

    /// Defines a benchmark `$name` that median-filters a `$s` x `$s` grayscale bench
    /// image with the given radii.
    macro_rules! bench_median_filter {
        ($name:ident, side: $s:expr, x_radius: $rx:expr, y_radius: $ry:expr) => {
            #[bench]
            fn $name(b: &mut Bencher) {
                let image = gray_bench_image($s, $s);
                b.iter(|| {
                    let filtered = median_filter(&image, $rx, $ry);
                    black_box(filtered);
                })
            }
        };
    }

    bench_median_filter!(bench_median_filter_s100_r1, side: 100, x_radius: 1, y_radius: 1);
    bench_median_filter!(bench_median_filter_s100_r4, side: 100, x_radius: 4, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_r8, side: 100, x_radius: 8, y_radius: 8);
    bench_median_filter!(bench_median_filter_s100_rx1_ry4, side: 100, x_radius: 1, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_rx1_ry8, side: 100, x_radius: 1, y_radius: 8);
    // Previously named `..._rx4_ry8` even though it runs with y_radius 1.
    bench_median_filter!(bench_median_filter_s100_rx4_ry1, side: 100, x_radius: 4, y_radius: 1);
    bench_median_filter!(bench_median_filter_s100_rx8_ry1, side: 100, x_radius: 8, y_radius: 1);

    /// Brute-force oracle: for every pixel, collects the clamped window, sorts it and
    /// takes the middle element.
    fn reference_median_filter(image: &GrayImage, x_radius: u32, y_radius: u32) -> GrayImage {
        let (width, height) = image.dimensions();
        if width == 0 || height == 0 {
            return image.clone();
        }
        let mut out = GrayImage::new(width, height);
        let x_filter_side = (2 * x_radius + 1) as usize;
        let y_filter_side = (2 * y_radius + 1) as usize;
        let mut neighbors = vec![0u8; x_filter_side * y_filter_side];
        let rx = x_radius as i32;
        let ry = y_radius as i32;
        for y in 0..height {
            for x in 0..width {
                let mut idx = 0;
                for dy in -ry..(ry + 1) {
                    for dx in -rx..(rx + 1) {
                        let px = min(max(0, x as i32 + dx), (width - 1) as i32) as u32;
                        let py = min(max(0, y as i32 + dy), (height - 1) as i32) as u32;
                        neighbors[idx] = image.get_pixel(px, py)[0];
                        idx += 1;
                    }
                }
                neighbors.sort_unstable();
                let m = median(&neighbors);
                out.put_pixel(x, y, Luma([m]));
            }
        }
        out
    }

    /// Middle element of an already-sorted, odd-length slice.
    fn median(sorted: &[u8]) -> u8 {
        let mid = sorted.len() / 2;
        sorted[mid]
    }

    #[test]
    fn test_median_filter_matches_reference_implementation() {
        fn prop(image: GrayTestImage, x_radius: u32, y_radius: u32) -> TestResult {
            let x_radius = x_radius % 5;
            let y_radius = y_radius % 5;
            let expected = reference_median_filter(&image.0, x_radius, y_radius);
            let actual = median_filter(&image.0, x_radius, y_radius);
            match pixel_diff_summary(&actual, &expected) {
                None => TestResult::passed(),
                Some(err) => TestResult::error(err),
            }
        }
        quickcheck(prop as fn(GrayTestImage, u32, u32) -> TestResult);
    }

    #[test]
    fn test_median_filter_removes_isolated_outlier() {
        // Every 3x3 (clamped) window contains the bright centre pixel at most once,
        // so every median is 0.
        let mut image = GrayImage::new(3, 3);
        image.put_pixel(1, 1, Luma([255]));
        let filtered = median_filter(&image, 1, 1);
        assert!(filtered.pixels().all(|p| p[0] == 0));
    }

    #[test]
    fn test_median_filter_empty_image() {
        let image = GrayImage::new(0, 0);
        let filtered = median_filter(&image, 2, 3);
        assert_eq!(filtered.dimensions(), (0, 0));
    }
}