Source code for MEDiml.filters.utils

from typing import List

import numpy as np
from scipy.signal import fftconvolve


def pad_imgs(
        images: np.ndarray,
        padding_length: List,
        axis: List,
        mode: str
    ) -> np.ndarray:
    """Pad selected axes of an image array by the same width on both sides.

    Args:
        images (ndarray): A numpy array that represents the image(s) to pad.
        padding_length (List): Padding widths, consumed from the END of the list.
            Axes are visited in increasing index order. The first visited axis
            that appears in ``axis`` receives ``padding_length[-1]``, the next one
            ``padding_length[-2]``, and so on. Each width is applied on both sides
            of its axis.
        axis (List): Indices of the axes to pad. Every other axis gets no padding.
            Only non-negative indices can match, since they are compared against
            ``range(images.ndim)``.
        mode (str): The padding mode. Check options here:
            `numpy.pad <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`__.

    Returns:
        ndarray: A numpy array that represents the padded image.
    """
    pad_width = []
    n_padded = 0
    for dim_idx in range(np.ndim(images)):
        if dim_idx not in axis:
            pad_width.append((0, 0))
            continue
        # Widths are taken from the end of the list, one per padded axis.
        n_padded += 1
        width = padding_length[-n_padded]
        pad_width.append((width, width))
    return np.pad(images, tuple(pad_width), mode=mode)
def convolve(
        dim: int,
        kernel: np.ndarray,
        images: np.ndarray,
        orthogonal_rot: bool = False,
        mode: str = "symmetric"
    ) -> np.ndarray:
    """Convolve a batch of images with a kernel to generate filtered images.

    Each spatial axis of the images is padded with ``mode`` by ``(k - 1) // 2``
    voxels on both sides, where ``k`` is the kernel size along that axis. The
    padded batch is then convolved in ``'valid'`` mode, so an odd-sized kernel
    (cubic or not) preserves the spatial shape. With an even size ``k``, the
    output is one voxel shorter along that axis, and the per-slice reshape fails.

    Args:
        dim (int): The number of spatial dimensions of the kernel. When it is
            smaller than the number of spatial dimensions of ``images`` (a 2D
            kernel on a (B, D, H, W) batch), every depth slice is filtered
            independently.
        kernel (ndarray): The kernel to use for the convolution. It must have
            ``dim + 2`` dimensions, because it is convolved with a (N, 1, *spatial)
            array. Its last ``dim`` axes are the spatial ones.
        images (ndarray): A batch of images of shape (B, H, W) or (B, D, H, W).
        orthogonal_rot (bool, optional): If True, the 3D images are also filtered
            after rotating them over the coronal, axial and sagittal axes. The
            three results are stacked on a new leading axis. This option requires
            a (B, D, H, W) batch.
        mode (str, optional): The padding mode. Check options here:
            `numpy.pad <https://numpy.org/doc/stable/reference/generated/numpy.pad.html>`__.

    Returns:
        ndarray: The filtered images.

        - For a 2D kernel on a (B, D, H, W) batch, the shape is (B, C, D, H, W).
        - Otherwise, the result is what ``scipy.signal.fftconvolve(..., mode='valid')``
          returns for the padded (B, 1, *spatial) batch and ``kernel``.
        - With ``orthogonal_rot``, the three orientation results are stacked
          along a new axis 0.

    Raises:
        AssertionError: If ``images`` is neither 3D nor 4D.
    """
    in_size = np.shape(images)
    # We only handle 2D or 3D images.
    assert len(in_size) == 3 or len(in_size) == 4, \
        "The tensor should have the followed shape (B, H, W) or (B, D, H, W)"

    if orthogonal_rot:
        # Filter the volume in its original orientation and after two 90-degree
        # rotations, then rotate each result back into the original frame.
        coronal_imgs = images
        axial_imgs = np.rot90(images, 1, (1, 2))
        sagittal_imgs = np.rot90(images, 1, (1, 3))
        result_coronal = convolve(dim, kernel, coronal_imgs, False, mode)
        result_axial = convolve(dim, kernel, axial_imgs, False, mode)
        result_sagittal = convolve(dim, kernel, sagittal_imgs, False, mode)
        # The results gain a channel axis at position 1. The rotated image axes
        # (1, 2) and (1, 3) are therefore now (2, 3) and (2, 4), and reversing
        # the axis order undoes each rotation.
        result_axial = np.rot90(result_axial, 1, (3, 2))
        result_sagittal = np.rot90(result_sagittal, 1, (4, 2))
        # Stack the three orientations on a new leading axis.
        return np.stack([result_coronal, result_axial, result_sagittal])

    # A kernel with fewer spatial axes than the images is applied slice by slice.
    per_slice = dim < len(in_size) - 1
    if per_slice:
        # Fold the depth axis into the batch axis: (B, D, H, W) -> (B * D, H, W).
        images = images.reshape((in_size[0] * in_size[1], in_size[2], in_size[3]))

    # Pad every spatial axis by (k - 1) // 2 on both sides, where k is the kernel
    # size along THAT axis. A 'valid' convolution then keeps the spatial shape
    # even for non-cubic kernels.
    pad_width = [(0, 0)] + [((k - 1) // 2,) * 2 for k in kernel.shape[-dim:]]
    # Any axis past the first ``dim`` spatial ones is left unpadded.
    pad_width += [(0, 0)] * (np.ndim(images) - len(pad_width))
    padded_imgs = np.pad(images, pad_width, mode=mode)
    # Insert a singleton channel axis: (N, 1, *spatial).
    new_imgs = np.expand_dims(padded_imgs, axis=1)

    if not per_slice:
        return fftconvolve(new_imgs, kernel, mode='valid')

    # Convolve one slice at a time.
    # NOTE(review): this is presumably because fftconvolve convolves along
    # every axis where both operands are longer than 1. A batch of slices and a
    # bank of filters sharing the leading axis would otherwise be mixed.
    result = np.stack([
        fftconvolve(img[np.newaxis], kernel, mode='valid') for img in new_imgs
    ])
    # Drop axis 2, which must be a singleton: -> (B * D, C, H, W).
    result = np.squeeze(result, axis=2)
    # Unfold the depth axis and move it behind the channel axis: (B, C, D, H, W).
    return result.reshape(
        (in_size[0], in_size[1], result.shape[1], in_size[2], in_size[3])
    ).transpose(0, 2, 1, 3, 4)