# Source code for imars3d.backend.corrections.ring_removal

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""iMars3D's ring artifact correction module."""
import logging
import param
from imars3d.backend.util.functions import clamp_max_workers
import scipy
import numpy as np

try:
    import bm3d_streak_removal as bm3dsr
except ImportError:
    bm3dsr = None
from multiprocessing.managers import SharedMemoryManager
from tqdm.contrib.concurrent import process_map
from functools import partial

# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(name=__name__)


class bm3d_ring_removal(param.ParameterizedFunction):
    """
    Remove ring artifact from sinograms using BM3D method.

    This method requires BM3D suite, which can be installed by
    ``pip install bm3d_streak_removal``.

    ref: `10.1107/S1600577521001910 <http://doi.org/10.1107/S1600577521001910>`_

    The correction runs in two stages. Extreme streak attenuation runs first.
    Multiscale streak removal then runs on the *attenuated* stack.

    Parameters
    ----------
    arrays: np.ndarray
        Input radiograph stack.
    extreme_streak_iterations: int
        Number of iterations for extreme streak attenuation.
    extreme_detect_lambda: float
        Consider streaks which are stronger than lambda * local_std as extreme.
    extreme_detect_size: int
        Half window size for extreme streak detection -- total (2*s + 1).
    extreme_replace_size: int
        Half window size for extreme streak replacement -- total (2*s + 1).
    max_bin_iter_horizontal: int
        The number of total horizontal scales (counting the full scale).
        0 (default) is passed to ``bm3d_streak_removal`` as ``"auto"``.
    bin_vertical: int
        The factor of vertical binning, e.g. bin_vertical=32 would perform
        denoising in 1/32th of the original vertical size.
        0 (default) is passed to ``bm3d_streak_removal`` as ``"auto"``.
    filter_strength: float
        Strength of BM4D denoising (>0), where 1 is the standard application,
        >1 is stronger, and <1 is weaker.
    use_slices: bool
        If True, the sinograms will be split horizontally across each binning
        iteration into overlapping slices.
    slice_sizes: list
        A list of horizontal sizes for use of the slicing if use_slices=True.
        By default, slice size is either 39 pixels or 1/5th of the total width
        of the current iteration, whichever is larger.
    slice_step_sizes: list
        List of number of pixels between slices obtained with use_slices=True,
        one for each binning iteration. By default 1/4th of the corresponding
        slice size.
    denoise_indices: list
        Indices of sinograms to denoise; by default, denoises the full stack provided.

    Returns
    -------
    Radiograph stack with ring artifact removed.

    Raises
    ------
    RuntimeError
        If ``bm3d_streak_removal`` is not installed.

    Notes
    -----
    1. The parallel processing is handled at the bm3d level, and it is an
       intrinsic slow correction algorithm running on CPU.
    2. The underlying BM3D library uses stdout to print progress instead of a
       progress bar.
    """

    arrays = param.Array(doc="Input radiograph stack.", default=None)
    # parameters passed to bm3dsr.extreme_streak_attenuation
    extreme_streak_iterations = param.Integer(default=3, doc="Number of iterations for extreme streak attenuation.")
    extreme_detect_lambda = param.Number(
        default=4.0,
        doc="Consider streaks which are stronger than lambda * local_std as extreme.",
    )
    extreme_detect_size = param.Integer(
        default=9,
        doc="Half window size for extreme streak detection -- total (2*s + 1).",
    )
    extreme_replace_size = param.Integer(
        default=2,
        doc="Half window size for extreme streak replacement -- total (2*s + 1).",
    )
    # parameters passed to bm3dsr.multiscale_streak_removal
    max_bin_iter_horizontal = param.Integer(
        default=0,
        doc="The number of total horizontal scales (counting the full scale).",
        bounds=(0, None),
    )
    bin_vertical = param.Integer(
        default=0,
        doc="The factor of vertical binning, e.g. bin_vertical=32 would perform denoising in 1/32th of the original vertical size.",
        bounds=(0, None),
    )
    filter_strength = param.Number(
        default=1.0,
        doc="Strength of BM4D denoising (>0), where 1 is the standard application, >1 is stronger, and <1 is weaker.",
        bounds=(0, None),
    )
    use_slices = param.Boolean(
        default=True,
        doc="If True, the sinograms will be split horizontally across each binning iteration into overlapping.",
    )
    slice_sizes = param.List(
        default=None,
        doc="A list of horizontal sizes for use of the slicing if use_slices=True. By default, slice size is either 39 pixels or 1/5th of the total width of the current iteration, whichever is larger.",
    )
    slice_step_sizes = param.List(
        default=None,
        doc="List of number of pixels between slices obtained with use_slices=True, one for each binning iteration. By default 1/4th of the corresponding slice size.",
    )
    denoise_indices = param.List(
        default=None,
        doc="Indices of sinograms to denoise; by default, denoises the full stack provided.",
    )
    # note: we are skipping the bm3d_profile_obj parameter as bm3d is not explicitly used in iMars3D.

    def __call__(self, **params):
        """See class level documentation for help."""
        if not bm3dsr:
            logger.warning("To use method, make sure to install bm3d_streak_removal package via pip.")
            raise RuntimeError("BM3D suite not installed, please install with pip install bm3d_streak_removal")

        logger.info("Executing Filter: Remove Ring Artifact with BM3D")
        _ = self.instance(**params)
        params = param.ParamOverrides(self, params)

        # 0 is this class's sentinel for "let bm3d_streak_removal decide"
        max_bin_iter_horizontal = (
            "auto" if params.max_bin_iter_horizontal == 0 else params.max_bin_iter_horizontal
        )
        bin_vertical = "auto" if params.bin_vertical == 0 else params.bin_vertical

        # step 1: extreme streak attenuation
        logger.debug("Perform extreme streak attenuation")
        attenuated = bm3dsr.extreme_streak_attenuation(
            data=params.arrays,
            extreme_streak_iterations=params.extreme_streak_iterations,
            extreme_detect_lambda=params.extreme_detect_lambda,
            extreme_detect_size=params.extreme_detect_size,
            extreme_replace_size=params.extreme_replace_size,
        )

        # step 2: multiscale streak removal, applied to the output of step 1
        # (keep results in locals; never stash them on the ``param`` module)
        logger.debug("Perform multiscale streak removal")
        result = bm3dsr.multiscale_streak_removal(
            data=attenuated,
            max_bin_iter_horizontal=max_bin_iter_horizontal,
            bin_vertical=bin_vertical,
            filter_strength=params.filter_strength,
            use_slices=params.use_slices,
            slice_sizes=params.slice_sizes,
            slice_step_sizes=params.slice_step_sizes,
            denoise_indices=params.denoise_indices,
        )
        logger.info("FINISHED Executing Filter: Remove Ring Artifact")
        return result
