Source code for imars3d.ui.stages.reconstruction

#!/usr/bin/env python3
"""Reconstruction stage for iMars3D."""

from pathlib import Path

import holoviews as hv
import numpy as np
import panel as pn
import param
from holoviews import opts, streams
from holoviews.operation.datashader import rasterize

from imars3d.backend.dataio.data import save_checkpoint as imars_save_checkpoint
from imars3d.backend.dataio.data import save_data
from imars3d.backend.reconstruction import recon
from imars3d.ui.widgets.rotation import FindRotationCenter


class Reconstruction(param.Parameterized):
    """Panel for conducting guided reconstruction with iMars3D.

    The stage takes a radiograph stack (``ct``) and its rotation angles
    (``omegas``), lets the user pick a rotation center, an algorithm and a
    filter, runs the reconstruction, and exposes the result as ``recon``.

    ``ct`` is indexed as ``ct[frame, row, column]`` throughout this class:
    the active radiograph is ``ct[idx_active_ct]`` and a sinogram is
    ``ct[:, row, :]``.
    """

    # -- data container
    # ** input data from previous step
    ct = param.Array(
        doc="radiograph stack as numpy array",
        precedence=-1,  # hide
    )
    omegas = param.Array(
        doc="rotation angles in degrees derived from tiff filename.",
    )
    recn_root = param.Path(
        default=Path.home(),
        doc="reconstruction results root, default should be proj_root/shared/processed_data",
    )
    temp_root = param.Path(
        default=Path.home() / Path("tmp"),
        doc="intermediate results save location",
    )
    recn_name = param.String(
        default="myrecon",
        doc="reconstruction results folder name",
    )
    # ** output reconstruction results
    recon = param.Array(
        doc="reconstruction results as numpy array",
        precedence=-1,  # hide
    )
    # reconstruction control
    algorithm = param.Selector(
        default="gridrec",
        objects=[
            "fbp",
            "gridrec",
            "art",
            "bart",
            "mlem",
            "osem",
            "ospml_hybrid",
            "ospml_quad",
            "pml_hybrid",
            "pml_quad",
            "sirt",
            "tv",
            "grad",
            "tikh",
        ],
        doc="algorithm provided by tomopy",
    )
    post_recon_filter = param.Selector(
        default="hann",
        objects=["none", "shepp", "cosine", "hann", "hamming", "ramlak", "parzen", "butterworth"],
        doc="post recon filter",
    )
    # -- viewer
    idx_active_ct = param.Integer(default=0, doc="index of active ct")
    # cmap
    colormap = param.Selector(
        default="gray",
        objects=["gray", "viridis", "plasma"],
        doc="colormap used for images",
    )
    colormap_scale = param.Selector(
        default="linear",
        objects=["linear", "log", "eq_hist"],
        doc="colormap scale for displaying images",
    )
    # image size
    frame_width = param.Integer(default=500, doc="viewer frame size")
    # check point
    ct_checkpoint_action = param.Action(lambda x: x.param.trigger("ct_checkpoint_action"), label="Checkpoint")
    # rotation center
    # NOTE(review): this is a class-level instance, so every Reconstruction
    # stage shares the same finder (and its ``parent``) -- confirm intended.
    rotation_center_finder = FindRotationCenter()
    #
    execute = param.Action(lambda x: x.param.trigger("execute"), label="Execute")
    status = param.Boolean(default=False, doc="Reconstruction completion status")
    # save recon results
    # NOTE(review): the label reads "Checkpoint"; the button built in
    # recon_viewer overrides the displayed name with "Save Reconstruction".
    recon_save = param.Action(lambda x: x.param.trigger("recon_save"), label="Checkpoint")

    @staticmethod
    def _notify(level, message, **kwargs):
        """Show a toast of the given level, or log it if toasts are disabled.

        Parameters
        ----------
        level:
            name of the notification method to call, e.g. "warning" or "success"
        message:
            text to display
        **kwargs:
            forwarded to the notification call (e.g. ``duration``)
        """
        notifications = pn.state.notifications
        if notifications is None:
            # notification area not enabled: keep the message visible in the log
            pn.state.log(message, level="warning" if level == "warning" else "info")
            return
        getattr(notifications, level)(message, **kwargs)

    @param.depends("execute", watch=True)
    def apply(self):
        """Run reconstruction and store the result in ``recon``.

        Does nothing except warn the user when ``ct`` or ``omegas`` is missing.
        """
        # sanity check
        if self.ct is None:
            self._notify("warning", "no CT found!")
            return
        if self.omegas is None:
            self._notify("warning", "no omegas provided")
            return
        # omegas are kept in degrees, recon is given radians
        self.recon = recon(
            arrays=self.ct,
            theta=np.radians(self.omegas),
            center=self.rotation_center_finder.rot_center,
            algorithm=self.algorithm,
            filter_name=self.post_recon_filter,
        )
        #
        self.status = True
        #
        self._notify("success", "Reconstruction complete.", duration=3000)

    @param.output(
        ("recon", param.Array),
    )
    def output(self):
        """Return reconstruction results to next step."""
        return self.recon

    @param.depends("ct_checkpoint_action", watch=True)
    def save_checkpoint(self):
        """Save current ct to checkpoint under ``temp_root``."""
        if self.ct is None:
            self._notify("warning", "No CT to save")
            return
        imars_save_checkpoint(
            data=self.ct,
            outputbase=self.temp_root,
            name=self.recn_name,
            omegas=self.omegas,
        )

    @param.depends("recon_save", watch=True)
    def save_reconstruction_results(self):
        """Save reconstruction results and the rotation center to disk."""
        if self.recon is None:
            self._notify("warning", "No reconstruction results to save")
            return
        # save the reconstructed volume (not the input radiographs)
        savedirname = save_data(data=self.recon, outputbase=self.recn_root, name=self.recn_name)
        # save the rotation center next to the data
        with open(f"{savedirname}/rot_center.txt", "w") as f_rotcnt:
            f_rotcnt.write(f"{self.rotation_center_finder.rot_center}")

    def cross_hair_view(self, x, y):
        """Return cross hair view of the image, centered at (x, y)."""
        return (hv.HLine(y) * hv.VLine(x)).opts(
            opts.HLine(
                color="yellow",
                line_width=0.5,
            ),
            opts.VLine(
                color="yellow",
                line_width=0.5,
            ),
        )

    def _sino_view(self, x, y):
        """Return the sinogram at pointer row ``y`` with position markers.

        A vertical line marks the pointer column ``x`` and a horizontal line
        marks the currently active radiograph index.
        """
        # The pointer can leave the image (e.g. after pan/zoom); clamp the row
        # so we never wrap around with a negative index or run past the end.
        row = int(np.clip(y, 0, self.ct.shape[1] - 1))
        sino = hv.Image(
            (
                np.arange(self.ct.shape[2]),
                np.arange(self.ct.shape[0]),
                self.ct[:, row, :],
            ),
            kdims=["x", "ω"],
            vdims=["count"],
        )
        return (sino * hv.VLine(x) * hv.HLine(self.idx_active_ct)).opts(
            opts.Image(
                frame_width=self.frame_width,
                tools=["hover"],
                cmap=self.colormap,
                cnorm=self.colormap_scale,
                invert_yaxis=True,
                xaxis=None,
                yaxis=None,
                title="sinogram",
            ),
            opts.VLine(
                color="yellow",
                line_width=0.5,
            ),
            opts.HLine(
                color="yellow",
                line_width=0.5,
            ),
        )

    def _ct_active_view(self):
        """Return the active radiograph (``ct[idx_active_ct]``) as an image."""
        ct_active = self.ct[self.idx_active_ct]
        return hv.Image(
            (
                np.arange(ct_active.shape[1]),
                np.arange(ct_active.shape[0]),
                ct_active,
            ),
            kdims=["x", "y"],
            vdims=["count"],
        ).opts(
            opts.Image(
                frame_width=self.frame_width,
                tools=["hover"],
                cmap=self.colormap,
                cnorm=self.colormap_scale,
                data_aspect=1.0,
                invert_yaxis=True,
                xaxis=None,
                yaxis=None,
            )
        )

    @param.depends(
        "frame_width",
        "idx_active_ct",
        "colormap",
        "colormap_scale",
    )
    def ct_viewer(self):
        """Radiograph viewer with a linked sinogram and frame selector."""
        if self.ct is None:
            return pn.pane.Markdown("no CT to display")
        # ct image object
        img = self._ct_active_view()
        # cross-hair pointer
        crosshair = streams.PointerXY(x=0, y=0, source=img)
        crosshair_dmap = hv.DynamicMap(self.cross_hair_view, streams=[crosshair])
        # sinogram view as dynamic map
        sino_dmap = hv.DynamicMap(self._sino_view, streams=[crosshair])
        #
        viewer = rasterize(img.hist() * crosshair_dmap + sino_dmap).cols(1)
        #
        save_ct_button = pn.widgets.Button.from_param(
            self.param.ct_checkpoint_action,
            name="Save CT",
            width=self.frame_width // 4,
            align="center",
        )
        ct_num_control = pn.widgets.IntSlider.from_param(
            self.param.idx_active_ct,
            start=0,
            # the slider end is selectable, so it must be the last valid index
            end=self.ct.shape[0] - 1,
            name="CT num",
            width=self.frame_width // 2,
        )
        ct_control = pn.Row(save_ct_button, ct_num_control)
        return pn.Column(ct_control, viewer)

    @param.depends(
        "frame_width",
        "recon",
        "colormap",
        "colormap_scale",
    )
    def recon_viewer(self):
        """Reconstruction viewer with a save button."""
        if self.recon is None:
            return pn.pane.Markdown("no reconstruction results to display")
        # coordinates are listed x, z, y, i.e. recon axes 2, 1, 0
        ds = hv.Dataset(
            (
                np.arange(self.recon.shape[2]),
                np.arange(self.recon.shape[1]),
                np.arange(self.recon.shape[0]),
                self.recon,
            ),
            ["x", "z", "y"],
            "intensity",
        )
        img = ds.to(hv.Image, ["x", "z"], dynamic=True).opts(
            opts.Image(
                frame_width=self.frame_width,
                tools=["hover"],
                cmap=self.colormap,
                cnorm=self.colormap_scale,
                data_aspect=1.0,
                invert_yaxis=True,
                xaxis=None,
                yaxis=None,
            )
        )
        #
        save_recon_button = pn.widgets.Button.from_param(
            self.param.recon_save,
            name="Save Reconstruction",
            width=self.frame_width // 4,
            align="center",
        )
        return pn.Column(save_recon_button, rasterize(img.hist()))

    def recon_panel(self, width=200):
        """Reconstruction control panel: algorithm, filter, status and execute."""
        # methods
        algorithm_select = pn.widgets.Select.from_param(
            self.param.algorithm,
            name="Algorithm",
            width=int(width / 1.2),
        )
        filter_select = pn.widgets.Select.from_param(
            self.param.post_recon_filter,
            name="PostReconFilter",
            width=int(width / 1.2),
        )
        # action
        status_indicator = pn.widgets.BooleanStatus.from_param(
            self.param.status,
            color="success",
        )
        execute_button = pn.widgets.Button.from_param(
            self.param.execute,
            width=width // 2,
        )
        action_pn = pn.Row(status_indicator, execute_button, width=width)
        #
        recon_panel = pn.WidgetBox(
            "**Reconstruction control**",
            algorithm_select,
            filter_select,
            action_pn,
            width=width,
        )
        return recon_panel

    def plot_control(self, width=80):
        """Plot control panel: colormap, colormap scale and frame width."""
        # color map
        cmap = pn.widgets.Select.from_param(
            self.param.colormap,
            name="colormap",
        )
        cmapscale = pn.widgets.Select.from_param(
            self.param.colormap_scale,
            name="colormap scale",
        )
        framewidth = pn.widgets.LiteralInput.from_param(
            self.param.frame_width,
            name="frame width",
        )
        plot_pn = pn.Card(
            cmap,
            cmapscale,
            framewidth,
            width=width,
            header="Display",
            collapsible=True,
        )
        return plot_pn

    def panel(self):
        """App panel view: side controls next to the CT / Recon viewer tabs."""
        # rotation center finder reads its data from this stage
        self.rotation_center_finder.parent = self
        # -- side panel
        width = self.frame_width // 2
        rotcnt_pn = pn.WidgetBox(
            "Rotation Center",
            self.rotation_center_finder.panel(width=width),
        )
        side_pn = pn.Column(
            self.plot_control(width=width),
            rotcnt_pn,
            self.recon_panel(width=width),
            width=int(width * 1.1),
        )
        # -- viewer
        viewer = pn.Tabs(
            ("CT", self.ct_viewer),
            ("Recon", self.recon_viewer),
        )
        #
        app = pn.Row(side_pn, viewer)
        return app