from __future__ import annotations
import numpy as np
import random
import torch
from torchvision.transforms import v2
import math
from einops import rearrange
from loguru import logger
import torch.nn.functional as F
from typing import Any, Dict, Tuple, List
class CustomGaussianBlur(object):
"""
Applies a Gaussian blur to the input image using torchvision's v2.GaussianBlur.
This wrapper allows the transform to be applied to both PIL Images and NumPy arrays.
If the input is a NumPy array, it is converted to a PyTorch tensor, the Gaussian blur is applied, and then it is converted back to a NumPy array.
"""
def __init__(self, kernel_size, sigma):
self.transform = v2.GaussianBlur(kernel_size=kernel_size, sigma=sigma)
def __call__(self, img):
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
return self.transform(img).numpy()
else:
return self.transform(img)
def custom_median_filter(input_tensor: torch.Tensor, kernel_size: int = 3, padding: str = 'reflect') -> torch.Tensor:
"""
Applies a median filter to a 4D input tensor (batch, channels, height, width).
Args:
input_tensor (torch.Tensor): Input tensor of shape (B, C, H, W)
kernel_size (int): Size of the kernel (must be odd, e.g., 3, 5, 7)
padding (str): Padding mode ('reflect', 'replicate', or 'constant')
Returns:
torch.Tensor: Filtered tensor of the same shape as input
"""
if input_tensor.ndim == 3:
input_tensor = input_tensor.unsqueeze(0) # Add batch dimension if missing
SHAPE_ADDED = True
else:
SHAPE_ADDED = False
# Ensure kernel_size is odd
assert kernel_size % 2 == 1, "Kernel size must be odd"
# Calculate padding
pad = kernel_size // 2
# Pad the input tensor
padded = F.pad(input_tensor, (pad, pad, pad, pad), mode=padding)
# Unfold the tensor to get all patches of size kernel_size x kernel_size
# Shape after unfold: (B, C, H * W, kernel_size * kernel_size)
patches = padded.unfold(2, kernel_size, 1).unfold(3, kernel_size, 1)
B, C, H, W, _, _ = patches.shape
patches = patches.reshape(B, C, H, W, kernel_size * kernel_size)
# Compute median along the last dimension (across the kernel)
# Shape after median: (B, C, H, W)
filtered = torch.median(patches, dim=-1).values
if SHAPE_ADDED:
filtered = filtered.squeeze(0) # Remove batch dimension if it was added
return filtered
[docs]
class FilterFactory:
"""
A factory class to create and apply a sequence of filters to an input tensor.
"""
def __init__(
self,
filters_to_apply: List[str],
filter_params: Dict[str, Dict[str, Any]],
):
"""
Initializes the FilterFactory with the specified filters and their parameters.
Args:
filters_to_apply (List[str]): A list of filter names to apply in sequence. Supported filters include "gaussian_blur" and "median_filter".
filter_params (Dict[str, Dict[str, Any]]): A dictionary mapping filter names to their respective parameters. For example:
{
"gaussian_blur": {"kernel_size": 5, "sigma": 1.0},
"median_filter": {"kernel_size": 3, "padding": "reflect"}
}
"""
self.filters_to_apply = filters_to_apply
self.filter_params = filter_params
for filter_name in self.filters_to_apply:
if filter_name == "gaussian_blur":
params = self.filter_params.get(filter_name, {})
kernel_size = params.get("kernel_size", 3)
sigma = params.get("sigma", 1.0)
setattr(self, filter_name, CustomGaussianBlur(kernel_size=kernel_size, sigma=sigma))
elif filter_name == "median_filter":
params = self.filter_params.get(filter_name, {})
kernel_size = params.get("kernel_size", 3)
padding = params.get("padding", 'reflect')
setattr(self, filter_name, lambda x, k=kernel_size, p=padding: custom_median_filter(x, kernel_size=k, padding=p))
else:
raise ValueError(f"Unsupported filter {filter_name} provided to FilterFactory.")
def _ensure_tensor(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
if isinstance(x, np.ndarray):
return torch.from_numpy(x)
elif isinstance(x, torch.Tensor):
return x
else:
raise ValueError(f"Input must be a numpy array or a torch tensor, but got {type(x)}.")
def apply_filters(self, x: np.ndarray | torch.Tensor) -> torch.Tensor:
x_t = self._ensure_tensor(x)
for filter_name in self.filters_to_apply:
filter_fn = getattr(self, filter_name)
x_t = filter_fn(x_t)
return x_t
def print_filters(self):
print(f"Filters to apply: {self.filters_to_apply}")