class remove_ring_artifact(param.ParameterizedFunction):
    """
    Remove ring artifact from radiograph stack using Ketcham method.

    ref: `10.1117/12.680939 <https://doi.org/10.1117/12.680939>`_

    Parameters
    ----------
    arrays: np.ndarray
        Input radiograph stack (must be 3D). Each ``arrays[:, i, :]`` slice is
        corrected as one sinogram.
    kernel_size: int = 5
        The size of the kernel (moving window) during local smoothing with median filter.
    sub_division: int = 10
        Sub-dividing the sinogram into subsections (along rotation angle axis).
    correction_range: list = [0.9, 1.1]
        Multiplicative correction factor is capped within given range.
    max_workers: int = 0
        Number of cores to use for parallel processing (passed through
        ``clamp_max_workers`` before use).
    tqdm_class: panel.widgets.Tqdm
        Class to be used for rendering tqdm progress

    Returns
    -------
    Radiograph stack with ring artifact removed. This is the input ``arrays``
    object, overwritten in place. Corrected values are cast to its dtype on
    write-back.

    Raises
    ------
    ValueError
        If ``arrays`` is not 3D.
    """

    arrays = param.Array(doc="Input radiograph stack.", default=None)
    kernel_size = param.Integer(
        default=5, doc="The size of the kernel (moving window) during local smoothing with median filter."
    )
    sub_division = param.Integer(
        default=10, doc="Sub-dividing the sinogram into subsections (along rotation angle axis)."
    )
    correction_range = param.List(
        default=[0.9, 1.1], doc="Multiplicative correction factor is capped within given range."
    )
    max_workers = param.Integer(default=0, bounds=(0, None), doc="Number of cores to use for parallel processing.")
    tqdm_class = param.ClassSelector(class_=object, doc="Progress bar to render with")

    def __call__(self, **params):
        """See class level documentation for help."""
        logger.info("Executing Filter: Remove Ring Artifact")
        _ = self.instance(**params)
        params = param.ParamOverrides(self, params)
        val = self._remove_ring_artifact(
            params.arrays,
            params.kernel_size,
            params.sub_division,
            params.correction_range,
            params.max_workers,
            params.tqdm_class,
        )
        logger.info("FINISHED Executing Filter: Remove Ring Artifact")
        return val

    def _remove_ring_artifact(
        self,
        arrays: np.ndarray,
        kernel_size: int,
        sub_division: int,
        correction_range: tuple,
        max_workers: int,
        tqdm_class,
    ) -> np.ndarray:
        """Correct every sinogram of ``arrays`` in worker processes and write the results back in place."""
        # sanity check
        if arrays.ndim != 3:
            raise ValueError("This correction can only be used for a stack, i.e. a 3D image.")

        max_workers = clamp_max_workers(max_workers)

        # invoke mp via tqdm wrapper
        kwargs = {
            "max_workers": max_workers,
            "desc": "Removing ring artifact",
        }
        if tqdm_class:
            kwargs["tqdm_class"] = tqdm_class

        # process_map pickles every argument to ship it to a worker process, so a
        # shared-memory staging copy does not spare any per-sinogram copies. It only
        # adds a full extra copy of the stack, and finalizing the segment while an
        # ndarray still exports its buffer raises BufferError. Plain views suffice.
        # A list (not a generator) keeps the total known to the progress bar.
        sinograms = [arrays[:, sino_idx, :] for sino_idx in range(arrays.shape[1])]
        corrected = process_map(
            partial(
                _remove_ring_artifact_Ketcham,
                kernel_size=kernel_size,
                sub_division=sub_division,
                correction_range=correction_range,
            ),
            sinograms,
            **kwargs,
        )

        # Write back one sinogram at a time instead of stacking the results into
        # another full-size array first. Every result is already computed, so
        # overwriting ``arrays`` cannot affect the inputs.
        for sino_idx, sinogram in enumerate(corrected):
            arrays[:, sino_idx, :] = sinogram
        return arrays
