diff --git a/.travis.yml b/.travis.yml index aa2f612e..ec9f9d91 100644 --- a/.travis.yml +++ b/.travis.yml @@ -32,7 +32,7 @@ before_install: # Install packages install: - - conda install --yes cython numpy scipy h5py + - conda install --yes cython numpy scipy h5py pytest - pip install matplotlib - python setup.py install - pip install wget diff --git a/opmd_viewer/addons/pic/lpa_diagnostics.py b/opmd_viewer/addons/pic/lpa_diagnostics.py index 7cfc5d59..d42a9b38 100644 --- a/opmd_viewer/addons/pic/lpa_diagnostics.py +++ b/opmd_viewer/addons/pic/lpa_diagnostics.py @@ -14,12 +14,7 @@ import numpy as np import scipy.constants as const from scipy.optimize import curve_fit -from opmd_viewer.openpmd_timeseries.plotter import check_matplotlib -try: - import matplotlib.pyplot as plt -except ImportError: - # Any error will be caught later by `check_matplotlib` - pass +from opmd_viewer.openpmd_timeseries.plotter import check_matplotlib_and_axis class LpaDiagnostics( OpenPMDTimeSeries ): @@ -97,7 +92,8 @@ def get_mean_gamma( self, t=None, iteration=None, species=None, return( mean_gamma, std_gamma ) def get_sigma_gamma_slice(self, dz, t=None, iteration=None, species=None, - select=None, plot=False, **kw): + select=None, plot=False, ax=None, + figsize=None, **kw): """ Calculate the standard deviation of gamma for particles in z-slices of width dz @@ -128,6 +124,12 @@ def get_sigma_gamma_slice(self, dz, t=None, iteration=None, species=None, plot : bool, optional Whether to plot the requested quantity + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's `plot` method @@ -158,14 +160,14 @@ def get_sigma_gamma_slice(self, dz, t=None, iteration=None, species=None, i += 1 # Plot the result if needed if plot: - check_matplotlib() + ax = check_matplotlib_and_axis(ax, figsize) iteration = self.iterations[ self._current_i ] time_fs = 1.e15 * self.t[ self._current_i ] - plt.plot(z_pos, spreads, **kw) - plt.title("Slice energy spread at %.1f fs (iteration %d)" + ax.plot(z_pos, spreads, **kw) + ax.set_title("Slice energy spread at %.1f fs (iteration %d)" % (time_fs, iteration), fontsize=self.plotter.fontsize) - plt.xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) - plt.ylabel('$\sigma_\gamma (\Delta_z=%s\mu m)$' % dz, + ax.set_xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) + ax.set_ylabel('$\sigma_\gamma (\Delta_z=%s\mu m)$' % dz, fontsize=self.plotter.fontsize) return(spreads, z_pos) @@ -365,7 +367,7 @@ def get_emittance(self, t=None, iteration=None, species=None, return emittance_from_coord(x, y, ux, uy, w) def get_current( self, t=None, iteration=None, species=None, select=None, - bins=100, plot=False, **kw ): + bins=100, plot=False, ax=None, figsize=None, **kw ): """ Calculate the electric current along the z-axis for selected particles. @@ -395,6 +397,12 @@ def get_current( self, t=None, iteration=None, species=None, select=None, plot : bool, optional Whether to plot the requested quantity + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's `plot` method Returns @@ -425,20 +433,21 @@ def get_current( self, t=None, iteration=None, species=None, select=None, global_offset=(np.min(z) + len_z / bins / 2,), position=(0,)) # Plot the result if needed if plot: - check_matplotlib() + ax = check_matplotlib_and_axis(ax, figsize) iteration = self.iterations[ self._current_i ] time_fs = 1.e15 * self.t[ self._current_i ] - plt.plot( info.z, current, **kw) - plt.title("Current at %.1f fs (iteration %d)" + ax.plot( info.z, current, **kw) + ax.set_title("Current at %.1f fs (iteration %d)" % (time_fs, iteration ), fontsize=self.plotter.fontsize) - plt.xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) - plt.ylabel('$I \;(A)$', fontsize=self.plotter.fontsize) + ax.set_xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) + ax.set_ylabel('$I \;(A)$', fontsize=self.plotter.fontsize) # Return the current and bin centers return(current, info) def get_laser_envelope( self, t=None, iteration=None, pol=None, m='all', freq_filter=40, index='center', theta=0, - slicing_dir='y', plot=False, **kw ): + slicing_dir='y', plot=False, ax=None, + figsize=None, **kw ): """ Calculate a laser field by filtering out high frequencies. Can either return the envelope slice-wise or a full 2D envelope. @@ -484,6 +493,12 @@ def get_laser_envelope( self, t=None, iteration=None, pol=None, m='all', plot : bool, optional Whether to plot the requested quantity + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's `plot`(1D) or `imshow` (2D) method @@ -520,22 +535,22 @@ def get_laser_envelope( self, t=None, iteration=None, pol=None, m='all', # Plot the result if needed if plot: - check_matplotlib() + ax = check_matplotlib_and_axis(ax, figsize) iteration = self.iterations[ self._current_i ] time_fs = 1.e15 * self.t[ self._current_i ] if index != 'all': - plt.plot( 1.e6 * info.z, envelope, **kw) - plt.ylabel('$E_%s \;(V/m)$' % pol, - fontsize=self.plotter.fontsize) + ax.plot( 1.e6 * info.z, envelope, **kw) + ax.set_ylabel('$E_%s \;(V/m)$' % pol, + fontsize=self.plotter.fontsize) else: - plt.imshow( envelope, extent=1.e6 * info.imshow_extent, - aspect='auto', **kw) - plt.colorbar() - plt.ylabel('$%s \;(\mu m)$' % pol, + _ = ax.imshow( envelope, extent=1.e6 * info.imshow_extent, + aspect='auto', **kw) + ax.figure.colorbar(mappable=_, ax=ax) + ax.set_ylabel('$%s \;(\mu m)$' % pol, fontsize=self.plotter.fontsize) - plt.title("Laser envelope at %.1f fs (iteration %d)" + ax.set_title("Laser envelope at %.1f fs (iteration %d)" % (time_fs, iteration ), fontsize=self.plotter.fontsize) - plt.xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) + ax.set_xlabel('$z \;(\mu m)$', fontsize=self.plotter.fontsize) # Return the result return( envelope, info ) @@ -655,7 +670,7 @@ def get_main_frequency( self, t=None, iteration=None, pol=None, m='all', raise ValueError('Unknown method: {:s}'.format(method)) def get_spectrum( self, t=None, iteration=None, pol=None, - m='all', plot=False, **kw ): + m='all', plot=False, ax=None, figsize=None, **kw ): """ Return the spectrum of the laser (Absolute value of the Fourier transform of the fields.) @@ -682,6 +697,12 @@ def get_spectrum( self, t=None, iteration=None, pol=None, plot: bool, optional Whether to plot the data + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's `plot` method @@ -720,14 +741,14 @@ def get_spectrum( self, t=None, iteration=None, pol=None, # Plot the field if required if plot: - check_matplotlib() + ax = check_matplotlib_and_axis(ax, figsize) iteration = self.iterations[ self._current_i ] time_fs = 1.e15 * self.t[ self._current_i ] - plt.plot( spect_info.omega, spectrum, **kw ) - plt.xlabel('$\omega \; (rad.s^{-1})$', + ax.plot( spect_info.omega, spectrum, **kw ) + ax.set_xlabel('$\omega \; (rad.s^{-1})$', fontsize=self.plotter.fontsize ) - plt.ylabel('Spectrum', fontsize=self.plotter.fontsize ) - plt.title("Spectrum at %.1f fs (iteration %d)" + ax.set_ylabel('Spectrum', fontsize=self.plotter.fontsize ) + ax.set_title("Spectrum at %.1f fs (iteration %d)" % (time_fs, iteration ), fontsize=self.plotter.fontsize) return( spectrum, spect_info ) @@ -903,7 +924,8 @@ def get_laser_waist( self, t=None, iteration=None, pol=None, theta=0, raise ValueError('Unknown method: {:s}'.format(method)) def get_spectrogram( self, t=None, iteration=None, pol=None, theta=0, - slicing_dir='y', plot=False, **kw ): + slicing_dir='y', plot=False, ax=None, + figsize=None, **kw ): """ Calculates the spectrogram of a laserpulse, by the FROG method. @@ -932,6 +954,12 @@ def get_spectrogram( self, t=None, iteration=None, pol=None, theta=0, plot: bool, optional Whether to plot the spectrogram + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's `imshow` method @@ -984,16 +1012,16 @@ def get_spectrogram( self, t=None, iteration=None, pol=None, theta=0, # Plot the result if needed if plot: - check_matplotlib() + ax = check_matplotlib_and_axis(ax, figsize) iteration = self.iterations[ self._current_i ] time_fs = 1.e15 * self.t[ self._current_i ] - plt.imshow( spectrogram, extent=info.imshow_extent, aspect='auto', + ax.imshow( spectrogram, extent=info.imshow_extent, aspect='auto', **kw) - plt.title("Spectrogram at %.1f fs (iteration %d)" + ax.set_title("Spectrogram at %.1f fs (iteration %d)" % (time_fs, iteration ), fontsize=self.plotter.fontsize) - plt.xlabel('$t \;(s)$', fontsize=self.plotter.fontsize ) - plt.ylabel('$\omega \;(rad.s^{-1})$', - fontsize=self.plotter.fontsize ) + ax.set_xlabel('$t \;(s)$', fontsize=self.plotter.fontsize ) + ax.set_ylabel('$\omega \;(rad.s^{-1})$', + fontsize=self.plotter.fontsize ) return( spectrogram, info ) diff --git a/opmd_viewer/openpmd_timeseries/interactive.py b/opmd_viewer/openpmd_timeseries/interactive.py index e3a7e4c7..9d44401e 100644 --- a/opmd_viewer/openpmd_timeseries/interactive.py +++ b/opmd_viewer/openpmd_timeseries/interactive.py @@ -82,8 +82,9 @@ def refresh_field(change=None, force=False): do_refresh = True # Do the refresh if do_refresh: - plt.figure(fld_figure_button.value, figsize=figsize) - plt.clf() + fig_fld = plt.figure(fld_figure_button.value, figsize=figsize) + fig_fld.clf() + ax_fld = fig_fld.subplots(1, 1) # When working in inline mode, in an ipython notebook, # clear the output (prevents the images from stacking @@ -118,7 +119,7 @@ def refresh_field(change=None, force=False): m=convert_to_int(mode_button.value), slicing=slicing_button.value, theta=theta_button.value, slicing_dir=slicing_dir_button.value, - plot_range=plot_range, **kw_fld ) + plot_range=plot_range, ax=ax_fld, **kw_fld ) def refresh_ptcl(change=None, force=False): """ @@ -142,8 +143,10 @@ def refresh_ptcl(change=None, force=False): do_refresh = True # Do the refresh if do_refresh: - plt.figure(ptcl_figure_button.value, figsize=figsize) - plt.clf() + fig_ptcl = plt.figure(ptcl_figure_button.value, + figsize=figsize) + fig_ptcl.clf() + ax_ptcl = fig_ptcl.subplots(1, 1) # When working in inline mode, in an ipython notebook, # clear the output (prevents the images from stacking @@ -169,7 +172,8 @@ def refresh_ptcl(change=None, force=False): species=ptcl_species_button.value, plot=True, nbins=ptcl_bins_button.value, plot_range=plot_range, - use_field_mesh=ptcl_use_field_button.value, **kw_ptcl ) + use_field_mesh=ptcl_use_field_button.value, + ax=ax_ptcl, **kw_ptcl ) else: # 2D histogram self.get_particle( iteration=self.current_iteration, @@ -179,7 +183,8 @@ def refresh_ptcl(change=None, force=False): species=ptcl_species_button.value, plot=True, nbins=ptcl_bins_button.value, plot_range=plot_range, - use_field_mesh=ptcl_use_field_button.value, **kw_ptcl ) + use_field_mesh=ptcl_use_field_button.value, + ax=ax_ptcl, **kw_ptcl ) def refresh_field_type(change): """ diff --git a/opmd_viewer/openpmd_timeseries/main.py b/opmd_viewer/openpmd_timeseries/main.py index 695f991a..4f6eeafa 100644 --- a/opmd_viewer/openpmd_timeseries/main.py +++ b/opmd_viewer/openpmd_timeseries/main.py @@ -113,7 +113,8 @@ def __init__(self, path_to_dir, check_all_files=True): def get_particle(self, var_list=None, species=None, t=None, iteration=None, select=None, output=True, plot=False, nbins=150, plot_range=[[None, None], [None, None]], - use_field_mesh=True, histogram_deposition='cic', **kw): + use_field_mesh=True, histogram_deposition='cic', + ax=None, figsize=None, **kw): """ Extract a list of particle variables from an HDF5 file in the openPMD format. @@ -193,6 +194,12 @@ def get_particle(self, var_list=None, species=None, t=None, iteration=None, particles affects neighboring bins. `cic` (which is the default) leads to smoother results than `ngp`. + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's hist or hist2d. @@ -331,14 +338,16 @@ def get_particle(self, var_list=None, species=None, t=None, iteration=None, # Do the plotting self.plotter.hist1d(data_list[0], w, var_list[0], species, self._current_i, hist_bins[0], hist_range, - deposition=histogram_deposition, **kw) + deposition=histogram_deposition, ax=ax, + figsize=figsize, **kw) # - In the case of two quantities elif len(data_list) == 2: # Do the plotting self.plotter.hist2d(data_list[0], data_list[1], w, var_list[0], var_list[1], species, self._current_i, hist_bins, hist_range, - deposition=histogram_deposition, **kw) + deposition=histogram_deposition, ax=ax, + figsize=figsize, **kw) # Close the file file_handle.close() @@ -349,7 +358,8 @@ def get_particle(self, var_list=None, species=None, t=None, iteration=None, def get_field(self, field=None, coord=None, t=None, iteration=None, m='all', theta=0., slicing=0., slicing_dir='y', output=True, plot=False, - plot_range=[[None, None], [None, None]], **kw): + plot_range=[[None, None], [None, None]], ax=None, + figsize=None, **kw): """ Extract a given field from an HDF5 file in the openPMD format. @@ -408,6 +418,12 @@ def get_field(self, field=None, coord=None, t=None, iteration=None, along the 1st axis (first list) and 2nd axis (second list) Default: plots the full extent of the simulation box + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional Additional options to be passed to matplotlib's imshow. @@ -500,11 +516,13 @@ def get_field(self, field=None, coord=None, t=None, iteration=None, if plot: if geometry == "1dcartesian": self.plotter.show_field_1d(F, info, field_label, - self._current_i, plot_range=plot_range, **kw) + self._current_i, plot_range=plot_range, ax=ax, + figsize=figsize, **kw) else: self.plotter.show_field_2d(F, info, slicing_dir, m, field_label, geometry, self._current_i, - plot_range=plot_range, **kw) + plot_range=plot_range, ax=ax, + figsize=figsize, **kw) # Return the result return(F, info) diff --git a/opmd_viewer/openpmd_timeseries/plotter.py b/opmd_viewer/openpmd_timeseries/plotter.py index 15af43dd..f16840c4 100644 --- a/opmd_viewer/openpmd_timeseries/plotter.py +++ b/opmd_viewer/openpmd_timeseries/plotter.py @@ -51,7 +51,8 @@ def __init__(self, t, iterations): self.iterations = iterations def hist1d(self, q1, w, quantity1, species, current_i, nbins, hist_range, - cmap='Blues', vmin=None, vmax=None, deposition='cic', **kw): + cmap='Blues', vmin=None, vmax=None, deposition='cic', ax=None, + figsize=None, **kw): """ Plot a 1D histogram of the particle quantity q1 Sets the proper labels @@ -87,11 +88,17 @@ def hist1d(self, q1, w, quantity1, species, current_i, nbins, hist_range, particles affects neighboring bins. `cic` (which is the default) leads to smoother results than `ngp`. + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional - Additional options to be passed to matplotlib's bar function + Additional options to be passed to matplotlib's `bar` function """ - # Check if matplotlib is available - check_matplotlib() + # Check if matplotlib is available and if axis is defined + ax = check_matplotlib_and_axis(ax, figsize) # Find the iteration and time iteration = self.iterations[current_i] @@ -115,16 +122,16 @@ def hist1d(self, q1, w, quantity1, species, current_i, nbins, hist_range, # Do the plot bin_size = (hist_range[0][1] - hist_range[0][0]) / nbins bin_coords = hist_range[0][0] + bin_size * ( 0.5 + np.arange(nbins) ) - plt.bar( bin_coords, binned_data, width=bin_size, **kw ) - plt.xlim( hist_range[0] ) - plt.ylim( hist_range[1] ) - plt.xlabel(quantity1, fontsize=self.fontsize) - plt.title("%s: t = %.0f fs (iteration %d)" + ax.bar( bin_coords, binned_data, width=bin_size, **kw ) + ax.set_xlim( hist_range[0] ) + ax.set_ylim( hist_range[1] ) + ax.set_xlabel(quantity1, fontsize=self.fontsize) + ax.set_title("%s: t = %.0f fs (iteration %d)" % (species, time_fs, iteration), fontsize=self.fontsize) def hist2d(self, q1, q2, w, quantity1, quantity2, species, current_i, nbins, hist_range, cmap='Blues', vmin=None, vmax=None, - deposition='cic', **kw): + deposition='cic', ax=None, figsize=None, **kw): """ Plot a 2D histogram of the particle quantity q1 Sets the proper labels @@ -160,11 +167,17 @@ def hist2d(self, q1, q2, w, quantity1, quantity2, species, current_i, particles affects neighboring bins. `cic` (which is the default) leads to smoother results than `ngp`. + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + **kw : dict, otional - Additional options to be passed to matplotlib's imshow function + Additional options to be passed to matplotlib's `imshow` function """ - # Check if matplotlib is available - check_matplotlib() + # Check if matplotlib is available and if axis is defined + ax = check_matplotlib_and_axis(ax, figsize) # Find the iteration and time iteration = self.iterations[current_i] @@ -188,18 +201,18 @@ def hist2d(self, q1, q2, w, quantity1, quantity2, species, current_i, else: raise ValueError('Unknown deposition method: %s' % deposition) - # Do the plot - plt.imshow( binned_data.T, extent=hist_range[0] + hist_range[1], - origin='lower', interpolation='nearest', aspect='auto', + # Plot the data + _ = ax.imshow( binned_data.T, extent=hist_range[0] + hist_range[1], + origin='lower', aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax, **kw ) - plt.colorbar() - plt.xlabel(quantity1, fontsize=self.fontsize) - plt.ylabel(quantity2, fontsize=self.fontsize) - plt.title("%s: t = %.1f fs (iteration %d)" + ax.figure.colorbar(mappable=_, ax=ax) + ax.set_xlabel(quantity1, fontsize=self.fontsize) + ax.set_ylabel(quantity2, fontsize=self.fontsize) + ax.set_title("%s: t = %.1f fs (iteration %d)" % (species, time_fs, iteration), fontsize=self.fontsize) def show_field_1d( self, F, info, field_label, current_i, plot_range, - vmin=None, vmax=None, **kw ): + vmin=None, vmax=None, ax=None, figsize=None, **kw ): """ Plot the given field in 1D @@ -220,36 +233,45 @@ def show_field_1d( self, F, info, field_label, current_i, plot_range, plot_range : list of lists Indicates the values between which to clip the plot, along the 1st axis (first list) and 2nd axis (second list) + + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + + **kw : dict, otional + Additional options to be passed to matplotlib's `plot` function """ - # Check if matplotlib is available - check_matplotlib() + # Check if matplotlib is available and if axis is defined + ax = check_matplotlib_and_axis(ax, figsize) # Find the iteration and time iteration = self.iterations[current_i] time_fs = 1.e15 * self.t[current_i] # Get the title and labels - plt.title("%s at %.1f fs (iteration %d)" + ax.set_title("%s at %.1f fs (iteration %d)" % (field_label, time_fs, iteration), fontsize=self.fontsize) # Add the name of the axes - plt.xlabel('$%s \;(\mu m)$' % info.axes[0], fontsize=self.fontsize) + ax.set_xlabel('$%s \;(\mu m)$' % info.axes[0], fontsize=self.fontsize) # Get the x axis in microns xaxis = 1.e6 * getattr( info, info.axes[0] ) # Plot the data - plt.plot( xaxis, F ) + ax.plot( xaxis, F, **kw ) # Get the limits of the plot # - Along the first dimension if (plot_range[0][0] is not None) and (plot_range[0][1] is not None): - plt.xlim( plot_range[0][0], plot_range[0][1] ) + ax.set_xlim( plot_range[0][0], plot_range[0][1] ) else: - plt.xlim( xaxis.min(), xaxis.max() ) # Full extent of the box + ax.set_xlim( xaxis.min(), xaxis.max() ) # Full extent of the box # - Along the second dimension if (plot_range[1][0] is not None) and (plot_range[1][1] is not None): - plt.ylim( plot_range[1][0], plot_range[1][1] ) + ax.set_ylim( plot_range[1][0], plot_range[1][1] ) def show_field_2d(self, F, info, slicing_dir, m, field_label, geometry, - current_i, plot_range, **kw): + current_i, plot_range, ax=None, figsize=None, **kw): """ Plot the given field in 2D @@ -278,9 +300,18 @@ def show_field_2d(self, F, info, slicing_dir, m, field_label, geometry, plot_range : list of lists Indicates the values between which to clip the plot, along the 1st axis (first list) and 2nd axis (second list) + + ax : matplotlib axis, optional + Axis to be used for the plot + + figsize : tuple of two integers, optional + Size of the figure for the plot, same as defined in matplotlib + + **kw : dict, otional + Additional options to be passed to matplotlib's `imshow` function """ - # Check if matplotlib is available - check_matplotlib() + # Check if matplotlib is available and if axis is defined + ax = check_matplotlib_and_axis(ax, figsize) # Find the iteration and time iteration = self.iterations[current_i] @@ -290,37 +321,37 @@ def show_field_2d(self, F, info, slicing_dir, m, field_label, geometry, # Cylindrical geometry if geometry == "thetaMode": mode = str(m) - plt.title("%s in the mode %s at %.1f fs (iteration %d)" + ax.set_title("%s in the mode %s at %.1f fs (iteration %d)" % (field_label, mode, time_fs, iteration), fontsize=self.fontsize) # 2D Cartesian geometry elif geometry == "2dcartesian": - plt.title("%s at %.1f fs (iteration %d)" + ax.set_title("%s at %.1f fs (iteration %d)" % (field_label, time_fs, iteration), fontsize=self.fontsize) # 3D Cartesian geometry elif geometry == "3dcartesian": slice_plane = info.axes[0] + '-' + info.axes[1] - plt.title("%s sliced in %s at %.1f fs (iteration %d)" + ax.set_title("%s sliced in %s at %.1f fs (iteration %d)" % (field_label, slice_plane, time_fs, iteration), fontsize=self.fontsize) # Add the name of the axes - plt.xlabel('$%s \;(\mu m)$' % info.axes[1], fontsize=self.fontsize) - plt.ylabel('$%s \;(\mu m)$' % info.axes[0], fontsize=self.fontsize) + ax.set_xlabel('$%s \;(\mu m)$' % info.axes[1], fontsize=self.fontsize) + ax.set_ylabel('$%s \;(\mu m)$' % info.axes[0], fontsize=self.fontsize) # Plot the data - plt.imshow(F, extent=1.e6 * info.imshow_extent, origin='lower', - interpolation='nearest', aspect='auto', **kw) - plt.colorbar() + _ = ax.imshow(F, extent=1.e6 * info.imshow_extent, origin='lower', + aspect='auto', **kw) + ax.figure.colorbar(mappable=_, ax=ax) # Get the limits of the plot # - Along the first dimension if (plot_range[0][0] is not None) and (plot_range[0][1] is not None): - plt.xlim( plot_range[0][0], plot_range[0][1] ) + ax.set_xlim( plot_range[0][0], plot_range[0][1] ) # - Along the second dimension if (plot_range[1][0] is not None) and (plot_range[1][1] is not None): - plt.ylim( plot_range[1][0], plot_range[1][1] ) + ax.set_ylim( plot_range[1][0], plot_range[1][1] ) def print_cic_unavailable(): @@ -332,9 +363,10 @@ def print_cic_unavailable(): " - then reinstall openPMD-viewer") -def check_matplotlib(): +def check_matplotlib_and_axis(ax, figsize): """Raise error messages or warnings when potential issues when - potenial issues with matplotlib are detected.""" + potenial issues with matplotlib are detected. Check if axis is + defined and if it is not, generate a new axis.""" if not matplotlib_installed: raise RuntimeError( "Failed to import the openPMD-viewer plotter.\n" @@ -345,3 +377,12 @@ def check_matplotlib(): "backend. \n(This typically obtained when typing `%matplotlib`.)\n" "With recent version of Jupyter, the plots might not appear.\nIn this " "case, switch to `%matplotlib notebook` and restart the notebook.") + + # check the axis + if ax is None: + if figsize is None: + ax = plt.gca() + else: + fig, ax = plt.subplots(1, 1, figsize=figsize) + + return ax