Source code for MEDiml.filters.wavelet

import math
from itertools import combinations, permutations
from typing import List, Union

import numpy as np
import pywt

from ..MEDscan import MEDscan
from ..utils.image_volume_obj import image_volume_obj


class Wavelet():
    """The wavelet filter class.

    Filters batches of images with the stationary (undecimated) wavelet
    transform of PyWavelets (``pywt.swtn``). Optionally, the response is
    averaged over every combination of axis flips and every permutation of
    the requested subband to obtain a rotation-invariant response.
    """

    def __init__(
            self,
            ndims: int,
            wavelet_name="haar",
            padding="symmetric",
            rot_invariance=False):
        """The constructor of the wavelet filter.

        Args:
            ndims (int): The number of spatial dimensions of the images that
                will be filtered.
            wavelet_name (str): The name of the wavelet kernel, as accepted by
                ``pywt.Wavelet``.
            padding (str): The ``numpy.pad`` mode used to pad the images
                before the transform.
            rot_invariance (bool): If True, the filter response is averaged
                over axis flips and subband permutations.

        Returns:
            None
        """
        self.dim = ndims
        self.padding = padding
        self.rot = rot_invariance
        self.wavelet = None
        self.kernel_length = None
        self.create_kernel(wavelet_name)

    def create_kernel(self, wavelet_name: str):
        """Create the pywt wavelet object and store its kernel length.

        Args:
            wavelet_name (str): The wavelet name passed to ``pywt.Wavelet``.

        Returns:
            None
        """
        self.wavelet = pywt.Wavelet(wavelet_name)
        # The longest of the reconstruction / decomposition filters drives
        # the amount of padding needed.
        self.kernel_length = max(self.wavelet.rec_len, self.wavelet.dec_len)

    def __unpad(self, images: np.ndarray, padding: List) -> np.ndarray:
        """Remove the padding from a batch of padded images.

        Args:
            images: The batch of padded images. Axis 0 is the batch axis and is
                followed by ``self.dim`` spatial axes, e.g. (B, H, W) or
                (B, H, W, D).
            padding: A list of length 2*self.dim giving the padding length
                before and after each spatial axis.

        Returns:
            ndarray: The batch of unpadded images.
        """
        crop = [slice(None)]  # the batch axis is never cropped
        for axis in range(self.dim):
            before, after = padding[2 * axis], padding[2 * axis + 1]
            # Use an explicit stop index: ``-after`` would be ``-0`` (an empty
            # slice) when nothing was padded after the data.
            crop.append(slice(before, images.shape[axis + 1] - after))
        return images[tuple(crop)]

    def __get_pad_length(self, image_shape: List, level: int) -> List[int]:
        """Compute the padding so that each padded axis length is a multiple of 2^level.

        Args:
            image_shape (List): The length of the image along each axis.
            level (int): The level of the wavelet transform.

        Returns:
            List[int]: A list of length 2*len(image_shape) giving the padding
            length before and after each axis.
        """
        padding = []
        ker_length = self.kernel_length * level
        for length in image_shape:
            # At least (ker_length - 1) samples on each side, then round the
            # total length up to the next multiple of 2**level.
            padded_length = math.ceil((length + 2 * (ker_length - 1)) / 2**level) * 2**level - length
            padding.extend([math.floor(padded_length / 2), math.ceil(padded_length / 2)])
        return padding

    def _pad_imgs(self, images: np.ndarray, padding, axis: List):
        """Pad selected axes of an array with ``numpy.pad`` in mode ``self.padding``.

        Args:
            images: The array to pad.
            padding: Padding lengths as a flat list ``[before, after, ...]``.
                The k-th pair is applied to the k-th axis (in increasing axis
                order) that belongs to ``axis``.
            axis: The axes on which padding is applied. Other axes are left
                unpadded.

        Returns:
            ndarray: The padded array.
        """
        pad_width = []
        j = 0
        for i in range(np.ndim(images)):
            if i in axis:
                pad_width.append((padding[j], padding[j + 1]))
                j += 2
            else:
                pad_width.append((0, 0))
        return np.pad(images, tuple(pad_width), mode=self.padding)

    def convolve(self, images: np.ndarray, _filter="LHL", level=1) -> np.ndarray:
        """Filter a batch of images using the pywt stationary wavelet transform.

        Args:
            images (ndarray): A batch of images. Axis 0 is the batch axis and
                is followed by ``self.dim`` spatial axes.
            _filter (str): The subband to return, with one character per
                spatial axis. "L" selects the low-pass kernel; any other
                character selects the high-pass kernel.
            level (int): The number of decomposition steps to perform. The
                coefficients of the last (deepest) level are returned.

        Returns:
            ndarray: The filtered batch, with the same shape as ``images``.
        """
        # Pad every image of the batch along its spatial axes (1..dim).
        # Previously only images[0] was padded and the rest were dropped.
        padding = self.__get_pad_length(np.shape(images)[1:], level)
        images = self._pad_imgs(images, padding, list(range(1, self.dim + 1)))

        # pywt subband key: 'a' = approximation (low-pass), 'd' = detail (high-pass).
        _index = "".join('a' if band == 'L' else 'd' for band in _filter)

        result = []
        if self.rot:
            # Loop-invariant: every distinct permutation of the subband, and
            # every subset of spatial axes to flip (including the empty set).
            _index_list = sorted({"".join(perm) for perm in permutations(_index, self.dim)})
            axis_rot = [comb for j in range(self.dim + 1) for comb in combinations(range(self.dim), j)]
            n_responses = len(axis_rot) * len(_index_list)
            for image in images:
                # Running sum instead of stacking every response, which keeps
                # peak memory at one accumulator volume.
                total = None
                for axes in axis_rot:
                    # pywt.swtn returns levels coarsest-first: [0] is `level`.
                    coeffs = pywt.swtn(np.flip(image, axes), self.wavelet, level=level)[0]
                    for key in _index_list:
                        response = np.flip(coeffs[key], axis=axes)
                        if total is None:
                            total = np.array(response)  # owned copy
                        else:
                            total += response
                result.append(total / n_responses)
        else:
            for image in images:
                # [0] is the deepest level. The former [level-1] always picked
                # level 1, because swtn orders its output coarsest-first.
                result.append(pywt.swtn(image, self.wavelet, level=level)[0][_index])

        return self.__unpad(np.array(result), padding)