class remove_ring_artifact_Ketcham(param.ParameterizedFunction):
    """Ketcham's ring artifact removal method.

    Apply the Ketcham method (doi:`10.1117/12.680939 <https://doi.org/10.1117/12.680939>`_)
    to a single sinogram to suppress ring artifact.

    Parameters
    ----------
    sinogram: np.ndarray
        Input sinogram (2D).
    kernel_size: int = 5
        Width of the moving window used for median-filter smoothing.
    sub_division: int = 10
        Number of subsections the sinogram is cut into along the rotation angle axis.
    correction_range: tuple = (0.9, 1.1)
        Lower and upper cap applied to the multiplicative correction factor.

    Returns
    -------
    Sinogram with ring artifact removed.

    NOTE
    ----
    0. Ring artifact means the halo-type features seen in the final
       reconstruction, typically produced by local detector/pixel gain errors
       during measurement.
    1. Only a single sinogram can be processed per call.
    2. The artifact is modelled as multiplicative, i.e. measured = signal * error.
    """

    sinogram = param.Array(doc="Input sinogram.", default=None)
    kernel_size = param.Integer(
        default=5, doc="The size of the kernel (moving window) during local smoothing via median filter."
    )
    sub_division = param.Integer(
        default=10, doc="Sub-dividing the sinogram into subsections (along rotation angle axis)."
    )
    correction_range = param.Tuple(
        default=(0.9, 1.1), doc="Multiplicative correction factor is capped within given range."
    )

    def __call__(self, **params):
        """See class level documentation for help."""
        logger.info("Executing Filter: Remove Ring Artifact (Ketcham)")
        # validates the supplied keyword arguments against the declared parameters
        _ = self.instance(**params)
        resolved = param.ParamOverrides(self, params)
        corrected = _remove_ring_artifact_Ketcham(
            sinogram=resolved.sinogram,
            kernel_size=resolved.kernel_size,
            sub_division=resolved.sub_division,
            correction_range=resolved.correction_range,
        )
        logger.info("FINISHED Executing Filter: Remove Ring Artifact (Ketcham)")
        return corrected
def _remove_ring_artifact_Ketcham( sinogram: np.ndarray, kernel_size: int = 5, sub_division: int = 10, correction_range: tuple = (0.9, 1.1), ) -> np.ndarray: # sanity check if sinogram.ndim != 2: raise ValueError("This correction can only be used for a sinogram, i.e. a 2D image.") # sub-divide the sinogram into smaller sections edges = np.linspace(0, sinogram.shape[0], sub_division + 1).astype(int) # corr_ratios = [] for bottom, top in zip(edges[:-1], edges[1:]): sub_sinogram = sinogram[bottom:top] sum_over_angle = sub_sinogram.sum(axis=0) # avoid divide by zero issue when dealing with emission type sinogram sum_over_angle = np.where(sum_over_angle == 0, 1.0, sum_over_angle) sum_over_angle_smoothed = scipy.signal.medfilt(sum_over_angle, kernel_size) # identify regions where large local variation occurs # - skip one window/kernel on both end (mostly air anyway) # - correction ratio is capped within specified range corr_ratio = np.ones_like(sum_over_angle) corr_ratio[kernel_size:-kernel_size] = (sum_over_angle_smoothed / sum_over_angle)[kernel_size:-kernel_size] corr_ratio[corr_ratio < correction_range[0]] = correction_range[0] corr_ratio[corr_ratio > correction_range[1]] = correction_range[1] # corr_ratios.append(corr_ratio) # use median to select the most probable correction ratio from all sub-sinograms corr_raio = np.median(corr_ratios, axis=0) return sinogram * corr_raio[np.newaxis, :]