def apply_wavelet(
        input_images: Union[np.ndarray, image_volume_obj],
        medscan: MEDscan = None,
        ndims: int = 3,
        wavelet_name: str = "haar",
        subband: str = "LHL",
        level: int = 1,
        padding: str = "symmetric",
        rot_invariance: bool = False
) -> Union[np.ndarray, image_volume_obj]:
    """Apply the wavelet filter to the input image.

    Args:
        input_images (ndarray): The image to filter, either a numpy array or an
            ``image_volume_obj`` whose ``data`` is filtered.
        medscan (MEDscan, optional): The MEDscan object that will provide the
            filter parameters. When given, its ``params.filter.wavelet``
            settings (ndims, basis_function, subband, level, padding,
            rot_invariance) override the corresponding keyword arguments.
        ndims (int, optional): The number of dimensions of the input image.
        wavelet_name (str): The name of the wavelet kernel as string.
        subband (str, optional): String of the 1D wavelet kernels ("H" for
            high-pass filter or "L" for low-pass filter). Must have a size
            of ``ndims``.
        level (int, optional): The number of decomposition steps to perform.
        padding (str, optional): The padding type that will be used to produce
            the convolution. Check options here:
            `numpy.pad <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`__.
        rot_invariance (bool, optional): If true, rotation invariance will be
            done on the kernel.

    Returns:
        ndarray or image_volume_obj: The filtered image, with the same shape
        as the input. If the input was an ``image_volume_obj``, the result is
        wrapped in a new ``image_volume_obj`` with the same ``spatialRef``.
    """
    # Unwrap an image_volume_obj, keeping its spatial reference for the output.
    spatial_ref = None
    if isinstance(input_images, image_volume_obj):
        spatial_ref = input_images.spatialRef
        input_images = input_images.data

    # Add a leading batch axis of size 1 -> shape (B, W, H, D).
    input_images = np.expand_dims(np.asarray(input_images, dtype=np.float64), axis=0)

    if medscan:
        # The MEDscan parameters take precedence over the keyword arguments.
        wavelet_params = medscan.params.filter.wavelet
        ndims = wavelet_params.ndims
        wavelet_name = wavelet_params.basis_function
        rot_invariance = wavelet_params.rot_invariance
        padding = wavelet_params.padding
        subband = wavelet_params.subband
        level = wavelet_params.level

    _filter = Wavelet(
        ndims=ndims,
        wavelet_name=wavelet_name,
        rot_invariance=rot_invariance,
        padding=padding
    )
    result = _filter.convolve(input_images, _filter=subband, level=level)

    # Drop only the batch axis. np.squeeze would also remove singleton spatial
    # axes and change the image shape.
    filtered = result[0]
    if spatial_ref is not None:
        return image_volume_obj(filtered, spatial_ref)
    return filtered