diff --git a/__init__.py b/__init__.py
new file mode 100644
index 00000000..c560c12d
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,72 @@
+"""
+.. module:: ehtim
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import ehtim.observing
+from ehtim.const_def import *
+from ehtim.imaging.imager_utils import imager_func
+from ehtim.modeling.modeling_utils import modeler_func
+import ehtim.imaging
+from ehtim.features import rex
+import ehtim.features
+from ehtim.plotting.summary_plots import *
+from ehtim.plotting.comparisons import *
+from ehtim.plotting.comp_plots import *
+from ehtim.plotting import comparisons
+from ehtim.plotting import comp_plots
+import ehtim.plotting
+from ehtim.calibrating.network_cal import network_cal as netcal
+from ehtim.calibrating.self_cal import self_cal as selfcal
+from ehtim.calibrating.pol_cal import *
+from ehtim.calibrating.pol_cal_new import *
+from ehtim.calibrating import pol_cal
+from ehtim.calibrating import network_cal
+from ehtim.calibrating import self_cal
+import ehtim.calibrating
+import ehtim.parloop
+import ehtim.caltable
+import ehtim.vex
+import ehtim.imager
+import ehtim.obsdata
+import ehtim.array
+import ehtim.movie
+import ehtim.image
+import ehtim.model
+import ehtim.survey
+
+
+import warnings
+warnings.filterwarnings(
+ "ignore", message="numpy.dtype size changed, may indicate binary incompatibility.")
+
+# necessary to prevent hangs from astropy iers bug in astropy v 2.0.8
+#from astropy.utils import iers
+#iers.conf.auto_download = False
+
+try:
+ import pkg_resources
+ version = pkg_resources.get_distribution("ehtim").version
+ print("Welcome to eht-imaging! v", version,'\n')
+except:
+ print("Welcome to eht-imaging!\n")
+
+
+def logo():
+ for line in BHIMAGE:
+ print(line)
+
+
+def eht():
+ for line in EHTIMAGE:
+ print(line)
diff --git a/array.py b/array.py
new file mode 100644
index 00000000..8415f22d
--- /dev/null
+++ b/array.py
@@ -0,0 +1,388 @@
+# array.py
+# an interferometric telescope array class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import copy
+from astropy.time import Time
+import matplotlib.pyplot as plt
+import matplotlib
+import ehtim.observing.obs_simulate as simobs
+import ehtim.observing.obs_helpers as obsh
+import ehtim.io.save
+import ehtim.io.load
+import ehtim.const_def as ehc
+from ehtim.caltable import plot_tarr_dterms
+
+
+###################################################################################################
+# Array object
+###################################################################################################
+
+
+class Array(object):
+
+ """A VLBI array of telescopes with site locations, SEFDs, and other data.
+
+ Attributes:
+ tarr (numpy.recarray): The array of telescope data with datatype DTARR
+ tkey (dict): A dictionary of rows in the tarr for each site name
+        ephem (dict): A dictionary of TLEs for each space antenna.
+                      Space antennas have x=y=z=0 in the tarr
+ """
+
+    def __init__(self, tarr, ephem=None):
+        if ephem is None:
+            ephem = {}
+        self.tarr = tarr
+ self.ephem = ephem
+
+ # check to see if ephemeris is correct
+ for line in self.tarr:
+ if np.any(np.isnan([line['x'], line['y'], line['z']])):
+ sitename = str(line['site'])
+ try:
+ elen = len(ephem[sitename])
+                except KeyError:
+ raise Exception('no ephemeris for site %s !' % sitename)
+ if elen != 3:
+ raise Exception('wrong ephemeris format for site %s !' % sitename)
+
+ # Dictionary of array indices for site names
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+
+ @property
+ def tarr(self):
+ return self._tarr
+
+ @tarr.setter
+ def tarr(self, tarr):
+ self._tarr = tarr
+ self.tkey = {tarr[i]['site']: i for i in range(len(tarr))}
+
+ def copy(self):
+ """Copy the array object.
+
+ Args:
+
+ Returns:
+ (Array): a copy of the Array object.
+ """
+
+ newarr = copy.deepcopy(self)
+ return newarr
+
+
+ def listbls(self):
+ """List all baselines.
+
+ Args:
+ Returns:
+ numpy.array : array of baselines
+ """
+
+ bls = []
+ for i1 in sorted(self.tarr['site']):
+ for i2 in sorted(self.tarr['site']):
+ if not ([i1, i2] in bls) and not ([i2, i1] in bls) and i1 != i2:
+ bls.append([i1, i2])
+ bls = np.array(bls)
+
+ return bls
+
+ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop,
+ mjd=ehc.MJD_DEFAULT, timetype='UTC', polrep='stokes',
+ elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
+ no_elevcut_space=False,
+ tau=ehc.TAUDEF, fix_theta_GMST=False, reorder=True):
+ """Generate u,v points and baseline uncertainties.
+
+ Args:
+ ra (float): the source right ascension in fractional hours
+            dec (float): the source declination in fractional degrees
+            rf (float): the observing frequency in Hz
+            bw (float): the observation bandwidth in Hz
+ tint (float): the scan integration time in seconds
+ tadv (float): the uniform cadence between scans in seconds
+ tstart (float): the start time of the observation in hours
+ tstop (float): the end time of the observation in hours
+
+ mjd (int): the mjd of the observation
+ timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC'
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ elevmin (float): station minimum elevation in degrees
+ elevmax (float): station maximum elevation in degrees
+ no_elevcut_space (bool): if True, do not apply elevation cut to orbiters
+ tau (float): the base opacity at all sites, or a dict giving one opacity per site
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v points
+
+ Returns:
+ Obsdata: an observation object with no data
+
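+        Example:
+            (a minimal usage sketch; the M87-like coordinates and EHT-like
+            frequencies are illustrative, and 'EHT2017.txt' is a hypothetical
+            array file)
+
+            arr = ehtim.array.load_txt('EHT2017.txt')
+            obs = arr.obsdata(ra=12.51, dec=12.39, rf=230.e9, bw=2.e9,
+                              tint=60., tadv=600., tstart=0., tstop=24.)
+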
+ """
+
+ obsarr = simobs.make_uvpoints(self, ra, dec, rf, bw,
+ tint, tadv, tstart, tstop,
+ mjd=mjd, polrep=polrep, tau=tau,
+ elevmin=elevmin, elevmax=elevmax,
+ no_elevcut_space=no_elevcut_space,
+ timetype=timetype, fix_theta_GMST=fix_theta_GMST)
+
+ uniquetimes = np.sort(np.unique(obsarr['time']))
+ scans = np.array([[time - 0.5 * tadv, time + 0.5 * tadv] for time in uniquetimes])
+ source = str(ra) + ":" + str(dec)
+ obs = ehtim.obsdata.Obsdata(ra, dec, rf, bw, obsarr, self.tarr,
+ source=source, mjd=mjd, timetype=timetype, polrep=polrep,
+ ampcal=True, phasecal=True, opacitycal=True,
+ dcal=True, frcal=True,
+ scantable=scans, reorder=reorder)
+
+ return obs
+
+ def make_subarray(self, sites):
+ """Make a subarray from the Array object array that only includes the sites listed.
+
+ Args:
+ sites (list) : list of sites in the subarray
+ Returns:
+ Array: an Array object with specified sites and metadata
+ """
+ all_sites = [t[0] for t in self.tarr]
+ mask = np.array([t in sites for t in all_sites])
+ subarr = Array(self.tarr[mask], ephem=self.ephem)
+ return subarr
+
+ def save_txt(self, fname):
+ """Save the array data in a text file.
+
+ Args:
+ fname (str) : path to output array file
+ """
+ ehtim.io.save.save_array_txt(self, fname)
+ return
+
+ def plot_dterms(self, sites='all', label=None, legend=True, clist=ehc.SCOLORS,
+ rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE,
+ show=True, grid=True, export_pdf=""):
+ """Make a plot of the D-terms.
+
+ Args:
+ sites (list) : list of sites to plot
+ label (str) : title for plot
+ legend (bool) : add telescope legend or not
+ clist (list) : list of colors for different stations
+ rangex (list) : lower and upper x-axis limits
+ rangey (list) : lower and upper y-axis limits
+ markersize (float) : marker size
+ show (bool) : display the plot or not
+ grid (bool) : add a grid to the plot or not
+ export_pdf (str) : save a pdf file to this path
+
+ Returns:
+ matplotlib.axes
+ """
+ # sites
+        if sites in ['all', 'All'] or sites == []:
+ sites = list(self.tkey.keys())
+
+ if not isinstance(sites, list):
+ sites = [sites]
+
+ keys = [self.tkey[site] for site in sites]
+
+ axes = plot_tarr_dterms(self.tarr, keys=keys, label=label, legend=legend, clist=clist,
+ rangex=rangex, rangey=rangey, markersize=markersize,
+ show=show, grid=grid, export_pdf=export_pdf)
+
+ return axes
+
+ def add_site(self, site, coords, sefd=10000,
+ fr_par=0, fr_elev=0, fr_off=0,
+ dr=0.+0.j, dl=0.+0.j):
+
+ """Add a ground station to the array
+
+ """
+ tarr_old = self.tarr.copy()
+ ephem_old = self.ephem.copy()
+
+
+ tarr_newline = np.array((str(site), float(coords[0]), float(coords[1]), float(coords[2]),
+ float(sefd), float(sefd),
+ dr, dl,
+ float(fr_par), float(fr_elev), float(fr_off)), dtype=ehc.DTARR)
+ tarr_new = np.append(tarr_old, tarr_newline)
+
+ arr_out = Array(tarr_new, ephem_old)
+ return arr_out
+
+ def remove_site(self, site):
+ """Remove a site from the array
+
+ """
+ tarr_old = self.tarr.copy()
+ ephem_old = self.ephem.copy()
+ ephem_new = ephem_old.copy()
+
+ try:
+ tarr_new = np.delete(tarr_old.copy(), self.tkey[site])
+ if site in ephem_old.keys():
+ ephem_new.pop(site)
+        except KeyError:
+            raise Exception("could not find site %s to delete from Array!" % site)
+
+ arr_out = Array(tarr_new, ephem_new)
+ return arr_out
+
+    def add_satellite_tle(self, tlearr, sefd=10000):
+
+ """Add an earth-orbiting satellite to the array from a TLE
+
+ Args:
+            tlearr (list) : 3-element list with [name, tle line 1, tle line 2] as strings
+ sefd (float) : assumed sefd for the array file (assumes sefdl = sefdr)
+ """
+ satname = tlearr[0]
+ tarr_new = self.tarr.copy()
+ ephem_new = self.ephem.copy()
+
+ tarr_newline = np.array((str(satname), 0., 0., 0.,
+ float(sefd), float(sefd),
+ 0., 0., 0., 0., 0.), dtype=ehc.DTARR)
+ tarr_new = np.append(tarr_new, tarr_newline)
+ ephem_new[satname] = tlearr
+ arr_out = Array(tarr_new, ephem_new)
+
+ return arr_out
+
+ def add_satellite_elements(self, satname,
+ perigee_mjd=Time.now().mjd,
+ period_days=1., eccentricity=0.,
+ inclination=0., arg_perigee=0., long_ascending=0.,
+ sefd=10000):
+ """Add an earth-orbiting satellite to the array from simple keplerian elements
+ perfect keplerian orbit is assumed, no derivatives
+
+ Args:
+ perigee time given in mjd
+ period given in days
+ inclination, arg_perigee, long_ascending given in degrees
+ """
+
+ tarr_new = self.tarr.copy()
+ ephem_new = self.ephem.copy()
+
+ tarr_newline = np.array((str(satname), 0., 0., 0.,
+ float(sefd), float(sefd),
+ 0., 0., 0., 0., 0.), dtype=ehc.DTARR)
+ tarr_new = np.append(tarr_new, tarr_newline)
+
+ ephem_new[satname] = [perigee_mjd, period_days, eccentricity, inclination, arg_perigee, long_ascending]
+ arr_out = Array(tarr_new, ephem_new)
+
+ return arr_out
+
+    def plot_satellite_orbits(self, tstart_mjd=Time.now().mjd, tstop_mjd=Time.now().mjd+1, npoints=1000):
+        """Plot the GCRS orbits of all satellites in the array in three projections.
+
+        Args:
+            tstart_mjd (float) : start time in MJD
+            tstop_mjd (float) : stop time in MJD
+            npoints (int) : number of orbit sample points
+        """
+        earth_radius_polar = 6357.  # km
+        earth_radius_eq = 6378.  # km
+
+ fig = plt.figure(figsize=(18,6))
+ gs = matplotlib.gridspec.GridSpec(1,3,width_ratios=[1,1,1])
+
+ satellites = self.ephem.keys()
+ for i,satellite in enumerate(satellites):
+
+ if i==0: color='k'
+ else: color=ehc.SCOLORS[i-1]
+
+        # get skyfield satellite object
+ if len(self.ephem[satellite])==3: # TLE
+ line1 = self.ephem[satellite][1]
+ line2 = self.ephem[satellite][2]
+ sat = obsh.sat_skyfield_from_tle(satellite, line1, line2)
+ elif len(self.ephem[satellite])==6: #keplerian elements
+ elements = self.ephem[satellite]
+ sat = obsh.sat_skyfield_from_elements(satellite, tstart_mjd,
+ elements[0],elements[1],elements[2],elements[3],elements[4],elements[5])
+ else:
+ raise Exception("ephemeris format not recognized for %s"%satellite)
+
+ # get GCRS positions
+ fracmjds = np.linspace(tstart_mjd, tstop_mjd, npoints)
+ positions = obsh.orbit_skyfield(sat, fracmjds, whichout='gcrs')
+ positions *= 1.e-3 # convert to km
+ distances = np.sqrt(positions[0]**2 + positions[1]**2 + positions[2]**2)
+ maxdist = np.max(distances)
+
+ ax1 = fig.add_subplot(gs[0])
+ ax1.set_aspect(1)
+ plt.plot(positions[0], positions[1], color=color, marker='.',ls='None')
+ circle1 = matplotlib.patches.Circle((0, 0), earth_radius_eq, color='b')
+ plt.gca().add_patch(circle1)
+ plt.xlabel('x (km)')
+ plt.ylabel('y (km)')
+ plt.xlim(-1.1*maxdist, 1.1*maxdist)
+ plt.ylim(-1.1*maxdist, 1.1*maxdist)
+ plt.grid()
+
+ ax2 = fig.add_subplot(gs[1])
+ ax2.set_aspect(1)
+ plt.plot(positions[1], positions[2], color=color, marker='.',ls='None')
+ circle1 = matplotlib.patches.Ellipse((0, 0), 2*earth_radius_eq, 2*earth_radius_polar, color='b')
+ plt.gca().add_patch(circle1)
+ plt.xlabel('y (km)')
+ plt.ylabel('z (km)')
+ plt.xlim(-1.1*maxdist, 1.1*maxdist)
+ plt.ylim(-1.1*maxdist, 1.1*maxdist)
+ plt.grid()
+
+ ax3 = fig.add_subplot(gs[2])
+ ax3.set_aspect(1)
+ plt.plot(positions[0], positions[2], color=color, marker='.',ls='None', label=satellite)
+ circle1 = matplotlib.patches.Ellipse((0, 0), 2*earth_radius_eq, 2*earth_radius_polar, color='b')
+ plt.gca().add_patch(circle1)
+ plt.xlabel('x (km)')
+ plt.ylabel('z (km)')
+ plt.xlim(-1.1*maxdist, 1.1*maxdist)
+ plt.ylim(-1.1*maxdist, 1.1*maxdist)
+ plt.legend(frameon=False,loc='center left', bbox_to_anchor=(1, 0.5))
+ plt.grid()
+
+ plt.subplots_adjust(wspace=1)
+ ehc.show_noblock()
+ return
+
+##########################################################################
+# Array creation functions
+##########################################################################
+
+
+def load_txt(fname, ephemdir='ephemeris'):
+ """Read an array from a text file.
+    Sites with x=y=z=0 are spacecraft; their TLE ephemerides are read from ephemdir.
+
+ Args:
+ fname (str) : path to input array file
+ ephemdir (str) : path to directory with TLE ephemerides for spacecraft
+ Returns:
+ Array: an Array object loaded from file
+ """
+
+ return ehtim.io.load.load_array_txt(fname, ephemdir=ephemdir)
diff --git a/calibrating/__init__.py b/calibrating/__init__.py
new file mode 100644
index 00000000..ede0bd27
--- /dev/null
+++ b/calibrating/__init__.py
@@ -0,0 +1,15 @@
+"""
+.. module:: ehtim.calibrating
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: calibration functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from . import self_cal
+from . import network_cal
+from . import pol_cal
+from . import pol_cal_new
+from . import polgains_cal
+
+from ..const_def import *
diff --git a/calibrating/cal_helpers.py b/calibrating/cal_helpers.py
new file mode 100644
index 00000000..c3976400
--- /dev/null
+++ b/calibrating/cal_helpers.py
@@ -0,0 +1,74 @@
+# cal_helpers.py
+# helper functions for calibration
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import itertools as it
+
+import ehtim.const_def as ehc
+
+ZBLCUTOFF = 1.e7
+
+
+def make_cluster_data(obs, zbl_uvdist_max=ZBLCUTOFF):
+ """Cluster sites in an observation into groups
+       with intra-group baseline length not exceeding zbl_uvdist_max
+ """
+
+ clusters = []
+ clustered_sites = []
+ for i1 in range(len(obs.tarr)):
+ t1 = obs.tarr[i1]
+
+ if t1['site'] in clustered_sites:
+ continue
+
+ csites = [t1['site']]
+ clustered_sites.append(t1['site'])
+        for i2 in range(i1, len(obs.tarr)):
+ t2 = obs.tarr[i2]
+ if t2['site'] in clustered_sites:
+ continue
+
+ site1coord = np.array([t1['x'], t1['y'], t1['z']])
+ site2coord = np.array([t2['x'], t2['y'], t2['z']])
+ uvdist = np.sqrt(np.sum((site1coord - site2coord)**2)) / (ehc.C / obs.rf)
+
+ if uvdist < zbl_uvdist_max:
+ csites.append(t2['site'])
+ clustered_sites.append(t2['site'])
+ clusters.append(csites)
+
+ clusterdict = {}
+ for site in obs.tarr['site']:
+ for k in range(len(clusters)):
+ if site in clusters[k]:
+ clusterdict[site] = k
+
+    clusterbls = [set(comb) for comb in it.combinations(range(len(clusters)), 2)]
+
+ cluster_data = (clusters, clusterdict, clusterbls)
+
+ return cluster_data
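+
+# A sketch of the returned structure for a hypothetical array with co-located
+# pairs (ALMA, APEX) and (SMA, JCMT) plus an isolated site (LMT):
+#   clusters    = [['ALMA', 'APEX'], ['SMA', 'JCMT'], ['LMT']]
+#   clusterdict = {'ALMA': 0, 'APEX': 0, 'SMA': 1, 'JCMT': 1, 'LMT': 2}
+#   clusterbls  = all pairs of cluster indices: [{0, 1}, {0, 2}, {1, 2}]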
diff --git a/calibrating/network_cal.py b/calibrating/network_cal.py
new file mode 100644
index 00000000..8fa1cedc
--- /dev/null
+++ b/calibrating/network_cal.py
@@ -0,0 +1,463 @@
+# network_cal.py
+# functions for network-calibration
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.optimize as opt
+import time
+import copy
+from multiprocessing import cpu_count, Pool
+
+import ehtim.obsdata
+import ehtim.parloop as parloop
+from . import cal_helpers as calh
+import ehtim.observing.obs_helpers as obsh
+import ehtim.const_def as ehc
+
+ZBLCUTOFF = 1.e7
+MAXIT = 5000
+
+###################################################################################################
+# Network-Calibration
+###################################################################################################
+
+
+def network_cal(obs, zbl, sites=[], zbl_uvdist_max=ZBLCUTOFF, method="amp", minimizer_method='BFGS',
+ pol='I', pad_amp=0., gain_tol=.2, solution_interval=0.0, scan_solutions=False,
+ caltable=False, processes=-1, show_solution=False, debias=True, msgtype='bar'):
+ """Network-calibrate a dataset with zero baseline constraints.
+
+ Args:
+ obs (Obsdata): The observation to be calibrated
+ zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour.
+ sites (list): list of sites to include in the network calibration.
+ empty list calibrates all sites
+ zbl_uvdist_max (float): maximum uv-distance considered a zero baseline
+ method (str): chooses what to calibrate, 'amp', 'phase', or 'both'.
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+ pol (str): which visibility to compute gains for
+
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+ gain_tol (float): gains that exceed this value will be disfavored by the prior
+ solution_interval (float): solution interval in seconds;
+ one gain is derived for each interval.
+ If 0.0, a solution is determined for each unique time
+ scan_solutions (bool): If True, determine one gain per site per scan.
+ Supersedes solution_interval
+
+ debias (bool): If True, debias the amplitudes
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ processes (int): number of cores to use in multiprocessing
+ show_solution (bool): if True, display the solution as it is calculated
+ msgtype (str): type of progress message to be printed, default is 'bar'
+
+ Returns:
+ (Obsdata): the calibrated observation, if caltable==False
+ (Caltable): the derived calibration table, if caltable==True
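+
+    Example:
+        (a minimal usage sketch; 'obs' is an existing Obsdata, and the
+        zero-baseline flux and tolerances are illustrative)
+
+        obs_cal = network_cal(obs, zbl=1.5, method='amp', pad_amp=0.01,
+                              gain_tol=0.2, processes=0)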
+ """
+
+ # Here, RRLL means to use both RR and LL (both as proxies for Stokes I)
+ # to derive a network calibration solution
+ if pol not in ['I', 'Q', 'U', 'V', 'RR', 'LL', 'RRLL']:
+ raise Exception("Can only network-calibrate to I, Q, U, V, RR, LL, or RRLL!")
+ if pol in ['I', 'Q', 'U', 'V']:
+ if obs.polrep != 'stokes':
+ raise Exception("netcal pol is a stokes parameter, but obs.polrep!='stokes'")
+ # obs = obs.switch_polrep('stokes',pol)
+ elif pol in ['RR', 'LL', 'RRLL']:
+ if obs.polrep != 'circ':
+ raise Exception("netcal pol is RR or LL or RRLL, but obs.polrep!='circ'")
+ # obs = obs.switch_polrep('circ',pol)
+
+ # V = model visibility, V' = measured visibility, G_i = site gain
+ # G_i * conj(G_j) * V_ij = V'_ij
+ if len(sites) == 0:
+ print("No stations specified in network cal: defaulting to calibrating all stations!")
+ sites = obs.tarr['site']
+
+ # find colocated sites and put into list allclusters
+ cluster_data = calh.make_cluster_data(obs, zbl_uvdist_max)
+
+ # get scans
+ scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions)
+ scans_cal = copy.copy(scans)
+
+ # Make the pool for parallel processing
+ if processes > 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ if processes > len(scans):
+ processes = len(scans)
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ elif processes == 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ processes = int(cpu_count())
+ if processes > len(scans):
+ processes = len(scans)
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ else:
+ print("Not Using Multiprocessing")
+
+ # loop over scans and calibrate
+ tstart = time.time()
+ if processes > 0: # with multiprocessing
+ scans_cal = np.array(pool.map(get_network_scan_cal, [[i, len(scans), scans[i],
+ zbl, sites, cluster_data, obs.polrep, pol,
+ method, pad_amp, gain_tol,
+ caltable, show_solution, debias, msgtype
+ ] for i in range(len(scans))]),
+ dtype=object)
+ else: # without multiprocessing
+ for i in range(len(scans)):
+ obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1)
+ scans_cal[i] = network_cal_scan(scans[i], zbl, sites, cluster_data,
+ polrep=obs.polrep, pol=pol,
+ method=method, minimizer_method=minimizer_method,
+ show_solution=show_solution, caltable=caltable,
+ pad_amp=pad_amp, gain_tol=gain_tol, debias=debias)
+
+ tstop = time.time()
+ print("\nnetwork_cal time: %f s" % (tstop - tstart))
+
+ if caltable: # create and return a caltable
+ allsites = obs.tarr['site']
+ caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()}
+ for i in range(1, len(scans_cal)):
+ row = scans_cal[i]
+ if len(row) == 0:
+ continue
+
+ for site in allsites:
+ try:
+ dat = row[site]
+ except KeyError:
+ continue
+
+ try:
+ caldict[site] = np.append(caldict[site], row[site])
+ except KeyError:
+ caldict[site] = [dat]
+
+ caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
+ source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
+ out = caltable
+
+ else: # return the calibrated observation
+ arglist, argdict = obs.obsdata_args()
+ arglist[4] = np.concatenate(scans_cal)
+ out = ehtim.obsdata.Obsdata(*arglist, **argdict)
+
+ # close multiprocessing jobs
+ if processes != -1:
+ pool.close()
+
+ return out
+
+
+def network_cal_scan(scan, zbl, sites, clustered_sites, polrep='stokes', pol='I',
+                     zbl_uvdist_max=ZBLCUTOFF, method="both", minimizer_method='BFGS',
+ show_solution=False, pad_amp=0., gain_tol=.2, caltable=False, debias=True):
+ """Network-calibrate a scan with zero baseline constraints.
+
+ Args:
+        scan (np.recarray): the data array of the scan to be calibrated
+ zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour.
+ sites (list): list of sites to include in the network calibration.
+ empty list calibrates all sites
+ clustered_sites (tuple): information on clustered sites, returned by make_cluster_data
+
+ polrep (str): 'stokes' or 'circ' to specify the polarization products in scan
+ pol (str): which image polarization to self-calibrate visibilities to
+ zbl_uvdist_max (float): maximum uv-distance considered a zero baseline
+ method (str): chooses what to calibrate, 'amp', 'phase', or 'both'
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+ gain_tol (float): gains that exceed this value will be disfavored by the prior
+
+ debias (bool): If True, debias the amplitudes
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ show_solution (bool): if True, display the solution as it is calculated
+
+
+ Returns:
+        (np.recarray): the calibrated scan data, if caltable==False
+        (dict): a dictionary of per-site calibration solutions, if caltable==True
+ """
+
+ # determine the zero-baseline flux of the scan
+ if callable(zbl):
+ zbl_scan = np.median(zbl(scan['time']))
+ else:
+ zbl_scan = zbl
+
+ # clustered site information
+ allclusters = clustered_sites[0]
+ clusterdict = clustered_sites[1]
+ clusterbls = clustered_sites[2]
+
+ # all the sites in the scan
+ allsites = list(set(np.hstack((scan['t1'], scan['t2']))))
+
+ if len(sites) == 0:
+ print("No stations specified in network cal: defaulting to calibrating all !")
+ sites = allsites
+
+ # only include sites that are present
+ sites = [s for s in sites if s in allsites]
+
+ # create a dictionary to keep track of gains;
+ # sites that aren't network calibrated (no co-located partners) get a value of -1
+ # so that they won't be network calibrated; other sites get a unique number
+ tkey = {b: a for a, b in enumerate(sites)}
+ for cluster in allclusters:
+ if len(cluster) == 1:
+ tkey[cluster[0]] = -1
+
+ clusterkey = clusterdict
+
+ # restrict solved cluster visibilities to ones present in the scan
+ # (this is much faster than allowing many unconstrained variables)
+ clusterbls_scan = [set([clusterkey[row['t1']], clusterkey[row['t2']]])
+ for row in scan
+ if len(set([clusterkey[row['t1']], clusterkey[row['t2']]])) == 2]
+
+ # now delete duplicates
+ clusterbls = [cluster for cluster in clusterbls if cluster in clusterbls_scan]
+
+ # make two lists of gain keys that relates scan bl gains to solved site ones
+ # (-1 means that this station does not have a gain that is being solved for)
+ # and make one list of scan keys that relates scan bl visibilities to solved cluster ones
+ # (-1 means it's a zero baseline!)
+
+ g1_keys = []
+ g2_keys = []
+ scan_keys = []
+ for row in scan:
+ try:
+ g1_keys.append(tkey[row['t1']])
+ except KeyError:
+ g1_keys.append(-1)
+ try:
+ g2_keys.append(tkey[row['t2']])
+ except KeyError:
+ g2_keys.append(-1)
+
+ clusternum1 = clusterkey[row['t1']]
+ clusternum2 = clusterkey[row['t2']]
+
+ if clusternum1 == clusternum2: # sites are in the same cluster
+ scan_keys.append(-1)
+ else: # sites are not in the same cluster
+ bl_index = clusterbls.index(set((clusternum1, clusternum2)))
+ scan_keys.append(bl_index)
+
+ # no sites to calibrate on this scan!
+ # if np.all(g1_keys == -1):
+ # return scan #Doesn't work with the caldict options
+
+ # Start by restricting to visibilities that include baselines to a site with a zero-baseline
+ vis_mask = [((row['t1'] in tkey.keys() and tkey[row['t1']] != -1)
+ or (row['t2'] in tkey.keys() and tkey[row['t2']] != -1)) for row in scan]
+
+ # get scan visibilities of the specified polarization
+ if pol != 'RRLL':
+ vis = scan[ehc.vis_poldict[pol]]
+ sigma = scan[ehc.sig_poldict[pol]]
+ else:
+ vis = np.concatenate([scan[ehc.vis_poldict['RR']], scan[ehc.vis_poldict['LL']]])
+ sigma = np.concatenate([scan[ehc.sig_poldict['RR']], scan[ehc.sig_poldict['LL']]])
+ vis_mask = np.concatenate([vis_mask, vis_mask])
+
+ if method == 'amp':
+ if debias:
+ vis = obsh.amp_debias(np.abs(vis), np.abs(sigma))
+ else:
+ vis = np.abs(vis)
+
+ sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2)
+
+ # initial guesses for parameters
+ n_gains = len(sites)
+ n_clusterbls = len(clusterbls)
+ if show_solution:
+ print('%d Gains; %d Clusters' % (n_gains, n_clusterbls))
+
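+    # gains and intercluster visibilities are complex; view them as interleaved
+    # float64 arrays so that scipy.optimize.minimize can treat them as real parameters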
+ gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64)
+ vpar_guess = np.ones(n_clusterbls, dtype=np.complex128)
+ for i in range(len(scan_keys)):
+ if scan_keys[i] < 0:
+ continue
+ if np.isnan(vis[i]):
+ continue
+ vpar_guess[scan_keys[i]] = vis[i]
+
+ vpar_guess = vpar_guess.view(dtype=np.float64)
+ gvpar_guess = np.hstack((gpar_guess, vpar_guess))
+
+ # error function
+ def errfunc(gvpar):
+
+ # all the forward site gains (complex)
+ g = gvpar[0:2 * n_gains].astype(np.float64).view(dtype=np.complex128)
+
+ # all the intercluster visibilities (complex)
+ v = gvpar[2 * n_gains:].astype(np.float64).view(dtype=np.complex128)
+
+        # choose to scale only amplitudes or phases
+ if method == "phase":
+ g = g / np.abs(g)
+ elif method == "amp":
+ g = np.abs(np.real(g))
+
+ # append the default values to g for missing points
+ # and to v for the zero baseline points
+ g = np.append(g, 1.)
+ v = np.append(v, zbl_scan)
+
+ # scan visibilities are either an intercluster visibility or the fixed zbl
+ v_scan = v[scan_keys]
+ g1 = g[g1_keys]
+ g2 = g[g2_keys]
+ if pol == 'RRLL':
+ v_scan = np.concatenate([v_scan, v_scan])
+ g1 = np.concatenate([g1, g1])
+ g2 = np.concatenate([g2, g2])
+
+ if method == 'amp':
+ verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
+ else:
+ verr = vis - g1 * g2.conj() * v_scan
+
+ chi = np.abs(verr) * sigma_inv
+ chisq = np.sum((chi * chi)[np.isfinite(chi) * vis_mask])
+
+ # prior on the gains
+ g_fracerr = gain_tol
+ if method == "phase":
+ chisq_g = 0 # because |g| == 1 so log(|g|) = 0
+ elif method == "amp":
+ logg = np.log(g)
+ chisq_g = np.sum(logg * logg) / (g_fracerr * g_fracerr)
+ else:
+ logabsg = np.log(np.abs(g))
+ chisq_g = np.sum(logabsg * logabsg) / (g_fracerr * g_fracerr)
+
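+        # weak quartic penalty on the solved intercluster visibility amplitudes,
+        # scaled by the zero-baseline flux, to discourage runaway solutions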
+ absv = np.abs(v)
+ vv = absv * absv
+ chisq_v = np.sum(vv * vv) / zbl_scan**4
+ return chisq + chisq_g + chisq_v
+
+ if np.max(g1_keys) > -1 or np.max(g2_keys) > -1:
+ # run the minimizer to get a solution (but only run if there's at least one gain to fit)
+ optdict = {'maxiter': MAXIT} # minimizer params
+ res = opt.minimize(errfunc, gvpar_guess, method=minimizer_method, options=optdict)
+
+ # get solution
+ g_fit = res.x[0:2 * n_gains].view(np.complex128)
+ v_fit = res.x[2 * n_gains:].view(np.complex128)
+
+ if method == "phase":
+ g_fit = g_fit / np.abs(g_fit)
+ if method == "amp":
+ g_fit = np.abs(np.real(g_fit))
+
+ if show_solution:
+ print(np.abs(g_fit))
+ print(np.abs(v_fit))
+ else:
+ g_fit = []
+ v_fit = []
+
+ g_fit = np.append(g_fit, 1.)
+ v_fit = np.append(v_fit, zbl_scan)
+
+ # Derive a calibration table or apply the solution to the scan
+ if caltable:
+ allsites = list(set(scan['t1']).union(set(scan['t2'])))
+
+ caldict = {}
+ for site in allsites:
+ if site in sites:
+ site_key = tkey[site]
+ else:
+ site_key = -1
+
+ # We will *always* set the R and L gain corrections to be equal in network calibration,
+ # to avoid breaking polarization consistency relationships
+ rscale = g_fit[site_key]**-1
+ lscale = g_fit[site_key]**-1
+
+ # Note: we may want to give two entries for the start/stop times
+ # when a non-zero solution interval is used
+ caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL)
+
+ out = caldict
+
+ else:
+ g1_fit = g_fit[g1_keys]
+ g2_fit = g_fit[g2_keys]
+
+ gij_inv = (g1_fit * g2_fit.conj())**(-1)
+
+ if polrep == 'stokes':
+ # scale visibilities
+ for vistype in ['vis', 'qvis', 'uvis', 'vvis']:
+ scan[vistype] *= gij_inv
+ # scale sigmas
+ for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']:
+ scan[sigtype] *= np.abs(gij_inv)
+ elif polrep == 'circ':
+ # scale visibilities
+ for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']:
+ scan[vistype] *= gij_inv
+ # scale sigmas
+ for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']:
+ scan[sigtype] *= np.abs(gij_inv)
+
+ out = scan
+
+ return out
+
+
+def init(x):
+ global counter
+ counter = x
+
+
+def get_network_scan_cal(args):
+ return get_network_scan_cal2(*args)
+
+
+def get_network_scan_cal2(i, n, scan, zbl, sites, cluster_data, polrep, pol,
+ method, pad_amp, gain_tol, caltable, show_solution, debias, msgtype):
+ if n > 1:
+ global counter
+ counter.increment()
+ obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)
+
+ return network_cal_scan(scan, zbl, sites, cluster_data, polrep=polrep, pol=pol,
+                            zbl_uvdist_max=ZBLCUTOFF, method=method, caltable=caltable,
+ show_solution=show_solution,
+ pad_amp=pad_amp, gain_tol=gain_tol, debias=debias)
diff --git a/calibrating/pol_cal.py b/calibrating/pol_cal.py
new file mode 100644
index 00000000..d317ba35
--- /dev/null
+++ b/calibrating/pol_cal.py
@@ -0,0 +1,400 @@
+# pol_cal.py
+# functions for polarimetric-calibration
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.optimize as opt
+import matplotlib.pyplot as plt
+import time
+
+import ehtim.imaging.imager_utils as iu
+import ehtim.observing.obs_simulate as simobs
+import ehtim.const_def as ehc
+
+MAXIT = 1000 # maximum number of iterations in pol-cal minimizer
+
+###################################################################################################
+# Polarimetric Calibration
+###################################################################################################
+
+
+def leakage_cal(obs, im=None, sites=[], leakage_tol=.1, pol_fit=['RL', 'LR'], dtype='vis',
+ const_fpol=False, inverse=False, minimizer_method='L-BFGS-B',
+ ttype='direct', fft_pad_factor=2, show_solution=True, obs_apply=False):
+ """Polarimetric calibration (detects and removes polarimetric leakage,
+ based on consistency with a given image)
+
+ Args:
+ obs (Obsdata): The observation to be calibrated
+ im (Image): the reference image used for calibration
+ (not needed if using const_fpol = True)
+ sites (list): list of sites to include in the polarimetric calibration.
+ empty list calibrates all sites
+
+ leakage_tol (float): leakage values exceeding this value will be disfavored by the prior
+ pol_fit (list): list of visibilities to use; e.g., ['RL','LR'] or ['RR','LL','RL','LR']
+ dtype (str): Type of data to fit ('vis' for complex visibilities;
+ 'amp' for just the amplitudes)
+ const_fpol (bool): If true, solve for a single fractional polarization
+ across all baselines in addition to leakage.
+                           For this option, the passed image is not used.
+        inverse (bool): if True (with const_fpol), fit a linearized model to
+                        leakage-induced visibility ratios instead of applying
+                        the inverse Jones matrices
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ show_solution (bool): if True, display the solution as it is calculated
+ obs_apply (Obsdata): apply the solution to another observation
+ Returns:
+ (Obsdata): the calibrated observation, with computed leakage values added to the obs.tarr
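+
+    Example:
+        (a minimal usage sketch; 'obs' is an Obsdata object and 'im' a
+        reference Image)
+
+        obs_dcal = leakage_cal(obs, im, leakage_tol=0.1, pol_fit=['RL', 'LR'])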
+ """
+ tstart = time.time()
+
+ mask = []
+
+ # Do everything in a circular basis
+ if not const_fpol:
+ im_circ = im.switch_polrep('circ')
+ else:
+ im_circ = None
+
+ if dtype not in ['vis', 'amp']:
+ raise Exception('dtype must be vis or amp')
+
+ # Create the obsdata object for searching
+ obs_test = obs.copy()
+ obs_test = obs_test.switch_polrep('circ')
+
+ # Check to see if the field rotation is corrected
+ if obs_test.frcal is False:
+ print("Field rotation angles have not been corrected. Correcting now...")
+ obs_test.data = simobs.apply_jones_inverse(obs_test, frcal=False, dcal=True, verbose=False)
+ obs_test.frcal = True
+
+ # List of all sites present in the observation
+ allsites = list(set(np.hstack((obs.data['t1'], obs.data['t2']))))
+
+ if len(sites) == 0:
+ print("No stations specified for leakage calibration: defaulting to calibrating all !")
+ sites = allsites
+
+    # Zero the leakage terms of any sites that are not being solved for
+    # (we will only correct leakage for those sites with new solutions)
+ for j in range(len(obs_test.tarr)):
+ if obs_test.tarr[j]['site'] in sites:
+ continue
+ obs_test.tarr[j]['dr'] = obs_test.tarr[j]['dl'] = 0.0j
+
+ # only include sites that are present
+ sites = [s for s in sites if s in allsites]
+ site_index = [list(obs.tarr['site']).index(s) for s in sites]
+
+ if not const_fpol:
+ (dataRR, sigmaRR, ARR) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='RR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (dataLL, sigmaLL, ALL) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='LL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (dataRL, sigmaRL, ARL) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='RL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (dataLR, sigmaLR, ALR) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='LR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+
+ # If inverse modeling, pre-compute field rotation angles
+ if inverse:
+ el1 = obs.unpack(['el1'], ang_unit='rad')['el1']
+ el2 = obs.unpack(['el2'], ang_unit='rad')['el2']
+ par1 = obs.unpack(['par_ang1'], ang_unit='rad')['par_ang1']
+ par2 = obs.unpack(['par_ang2'], ang_unit='rad')['par_ang2']
+
+ fr_elev1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_elev'] for o in obs.data])
+ fr_elev2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_elev'] for o in obs.data])
+ fr_par1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_par'] for o in obs.data])
+ fr_par2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_par'] for o in obs.data])
+ fr_off1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_off'] for o in obs.data])
+ fr_off2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_off'] for o in obs.data])
+
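+        # total field rotation angle at each station: elevation and parallactic-angle
+        # terms weighted by the mount parameters, plus a fixed offset in degrees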
+ fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180.
+ fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180.
+
+ def chisq_total(data, im, D, inverse=False):
+ if const_fpol:
+ if inverse:
+ # Note: These are linearized approximations (in leakage and source polarization)
+ fpol_model = D[-2]
+ cpol_model = np.real(D[-1])
+
+ D1R = [D[2*sites.index(o['t1'])] for o in data]
+ D1L = [D[2*sites.index(o['t1'])+1] for o in data]
+ D2R = [D[2*sites.index(o['t2'])] for o in data]
+ D2L = [D[2*sites.index(o['t2'])+1] for o in data]
+
+ lrll = data['lrvis']/data['llvis']
+ lrrr = data['lrvis']/data['rrvis']
+ rlll = data['rlvis']/data['llvis']
+ rlrr = data['rlvis']/data['rrvis']
+
+ lrll_sigma = data['lrsigma']/np.abs(data['llvis'])
+ lrrr_sigma = data['lrsigma']/np.abs(data['rrvis'])
+ rlll_sigma = data['rlsigma']/np.abs(data['llvis'])
+ rlrr_sigma = data['rlsigma']/np.abs(data['rrvis'])
+
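+                # visibility-ratio models to first order in the D-terms and in the
+                # source polarization, rotated by the field rotation angles fr1, fr2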
+ lrll_model = (np.conjugate(fpol_model) * (1.0 + cpol_model) +
+ D1L * np.exp(-2j*fr1) * (1.0 + 2.0*cpol_model) +
+ np.conjugate(D2R) * np.exp(-2j*fr2))
+ lrrr_model = (np.conjugate(fpol_model) * (1.0 - cpol_model) +
+ D1L * np.exp(-2j*fr1) +
+ np.conjugate(D2R) * np.exp(-2j*fr2) * (1.0 - 2.0*cpol_model))
+ rlll_model = (fpol_model * (1.0 + cpol_model) +
+ D1R * np.exp(2j*fr1) +
+ np.conjugate(D2L) * np.exp(2j*fr2) * (1.0 + 2.0*cpol_model))
+ rlrr_model = (fpol_model * (1.0 - cpol_model) +
+ D1R * np.exp(2j*fr1) * (1.0 - 2.0*cpol_model) +
+ np.conjugate(D2L) * np.exp(2j*fr2))
+
+ chisq = np.concatenate([np.abs((lrll - lrll_model)/lrll_sigma)**2,
+ np.abs((lrrr - lrrr_model)/lrrr_sigma)**2,
+ np.abs((rlll - rlll_model)/rlll_sigma)**2,
+ np.abs((rlrr - rlrr_model)/rlrr_sigma)**2])
+ chisq = chisq[~np.isnan(chisq)]
+ return np.mean(chisq)
+ else:
+ fpol_model = D[-2]
+ cpol_model = np.real(D[-1])
+
+ fpol_data_1 = 2.0 * data['rlvis']/(data['rrvis'] + data['llvis'])
+ fpol_data_2 = 2.0 * np.conj(data['lrvis']/(data['rrvis'] + data['llvis']))
+ fpol_sigma_1 = 2.0/np.abs(data['rrvis'] + data['llvis']) * data['rlsigma']
+ fpol_sigma_2 = 2.0/np.abs(data['rrvis'] + data['llvis']) * data['lrsigma']
+ return 0.5*np.mean(np.abs((fpol_model - fpol_data_1)/fpol_sigma_1)**2
+ + np.abs((fpol_model - fpol_data_2)/fpol_sigma_2)**2)
+ else:
+ chisq_RR = chisq_LL = chisq_RL = chisq_LR = 0.0
+ if 'RR' in pol_fit:
+ chisq_RR = iu.chisq(im.rrvec, ARR,
+ obs_test.unpack_dat(data, ['rr' + dtype])['rr' + dtype],
+ data['rrsigma'], dtype=dtype, ttype=ttype, mask=mask)
+ if 'LL' in pol_fit:
+ chisq_LL = iu.chisq(im.llvec, ALL,
+ obs_test.unpack_dat(data, ['ll' + dtype])['ll' + dtype],
+ data['llsigma'], dtype=dtype, ttype=ttype, mask=mask)
+ if 'RL' in pol_fit:
+ chisq_RL = iu.chisq(im.rlvec, ARL,
+ obs_test.unpack_dat(data, ['rl' + dtype])['rl' + dtype],
+ data['rlsigma'], dtype=dtype, ttype=ttype, mask=mask)
+ if 'LR' in pol_fit:
+ chisq_LR = iu.chisq(im.lrvec, ALR,
+ obs_test.unpack_dat(data, ['lr' + dtype])['lr' + dtype],
+ data['lrsigma'], dtype=dtype, ttype=ttype, mask=mask)
+ return (chisq_RR + chisq_LL + chisq_RL + chisq_LR)/len(pol_fit)
+
+ print("Finding leakage for sites:", sites)
+
+ def errfunc(Dpar):
+        # all the D-terms (complex). If const_fpol, the source polarization terms come last.
+ D = Dpar.astype(np.float64).view(dtype=np.complex128)
+
+ if not inverse:
+ for isite in range(len(sites)):
+ obs_test.tarr['dr'][site_index[isite]] = D[2*isite]
+ obs_test.tarr['dl'][site_index[isite]] = D[2*isite+1]
+ data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False)
+ else:
+ data = obs.data
+
+ # goodness-of-fit for the leakage-corrected data
+ chisq = chisq_total(data, im_circ, D, inverse=inverse)
+
+ # prior on the D terms
+ chisq_D = np.sum(np.abs(D/leakage_tol)**2)
+
+ return chisq + chisq_D
+
+ # Now, we will minimize the total chi-squared. We need two complex leakage terms for each site
+ optdict = {'maxiter': MAXIT} # minimizer params
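+    # the complex D-terms are packed as interleaved floats; when const_fpol is set,
+    # the trailing complex parameters hold the constant source polarization terms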
+ Dpar_guess = np.zeros((len(sites) + const_fpol*2)*2, dtype=np.complex128).view(dtype=np.float64)
+ print("Minimizing...")
+ res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict)
+    print(errfunc(Dpar_guess), errfunc(res.x))
+
+ # get solution
+ D_fit = res.x.astype(np.float64).view(dtype=np.complex128) # all the D-terms (complex)
+
+ # Apply the solution
+ for isite in range(len(sites)):
+ obs_test.tarr['dr'][site_index[isite]] = D_fit[2*isite]
+ obs_test.tarr['dl'][site_index[isite]] = D_fit[2*isite+1]
+ obs_test.data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False)
+ obs_test.dcal = True
+
+ # Re-populate any additional leakage terms that were present
+ for j in range(len(obs_test.tarr)):
+ if obs_test.tarr[j]['site'] in sites:
+ continue
+ obs_test.tarr[j]['dr'] = obs.tarr[j]['dr']
+ obs_test.tarr[j]['dl'] = obs.tarr[j]['dl']
+
+ if show_solution:
+ if inverse is False:
+ chisq_orig = chisq_total(obs.switch_polrep('circ').data,
+ im_circ, D_fit, inverse=inverse)
+ chisq_new = chisq_total(obs_test.data, im_circ, D_fit, inverse=inverse)
+ else:
+ chisq_orig = chisq_total(obs.switch_polrep('circ').data,
+ im_circ, D_fit*0, inverse=inverse)
+ chisq_new = chisq_total(obs.switch_polrep('circ').data, im_circ, D_fit, inverse=inverse)
+
+ print("Original chi-squared: {:.4f}".format(chisq_orig))
+ print("New chi-squared: {:.4f}\n".format(chisq_new))
+ for isite in range(len(sites)):
+ print(sites[isite])
+ print(' D_R: {:.4f}'.format(D_fit[2*isite]))
+ print(' D_L: {:.4f}\n'.format(D_fit[2*isite+1]))
+ if const_fpol:
+ print('Source Fractional Polarization Magnitude: {:.4f}'.format(np.abs(D_fit[-2])))
+ print('Source Fractional Polarization EVPA [deg]: {:.4f}\n'.format(
+ 90./np.pi*np.angle(D_fit[-2])))
+ if inverse:
+ print('Source Fractional Circular Polarization: {:.4f}'.format(np.real(D_fit[-1])))
+
+ tstop = time.time()
+ print("\nleakage_cal time: %f s" % (tstop - tstart))
+
+ if obs_apply is not False:
+ # Apply the solution to another observation
+ obs_test = obs_apply.copy()
+ obs_test.tarr['dr'] *= 0.0
+ obs_test.tarr['dl'] *= 0.0
+
+ # Copy the solved D-terms
+ for isite in range(len(sites)):
+ if sites[isite] in list(obs_test.tarr['site']):
+ i_site = list(obs_test.tarr['site']).index(sites[isite])
+ obs_test.tarr['dr'][i_site] = D_fit[2*isite]
+ obs_test.tarr['dl'][i_site] = D_fit[2*isite+1]
+
+ obs_test.data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False)
+ obs_test.dcal = True
+
+ # Copy in the remaining D-terms that were there before
+ for j in range(len(obs_test.tarr)):
+ if obs_test.tarr[j]['site'] in sites:
+ continue
+ obs_test.tarr[j]['dr'] = obs_apply.tarr[j]['dr']
+ obs_test.tarr[j]['dl'] = obs_apply.tarr[j]['dl']
+ else:
+ obs_test = obs_test.switch_polrep(obs.polrep)
+
+ if not const_fpol:
+ return obs_test
+ else:
+ if inverse:
+ return [obs_test, D_fit[-2], D_fit[-1]]
+ else:
+ return [obs_test, D_fit[-2]]
+
+
+def plot_leakage(obs, sites=[], axis=False, rangex=False, rangey=False,
+ markers=['o', 's'], markersize=6,
+ export_pdf="", axislabels=True, legend=True, sort_tarr=True, show=True):
+ """Plot polarimetric leakage terms in an observation
+
+ Args:
+        obs (Obsdata): observation (or Array) containing the tarr
+        sites (list): list of sites to plot; empty list plots all sites
+        axis (matplotlib.axes.Axes): add plot to this axis
+        rangex (list): [xmin, xmax] x-axis limits
+        rangey (list): [ymin, ymax] y-axis limits
+        markers (list): pair of matplotlib plot markers (for RCP and LCP)
+        markersize (int): size of plot markers
+
+ export_pdf (str): path to pdf file to save figure
+ axislabels (bool): Show axis labels if True
+        legend (bool): Show legend if True
+        sort_tarr (bool): sort the telescope array before plotting
+        show (bool): Display the plot if True
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+
+ tarr = obs.tarr.copy()
+ if sort_tarr:
+ tarr.sort(axis=0)
+
+ if len(sites):
+ mask = [t in sites for t in tarr['site']]
+ tarr = tarr[mask]
+
+ clist = ehc.SCOLORS
+
+ # make plot(s)
+ if axis:
+ fig = axis.figure
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ plt.axhline(0, color='black')
+ plt.axvline(0, color='black')
+
+ xymax = np.max([np.abs(tarr['dr']), np.abs(tarr['dl'])])*100.0
+
+ plot_points = []
+ for i in range(len(tarr)):
+ color = clist[i % len(clist)]
+ label = tarr['site'][i]
+ dre, = x.plot(np.real(tarr['dr'][i])*100.0, np.imag(tarr['dr'][i])*100.0, markers[0],
+ markersize=markersize, color=color, label=label)
+ dim, = x.plot(np.real(tarr['dl'][i])*100.0, np.imag(tarr['dl'][i])*100.0, markers[1],
+ markersize=markersize, color=color, label=label)
+ plot_points.append([dre, dim])
+
+ # Data ranges
+ if not rangex:
+ rangex = [-xymax*1.1-0.01, xymax*1.1+0.01]
+
+ if not rangey:
+ rangey = [-xymax*1.1-0.01, xymax*1.1+0.01]
+
+# if not rangex and not rangey:
+# plt.axes().set_aspect('equal', 'datalim')
+
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ # label and save
+ if axislabels:
+        x.set_xlabel(r'Re[$D$] (\%)')
+        x.set_ylabel(r'Im[$D$] (\%)')
+ if legend:
+ legend1 = plt.legend([l[0] for l in plot_points], tarr['site'], ncol=1, loc=1)
+        plt.legend(plot_points[0], [r'$D_R$ (\%)', r'$D_L$ (\%)'], loc=4)
+ plt.gca().add_artist(legend1)
+ if export_pdf != "": # and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return x
diff --git a/calibrating/pol_cal_new.py b/calibrating/pol_cal_new.py
new file mode 100644
index 00000000..eac9841d
--- /dev/null
+++ b/calibrating/pol_cal_new.py
@@ -0,0 +1,488 @@
+# pol_cal_new.py
+# functions for D-term calibration
+# new version that should be faster (2024)
+#
+# Copyright (C) 2024 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.optimize as opt
+import matplotlib.pyplot as plt
+import time
+
+import ehtim.imaging.imager_utils as iu
+import ehtim.observing.obs_simulate as simobs
+import ehtim.const_def as ehc
+
+MAXIT = 10000 # maximum number of iterations in the pol-cal minimizer
+NHIST = 50 # number of steps to store for hessian approx
+MAXLS = 40 # maximum number of line search steps in BFGS-B
+STOP = 1e-6 # convergence criterion
+
+###################################################################################################
+# Polarimetric Calibration
+###################################################################################################
+
+# TODO - other chi^2 terms, not just 'vis'?
+# TODO - do we want to start with some nonzero D-term initial guess?
+# TODO - option to not frcal?
+# TODO - pass other kwargs to the chisq?
+# TODO - handle gain cal == False, read in gains from a caltable
+
+def leakage_cal_new(obs, im, sites=[], leakage_tol=.1, rescale_leakage_tol=False,
+ pol_fit=['RL', 'LR'], dtype='vis',
+ minimizer_method='L-BFGS-B',
+ ttype='direct', fft_pad_factor=2,
+ use_grad=True,
+ show_solution=True, apply_solution=True):
+ """Polarimetric calibration (detects and removes polarimetric leakage,
+ based on consistency with a given image)
+
+ Args:
+ obs (Obsdata): The observation to be calibrated
+ im (Image): the reference image used for calibration
+
+ sites (list): list of sites to include in the polarimetric calibration.
+ empty list calibrates all sites
+
+ leakage_tol (float): leakage values exceeding this value will be disfavored by the prior
+ rescale_leakage_tol (bool): if True, properly scale leakage tol for number of sites
+ (not done correctly in old version)
+
+ pol_fit (list): list of visibilities to use; e.g., ['RL','LR'] or ['RR','LL','RL','LR']
+
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ use_grad (bool): if True, use gradients in minimizer
+
+ show_solution (bool): if True, display the solution as it is calculated
+        apply_solution (bool): if True, apply the solution to the data; otherwise only store the leakage terms in the tarr
+
+ Returns:
+ (Obsdata): the calibrated observation, with computed leakage values added to the obs.tarr
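+
+    Example:
+        (a minimal usage sketch; obs must already be amplitude- and
+        phase-calibrated, and 'im' is a reference Image)
+
+        obs_dcal = leakage_cal_new(obs, im, pol_fit=['RL', 'LR'], use_grad=True)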
+ """
+
+ if not(obs.ampcal and obs.phasecal):
+ raise Exception("obs must be amplitude and phase calibrated before leakage_cal! (TODO: generalize)")
+
+
+ tstart = time.time()
+
+ mask = [] # TODO: add image masks?
+ dtype = 'vis' # TODO: add other data terms?
+
+ # Do everything in a circular basis
+ im_circ = im.switch_polrep('circ')
+ obs_circ = obs.copy().switch_polrep('circ')
+
+ # Check to see if the field rotation is corrected
+ if obs_circ.frcal is False:
+ print("Field rotation angles have not been corrected. Correcting now...")
+ obs_circ.data = simobs.apply_jones_inverse(obs_circ, frcal=False, dcal=True, opacitycal=True, verbose=False)
+ obs_circ.frcal = True
+
+ # List of all sites present in the observation. Make sure they are all in the tarr
+ allsites = list(set(np.hstack((obs_circ.data['t1'], obs_circ.data['t2']))))
+ for site in allsites:
+ if not (site in obs_circ.tarr['site']):
+ raise Exception("site %s not in obs.tarr!"%site)
+
+ if len(sites) == 0:
+ print("No stations specified for leakage calibration: defaulting to calibrating all sites !")
+ sites = allsites
+ # only include sites that are present in obs.tarr
+ sites = [s for s in sites if s in allsites]
+ site_index = [list(obs_circ.tarr['site']).index(s) for s in sites]
+
+ # TODO do we want to start with some nonzero D-terms?
+    # Zero the leakage terms of any sites that are not being solved for
+    # (we will only correct leakage for those sites with new solutions)
+ for j in range(len(obs_circ.tarr)):
+ if obs_circ.tarr[j]['site'] in sites:
+ continue
+ obs_circ.tarr[j]['dr'] = obs_circ.tarr[j]['dl'] = 0.0j
+
+ print("Finding leakage for sites:", sites)
+
+ print("Precomputing visibilities...")
+ # get stations
+ t1 = obs_circ.unpack('t1')['t1']
+ t2 = obs_circ.unpack('t2')['t2']
+
+ # index sites in t1, t2 position. If no calibrated site is used in a baseline, -1
+ idx1 = np.array([sites.index(t) if (t in sites) else -1 for t in t1])
+ idx2 = np.array([sites.index(t) if (t in sites) else -1 for t in t2])
+
+ # get real data and sigmas
+ # TODO add other chisqdata parameters?
+ # TODO modify chisqdata function to have the option to return samples?
+
+ (vis_RR, sigma_RR, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='RR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (vis_LL, sigma_LL, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='LL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (vis_RL, sigma_RL, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='RL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (vis_LR, sigma_LR, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='LR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+
+ # get simulated data (from simple Fourier transform)
+ obs_sim = im_circ.observe_same_nonoise(obs_circ,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=True,verbose=False)
+
+ (ft_RR, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='RR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (ft_LL, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='LL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (ft_RL, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='RL',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+ (ft_LR, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='LR',
+ ttype=ttype, fft_pad_factor=fft_pad_factor)
+
+ # field rotation angles
+ el1 = obs_circ.unpack(['el1'], ang_unit='rad')['el1']
+ el2 = obs_circ.unpack(['el2'], ang_unit='rad')['el2']
+ par1 = obs_circ.unpack(['par_ang1'], ang_unit='rad')['par_ang1']
+ par2 = obs_circ.unpack(['par_ang2'], ang_unit='rad')['par_ang2']
+
+ fr_elev1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_elev'] for o in obs.data])
+ fr_elev2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_elev'] for o in obs.data])
+ fr_par1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_par'] for o in obs.data])
+ fr_par2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_par'] for o in obs.data])
+ fr_off1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_off'] for o in obs.data])
+ fr_off2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_off'] for o in obs.data])
+
+ fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180.
+ fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180.
+
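+    # the difference of the field rotation angles enters the parallel-hand (RR, LL)
+    # leakage terms below; the sum enters the cross-hand (RL, LR) terms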
+ Delta = fr1 - fr2
+ Phi = fr1 + fr2
+
+ # TODO: read in gains from caltable?
+ # gains
+ GR1 = np.ones(fr1.shape)
+ GL1 = np.ones(fr1.shape)
+ GR2 = np.ones(fr2.shape)
+ GL2 = np.ones(fr2.shape)
+
+ if not(len(Delta)==len(vis_RR)==len(sigma_LL)==len(ft_RR)==len(t1)):
+ raise Exception("not all data columns the right length in pol_cal!")
+ Nvis = len(vis_RR)
+
+ # define the error function
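+    # leakage model in the circular basis (full Jones model; e.g., for RR):
+    #   V'_RR = G_R1 conj(G_R2) * ( V_RR + D_R1 conj(D_R2) e^{ 2i*Delta} V_LL
+    #                                    + D_R1 e^{ 2i*fr1} V_LR
+    #                                    + conj(D_R2) e^{-2i*fr2} V_RL )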
+ def chisq_total(Dpar):
+        # all the D-terms as complex numbers, two per site (D_R, D_L)
+ D = Dpar.astype(np.float64).view(dtype=np.complex128)
+
+ # current D-terms for each baseline, zero for stations not calibrated (TODO faster?)
+ DR1 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t1])
+ DL1 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t1])
+
+ DR2 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t2])
+ DL2 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t2])
+
+ # simulated visibilities and chisqs with leakage
+ chisq_RR = chisq_LL = chisq_RL = chisq_LR = 0.0
+ if 'RR' in pol_fit:
+ vis_RR_leak = ft_RR + DR1*DR2.conj()*np.exp(2j*Delta)*ft_LL + DR1*np.exp(2j*fr1)*ft_LR + DR2.conj()*np.exp(-2j*fr2)*ft_RL
+ vis_RR_leak *= GR1*GR2.conj()
+
+ chisq_RR = np.sum(np.abs(vis_RR - vis_RR_leak)**2 / (sigma_RR**2))
+ chisq_RR = chisq_RR / (2.*Nvis)
+ if 'LL' in pol_fit:
+ vis_LL_leak = ft_LL + DL1*DL2.conj()*np.exp(-2j*Delta)*ft_RR + DL1*np.exp(-2j*fr1)*ft_RL + DL2.conj()*np.exp(2j*fr2)*ft_LR
+ vis_LL_leak *= GL1*GL2.conj()
+
+ chisq_LL = np.sum(np.abs(vis_LL - vis_LL_leak)**2 / (sigma_LL**2))
+ chisq_LL = chisq_LL / (2.*Nvis)
+ if 'RL' in pol_fit:
+ vis_RL_leak = ft_RL + DR1*DL2.conj()*np.exp(2j*Phi)*ft_LR + DR1*np.exp(2j*fr1)*ft_LL + DL2.conj()*np.exp(2j*fr2)*ft_RR
+ vis_RL_leak *= GR1*GL2.conj()
+
+ chisq_RL = np.sum(np.abs(vis_RL - vis_RL_leak)**2 / (sigma_RL**2))
+ chisq_RL = chisq_RL / (2.*Nvis)
+ if 'LR' in pol_fit:
+ vis_LR_leak = ft_LR + DL1*DR2.conj()*np.exp(-2j*Phi)*ft_RL + DL1*np.exp(-2j*fr1)*ft_RR + DR2.conj()*np.exp(-2j*fr2)*ft_LL
+ vis_LR_leak *= GL1*GR2.conj()
+
+ chisq_LR = np.sum(np.abs(vis_LR - vis_LR_leak)**2 / (sigma_LR**2))
+ chisq_LR = chisq_LR / (2.*Nvis)
+
+ chisq_tot = (chisq_RR + chisq_LL + chisq_RL + chisq_LR)/len(pol_fit)
+ return chisq_tot
+
+ def errfunc(Dpar):
+ # chi-squared
+ chisq_tot = chisq_total(Dpar)
+
+ # prior on the D terms
+ # TODO
+ prior = np.sum((np.abs(Dpar)**2)/(leakage_tol**2))
+ if rescale_leakage_tol:
+ prior = prior / (len(Dpar))
+
+ return chisq_tot + prior
+
+ # define the error function gradient
+ def chisq_total_grad(Dpar):
+ chisqgrad = np.zeros(len(Dpar))
+
+ # all the D-terms as complex numbers.
+ # stored in groups of 4 per site [Re(DR), Im(DR), Re(DL), Im(DL)]
+ D = Dpar.astype(np.float64).view(dtype=np.complex128)
+
+ # current D-terms for each baseline, zero for stations not calibrated (TODO faster?)
+ DR1 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t1])
+ DL1 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t1])
+
+ DR2 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t2])
+ DL2 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t2])
+
+ # residual and dV/dD terms
+ if 'RR' in pol_fit:
+ vis_RR_leak = ft_RR + DR1*DR2.conj()*np.exp(2j*Delta)*ft_LL + DR1*np.exp(2j*fr1)*ft_LR + DR2.conj()*np.exp(-2j*fr2)*ft_RL
+ vis_RR_leak *= GR1*GR2.conj()
+
+ resid_RR = (vis_RR - vis_RR_leak).conj()
+
+ dRR_dReDR1 = DR2.conj()*np.exp(2j*Delta)*ft_LL + np.exp(2j*fr1)*ft_LR
+ dRR_dReDR1 *= GR1*GR2.conj()
+
+ dRR_dReDR2 = DR1*np.exp(2j*Delta)*ft_LL + np.exp(-2j*fr2)*ft_RL
+ dRR_dReDR2 *= GR1*GR2.conj()
+
+ if 'LL' in pol_fit:
+ vis_LL_leak = ft_LL + DL1*DL2.conj()*np.exp(-2j*Delta)*ft_RR + DL1*np.exp(-2j*fr1)*ft_RL + DL2.conj()*np.exp(2j*fr2)*ft_LR
+ vis_LL_leak *= GL1*GL2.conj()
+
+ resid_LL = (vis_LL - vis_LL_leak).conj()
+
+ dLL_dReDL1 = DL2.conj()*np.exp(-2j*Delta)*ft_RR + np.exp(-2j*fr1)*ft_RL
+ dLL_dReDL1 *= GL1*GL2.conj()
+
+ dLL_dReDL2 = DL1*np.exp(-2j*Delta)*ft_RR + np.exp(2j*fr2)*ft_LR
+ dLL_dReDL2 *= GL1*GL2.conj()
+
+
+ if 'RL' in pol_fit:
+ vis_RL_leak = ft_RL + DR1*DL2.conj()*np.exp(2j*Phi)*ft_LR + DR1*np.exp(2j*fr1)*ft_LL + DL2.conj()*np.exp(2j*fr2)*ft_RR
+ vis_RL_leak *= GR1*GL2.conj()
+
+ resid_RL = (vis_RL - vis_RL_leak).conj()
+
+ dRL_dReDR1 = DL2.conj()*np.exp(2j*Phi)*ft_LR + np.exp(2j*fr1)*ft_LL
+ dRL_dReDR1 *= GR1*GL2.conj()
+
+ dRL_dReDL2 = DR1*np.exp(2j*Phi)*ft_LR + np.exp(2j*fr2)*ft_RR
+ dRL_dReDL2 *= GR1*GL2.conj()
+
+
+ if 'LR' in pol_fit:
+ vis_LR_leak = ft_LR + DL1*DR2.conj()*np.exp(-2j*Phi)*ft_RL + DL1*np.exp(-2j*fr1)*ft_RR + DR2.conj()*np.exp(-2j*fr2)*ft_LL
+ vis_LR_leak *= GL1*GR2.conj()
+
+ resid_LR = (vis_LR - vis_LR_leak).conj()
+
+ dLR_dReDL1 = DR2.conj()*np.exp(-2j*Phi)*ft_RL + np.exp(-2j*fr1)*ft_RR
+ dLR_dReDL1 *= GL1*GR2.conj()
+
+ dLR_dReDR2 = DL1*np.exp(-2j*Phi)*ft_RL + np.exp(-2j*fr2)*ft_LL
+ dLR_dReDR2 *= GL1*GR2.conj()
+
+
+ # to get gradients, sum over baselines
+ # TODO remove for loop with some fancy vectorization?
+ for isite in range(len(sites)):
+ mask1 = (idx1 == isite)
+ mask2 = (idx2 == isite)
+
+ # DR
+ regrad = 0
+ imgrad = 0
+ if 'RR' in pol_fit:
+ terms = resid_RR[mask1] * dRR_dReDR1[mask1] / (sigma_RR[mask1]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += np.sum(np.imag(terms))
+
+ terms = resid_RR[mask2] * dRR_dReDR2[mask2] / (sigma_RR[mask2]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += -1*np.sum(np.imag(terms))
+
+ if 'RL' in pol_fit:
+ terms = resid_RL[mask1] * dRL_dReDR1[mask1] / (sigma_RL[mask1]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += np.sum(np.imag(terms))
+
+ if 'LR' in pol_fit:
+ terms = resid_LR[mask2] * dLR_dReDR2[mask2] / (sigma_LR[mask2]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += -1*np.sum(np.imag(terms))
+
+ chisqgrad[4*isite] += regrad # Re(DR)
+ chisqgrad[4*isite+1] += imgrad # Im(DR)
+
+ # DL
+ regrad = 0
+ imgrad = 0
+ if 'LL' in pol_fit:
+ terms = resid_LL[mask1] * dLL_dReDL1[mask1] / (sigma_LL[mask1]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += np.sum(np.imag(terms))
+
+ terms = resid_LL[mask2] * dLL_dReDL2[mask2] / (sigma_LL[mask2]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += -1*np.sum(np.imag(terms))
+
+ if 'RL' in pol_fit:
+ terms = resid_RL[mask2] * dRL_dReDL2[mask2] / (sigma_RL[mask2]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += -1*np.sum(np.imag(terms))
+
+ if 'LR' in pol_fit:
+ terms = resid_LR[mask1] * dLR_dReDL1[mask1] / (sigma_LR[mask1]**2)
+ regrad += -1*np.sum(np.real(terms))
+ imgrad += np.sum(np.imag(terms))
+
+ chisqgrad[4*isite+2] += regrad # Re(DL)
+ chisqgrad[4*isite+3] += imgrad # Im(DL)
+
+ chisqgrad /= (Nvis*len(pol_fit))
+
+ return chisqgrad
+
+ def errfunc_grad(Dpar):
+ # gradient of the chi^2
+ chisqgrad = chisq_total_grad(Dpar)
+
+ # gradient of the prior
+ priorgrad = 2*Dpar / (leakage_tol**2)
+ if rescale_leakage_tol:
+ priorgrad = priorgrad / (len(Dpar))
+
+ return chisqgrad + priorgrad
+
+ # Gradient test - remove!
+# def test_grad(Dpar):
+# grad_ana = errfunc_grad(Dpar)
+# grad_num1 = np.zeros(len(Dpar))
+# for i in range(len(Dpar)):
+# dd = 1.e-8
+# Dpar_dd = Dpar.copy()
+# Dpar_dd[i] += dd
+# grad_num1[i] = (errfunc(Dpar_dd) - errfunc(Dpar))/dd
+# grad_num2 = np.zeros(len(Dpar))
+# for i in range(len(Dpar)):
+# dd = -1.e-8
+# Dpar_dd = Dpar.copy()
+# Dpar_dd[i] += dd
+# grad_num2[i] = (errfunc(Dpar_dd) - errfunc(Dpar))/dd
+#
+# plt.close('all')
+# plt.ion()
+# plt.figure()
+# plt.plot(np.arange(len(Dpar)), grad_ana, 'ro')
+# plt.plot(np.arange(len(Dpar)), grad_num1, 'b.')
+# plt.plot(np.arange(len(Dpar)), grad_num2, 'bx')
+# plt.xticks(np.arange(0,len(Dpar),4), sites)
+#
+# plt.figure()
+# zscal = 1.e-32*np.min(np.abs(grad_ana)[grad_ana!=0])
+# plt.plot(np.arange(len(Dpar)), 100-100*(grad_num1+zscal)/(grad_ana+zscal),'b.')
+# plt.plot(np.arange(len(Dpar)), 100-100*(grad_num2+zscal)/(grad_ana+zscal),'bx')
+# plt.xticks(np.arange(0,len(Dpar),4), sites)
+# plt.ylim(-1,1)
+# plt.show()
+# return
+
+# Dpar_guess = .1*np.random.randn(len(sites)*4)
+# test_grad(Dpar_guess)
+
+ print("Calibrating D-terms...")
+    # Minimize the total error function, solving for two complex leakage terms per site
+ if minimizer_method=='L-BFGS-B':
+ optdict = {'maxiter': MAXIT,
+ 'ftol': STOP, 'gtol': STOP,
+ 'maxcor': NHIST, 'maxls': MAXLS}
+ else:
+ optdict = {'maxiter': MAXIT}
+
+ Dpar_guess = np.zeros(len(sites)*2, dtype=np.complex128).view(dtype=np.float64)
+ if use_grad:
+ res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict, jac=errfunc_grad)
+ else:
+ res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict)
+
+    print("Error function: initial %0.4g --> final %0.4g" % (errfunc(Dpar_guess), errfunc(res.x)))
+
+ # get solution
+ Dpar_fit = res.x.astype(np.float64)
+ D_fit = Dpar_fit.view(dtype=np.complex128) # all the D-terms (complex)
+
+ # fill in the new D-terms to the tarr
+ obs_out = obs_circ.copy() # TODO or overwrite directly?
+ for isite in range(len(sites)):
+ obs_out.tarr['dr'][site_index[isite]] = D_fit[2*isite]
+ obs_out.tarr['dl'][site_index[isite]] = D_fit[2*isite+1]
+
+ # Apply the solution
+ if apply_solution:
+ obs_out.data = simobs.apply_jones_inverse(obs_out, dcal=False, frcal=True, opacitycal=True, verbose=False)
+ obs_out.dcal = True
+ else:
+ obs_out.dcal = False
+
+ # Re-populate any additional leakage terms that were present
+    # NOTE we don't want to do this above, in case we want to ignore these terms in apply_jones_inverse
+ # TODO can we do this better?
+ for j in range(len(obs_out.tarr)):
+ if obs_out.tarr[j]['site'] in sites:
+ continue
+ obs_out.tarr[j]['dr'] = obs.tarr[j]['dr']
+ obs_out.tarr[j]['dl'] = obs.tarr[j]['dl']
+
+ # TODO are these diagnostics correct?
+ if show_solution:
+ chisq_orig = chisq_total(Dpar_fit*0)
+ chisq_new = chisq_total(Dpar_fit)
+
+ print("Original chi-squared: {:.4f}".format(chisq_orig))
+ print("New chi-squared: {:.4f}\n".format(chisq_new))
+ for isite in range(len(sites)):
+ print(sites[isite])
+ print(' D_R: {:.4f}'.format(D_fit[2*isite]))
+ print(' D_L: {:.4f}\n'.format(D_fit[2*isite+1]))
+
+ tstop = time.time()
+ print("\nleakage_cal time: %f s" % (tstop - tstart))
+
+
+ obs_out = obs_out.switch_polrep(obs.polrep)
+
+ return obs_out
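+
+# A minimal illustrative sketch (not called anywhere in this module) of the
+# parameter packing used above: scipy.optimize.minimize works on real vectors,
+# so the complex D-terms are stored as interleaved floats
+# [Re(DR), Im(DR), Re(DL), Im(DL), ...] and reinterpreted with ndarray.view.
+def _dterm_packing_demo():
+    import numpy as np
+    D = np.array([0.01 + 0.02j, -0.03 + 0.00j])  # complex D-terms (DR, DL) for one site
+    Dpar = D.view(dtype=np.float64)              # interleaved [Re, Im, Re, Im]
+    assert np.allclose(Dpar, [0.01, 0.02, -0.03, 0.0])
+    D_back = Dpar.view(dtype=np.complex128)      # lossless round trip
+    assert np.allclose(D_back, D)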
+
+
+
+
diff --git a/calibrating/polgains_cal.py b/calibrating/polgains_cal.py
new file mode 100644
index 00000000..38c36eaa
--- /dev/null
+++ b/calibrating/polgains_cal.py
@@ -0,0 +1,312 @@
+# polgains_cal.py
+# functions for calibrating RCP-LCP phase offsets
+#
+# Copyright (C) 2019 Maciek Wielgus (maciek.wielgus(at)gmail.com)
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.optimize as opt
+import time
+import copy
+from multiprocessing import cpu_count, Pool
+
+import ehtim.obsdata
+import ehtim.parloop as parloop
+import ehtim.observing.obs_helpers as obsh
+import ehtim.const_def as ehc
+
+MAXIT = 5000
+
+###################################################################################################
+# Polarimetric-Phase-Gains-Calibration
+###################################################################################################
+
+
+def polgains_cal(obs, reference='AA', sites=[], method='phase', minimizer_method='BFGS', pad_amp=0.,
+ solution_interval=0.0, scan_solutions=False,
+ caltable=False, processes=-1, show_solution=False, msgtype='bar'):
+
+ # TODO: function to globalize the polarimetric solution in time
+ # given provided absolute calibration of the reference station
+ # so that we have a meaningful EVPA
+ """Polarimeteric-phase-gains-calibrate a dataset.
+ Numerically solves for polarimetric gains to align RCP and LCP feeds.
+ Uses all baselines to find the solution. Effectively assumes phase of Stokes V to be zero.
+ Because fits are local, it's not providing absolute phase calibration.
+
+ Args:
+ obs (Obsdata): The observation to be calibrated
+ reference (str): station used as reference to break the degeneracy
+ (LCP on baselines to the reference station remains unchanged)
+ sites (list): list of sites to include in the polarimetric calibration.
+ Empty list calibrates all sites
+        method (str): chooses what to calibrate, 'phase' or 'both'.
+                      'phase' (default) corrects relative R-L instrumental phase offsets;
+                      'both' aligns the RCP/LCP amplitudes as well
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+ solution_interval (float): solution interval in seconds; one gain is derived per interval.
+ If 0.0, a solution is determined for each unique time.
+ scan_solutions (bool): If True, determine one gain per site per scan
+ Supersedes solution_interval.
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ processes (int): number of cores to use in multiprocessing
+ show_solution (bool): if True, display the solution as it is calculated
+ msgtype (str): type of progress message to be printed, default is 'bar'
+
+ Returns:
+ (Obsdata): the calibrated observation, if caltable==False
+ (Caltable): the derived calibration table, if caltable==True
+ """
+ # circular representation is needed
+ if obs.polrep != 'circ':
+ obs = obs.switch_polrep('circ')
+
+ if len(sites) == 0:
+ print("No stations specified in polgain cal!")
+ print('Defaulting to calibrating all stations with reference station as: ' + reference)
+        sites = np.array([x for x in obs.tarr['site'] if x != reference], dtype='<U32')
+
+    # Partition the list of observed visibilities into scans
+    scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions)
+    scans_cal = copy.copy(scans)
+
+    # Make the pool for parallel processing
+    if processes > 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ if processes > len(scans):
+ processes = len(scans)
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ elif processes == 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ processes = int(cpu_count())
+ if processes > len(scans):
+ processes = len(scans)
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ else:
+ print("Not Using Multiprocessing")
+
+ # loop over scans and calibrate
+ tstart = time.time()
+ if processes > 0: # with multiprocessing
+ scans_cal = pool.map(get_polgains_scan_cal,
+                             [[i, len(scans), scans[i], reference,
+                               sites, method, minimizer_method, pad_amp, caltable,
+                               show_solution, msgtype]
+ for i in range(len(scans))
+ ])
+
+ else: # without multiprocessing
+ for i in range(len(scans)):
+ obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1)
+ scans_cal[i] = polgains_cal_scan(scans[i], reference, sites,
+ method=method, minimizer_method=minimizer_method,
+ show_solution=show_solution, caltable=caltable,
+ pad_amp=pad_amp)
+
+ tstop = time.time()
+ print("\npolgain_cal time: %f s" % (tstop - tstart))
+
+ if caltable: # create and return a caltable
+ allsites = obs.tarr['site']
+ caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()}
+ for i in range(1, len(scans_cal)):
+ row = scans_cal[i]
+ if len(row) == 0:
+ continue
+
+ for site in allsites:
+ try:
+ dat = row[site]
+ except KeyError:
+ continue
+
+ try:
+ caldict[site] = np.append(caldict[site], row[site])
+ except KeyError:
+ caldict[site] = [dat]
+
+ caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
+ source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
+ out = caltable
+
+ else: # return the calibrated observation
+ arglist, argdict = obs.obsdata_args()
+ arglist[4] = np.concatenate(scans_cal)
+ out = ehtim.obsdata.Obsdata(*arglist, **argdict)
+
+ # close multiprocessing jobs
+    if processes >= 0:  # a pool was created for any processes >= 0
+ pool.close()
+
+ return out
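+
+# Minimal sketch (illustrative only, not used by the pipeline) of the 'phase'
+# constraint applied inside polgains_cal_scan: projecting a complex gain onto
+# the unit circle keeps the R-L phase offset but discards the amplitude ratio.
+def _phase_only_projection_demo():
+    import numpy as np
+    g = 1.3 * np.exp(1j * 0.4)   # gain with amplitude 1.3 and phase 0.4 rad
+    g_phase = g / np.abs(g)      # method='phase' keeps only the phase
+    assert np.isclose(np.angle(g_phase), 0.4)
+    assert np.isclose(np.abs(g_phase), 1.0)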
+
+
+def polgains_cal_scan(scan, reference='AA', sites=[], method='phase', minimizer_method='BFGS',
+ show_solution=False, pad_amp=0., caltable=False, msgtype='bar'):
+ """Polarimeteric-phase-gains-calibrate a dataset.
+ Numerically solves for polarimetric gains to align RCP and LCP feeds.
+ Uses all baselines to find the solution. Effectively assumes phase of Stokes V to be zero.
+ Because fits are local, it's not providing absolute phase calibration.
+
+ Args:
+        scan (np.recarray): data array of a single scan, of type DTPOL_CIRC
+ reference (str): station used as reference to break the degeneracy
+ (LCP on baselines to the reference station remains unchanged)
+ sites (list): list of sites to include in the polarimetric calibration.
+ Empty list calibrates all sites
+        method (str): chooses what to calibrate, 'phase' or 'both'.
+                      'phase' (default) corrects relative R-L instrumental phase offsets;
+                      'both' aligns the RCP/LCP amplitudes as well
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ show_solution (bool): if True, display the solution as it is calculated
+ msgtype (str): type of progress message to be printed, default is 'bar'
+
+ Returns:
+        (np.recarray): the calibrated scan, if caltable==False
+        (dict): a dictionary of site calibration solutions, if caltable==True
+ """
+ indices_no_vis_nans = (~np.isnan(scan[ehc.vis_poldict['RR']])) & (
+ ~np.isnan(scan[ehc.vis_poldict['LL']]))
+ scan_no_vis_nans = scan[indices_no_vis_nans]
+ allsites_no_vis_nans = list(set(np.hstack((scan_no_vis_nans['t1'], scan_no_vis_nans['t2']))))
+
+ # all the sites in the scan
+ allsites = list(set(np.hstack((scan['t1'], scan['t2']))))
+ if len(sites) == 0:
+ print("No stations specified in polgain cal")
+ print("defaulting to calibrating all non-reference stations!")
+ sites = allsites
+
+ # only include sites that are present, not singlepol and not reference
+ sites = [s for s in sites if (s in allsites_no_vis_nans) & (s != reference)]
+
+ # create a dictionary to keep track of gains; reference site and singlepol sites get key -1
+ tkey = {b: a for a, b in enumerate(sites)}
+    # make two lists of gain keys that relate scan baseline gains to the solved site gains
+ # -1 means that this station does not have a gain that is being solved for
+ g1_keys = []
+ g2_keys = []
+ for row in scan:
+ try:
+ g1_keys.append(tkey[row['t1']])
+ except KeyError:
+ g1_keys.append(-1)
+ try:
+ g2_keys.append(tkey[row['t2']])
+ except KeyError:
+ g2_keys.append(-1)
+
+ # get scan visibilities of the specified polarization
+ visRR = scan_no_vis_nans[ehc.vis_poldict['RR']]
+ visLL = scan_no_vis_nans[ehc.vis_poldict['LL']]
+ sigmaRR = scan_no_vis_nans[ehc.sig_poldict['RR']]
+ sigmaLL = scan_no_vis_nans[ehc.sig_poldict['LL']]
+ sigma = np.sqrt(sigmaRR**2 + sigmaLL**2)
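+    # NOTE: pad_amp is accepted here but not currently applied; the padded sigma below is left commented out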
+ # sigma_inv = 1.0/np.sqrt(sigma**2 + (pad_amp*0.5*(np.abs(visRR)+np.abs(visLL)))**2)
+ # initial guesses for parameters
+ n_gains = len(sites)
+ gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64)
+
+ # error function
+ def errfunc(g):
+ g = g.view(dtype=np.complex128)
+ g = np.append(g, 1. + 0.j)
+ if method == "phase":
+ g = g / np.abs(g)
+ g1 = g[g1_keys]
+ g2 = g[g2_keys]
+ chisq = np.sum(np.abs((visRR - g1[indices_no_vis_nans] *
+ g2[indices_no_vis_nans].conj() * visLL) / sigma)**2)
+ return chisq
+
+ if np.max(g1_keys) > -1 or np.max(g2_keys) > -1:
+ # run the minimizer to get a solution (but only run if there's at least one gain to fit)
+ optdict = {'maxiter': MAXIT} # minimizer params
+ res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict)
+
+ # get solution
+ g_fit = res.x.view(np.complex128)
+ if method == "phase":
+ g_fit = g_fit / np.abs(g_fit)
+
+ if show_solution:
+ print(g_fit)
+ else:
+ g_fit = []
+ g_fit = np.append(g_fit, 1.)
+ # Derive a calibration table or apply the solution to the scan
+ if caltable:
+ allsites = list(set(scan['t1']).union(set(scan['t2'])))
+
+ caldict = {}
+ for site in allsites:
+ if site in sites:
+ site_key = tkey[site]
+ else:
+ site_key = -1
+
+ # Convention is that we calibrate RCP phase to align with the LCP phase
+ rscale = g_fit[site_key]**-1
+ lscale = 1. + 0.j
+
+ caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL)
+ out = caldict
+ else:
+ g1_fit = g_fit[g1_keys]
+ g2_fit = g_fit[g2_keys]
+ g1_inv = g1_fit**(-1)
+ g2_inv = g2_fit**(-1)
+ # scale visibilities
+ scan['rrvis'] *= g1_inv * g2_inv.conj()
+ scan['llvis'] *= 1. + 0.j
+ scan['rlvis'] *= g1_inv
+ scan['lrvis'] *= g2_inv.conj()
+ # don't scale sigmas
+ out = scan
+ return out
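+
+# Minimal sketch (illustrative only) of the model minimized in errfunc above:
+# with a per-site complex R/L gain ratio g_i, the RR and LL visibilities on a
+# baseline (i, j) are related by V_RR = g_i * conj(g_j) * V_LL, so the chi^2
+# of the true gains on noiseless data is exactly zero.
+def _polgains_model_demo():
+    import numpy as np
+    rng = np.random.default_rng(0)
+    g = np.exp(1j * rng.uniform(-np.pi, np.pi, size=3))  # 3 sites, phase-only gains
+    i1 = np.array([0, 0, 1])                             # baselines (0,1), (0,2), (1,2)
+    i2 = np.array([1, 2, 2])
+    vis_LL = rng.normal(size=3) + 1j * rng.normal(size=3)
+    vis_RR = g[i1] * g[i2].conj() * vis_LL               # noiseless forward model
+    chisq = np.sum(np.abs(vis_RR - g[i1] * g[i2].conj() * vis_LL)**2)
+    assert np.isclose(chisq, 0.0)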
+
+
+def init(x):
+ global counter
+ counter = x
+
+
+def get_polgains_scan_cal(args):
+ return get_polgains_scan_cal2(*args)
+
+
+def get_polgains_scan_cal2(i, n, scan, reference, sites, method, minimizer_method,
+                           pad_amp, caltable, show_solution, msgtype):
+
+ if n > 1:
+ global counter
+ counter.increment()
+ obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)
+
+    return polgains_cal_scan(scan, reference, sites,
+                             method=method, minimizer_method=minimizer_method,
+                             caltable=caltable, show_solution=show_solution,
+                             pad_amp=pad_amp, msgtype=msgtype)
diff --git a/calibrating/self_cal.py b/calibrating/self_cal.py
new file mode 100644
index 00000000..90d8e9d9
--- /dev/null
+++ b/calibrating/self_cal.py
@@ -0,0 +1,594 @@
+# selfcal.py
+# functions for self-calibration
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.optimize as opt
+import time
+import copy
+from multiprocessing import cpu_count, Pool
+
+import ehtim.obsdata
+import ehtim.parloop as parloop
+from . import cal_helpers as calh
+from ehtim.observing.obs_simulate import add_jones_and_noise
+import ehtim.observing.obs_helpers as obsh
+import ehtim.const_def as ehc
+
+import warnings
+warnings.filterwarnings("ignore", message="divide by zero encountered in log")
+
+MAXIT = 10000 # maximum number of iterations in self-cal minimizer
+NHIST = 50 # number of steps to store for hessian approx
+MAXLS = 40 # maximum number of line search steps in BFGS-B
+STOP = 1e-6 # convergence criterion
+
+###################################################################################################
+# Self-Calibration
+###################################################################################################
+
+
+def self_cal(obs, im, sites=[], pol='I', apply_singlepol=False, method="both",
+ minimizer_method='BFGS',
+ pad_amp=0., gain_tol=.2, solution_interval=0.0, scan_solutions=False,
+ ttype='direct', fft_pad_factor=2, caltable=False,
+ debias=True, apply_dterms=False,
+ copy_closure_tables=False,
+ processes=-1, show_solution=False, msgtype='bar',
+ use_grad=False):
+ """Self-calibrate a dataset to an image.
+
+ Args:
+ obs (Obsdata): The observation to be calibrated
+ im (Image): the image to be calibrated to
+ sites (list): list of sites to include in the self calibration.
+ empty list calibrates all sites
+
+ pol (str): which image polarization to self-calibrate visibilities to
+ apply_singlepol (str): if calibrating to pol='RR' or pol='LL',
+ apply solution only to the single polarization
+
+ method (str): chooses what to calibrate, 'amp', 'phase', or 'both'
+ minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS')
+
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+        gain_tol (float, list, or dict): gains that exceed this tolerance are
+                                  disfavored by the prior; pass a 2-element list
+                                  [tol_above, tol_below] for asymmetric tolerances on
+                                  corrections above / at-or-below unity, or a dict
+                                  keyed by site name (with a 'default' entry)
+ solution_interval (float): solution interval in seconds;
+ If 0., determine solution for each unique time
+ scan_solutions (bool): If True, determine one gain per site per scan
+ (supersedes solution_interval)
+
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ processes (int): number of cores to use in multiprocessing
+
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ debias (bool): If True, debias the amplitudes
+ apply_dterms (bool): if True, apply dterms (in obs.tarr) to clean data before calibrating
+ show_solution (bool): if True, display the solution as it is calculated
+ msgtype (str): type of progress message to be printed, default is 'bar'
+ use_grad (bool): if True, use gradients in minimizer
+
+ Returns:
+ (Obsdata): the calibrated observation, if caltable==False
+ (Caltable): the derived calibration table, if caltable==True
+ """
+ if use_grad and (method=='phase' or method=='amp'):
+ raise Exception("errfunc_grad in self_cal only works with method=='both'!")
+
+ if pol in ['I', 'Q', 'U', 'V']:
+ if obs.polrep != 'stokes':
+ raise Exception("selfcal pol is a stokes parameter, but obs.polrep!='stokes'")
+ im = im.switch_polrep('stokes', pol)
+ elif pol in ['RR', 'LL']:
+ if obs.polrep != 'circ':
+ raise Exception("selfcal pol is RR or LL, but obs.polrep!='circ'")
+ im = im.switch_polrep('circ', pol)
+ else:
+ raise Exception("Can only self-calibrate to I, Q, U, V, RR, or LL images!")
+
+ if apply_singlepol and obs.polrep!='circ':
+ raise Exception("apply_singlepol must be False unless self-calibrating to 'RR' or 'LL'")
+
+ # V = model visibility, V' = measured visibility, G_i = site gain
+ # G_i * conj(G_j) * V_ij = V'_ij
+ if len(sites) == 0:
+ print("No stations specified in self cal: defaulting to calibrating all stations!")
+ sites = obs.tarr['site']
+
+ # First, sample the model visibilities of the specified polarization
+ print("Computing the Model Visibilities with " + ttype + " Fourier Transform...")
+ obs_clean = im.observe_same_nonoise(obs, ttype=ttype, fft_pad_factor=fft_pad_factor)
+
+ # apply dterms
+ # TODO check!
+ if apply_dterms:
+ print("Applying dterms in obs.tarr to clean visibilities before selfcal!")
+ obsdata_dterms = add_jones_and_noise(obs_clean,
+ add_th_noise=False, ampcal=True, phasecal=True, opacitycal=True,
+ dcal=False, frcal=True, dterm_offset=0.0)
+ obs_clean.data = obsdata_dterms
+
+ # Partition the list of observed visibilities into scans
+ scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions)
+ scans_cal = copy.copy(scans)
+
+ # Partition the list of model visibilities into scans
+ V_scans = [o[ehc.vis_poldict[pol]] for o in obs_clean.tlist(
+ t_gather=solution_interval, scan_gather=scan_solutions)]
+
+ # Make the pool for parallel processing
+ if processes > 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ elif processes == 0:
+ counter = parloop.Counter(initval=0, maxval=len(scans))
+ processes = int(cpu_count())
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes, initializer=init, initargs=(counter,))
+ else:
+ print("Not Using Multiprocessing")
+
+ # loop over scans and calibrate
+ tstart = time.time()
+ if processes > 0: # run on multiple cores with multiprocessing
+ scans_cal = np.array(pool.map(get_selfcal_scan_cal, [[i, len(scans), scans[i],
+ im, V_scans[i], sites,
+ obs.polrep, pol, apply_singlepol,
+ method, minimizer_method,
+ show_solution, pad_amp, gain_tol,
+ debias, caltable, msgtype,
+ use_grad
+ ] for i in range(len(scans))]),
+ dtype=object)
+
+ else: # run on a single core
+ for i in range(len(scans)):
+ obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1)
+ scans_cal[i] = self_cal_scan(scans[i], im, V_scan=V_scans[i], sites=sites,
+ polrep=obs.polrep, pol=pol, apply_singlepol=apply_singlepol,
+ method=method, minimizer_method=minimizer_method,
+ show_solution=show_solution,
+ pad_amp=pad_amp, gain_tol=gain_tol,
+ debias=debias, caltable=caltable,
+ use_grad=use_grad)
+
+ tstop = time.time()
+ print("\nself_cal time: %f s" % (tstop - tstart))
+
+ if caltable: # assemble the caltable to return
+ allsites = obs.tarr['site']
+ caldict = scans_cal[0]
+ for i in range(1, len(scans_cal)):
+ row = scans_cal[i]
+ for site in allsites:
+ try:
+ dat = row[site]
+ except KeyError:
+ continue
+
+ try:
+ caldict[site] = np.append(caldict[site], row[site])
+ except KeyError:
+ caldict[site] = dat
+
+ caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
+ source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
+ out = caltable
+
+ else: # return a calibrated observation
+ arglist, argdict = obs.obsdata_args()
+ arglist[4] = np.concatenate(scans_cal)
+ out = ehtim.obsdata.Obsdata(*arglist, **argdict)
+ if copy_closure_tables:
+ out.camp = obs.camp
+ out.logcamp = obs.logcamp
+ out.cphase = obs.cphase
+
+ # close multiprocessing jobs
+ if processes >= 0:
+ pool.close()
+
+ return out
+
+
+def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', apply_singlepol=False,
+ method="both",
+ minimizer_method='BFGS', show_solution=False,
+ pad_amp=0., gain_tol=.2, debias=True, caltable=False,
+ use_grad=False):
+ """Self-calibrate a scan to an image.
+
+ Args:
+ scan (np.recarray): data array of type DTPOL_STOKES or DTPOL_CIRC
+ im (Image): the image to be calibrated to
+ sites (list): list of sites to include in the self calibration.
+ empty list calibrates all sites
+ V_scan (list) : precomputed scan visibilities
+
+ polrep (str): 'stokes' or 'circ' to specify the polarization products in scan
+ pol (str): which image polarization to self-calibrate visibilities to
+ apply_singlepol (str): if calibrating to pol='RR' or pol='LL',
+ apply solution only to the single polarization
+
+ method (str): chooses what to calibrate, 'amp', 'phase', or 'both'
+ minimizer_method (str): Method for scipy.optimize.minimize
+ (e.g., 'CG', 'BFGS', 'Nelder-Mead', etc.)
+ pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature
+        gain_tol (float, list, or dict): gains that exceed this tolerance are
+                                  disfavored by the prior; pass a 2-element list
+                                  [tol_above, tol_below] for asymmetric tolerances on
+                                  corrections above / at-or-below unity, or a dict
+                                  keyed by site name (with a 'default' entry)
+
+ debias (bool): If True, debias the amplitudes
+ caltable (bool): if True, returns a Caltable instead of an Obsdata
+ show_solution (bool): if True, display the solution as it is calculated
+ use_grad (bool): if True, use gradients in minimizer
+
+ Returns:
+        (np.recarray): the calibrated scan, if caltable==False
+        (dict): a dictionary of site calibration solutions, if caltable==True
+ """
+
+ if use_grad and (method=='phase' or method=='amp'):
+ raise Exception("errfunc_grad in self_cal only works with method=='both'!")
+
+ if len(sites) == 0:
+ print("No stations specified in self cal: defaulting to calibrating all !")
+ sites = list(set(scan['t1']).union(set(scan['t2'])))
+
+ if len(V_scan) < 1:
+ # TODO This is not correct. Need to update to use polarization dictionary
+ uv = np.hstack((scan['u'].reshape(-1, 1), scan['v'].reshape(-1, 1)))
+ A = obsh.ftmatrix(im.psize, im.xdim, im.ydim, uv, pulse=im.pulse)
+ V_scan = np.dot(A, im.imvec)
+
+ # convert gain tolerance to lookup table if needed
+    if not isinstance(gain_tol, dict):
+        gain_tol = {'default': gain_tol}
+    # convert any 1-sided tolerance to the 2-sided tolerance parameterization
+    for (key, val) in gain_tol.items():
+        if isinstance(val, (float, int)):
+            gain_tol[key] = [val, val]
+
+ # create a dictionary to keep track of gains
+ tkey = {b: a for a, b in enumerate(sites)}
+
+ # make a list of gain keys that relates scan bl gains to solved site ones
+ # -1 means that this station does not have a gain that is being solved for
+ g1_keys = []
+ g2_keys = []
+ for row in scan:
+ try:
+ g1_keys.append(tkey[row['t1']])
+ except KeyError:
+ g1_keys.append(-1)
+ try:
+ g2_keys.append(tkey[row['t2']])
+ except KeyError:
+ g2_keys.append(-1)
+
+ # no sites to calibrate on this scan!
+    if np.all(np.array(g1_keys) == -1) and np.all(np.array(g2_keys) == -1):  # compare elementwise, not list to int
+ return scan
+
+ # get scan visibilities of the specified polarization
+ vis = scan[ehc.vis_poldict[pol]]
+ sigma = scan[ehc.sig_poldict[pol]]
+
+ if method == 'amp':
+ if debias:
+ vis = obsh.amp_debias(np.abs(vis), np.abs(sigma))
+ else:
+ vis = np.abs(vis)
+
+ sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2)
+
+ # initial guess for gains
+ gpar_guess = np.ones(len(sites), dtype=np.complex128).view(dtype=np.float64)
+
+ # error function
+ def errfunc(gpar):
+ return errfunc_full(gpar, vis, V_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method)
+
+ def errfunc_grad(gpar):
+ return errfunc_grad_full(gpar, vis, V_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method)
+
+ # use gradient descent to find the gains
+ # minimizer params
+ if minimizer_method=='L-BFGS-B':
+ optdict = {'maxiter': MAXIT,
+ 'ftol': STOP, 'gtol': STOP,
+ 'maxcor': NHIST, 'maxls': MAXLS}
+ else:
+ optdict = {'maxiter': MAXIT}
+
+ if use_grad:
+ res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict, jac=errfunc_grad)
+ else:
+ res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict)
+
+ # save the solution
+ g_fit = res.x.view(np.complex128)
+
+ if show_solution:
+ print(np.abs(g_fit))
+
+ if method == "phase":
+ g_fit = g_fit / np.abs(g_fit)
+ if method == "amp":
+ g_fit = np.abs(np.real(g_fit))
+
+ g_fit = np.append(g_fit, 1.)
+
+ # Derive a calibration table or apply the solution to the scan
+ if caltable:
+ allsites = list(set(scan['t1']).union(set(scan['t2'])))
+
+ caldict = {}
+ for site in allsites:
+ if site in sites:
+ site_key = tkey[site]
+ else:
+ site_key = -1
+
+            # We set the R and L gain corrections to be equal in self-calibration
+            # to avoid breaking polarization consistency relationships,
+            # unless apply_singlepol is set to calibrate only one polarization
+ if apply_singlepol:
+ if pol=='RR':
+ rscale = g_fit[site_key]**-1
+ lscale = np.ones(g_fit[site_key].shape)
+ elif pol=='LL':
+ rscale = np.ones(g_fit[site_key].shape)
+ lscale = g_fit[site_key]**-1
+ else:
+ rscale = g_fit[site_key]**-1
+ lscale = g_fit[site_key]**-1
+
+ # TODO: we may want to give two entries for the start/stop times
+ # when a non-zero interval is used
+ caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL)
+
+ out = caldict
+
+ else:
+ g1_fit = g_fit[g1_keys]
+ g2_fit = g_fit[g2_keys]
+ gij_inv = (g1_fit * g2_fit.conj())**(-1)
+
+ if polrep == 'stokes':
+
+ # scale visibilities
+ for vistype in ['vis', 'qvis', 'uvis', 'vvis']:
+ scan[vistype] *= gij_inv
+ # scale sigmas
+ for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']:
+ scan[sigtype] *= np.abs(gij_inv)
+
+ elif polrep == 'circ':
+ if apply_singlepol: #scale only solved polarization
+ if pol=='RR':
+ grr_inv = (g1_fit * g2_fit.conj())**(-1)
+ gll_inv = np.ones(g1_fit.shape)
+ grl_inv = (g1_fit)**(-1)
+ glr_inv = (g2_fit.conj())**(-1)
+
+ elif pol=='LL':
+ grr_inv = np.ones(g1_fit.shape)
+ gll_inv = (g1_fit * g2_fit.conj())**(-1)
+ grl_inv = (g2_fit.conj())**(-1)
+ glr_inv = (g1_fit)**(-1)
+
+ # scale visibilities
+ scan['rrvis'] *= grr_inv
+ scan['llvis'] *= gll_inv
+ scan['rlvis'] *= grl_inv
+ scan['lrvis'] *= glr_inv
+
+ # scale sigmas
+ scan['rrsigma'] *= np.abs(grr_inv)
+ scan['llsigma'] *= np.abs(gll_inv)
+ scan['rlsigma'] *= np.abs(grl_inv)
+ scan['lrsigma'] *= np.abs(glr_inv)
+
+            else:  # scale both polarizations (gij_inv computed above)
+
+ # scale visibilities
+ for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']:
+ scan[vistype] *= gij_inv
+ # scale sigmas
+ for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']:
+ scan[sigtype] *= np.abs(gij_inv)
+
+ out = scan
+
+ return out
+
+
+def init(x):
+ global counter
+ counter = x
+
+
+def get_selfcal_scan_cal(args):
+ return get_selfcal_scan_cal2(*args)
+
+
+def get_selfcal_scan_cal2(i, n, scan, im, V_scan, sites, polrep, pol, apply_singlepol, method, minimizer_method,
+ show_solution, pad_amp, gain_tol, debias, caltable, msgtype, use_grad):
+ if n > 1:
+ global counter
+ counter.increment()
+ obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1)
+
+ return self_cal_scan(scan, im, V_scan=V_scan, sites=sites, polrep=polrep, pol=pol, apply_singlepol=apply_singlepol,
+ method=method, minimizer_method=minimizer_method,
+ show_solution=show_solution,
+ pad_amp=pad_amp, gain_tol=gain_tol, debias=debias, caltable=caltable,
+ use_grad=use_grad)
+
+# error function
+def errfunc_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
+ # all the forward site gains (complex)
+ g = gpar.astype(np.float64).view(dtype=np.complex128)
+
+ if method == "phase":
+ g = g / np.abs(g)
+ if method == "amp":
+ g = np.abs(np.real(g))
+
+ # append the default values to g for missing gains
+ g = np.append(g, 1.)
+ g1 = g[g1_keys]
+ g2 = g[g2_keys]
+
+ # build site specific tolerance parameters
+ tol0 = np.array([gain_tol.get(s, gain_tol['default'])[0] for s in sites])
+ tol1 = np.array([gain_tol.get(s, gain_tol['default'])[1] for s in sites])
+
+ if method == 'amp':
+ verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan)
+ else:
+ verr = vis - g1 * g2.conj() * v_scan
+
+    nan_mask = ~np.isnan(verr)
+ verr = verr[nan_mask]
+
+ # goodness-of-fit for gains
+ chisq = np.sum((verr.real * sigma_inv[nan_mask])**2) + \
+ np.sum((verr.imag * sigma_inv[nan_mask])**2)
+
+ # prior on the gains
+ # don't count the last (default missing site) gain dummy value
+ tolsq = ((np.abs(g[:-1]) > 1) * tol0 + (np.abs(g[:-1]) <= 1) * tol1)**2
+ chisq_g = np.sum(np.log(np.abs(g[:-1]))**2 / tolsq)
+
+ # total chi^2
+ chisqtot = chisq + chisq_g
+ return chisqtot
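+
+# Numeric sketch (illustrative only) of the asymmetric gain prior above: per
+# the tolsq expression in errfunc_full, the first gain_tol element penalizes
+# gains above unity and the second penalizes gains at or below unity.
+def _gain_prior_demo():
+    import numpy as np
+    tol0, tol1 = 0.1, 0.5
+    for gabs, tol in [(1.2, tol0), (0.8, tol1)]:
+        tolsq = ((gabs > 1) * tol0 + (gabs <= 1) * tol1)**2
+        assert np.isclose(tolsq, tol**2)
+        chisq_g = np.log(gabs)**2 / tolsq  # prior contribution of one gain
+        assert chisq_g > 0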
+
+def errfunc_grad_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method):
+ # does not work for method=='phase' or method=='amp'
+ if method=='phase' or method=='amp':
+ raise Exception("errfunc_grad in self_cal only works with method=='both'!")
+
+ # all the forward site gains (complex)
+ g = gpar.astype(np.float64).view(dtype=np.complex128)
+ gr = np.real(g)
+ gi = np.imag(g)
+
+ # build site specific tolerance parameters
+ tol0 = np.array([gain_tol.get(s, gain_tol['default'])[0] for s in sites])
+ tol1 = np.array([gain_tol.get(s, gain_tol['default'])[1] for s in sites])
+
+ # append the default values to g for missing gains
+ g = np.append(g, 1.)
+ g1 = g[g1_keys]
+ g2 = g[g2_keys]
+
+ g1r = np.real(g1)
+ g1i = np.imag(g1)
+ g2r = np.real(g2)
+ g2i = np.imag(g2)
+
+ v_scan_sq = v_scan*v_scan.conj()
+ g1sq = g1*(g1.conj())
+ g2sq = g2*(g2.conj())
+
+ ###################################
+    # data term chi^2 derivative
+ ###################################
+
+ # chi^2 term gradients
+ dchisq_dg1r = (-g2.conj()*vis.conj()*v_scan - g2*vis*v_scan.conj() + 2*g1r*g2sq*v_scan_sq)
+ dchisq_dg1i = (-1j*g2.conj()*vis.conj()*v_scan + 1j*g2*vis*v_scan.conj() + 2*g1i*g2sq*v_scan_sq)
+
+ dchisq_dg2r = (-g1*vis.conj()*v_scan - g1.conj()*vis*v_scan.conj() + 2*g2r*g1sq*v_scan_sq)
+ dchisq_dg2i = (1j*g1*vis.conj()*v_scan - 1j*g1.conj()*vis*v_scan.conj() + 2*g2i*g1sq*v_scan_sq)
+
+
+ dchisq_dg1r *= ((sigma_inv)**2)
+ dchisq_dg1i *= ((sigma_inv)**2)
+ dchisq_dg2r *= ((sigma_inv)**2)
+ dchisq_dg2i *= ((sigma_inv)**2)
+
+ # same masking function as in errfunc
+ # preserve length of dchisq arrays
+ verr = vis - g1 * g2.conj() * v_scan
+ nan_mask = np.isnan(verr)
+
+ dchisq_dg1r[nan_mask] = 0
+ dchisq_dg1i[nan_mask] = 0
+ dchisq_dg2r[nan_mask] = 0
+ dchisq_dg2i[nan_mask] = 0
+
+    # derivatives of the real and imaginary gain components
+ dchisq_dgr = np.zeros(len(gpar)//2) #len(gpar) must be even
+ dchisq_dgi = np.zeros(len(gpar)//2)
+
+ # TODO faster than a for loop?
+ for i in range(len(gpar)//2):
+ g1idx = np.argwhere(np.array(g1_keys)==i)
+ g2idx = np.argwhere(np.array(g2_keys)==i)
+
+ dchisq_dgr[i] = np.sum(dchisq_dg1r[g1idx]) + np.sum(dchisq_dg2r[g2idx])
+ dchisq_dgi[i] = np.sum(dchisq_dg1i[g1idx]) + np.sum(dchisq_dg2i[g2idx])
+
+ ###################################
+ # prior term chi^2 derivitive
+ ###################################
+
+    # NOTE this derivative doesn't account for the possible sharp change in tol at |g|=1
+ gsq = np.abs(g[:-1])**2 # don't count default missing site dummy value
+ tolsq = ((np.abs(g[:-1]) > 1) * tol0 + (np.abs(g[:-1]) <= 1) * tol1)**2
+
+ dchisqg_dgr = gr*np.log(gsq)/gsq/tolsq
+ dchisqg_dgi = gi*np.log(gsq)/gsq/tolsq
+
+ # total derivative
+ dchisqtot_dgr = dchisq_dgr + dchisqg_dgr
+ dchisqtot_dgi = dchisq_dgi + dchisqg_dgi
+
+ # interleave final derivs
+ dchisqtot_dgpar = np.zeros(len(gpar))
+ dchisqtot_dgpar[0::2] = dchisqtot_dgr
+ dchisqtot_dgpar[1::2] = dchisqtot_dgi
+
+    # the gradient is analytically real; discard any residual imaginary part
+ dchisqtot_dgpar = np.real(dchisqtot_dgpar)
+
+ return dchisqtot_dgpar
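+
+# Minimal finite-difference check (illustrative only, not called by the pipeline)
+# of errfunc_grad_full against errfunc_full on tiny synthetic single-scan data;
+# the site names, baseline keys, and tolerances below are arbitrary choices.
+def _selfcal_grad_check_demo():
+    import numpy as np
+    rng = np.random.default_rng(1)
+    sites = ['A', 'B', 'C']
+    g1_keys, g2_keys = [0, 0, 1], [1, 2, 2]         # baselines (A,B), (A,C), (B,C)
+    v_scan = rng.normal(size=3) + 1j * rng.normal(size=3)
+    vis = v_scan * (1. + 0.1 * rng.normal(size=3))  # perturbed "measured" data
+    sigma_inv = np.ones(3)
+    gain_tol = {'default': [0.2, 0.2]}
+    gpar = np.ones(6) + 0.01 * rng.normal(size=6)   # 3 complex gains as 6 floats
+    args = (vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, 'both')
+    # analytic gradient (may emit a harmless ComplexWarning on internal casts)
+    grad = errfunc_grad_full(gpar, *args)
+    eps = 1.e-7
+    for i in range(len(gpar)):
+        dg = np.zeros(len(gpar))
+        dg[i] = eps
+        num = (errfunc_full(gpar + dg, *args) - errfunc_full(gpar - dg, *args)) / (2. * eps)
+        assert np.isclose(num, grad[i], rtol=1.e-4, atol=1.e-7)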
+
+
+
diff --git a/caltable.py b/caltable.py
new file mode 100644
index 00000000..ed47e617
--- /dev/null
+++ b/caltable.py
@@ -0,0 +1,1031 @@
+# caltable.py
+# a calibration table class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import matplotlib.pyplot as plt
+import os
+import copy
+import scipy.interpolate
+
+import ehtim.io.save
+import ehtim.io.load
+
+import ehtim.const_def as ehc
+import ehtim.observing.obs_helpers as obsh
+
+
+##################################################################################################
+# Caltable object
+##################################################################################################
+
+
+class Caltable(object):
+ """
+ Attributes:
+ source (str): The source name
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ mjd (int): The integer MJD of the observation
+ rf (float): The observation frequency in Hz
+ bw (float): The observation bandwidth in Hz
+ timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'
+
+ tarr (numpy.recarray): The array of telescope data with datatype DTARR
+ tkey (dict): A dictionary of rows in the tarr for each site name
+
+ data (dict): keys are sites in tarr, entries are calibration data tables of type DTCAL
+
+ """
+
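+    # NOTE (illustrative, inferred from usage below, e.g. in pad_scans): each
+    # per-site table in `data` has dtype ehc.DTCAL, whose rows hold
+    # (time [hr], rscale, lscale), with complex R and L gain scale factors.
+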
+ def __init__(self, ra, dec, rf, bw, datadict, tarr,
+ source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC'):
+ """A Calibration Table.
+
+ Args:
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The observation frequency in Hz
+ mjd (int): The integer MJD of the observation
+ bw (float): The observation bandwidth in Hz
+
+ datadict (dict): keys are sites in tarr, entries are data tables of type DTCAL
+ tarr (numpy.recarray): The array of telescope data with datatype DTARR
+
+ source (str): The source name
+ mjd (int): The integer MJD of the observation
+ timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'
+
+ Returns:
+            (Caltable): a Caltable object
+ """
+
+ # Set the various parameters
+ self.source = str(source)
+ self.ra = float(ra)
+ self.dec = float(dec)
+ self.rf = float(rf)
+ self.bw = float(bw)
+ self.mjd = int(mjd)
+
+ if timetype not in ['GMST', 'UTC']:
+ raise Exception("timetype must by 'GMST' or 'UTC'")
+ self.timetype = timetype
+
+ # Dictionary of array indices for site names
+ self.tarr = tarr
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+
+ # Save the data
+ self.data = datadict
+
+ def copy(self):
+ """Copy the observation object.
+
+ Args:
+
+ Returns:
+ (Caltable): a copy of the Caltable object.
+ """
+ new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, self.data, self.tarr,
+ source=self.source, mjd=self.mjd, timetype=self.timetype)
+ return new_caltable
+
+ def plot_dterms(self, sites='all', label=None, legend=True, clist=ehc.SCOLORS,
+ rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE,
+ show=True, grid=True, export_pdf=""):
+ """Make a plot of the D-terms.
+
+ Args:
+ sites (list) : list of sites to plot
+ label (str) : title for plot
+ legend (bool) : add telescope legend or not
+ clist (list) : list of colors for different stations
+ rangex (list) : lower and upper x-axis limits
+ rangey (list) : lower and upper y-axis limits
+ markersize (float) : marker size
+ show (bool) : display the plot or not
+ grid (bool) : add a grid to the plot or not
+ export_pdf (str) : save a pdf file to this path
+
+ Returns:
+ matplotlib.axes
+ """
+ # sites
+        if sites in ['all', 'All'] or sites == []:
+ sites = list(self.data.keys())
+
+ if not isinstance(sites, list):
+ sites = [sites]
+
+ keys = [self.tkey[site] for site in sites]
+
+ axes = plot_tarr_dterms(self.tarr, keys=keys, label=label, legend=legend, clist=clist,
+ rangex=rangex, rangey=rangey, markersize=markersize,
+ show=show, grid=grid, export_pdf=export_pdf)
+
+ return axes
+
+ def plot_gains(self, sites, gain_type='amp', pol='R', label=None,
+ ang_unit='deg', timetype=False, yscale='log', legend=True,
+ clist=ehc.SCOLORS, rangex=False, rangey=False, markersize=[ehc.MARKERSIZE],
+ show=True, grid=False, axislabels=True, axis=False, export_pdf=""):
+ """Plot gains on multiple sites vs time.
+ Args:
+ sites (list): a list of site names for which to plot gains. Empty list is all sites.
+ gain_type (str): 'amp' or 'phase'
+            pol (str): 'R' or 'L'
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ yscale (str): 'log' or 'lin',
+ clist (list): list of colors for the plot
+ label (str): base label for legend
+
+ rangex (list): [xmin, xmax] x-axis (time) limits
+ rangey (list): [ymin, ymax] y-axis (gain) limits
+
+ legend (bool): Plot legend if True
+ grid (bool): Plot gridlines if True
+ axislabels (bool): Show axis labels if True
+ show (bool): Display the plot if true
+ axis (matplotlib.axes.Axes): add plot to this axis
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with the plot
+ """
+
+ colors = iter(clist)
+
+ if timetype is False:
+ timetype = self.timetype
+ if timetype not in ['GMST', 'UTC', 'utc', 'gmst']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+ if gain_type not in ['amp', 'phase']:
+ raise Exception("gain_type must be 'amp' or 'phase' ")
+        if pol not in ['R', 'L']:  # only 'R' and 'L' gains are plotted below
+            raise Exception("pol must be 'R' or 'L'")
+
+ if ang_unit == 'deg':
+ angle = ehc.DEGREE
+ else:
+ angle = 1.0
+
+ # axis
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ # sites
+        if sites in ['all', 'All'] or sites == []:
+ sites = sorted(list(self.data.keys()))
+
+ if not isinstance(sites, list):
+ sites = [sites]
+
+ if len(markersize) == 1:
+ markersize = markersize * np.ones(len(sites))
+
+ # plot gain on each site
+        tmins, tmaxes, gmins, gmaxes = [], [], [], []  # separate lists, not aliases of one list
+ for s in range(len(sites)):
+ site = sites[s]
+ times = self.data[site]['time']
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ times = obsh.gmst_to_utc(times, self.mjd)
+ elif timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ times = obsh.utc_to_gmst(times, self.mjd)
+ if pol == 'R':
+ gains = self.data[site]['rscale']
+ elif pol == 'L':
+ gains = self.data[site]['lscale']
+
+ if gain_type == 'amp':
+ gains = np.abs(gains)
+ ylabel = r'$|G|$'
+
+ if gain_type == 'phase':
+ gains = np.angle(gains) / angle
+                if ang_unit == 'deg':
+                    ylabel = r'arg($G$) ($^\circ$)'
+                else:
+                    ylabel = r'arg($G$) (radian)'
+
+ tmins.append(np.min(times))
+ tmaxes.append(np.max(times))
+ gmins.append(np.min(gains))
+ gmaxes.append(np.max(gains))
+
+ # Plot the data
+ if label is None:
+ bllabel = str(site)
+ else:
+ bllabel = label + ' ' + str(site)
+ plt.plot(times, gains, color=next(colors), marker='o', markersize=markersize[s],
+ label=bllabel, linestyle='none')
+
+ if not rangex:
+ rangex = [np.min(tmins) - 0.2 * np.abs(np.min(tmins)),
+ np.max(tmaxes) + 0.2 * np.abs(np.max(tmaxes))]
+ if np.any(np.isnan(np.array(rangex))):
+ print("Warning: NaN in data x range: specifying rangex to default")
+ rangex = [0, 24]
+ if not rangey:
+ rangey = [np.min(gmins) - 0.2 * np.abs(np.min(gmins)),
+ np.max(gmaxes) + 0.2 * np.abs(np.max(gmaxes))]
+ if np.any(np.isnan(np.array(rangey))):
+ print("Warning: NaN in data x range: specifying rangey to default")
+ rangey = [1.e-2, 1.e2]
+
+ plt.plot(np.linspace(rangex[0], rangex[1], 5), np.ones(5), 'k--')
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ # labels
+ if axislabels:
+ x.set_xlabel(self.timetype + ' (hr)')
+ x.set_ylabel(ylabel)
+ plt.title('Caltable gains for %s on day %s' % (self.source, self.mjd))
+
+ if legend:
+ plt.legend()
+
+ if yscale == 'log':
+ x.set_yscale('log')
+ if grid:
+ x.grid()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return x
+
+ def enforce_positive(self, method='median', min_gain=0.9, sites=[], verbose=True):
+ """Enforce that caltable gains are not low
+ (e.g., that sites are not significantly more sensitive than estimated).
+ By rescaling the entire gain curve to enforce a specified minimum site gain.
+
+ Args:
+ caltab (Caltable): Input Caltable with station gains
+ method (str): 'median', 'mean', or 'min'
+ min_gain (float): Site gains above this value are not modified.
+ sites (list): List of sites to check and adjust. For sites=[], all sites are fixed.
+ verbose (bool): If True, print corrections.
+
+ Returns:
+ (Caltable): Axes object with the plot
+ """
+
+ if len(sites) == 0:
+ sites = self.data.keys()
+
+ caltab_pos = self.copy()
+ for site in self.data.keys():
+ if site not in sites:
+ continue
+ if len(self.data[site]['rscale']) == 0:
+ continue
+
+ if method == 'min':
+ sitemin = np.min([np.abs(self.data[site]['rscale']),
+ np.abs(self.data[site]['lscale'])])
+ elif method == 'mean':
+ sitemin = np.mean([np.abs(self.data[site]['rscale']),
+ np.abs(self.data[site]['lscale'])])
+ elif method == 'median':
+ sitemin = np.median([np.abs(self.data[site]['rscale']),
+ np.abs(self.data[site]['lscale'])])
+ else:
+ print('Method ' + method + ' not recognized!')
+ return caltab_pos
+
+ if sitemin < min_gain:
+ if verbose:
+ print(method + ' gain for ' + site + ' is ' + str(sitemin) + '. Rescaling.')
+ caltab_pos.data[site]['rscale'] /= sitemin
+ caltab_pos.data[site]['lscale'] /= sitemin
+ else:
+ if verbose:
+ print(method + ' gain for ' + site + ' is ' + str(sitemin) + '. Not adjusting.')
+
+ return caltab_pos
+
+ # TODO default extrapolation?
+ def pad_scans(self, maxdiff=60, padtype='median'):
+ """Pad data points around scans.
+
+ Args:
+ maxdiff (float): "scan" separation length (seconds)
+            padtype (str): padding type, 'endval' or 'median'; any other value pads with ones
+
+ Returns:
+ (Caltable): a padded caltable object
+ """
+
+ outdict = {}
+ scopes = list(self.data.keys())
+ for scope in scopes:
+            if self.data[scope] is None or len(self.data[scope]) == 0:
+ continue
+
+ caldata = copy.deepcopy(self.data[scope])
+
+ # Gather data into "scans"
+ # TODO we could use a scan table for this as well!
+ gathered_data = []
+ scandata = [caldata[0]]
+ for i in range(1, len(caldata)):
+ if (caldata[i]['time'] - caldata[i - 1]['time']) * 3600 > maxdiff:
+ scandata = np.array(scandata, dtype=ehc.DTCAL)
+ gathered_data.append(scandata)
+ scandata = [caldata[i]]
+ else:
+ scandata.append(caldata[i])
+
+ # This adds the last scan
+            scandata = np.array(scandata, dtype=ehc.DTCAL)
+ gathered_data.append(scandata)
+
+ # Compute padding values and pad scans
+ for i in range(len(gathered_data)):
+ gg = gathered_data[i]
+
+
+ timepre = gg['time'][0] - maxdiff / 2. / 3600.
+ timepost = gg['time'][-1] + maxdiff / 2. / 3600.
+
+ if padtype == 'median': # pad with median scan value
+ medR = np.median(gg['rscale'])
+ medL = np.median(gg['lscale'])
+ preR = medR
+ postR = medR
+ preL = medL
+ postL = medL
+ elif padtype == 'endval': # pad with endpoints
+ preR = gg['rscale'][0]
+ postR = gg['rscale'][-1]
+ preL = gg['lscale'][0]
+ postL = gg['lscale'][-1]
+ else: # pad with ones
+ preR = 1.
+ postR = 1.
+ preL = 1.
+ postL = 1.
+
+ valspre = np.array([(timepre, preR, preL)], dtype=ehc.DTCAL)
+ valspost = np.array([(timepost, postR, postL)], dtype=ehc.DTCAL)
+
+ gg = np.insert(gg, 0, valspre)
+ gg = np.append(gg, valspost)
+
+ # output data table
+ if i == 0:
+ caldata_out = gg
+ else:
+ caldata_out = np.append(caldata_out, gg)
+
+ try:
+                caldata_out  # TODO: refactor to avoid using an exception for this check
+ except NameError:
+ print("No gathered_data")
+ else:
+ outdict[scope] = caldata_out
+
+ return Caltable(self.ra, self.dec, self.rf, self.bw, outdict, self.tarr,
+ source=self.source, mjd=self.mjd, timetype=self.timetype)
+
+ def applycal(self, obs, interp='linear', extrapolate=None,
+ force_singlepol=False, copy_closure_tables=True):
+ """Apply the calibration table to an observation.
+
+ Args:
+ obs (Obsdata): The observation with data to be calibrated
+ interp (str): Interpolation method ('linear','nearest','cubic')
+ extrapolate (bool): If True, points outside interpolation range will be extrapolated.
+            force_singlepol (str): If 'L' or 'R', will set opposite polarization gains
+                                   equal to the chosen polarization
+            copy_closure_tables (bool): If True, copy the closure tables from the original observation
+
+ Returns:
+ (Obsdata): the calibrated Obsdata object
+ """
+ if not (self.tarr == obs.tarr).all():
+ raise Exception("The telescope array in the Caltable is not the same as in the Obsdata")
+
+ if extrapolate is True: # extrapolate can be a tuple or numpy array
+ fill_value = "extrapolate"
+ else:
+ fill_value = extrapolate
+
+ obs_orig = obs.copy() # Need to do this before switch_polrep to keep tables
+ orig_polrep = obs.polrep
+ obs = obs.switch_polrep('circ')
+
+ rinterp = {}
+ linterp = {}
+ skipsites = []
+ for s in range(0, len(self.tarr)):
+ site = self.tarr[s]['site']
+
+ try:
+ self.data[site]
+ except KeyError:
+ skipsites.append(site)
+ print("No Calibration Data for %s !" % site)
+ continue
+
+ time_mjd = self.data[site]['time'] / 24.0 + self.mjd
+ rinterp[site] = relaxed_interp1d(time_mjd, self.data[site]['rscale'],
+ kind=interp, fill_value=fill_value, bounds_error=False)
+ linterp[site] = relaxed_interp1d(time_mjd, self.data[site]['lscale'],
+ kind=interp, fill_value=fill_value, bounds_error=False)
+
+ bllist = obs.bllist()
+ datatable = []
+ for bl_obs in bllist:
+ t1 = bl_obs['t1'][0]
+ t2 = bl_obs['t2'][0]
+ time_mjd = bl_obs['time'] / 24.0 + obs.mjd
+
+ if t1 in skipsites:
+ rscale1 = lscale1 = np.array(1.)
+ else:
+ rscale1 = rinterp[t1](time_mjd)
+ lscale1 = linterp[t1](time_mjd)
+ if t2 in skipsites:
+ rscale2 = lscale2 = np.array(1.)
+ else:
+ rscale2 = rinterp[t2](time_mjd)
+ lscale2 = linterp[t2](time_mjd)
+
+ if force_singlepol == 'R':
+ lscale1 = rscale1
+ lscale2 = rscale2
+
+ if force_singlepol == 'L':
+ rscale1 = lscale1
+ rscale2 = lscale2
+
+ rrscale = rscale1 * rscale2.conj()
+ llscale = lscale1 * lscale2.conj()
+ rlscale = rscale1 * lscale2.conj()
+ lrscale = lscale1 * rscale2.conj()
+
+ bl_obs['rrvis'] = (bl_obs['rrvis']) * rrscale
+ bl_obs['llvis'] = (bl_obs['llvis']) * llscale
+ bl_obs['rlvis'] = (bl_obs['rlvis']) * rlscale
+ bl_obs['lrvis'] = (bl_obs['lrvis']) * lrscale
+
+ bl_obs['rrsigma'] = bl_obs['rrsigma'] * np.abs(rrscale)
+ bl_obs['llsigma'] = bl_obs['llsigma'] * np.abs(llscale)
+ bl_obs['rlsigma'] = bl_obs['rlsigma'] * np.abs(rlscale)
+ bl_obs['lrsigma'] = bl_obs['lrsigma'] * np.abs(lrscale)
+
+ if len(datatable):
+ datatable = np.hstack((datatable, bl_obs))
+ else:
+ datatable = bl_obs
+
+ calobs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw,
+ np.array(datatable), obs.tarr,
+ polrep=obs.polrep, scantable=obs.scans, source=obs.source,
+ mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal,
+ opacitycal=obs.opacitycal, dcal=obs.dcal,
+ frcal=obs.frcal, timetype=obs.timetype)
+ calobs = calobs.switch_polrep(orig_polrep)
+
+ if copy_closure_tables:
+ calobs.camp = obs_orig.camp
+ calobs.logcamp = obs_orig.logcamp
+ calobs.cphase = obs_orig.cphase
+
+ return calobs
+
+ def merge(self, caltablelist, interp='linear', extrapolate=1):
+ """Merge the calibration table with a list of other calibration tables
+
+ Args:
+ caltablelist (list): The list of caltables to be merged
+ interp (str): Interpolation method ('linear','nearest','cubic')
+ extrapolate (bool): If True, points outside interpolation range will be extrapolated.
+
+ Returns:
+ (Caltable): the merged Caltable object
+ """
+
+ if extrapolate is True:  # otherwise extrapolate (e.g. a float or tuple) is the interp1d fill value
+ fill_value = "extrapolate"
+ else:
+ fill_value = extrapolate
+
+ if not hasattr(caltablelist, '__iter__'):
+ caltablelist = [caltablelist]
+
+ tarr1 = self.tarr.copy()
+ tkey1 = self.tkey.copy()
+ data1 = self.data.copy()
+ for caltable in caltablelist:
+
+ # TODO check metadata!
+
+ # TODO CHECK ARE THEY ALL REFERENCED TO SAME MJD???
+ tarr2 = caltable.tarr.copy()
+ tkey2 = caltable.tkey.copy()
+ data2 = caltable.data.copy()
+ sites2 = list(data2.keys())
+ sites1 = list(data1.keys())
+ for site in sites2:
+ if site in sites1: # if site in both tables
+
+ # merge the data by interpolating
+ time1 = data1[site]['time']
+ time2 = data2[site]['time']
+
+ rinterp1 = relaxed_interp1d(time1, data1[site]['rscale'],
+ kind=interp, fill_value=fill_value,
+ bounds_error=False)
+ linterp1 = relaxed_interp1d(time1, data1[site]['lscale'],
+ kind=interp, fill_value=fill_value,
+ bounds_error=False)
+ rinterp2 = relaxed_interp1d(time2, data2[site]['rscale'],
+ kind=interp, fill_value=fill_value,
+ bounds_error=False)
+ linterp2 = relaxed_interp1d(time2, data2[site]['lscale'],
+ kind=interp, fill_value=fill_value,
+ bounds_error=False)
+
+ times_merge = np.unique(np.hstack((time1, time2)))
+
+ rscale_merge = rinterp1(times_merge) * rinterp2(times_merge)
+ lscale_merge = linterp1(times_merge) * linterp2(times_merge)
+
+ # put the merged data back in data1
+ # TODO can we do this faster?
+ datatable = []
+ for i in range(len(times_merge)):
+ datatable.append(
+ np.array((times_merge[i], rscale_merge[i], lscale_merge[i]),
+ dtype=ehc.DTCAL))
+ data1[site] = np.array(datatable)
+
+ # sites not in both caltables
+ else:
+ if site not in tkey1.keys():
+ tarr1 = np.append(tarr1, tarr2[tkey2[site]])
+ data1[site] = data2[site]
+
+ # update tkeys every time
+ tkey1 = {tarr1[i]['site']: i for i in range(len(tarr1))}
+
+ new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, data1, tarr1,
+ source=self.source, mjd=self.mjd, timetype=self.timetype)
+
+ return new_caltable
+
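+ # Sketch: combining gain solutions from two calibration stages
+ # (directory names hypothetical):
+ #
+ #   ct_net = load_caltable(obs, './netcal')
+ #   ct_self = load_caltable(obs, './selfcal')
+ #   ct_total = ct_net.merge([ct_self])  # gains multiply where sites overlap
+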
+ def save_txt(self, obs, datadir='.', sqrt_gains=False):
+ """Save the Caltable object to text files in the given directory
+
+ Args:
+ obs (Obsdata): The observation object associated with the Caltable
+ datadir (str): directory to save the caltable in
+ sqrt_gains (bool): If True, square the gains before saving
+ (the inverse of the sqrt taken on load)
+ """
+
+ return save_caltable(self, obs, datadir=datadir, sqrt_gains=sqrt_gains)
+
+ def scan_avg(self, obs, incoherent=True):
+ """average the gains across scans.
+
+ Args:
+ obs (ehtim.Obsdata) : input observation
+ incoherent (bool) : True to average gain amps, False to average amps+phase
+
+ Returns:
+ (Caltable): the averaged Caltable object
+ """
+ sites = list(self.data.keys())  # list() so the keys can be indexed below
+ ntele = len(sites)
+
+ datatables = {}
+
+ # compute the scan boundaries once, before the site loop
+ obs.add_scans()
+ scans = obs.scans
+
+ # iterate over each site
+ for s in range(0, ntele):
+ site = sites[s]
+
+ # make a list of times that take the same value for all points in the same scan
+ times = self.data[site]['time']
+ times_stable = times.copy()
+ for j in range(len(times_stable)):
+ for scan in scans:
+ if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]:
+ times_stable[j] = scan[0]
+ break
+
+ datatable = []
+ for scan in scans:
+ gains_l = self.data[site]['lscale']
+ gains_r = self.data[site]['rscale']
+
+ # if incoherent average then average the magnitude of gains
+ if incoherent:
+ gains_l = np.abs(gains_l)
+ gains_r = np.abs(gains_r)
+
+ # average the gains
+ gains_l_avg = np.mean(gains_l[np.array(times_stable == scan[0])])
+ gains_r_avg = np.mean(gains_r[np.array(times_stable == scan[0])])
+
+ # add them to a new datatable
+ datatable.append(np.array((scan[0], gains_r_avg, gains_l_avg), dtype=ehc.DTCAL))
+
+ datatables[site] = np.array(datatable)
+
+ if len(datatables) > 0:
+ caltable = Caltable(obs.ra, obs.dec, obs.rf,
+ obs.bw, datatables, obs.tarr, source=obs.source,
+ mjd=obs.mjd, timetype=obs.timetype)
+ else:
+ caltable = False
+
+ return caltable
+
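+ # Sketch: scan-averaged amplitude gains (site name hypothetical):
+ #
+ #   ct_avg = ct.scan_avg(obs, incoherent=True)
+ #   ct_avg.data['ALMA']['time']  # one entry per scan start time
+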
+ def invert_gains(self):
+ """Invert the R and L gains in the Caltable, in place.
+
+ Returns:
+ (Caltable): the Caltable with inverted gains
+ """
+ sites = self.data.keys()
+
+ for site in sites:
+ self.data[site]['rscale'] = 1 / self.data[site]['rscale']
+ self.data[site]['lscale'] = 1 / self.data[site]['lscale']
+
+ return self
+
+
+def load_caltable(obs, datadir, sqrt_gains=False):
+ """Load apriori Caltable object from text files in the given directory
+ Args:
+ obs (Obsdata): The observation object associated with the Caltable
+ datadir (str): directory to save caltable in
+ sqrt_gains (bool): If True, we take the sqrt of table gains before loading.
+
+ Returns:
+ (Caltable): a caltable object
+ """
+
+ tarr = obs.tarr
+ array_filename = datadir + '/array.txt'
+ if os.path.exists(array_filename):
+ tarr = ehtim.io.load.load_array_txt(array_filename).tarr
+
+ datatables = {}
+ for s in range(0, len(tarr)):
+
+ site = tarr[s]['site']
+ filename = os.path.join(datadir, obs.source + '_' + site + '.txt')
+ try:
+ data = np.loadtxt(filename, dtype=bytes).astype(str)
+ except IOError:
+ try:
+ # fall back to a filename without the source prefix
+ filename = os.path.join(datadir, site + '.txt')
+ data = np.loadtxt(filename, dtype=bytes).astype(str)
+ except IOError:
+ continue
+
+ datatable = []
+
+ for row in data:
+
+ time = (float(row[0]) - obs.mjd) * 24.0 # time is given in mjd
+
+ if len(row) == 3:
+ rscale = float(row[1])
+ lscale = float(row[2])
+ elif len(row) == 5:
+ rscale = float(row[1]) + 1j * float(row[2])
+ lscale = float(row[3]) + 1j * float(row[4])
+ else:
+ raise Exception("cannot load caltable -- format unknown!")
+ if sqrt_gains:
+ rscale = rscale**.5
+ lscale = lscale**.5
+ datatable.append(np.array((time, rscale, lscale), dtype=ehc.DTCAL))
+
+ datatables[site] = np.array(datatable)
+ if len(datatables) > 0:
+ caltable = Caltable(obs.ra, obs.dec, obs.rf, obs.bw, datatables, tarr,
+ source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
+ else:
+ print("COULD NOT FIND CALTABLE IN DIRECTORY %s" % datadir)
+ caltable = False
+ return caltable
+
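+# The on-disk caltable format (as written by save_caltable below) is one text
+# file per site, named '<source>_<site>.txt', with one row per gain solution:
+#
+#   time_mjd  Re(rscale)  Im(rscale)  Re(lscale)  Im(lscale)
+#
+# Three-column rows (time_mjd, rscale, lscale) with real gains are also accepted.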
+
+def save_caltable(caltable, obs, datadir='.', sqrt_gains=False):
+ """Saves a Caltable object to text files in the given directory
+ Args:
+ obs (Obsdata): The observation object associated with the Caltable
+ datadir (str): directory to save caltable in
+ sqrt_gains (bool): If True, we square gains before saving.
+
+ Returns:
+ """
+
+ if not os.path.exists(datadir):
+ os.makedirs(datadir)
+
+ ehtim.io.save.save_array_txt(obs.tarr, datadir + '/array.txt')
+
+ datatables = caltable.data
+ src = caltable.source
+ for site_info in caltable.tarr:
+ site = site_info['site']
+
+ if len(datatables.get(site, [])) == 0:
+ continue
+
+ filename = os.path.join(datadir, src + '_' + site + '.txt')
+ site_data = datatables[site]
+ with open(filename, 'w') as outfile:
+ for entry in site_data:
+ time = entry['time'] / 24.0 + obs.mjd
+
+ if sqrt_gains:
+ rscale = np.square(entry['rscale'])
+ lscale = np.square(entry['lscale'])
+ else:
+ rscale = entry['rscale']
+ lscale = entry['lscale']
+
+ rreal = float(np.real(rscale))
+ rimag = float(np.imag(rscale))
+ lreal = float(np.real(lscale))
+ limag = float(np.imag(lscale))
+ outline = (str(float(time)) + ' ' +
+ str(rreal) + ' ' + str(rimag) + ' ' +
+ str(lreal) + ' ' + str(limag) + '\n')
+ outfile.write(outline)
+
+ return
+
+
+def make_caltable(obs, gains, sites, times):
+ """Create a Caltable object for an observation
+ Args:
+ obs (Obsdata): The observation object associated with the Caltable
+ gains (list): list of gains (?? format ??)
+ sites (list): list of sites
+ times (list): list of times
+
+ Returns:
+ (Caltable): a caltable object
+ """
+ ntele = len(sites)
+ ntimes = len(times)
+
+ datatables = {}
+ for s in range(0, ntele):
+ datatable = []
+ for t in range(0, ntimes):
+ gain = gains[s * ntele + t]
+ datatable.append(np.array((times[t], gain, gain), dtype=ehc.DTCAL))
+ datatables[sites[s]] = np.array(datatable)
+ if len(datatables) > 0:
+ caltable = Caltable(obs.ra, obs.dec, obs.rf,
+ obs.bw, datatables, obs.tarr, source=obs.source,
+ mjd=obs.mjd, timetype=obs.timetype)
+ else:
+ caltable = False
+
+ return caltable
+
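+# Sketch: building a caltable from a flat gain list (values hypothetical):
+#
+#   sites = ['ALMA', 'SMA']
+#   times = [0.0, 1.0, 2.0]                 # hours
+#   gains = [1.1, 1.0, 0.9, 0.8, 1.2, 1.0]  # site-major: ALMA gains, then SMA
+#   ct = make_caltable(obs, gains, sites, times)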
+
+def relaxed_interp1d(x, y, **kwargs):
+ """Wrap scipy.interpolate.interp1d so that scalar or single-point
+ inputs yield a constant interpolant instead of raising an error.
+ """
+ try:
+ len(x)
+ except TypeError:  # allow scalar inputs
+ x = np.asarray([x])
+ y = np.asarray([y])
+ if len(x) == 1:  # extend a single point to a constant function
+ x = np.array([-0.5, 0.5]) + x[0]
+ y = np.array([1.0, 1.0]) * y[0]
+ return scipy.interpolate.interp1d(x, y, **kwargs)
+
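+# Sketch: a single gain sample becomes a constant curve (values hypothetical):
+#
+#   f = relaxed_interp1d(2.0, 1.5 + 0.1j, kind='linear',
+#                        fill_value='extrapolate', bounds_error=False)
+#   f(10.0)  # -> (1.5+0.1j) at any time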
+
+def plot_tarr_dterms(tarr, keys=None, label=None, legend=True, clist=ehc.SCOLORS,
+ rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE,
+ show=True, grid=True, export_pdf="", auto_order=True):
+
+ if keys is None:
+ if auto_order:
+ # ensure that the plot puts the stations in alphabetical order
+ keys = np.argsort(tarr['site'])
+ else:
+ keys = range(len(tarr))
+
+ colors = iter(clist)
+
+ if export_pdf != "":
+ fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True, figsize=(16, 8))
+ else:
+ fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True)
+
+ for key in keys:
+
+ # get the label
+ site = str(tarr[key]['site'])
+ if label is None:
+ bllabel = str(site)
+ else:
+ bllabel = label + ' ' + str(site)
+ color = next(colors)
+
+ axes[0].plot(np.real(tarr[key]['dr']), np.imag(tarr[key]['dr']),
+ color=color, marker='o', markersize=markersize,
+ label=bllabel, linestyle='none')
+ axes[0].set_title("Right D-terms")
+ axes[0].set_xlabel("Real")
+ axes[0].set_ylabel("Imaginary")
+
+ axes[1].plot(np.real(tarr[key]['dl']), np.imag(tarr[key]['dl']),
+ color=color, marker='o', markersize=markersize,
+ label=bllabel, linestyle='none')
+ axes[1].set_title("Left D-terms")
+ axes[1].set_xlabel("Real")
+ axes[1].set_ylabel("Imaginary")
+
+ axes[0].axhline(y=0, color='k')
+ axes[0].axvline(x=0, color='k')
+ axes[1].axhline(y=0, color='k')
+ axes[1].axvline(x=0, color='k')
+
+ if grid:
+ axes[0].grid()
+ axes[1].grid()
+
+ if rangex:
+ axes[0].set_xlim(rangex)
+ axes[1].set_xlim(rangex)
+
+ if rangey:
+ axes[0].set_ylim(rangey)
+ axes[1].set_ylim(rangey)
+
+ if legend:
+ axes[1].legend(loc='center left', bbox_to_anchor=(1, 0.5))
+
+ if export_pdf != "":
+ fig.savefig(export_pdf, bbox_inches='tight')
+
+ return axes
+
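+# Sketch: plotting the D-terms of an array file (path hypothetical):
+#
+#   arr = ehtim.array.load_txt('arrays/EHT2017.txt')
+#   axes = plot_tarr_dterms(arr.tarr, legend=True, export_pdf='dterms.pdf')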
+
+def plot_compare_gains(caltab1, caltab2, obs, sites='all', pol='R', gain_type='amp', ang_unit='deg',
+ scan_avg=True, site_name_dict=None, fontsize=13, legend_fontsize=13,
+ yscale='log', legend=True, clist=ehc.SCOLORS,
+ rangex=False, rangey=False, scalefac=[0.9, 1.1],
+ markersize=[2 * ehc.MARKERSIZE], show=True, grid=False,
+ axislabels=True, remove_ticks=False, axis=False,
+ export_pdf=""):
+
+ colors = iter(clist)
+
+ if ang_unit == 'deg':
+ angle = ehc.DEGREE
+ else:
+ angle = 1.0
+
+ # axis
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ if scan_avg:
+ caltab1 = caltab1.scan_avg(obs, incoherent=True)
+ caltab2 = caltab2.scan_avg(obs, incoherent=True)
+
+ # sites
+ if sites in ['all', 'All'] or sites == []:
+ sites = list(set(caltab1.data.keys()).intersection(caltab2.data.keys()))
+
+ if not isinstance(sites, list):
+ sites = [sites]
+
+ if site_name_dict is None:
+ site_name_dict = {}
+ for site in sites:
+ site_name_dict[site] = site
+
+ if len(markersize) == 1:
+ markersize = markersize * np.ones(len(sites))
+
+ maxgain = 0.0
+ mingain = 10000
+
+ for s in range(len(sites)):
+
+ site = sites[s]
+ if pol == 'R':
+ gains1 = caltab1.data[site]['rscale']
+ gains2 = caltab2.data[site]['rscale']
+ elif pol == 'L':
+ gains1 = caltab1.data[site]['lscale']
+ gains2 = caltab2.data[site]['lscale']
+ else:
+ raise Exception("pol must be 'R' or 'L'")
+
+ if gain_type == 'amp':
+ gains1 = np.abs(gains1)
+ gains2 = np.abs(gains2)
+ ylabel = 'Amplitudes'  # r'$|G|$'
+
+ elif gain_type == 'phase':
+ gains1 = np.angle(gains1) / angle
+ gains2 = np.angle(gains2) / angle
+ if ang_unit == 'deg':
+ ylabel = r'arg($|G|$) ($^\circ$)'
+ else:
+ ylabel = 'Phases (radian)'  # r'arg($|G|$) (radian)'
+ else:
+ raise Exception("gain_type must be 'amp' or 'phase'")
+
+ # track the global gain range for the axis limits
+ maxgain = np.nanmax([maxgain, np.nanmax(gains1), np.nanmax(gains2)])
+ mingain = np.nanmin([mingain, np.nanmin(gains1), np.nanmin(gains2)])
+
+ # mark the gains on the plot
+ plt.plot(gains1, gains2, marker='.', linestyle='None', color=next(
+ colors), markersize=markersize[s], label=site_name_dict[site])
+
+ plt.xticks(fontsize=fontsize)
+ plt.yticks(fontsize=fontsize)
+
+ x.set_aspect('equal')  # plt.axes() without arguments is deprecated
+
+ if rangex:
+ x.set_xlim(rangex)
+ else:
+ x.set_xlim([mingain * scalefac[0], maxgain * scalefac[1]])
+ if rangey:
+ x.set_ylim(rangey)
+ else:
+ x.set_ylim([mingain * scalefac[0], maxgain * scalefac[1]])
+
+ plt.plot([mingain * scalefac[0], maxgain * scalefac[1]],
+ [mingain * scalefac[0], maxgain * scalefac[1]], 'grey', linewidth=1)
+
+ # labels
+ if axislabels:
+ x.set_xlabel('Ground Truth Gain ' + ylabel, fontsize=fontsize)
+ x.set_ylabel('Recovered Gain ' + ylabel, fontsize=fontsize)
+ else:
+ x.tick_params(axis="y", direction="in", pad=-30)
+ x.tick_params(axis="x", direction="in", pad=-18)
+
+ if remove_ticks:
+ plt.setp(x.get_xticklabels(), visible=False)
+ plt.setp(x.get_yticklabels(), visible=False)
+
+ if legend:
+ plt.legend(frameon=False, fontsize=legend_fontsize)
+
+ if yscale == 'log':
+ x.set_yscale('log')
+ x.set_xscale('log')
+ if grid:
+ x.grid()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ ehc.show_noblock()
+
+ return x
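+
+# Sketch: comparing recovered gains against ground truth (names hypothetical):
+#
+#   ax = plot_compare_gains(ct_true, ct_recovered, obs, sites='all',
+#                           pol='R', gain_type='amp', yscale='log')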
diff --git a/const_def.py b/const_def.py
new file mode 100644
index 00000000..67ed1f28
--- /dev/null
+++ b/const_def.py
@@ -0,0 +1,377 @@
+# const_def.py
+# useful constants and definitions
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+from packaging import version
+
+from ehtim.observing.pulses import trianglePulse2D
+
+mpl.rc('font', **{'family': 'serif', 'size': 12})
+
+EP = 1.0e-10
+C = 299792458.0
+DEGREE = 3.141592653589/180.0
+HOUR = 15.0*DEGREE
+RADPERAS = DEGREE/3600.0
+RADPERUAS = RADPERAS*1.e-6
+
+# Default Parameters
+SOURCE_DEFAULT = "SgrA"
+RA_DEFAULT = 17.761122472222223
+DEC_DEFAULT = -28.992189444444445
+RF_DEFAULT = 230e9
+MJD_DEFAULT = 51544
+PULSE_DEFAULT = trianglePulse2D
+
+# movie parameters
+INTERP_DEFAULT = 'linear'
+BOUNDS_ERROR = True # When False, movie will return NEAREST NEIGHBOR frames
+# for times beyond [movie.start_hr, movie.stop_hr]
+
+# Telescope elevation cuts (degrees)
+ELEV_LOW = 10.0
+ELEV_HIGH = 85.0
+
+TAUDEF = 0.1 # Default Optical Depth
+GAINPDEF = 0.1 # Default rms of gain errors
+DTERMPDEF = 0.05 # Default rms of D-term errors
+
+# Sgr A* Kernel Values (Bower et al., in uas/cm^2)
+FWHM_MAJ = 1.309 * 1000 # in uas
+FWHM_MIN = 0.64 * 1000
+POS_ANG = 78 # in degree, E of N
+
+# FFT & NFFT options
+NFFT_KERSIZE_DEFAULT = 20
+GRIDDER_P_RAD_DEFAULT = 2
+GRIDDER_CONV_FUNC_DEFAULT = 'gaussian'
+FFT_PAD_DEFAULT = 2
+FFT_INTERP_DEFAULT = 3
+
+# Observation recarray datatypes
+DTARR = [('site', 'U32'), ('x', 'f8'), ('y', 'f8'), ('z', 'f8'),
+ ('sefdr', 'f8'), ('sefdl', 'f8'), ('dr', 'c16'), ('dl', 'c16'),
+ ('fr_par', 'f8'), ('fr_elev', 'f8'), ('fr_off', 'f8')]
+
+DTPOL_STOKES = [('time', 'f8'), ('tint', 'f8'),
+ ('t1', 'U32'), ('t2', 'U32'),
+ ('tau1', 'f8'), ('tau2', 'f8'),
+ ('u', 'f8'), ('v', 'f8'),
+ ('vis', 'c16'), ('qvis', 'c16'), ('uvis', 'c16'), ('vvis', 'c16'),
+ ('sigma', 'f8'), ('qsigma', 'f8'), ('usigma', 'f8'), ('vsigma', 'f8')]
+
+DTPOL_CIRC = [('time', 'f8'), ('tint', 'f8'),
+ ('t1', 'U32'), ('t2', 'U32'),
+ ('tau1', 'f8'), ('tau2', 'f8'),
+ ('u', 'f8'), ('v', 'f8'),
+ ('rrvis', 'c16'), ('llvis', 'c16'), ('rlvis', 'c16'), ('lrvis', 'c16'),
+ ('rrsigma', 'f8'), ('llsigma', 'f8'), ('rlsigma', 'f8'), ('lrsigma', 'f8')]
+
+DTAMP = [('time', 'f8'), ('tint', 'f8'),
+ ('t1', 'U32'), ('t2', 'U32'),
+ ('u', 'f8'), ('v', 'f8'),
+ ('amp', 'f8'), ('sigma', 'f8')]
+
+DTBIS = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'),
+ ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'),
+ ('bispec', 'c16'), ('sigmab', 'f8')]
+
+DTCPHASE = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'),
+ ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'),
+ ('cphase', 'f8'), ('sigmacp', 'f8')]
+
+DTCAMP = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), ('t4', 'U32'),
+ ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'),
+ ('u3', 'f8'), ('v3', 'f8'), ('u4', 'f8'), ('v4', 'f8'),
+ ('camp', 'f8'), ('sigmaca', 'f8')]
+
+DTCPHASEDIAG = [('time', 'f8'), ('cphase', 'f8'), ('sigmacp', 'f8'),
+ ('triangles', 'O'), ('u', 'O'), ('v', 'O'), ('tform_matrix', 'O')]
+
+DTLOGCAMPDIAG = [('time', 'f8'), ('camp', 'f8'), ('sigmaca', 'f8'),
+ ('quadrangles', 'O'), ('u', 'O'), ('v', 'O'), ('tform_matrix', 'O')]
+
+DTCAL = [('time', 'f8'), ('rscale', 'c16'), ('lscale', 'c16')]
+
+DTSCANS = [('time', 'f8'), ('interval', 'f8'), ('startvis', 'f8'), ('endvis', 'f8')]
+
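+# Each caltable entry is a recarray row with the DTCAL dtype; a minimal sketch:
+#
+#   import numpy as np
+#   row = np.array((12.5, 1.02 + 0.01j, 0.98 - 0.02j), dtype=DTCAL)
+#   row['time'], row['rscale'], row['lscale']
+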
+# Dictionaries for keeping track of polarization fields
+POLDICT_STOKES = {'vis1': 'vis', 'vis2': 'qvis', 'vis3': 'uvis', 'vis4': 'vvis',
+ 'sigma1': 'sigma', 'sigma2': 'qsigma', 'sigma3': 'usigma', 'sigma4': 'vsigma'}
+POLDICT_CIRC = {'vis1': 'rrvis', 'vis2': 'llvis', 'vis3': 'rlvis', 'vis4': 'lrvis',
+ 'sigma1': 'rrsigma', 'sigma2': 'llsigma', 'sigma3': 'rlsigma', 'sigma4': 'lrsigma'}
+vis_poldict = {'I': 'vis', 'Q': 'qvis', 'U': 'uvis', 'V': 'vvis',
+ 'RR': 'rrvis', 'LL': 'llvis', 'RL': 'rlvis', 'LR': 'lrvis'}
+amp_poldict = {'I': 'amp', 'Q': 'qamp', 'U': 'uamp', 'V': 'vamp',
+ 'RR': 'rramp', 'LL': 'llamp', 'RL': 'rlamp', 'LR': 'lramp'}
+sig_poldict = {'I': 'sigma', 'Q': 'qsigma', 'U': 'usigma', 'V': 'vsigma',
+ 'RR': 'rrsigma', 'LL': 'llsigma', 'RL': 'rlsigma', 'LR': 'lrsigma'}
+
+# Observation fields for plotting and retrieving data
+FIELDS = ['time', 'time_utc', 'time_gmst',
+ 'tint', 'u', 'v', 'uvdist',
+ 't1', 't2', 'tau1', 'tau2',
+ 'el1', 'el2', 'hr_ang1', 'hr_ang2', 'par_ang1', 'par_ang2',
+ 'vis', 'amp', 'phase', 'snr',
+ 'qvis', 'qamp', 'qphase', 'qsnr',
+ 'uvis', 'uamp', 'uphase', 'usnr',
+ 'vvis', 'vamp', 'vphase', 'vsnr',
+ 'sigma', 'qsigma', 'usigma', 'vsigma',
+ 'sigma_phase', 'qsigma_phase', 'usigma_phase', 'vsigma_phase',
+ 'psigma_phase', 'msigma_phase',
+ 'pvis', 'pamp', 'pphase', 'psnr',
+ 'evis', 'eamp', 'ephase', 'esnr',
+ 'bvis', 'bamp', 'bphase', 'bsnr',
+ 'm', 'mamp', 'mphase', 'msnr',
+ 'rrvis', 'rramp', 'rrphase', 'rrsnr', 'rrsigma', 'rrsigma_phase',
+ 'llvis', 'llamp', 'llphase', 'llsnr', 'llsigma', 'llsigma_phase',
+ 'rlvis', 'rlamp', 'rlphase', 'rlsnr', 'rlsigma', 'rlsigma_phase',
+ 'lrvis', 'lramp', 'lrphase', 'lrsnr', 'lrsigma', 'lrsigma_phase',
+ 'rrllvis', 'rrllamp', 'rrllphase', 'rrllsnr', 'rrllsigma', 'rrllsigma_phase']
+
+FIELDS_AMPS = ["amp", "qamp", "uamp", "vamp",
+ "pamp", "mamp", "rramp", "llamp", "rlamp", "lramp", "rrllamp"]
+FIELDS_SIGS = ["sigma", "qsigma", "usigma", "vsigma",
+ "psigma", "msigma", "rrsigma", "llsigma", "rlsigma", "lrsigma", "rrllsigma"]
+FIELDS_PHASE = ["phase", "qphase", "uphase", "vphase", "pphase", "mphase",
+ "rrphase", "llphase", "rlphase", "lrphase", "rrllphase"]
+FIELDS_SIGPHASE = ["sigma_phase", "qsigma_phase", "usigma_phase", "vsigma_phase",
+ "psigma_phase", "msigma_phase", "rrsigma_phase", "llsigma_phase",
+ "rlsigma_phase", "lrsigma_phase", "rrllsigma_phase"]
+FIELDS_SNRS = ["snr", "qsnr", "usnr", "vsnr", "psnr", "msnr",
+ "rrsnr", "llsnr", "rlsnr", "lrsnr", "rrllsnr"]
+
+# Plotting
+MARKERSIZE = 3
+
+def show_noblock(pause=0.001):
+ """helper function for non-blocking image display across matplotlib versions"""
+
+ # the ion/pause workaround seems to be required for matplotlib versions
+ # between 3.2.2 and 3.6 (e.g. 3.5)
+ if (version.parse(mpl.__version__) <= version.parse('3.2.2') or
+ version.parse(mpl.__version__) > version.parse('3.6')):
+ plt.show(block=False)
+ else:
+ plt.ion()
+ plt.show()
+ plt.pause(pause)
+ plt.draw()
+ plt.pause(pause)
+
+FIELD_LABELS = {'time': 'Time',
+ 'time_utc': 'Time (UTC)',
+ 'time_gmst': 'Time (GMST)',
+ 'tint': 'Integration Time',
+ 'u': r'$u$',
+ 'v': r'$v$',
+ 'uvdist': r'$u-v$ Distance',
+ 't1': 'Site 1',
+ 't2': 'Site 2',
+ 'tau1': r'$\tau_1$',
+ 'tau2': r'$\tau_2$',
+ 'el1': r'Elevation Angle$_1$',
+ 'el2': r'Elevation Angle$_2$',
+ 'hr_ang1': r'Hour Angle$_1$',
+ 'hr_ang2': r'Hour Angle$_2$',
+ 'par_ang1': r'Parallactic Angle$_1$',
+ 'par_ang2': r'Parallactic Angle$_2$',
+ 'vis': 'Visibility',
+ 'amp': 'Amplitude',
+ 'phase': 'Phase',
+ 'snr': 'SNR',
+ 'qvis': 'Q-Visibility',
+ 'qamp': 'Q-Amplitude',
+ 'qphase': 'Q-Phase',
+ 'qsnr': 'Q-SNR',
+ 'uvis': 'U-Visibility',
+ 'uamp': 'U-Amplitude',
+ 'uphase': 'U-Phase',
+ 'usnr': 'U-SNR',
+ 'vvis': 'V-Visibility',
+ 'vamp': 'V-Amplitude',
+ 'vphase': 'V-Phase',
+ 'vsnr': 'V-SNR',
+ 'sigma': r'$\sigma$',
+ 'qsigma': r'$\sigma_{Q}$',
+ 'usigma': r'$\sigma_{U}$',
+ 'vsigma': r'$\sigma_{V}$',
+ 'sigma_phase': r'$\sigma_{phase}$',
+ 'qsigma_phase': r'$\sigma_{Q phase}$',
+ 'usigma_phase': r'$\sigma_{U phase}$',
+ 'vsigma_phase': r'$\sigma_{V phase}$',
+ 'psigma_phase': r'$\sigma_{P phase}$',
+ 'msigma_phase': r'$\sigma_{m phase}$',
+ 'pvis': r'P-Visibility',
+ 'pamp': r'P-Amplitude',
+ 'pphase': 'P-Phase',
+ 'psnr': 'P-SNR',
+ 'mvis': r'm-Visibility',
+ 'mamp': r'm-Amplitude',
+ 'mphase': 'm-Phase',
+ 'msnr': 'm-SNR',
+ 'evis': r'E-Visibility',
+ 'eamp': r'E-Amplitude',
+ 'ephase': 'E-Phase',
+ 'esnr': 'E-SNR',
+ 'bvis': r'B-Visibility',
+ 'bamp': r'B-Amplitude',
+ 'bphase': 'B-Phase',
+ 'bsnr': 'B-SNR',
+ 'rrvis': r'RR-Visibility',
+ 'rramp': r'RR-Amplitude',
+ 'rrphase': 'RR-Phase',
+ 'rrsnr': 'RR-SNR',
+ 'rrsigma': r'$\sigma_{RR}$',
+ 'rrsigma_phase': r'$\sigma_{RR phase}$',
+ 'llvis': r'LL-Visibility',
+ 'llamp': r'LL-Amplitude',
+ 'llphase': 'LL-Phase',
+ 'llsnr': 'LL-SNR',
+ 'llsigma': r'$\sigma_{LL}$',
+ 'llsigma_phase': r'$\sigma_{LL phase}$',
+ 'rlvis': r'RL-Visibility',
+ 'rlamp': r'RL-Amplitude',
+ 'rlphase': 'RL-Phase',
+ 'rlsnr': 'RL-SNR',
+ 'rlsigma': r'$\sigma_{RL}$',
+ 'rlsigma_phase': r'$\sigma_{RL phase}$',
+ 'lrvis': r'LR-Visibility',
+ 'lramp': r'LR-Amplitude',
+ 'lrphase': 'LR-Phase',
+ 'lrsnr': 'LR-SNR',
+ 'lrsigma': r'$\sigma_{LR}$',
+ 'lrsigma_phase': r'$\sigma_{LR phase}$',
+ 'rrllvis': r'RR/LL-Visibility',
+ 'rrllamp': r'RR/LL-Amplitude',
+ 'rrllphase': 'RR/LL-Phase',
+ 'rrllsnr': 'RR/LL-SNR',
+ 'rrllsigma': r'$\sigma_{RR/LL}$',
+ 'rrllsigma_phase': r'$\sigma_{RR/LL phase}$'}
+
+# Seaborn Colors from Maciek
+SCOLORS = [(0.11764705882352941, 0.5647058823529412, 1.0),
+ (1.0, 0.38823529411764707, 0.2784313725490196),
+ (0.5411764705882353, 0.16862745098039217, 0.8862745098039215),
+ (0.4196078431372549, 0.5568627450980392, 0.13725490196078433),
+ (1.0, 0.6470588235294118, 0.0),
+ (0.5450980392156862, 0.27058823529411763, 0.07450980392156863),
+ (0.0, 0.0, 0.803921568627451),
+ (1.0, 0.0, 0.0),
+ (0.0, 1.0, 1.0),
+ (1.0, 0.0, 1.0),
+ (0.0, 0.39215686274509803, 0.0),
+ (0.8235294117647058, 0.7058823529411765, 0.5490196078431373),
+ (0.0, 0.0, 0.0)]
+
+
+BHIMAGE = [
+ ' .. ',
+ ' ..... . ',
+ ' ........... .... ',
+ ' ............ ........................ ',
+ ' .........................,,******,,..... ',
+ ' .......,,,,..........,,**/(((/*,,... ',
+ ' .....,,,,,**,.......,**/(#%%#/,... ',
+ ' .....,,,,,,**,.....,*/(#%&&%/,. ',
+ ' ......,**,***/*,,,,*/(#%@@*. ',
+ ' ..,....,****////***/((%&@@@#, ',
+ ' ..,,,...****///////((#%&@@@%*. ',
+ ' .,**,,.,*////((((((##%&@@@&/. ',
+ ' .,//*,,,*/(((((((##%%&@@@@%, ,,,,,. .,.. ',
+ ' ,/(//***/(###((##%%%&@@@@#, . ,*/,. ... ',
+ ' .*(((////(#%%###%%&&@@@@@#, ..*(*. .. ',
+ ' ./((((//(#%&%%%%&&@@@@@@%*. .*#/. . ',
+ ' ,(#((#(((#&&&&&&@@@@@@@&(.. .(%* .. ',
+ ' ... ,(#####((#&&@@@@@@@@@@@%/..,##, .. ',
+ ' .... ,(###%%#(#%&@@@@@@@@@@@%/,#&* .. ',
+ ' ..... .(###%%%##%&@@@@@@@@@@@&%, . ',
+ ' ...... ./#%%%%%%##&@@@@@@@@@@@@@%, . ',
+ ' ........ .*%%%%%&%##%@@@@@@@@@@@@@(,. . ',
+ ' ..,,,... .. .,#%%%%%&%%%&@@@@@@@@@@@&%/. . ',
+ ' ..,,,,... ... ../#%%%%&%%%&@@@@@@@@@@@@*. @@ @ @ # ',
+ ' ..,,,,... ......*(##%%&&%%&&@@@@@@@@@@@&%(,. @ @ @ @ # ### ## . ',
+ ' ..,**,,.......,.,/###%%&%#%&@@@@@@@@@@@&%#/,.. @@@@ @@@ @@@ %%% # # # # .. ',
+ ' ..,**,,.......,,*(((#%&&%#%&@@@@@@@@&&&&%(*... @ @ @ @ # # # # .. ',
+ ' .,,*,,,,.......,/(((#%&&%##&@@&@@@&&%%%##(*.. @@ @ @ @@ ## # # # . ',
+ ' ..,,,,,,,.......,/((((%&&%#(%&&&@@%%%###((/*.. . ',
+ ' ...,,.,,,....,,,*/((/(#%&(#%&@@%%%#(///**,.. . ',
+ ' ........,.....,,*////((%&&%###&@&%###(/*,,,,... .. ',
+ ' .............,,**///((#%&&%%#@&%%#((((/*....... ,. ',
+ ' ............,,,**////((#%%&%@@#####((/***,.... .,. ',
+ ' ....,,*****////((#%%&@(((((((/**,,,..... ,, ',
+ ' . ...,**/****/////(#%&@&%##((((((//**,....... ..... .*, ',
+ ' ... ..,*/////***////(#%&%%%###((((///****,,,,.......,,....... .,,. ',
+ ' .... .,*/((//****///((#%#######((((///*****,,,,,,........... ,*. ',
+ ' ..... ..,/(((//****///((#(///((((((///******,,,,,............... .*, ',
+ ' ... ...,*/(#((///////((#(/*****///(((((////***,,,......... .*, ',
+ ' .....,*(#####(((((##%%##((((((////****,,,......... ,*, ',
+ ' .....,,*/(##%%#########(///*******,,,,........... .,*. ',
+ ' .......,,*/(((#########(//**,,,,,,,**,,,... .,*,. ',
+ ' ..........,,,*****/********,,,.... .,*,. ',
+ ' . .... ..**. ',
+ ' . .. ..,,,. ',
+ ' . ........ ',
+ ' ................... ',
+ ' ']
+
+EHTIMAGE = [
+ ' `..----..` ',
+ ' `-/oyhmNNNNMMMMNNNNmhs+:-` ',
+ ' `.+ymNMMMMMMMMMMMMMMMMMMMMMMMMNds/. ',
+ ' `:ymNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNdo-` ',
+ ' `-yNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNmo. ',
+ ' `/dNMMMMMMMMMMMMMMmdysoo+++dMMMMMMMMMMMMMMMMMMMMNh: :////// ',
+ ' `+mMMMMMMMMMMMNmy+-``` sMMMMMNshNMMMMMMMMMMMMMNh- mNyoooo``` `` ``` ` ``` `ds`` ',
+ ' :dMMMMMMMMMMMho-` .:- :sd- sMMMMNs.-yMMMMMMMMMMMMMMMNy. mN: sd: /ds -ydyhh+` .dyyhhdh: :dMmyy. ',
+ ' `yNMMMMMMMMMmo-` `/ymh:`-hNMM- sMMMMMMmNMMMMMMNmMMMMMMMMMMm/ mMmdddh -mm. .Nm.-Nm-..oMs -Mm:``:Mh``oMy`` ',
+ ' .hMMMMMMMMMNo. :ymMNo` oNMMMM- sMMMMMMMMMMMMMm/`oNMMMMMMMMMNs` mM/...` /My`yM/ +Mmssssh+ -Md .Md` +Ms ',
+ ' -mMMMMMMMMMs. /hNMMN: sMMMMMM- sMMMMMMMMMNMMMNh+mMMMMMMMMMMMMy` mM/.... sNsMo -mm/```-` -Md .Md` +Mh`` ',
+ ' -mMMMMMMMMN/ -hNMMMN: oMMMMMMM: sMMMMMMMMy-sNMMMMMMMMMMMMMMMMMMy ydhyyyy` `yhy .oyhhyy- .ds .hy` `sdhy- ',
+ ' `mMMMMMMMMm- -:::::- -:::::::` sMMMMMMMMmodMMMMMMMMMdMMMMMMMMMMo `````` ` ``` ` ` ``` ',
+ ' oMMMMMMMMm- -::::::. `-::::::::` sMNhNMMMMMMMMMMMMMMmo`/mMMMMMMMMM: ',
+ ' -NMMMMMMMM: +NMMMMMM: /NMMMMMMMM- sMd -+hNMMMMMMMMMMMMm+dNMMMMMMMMMh` ',
+ ' yMMMMMMMMy .NMMMMMMd` hMMMMMMMMM- sMNo-``.+hNMMMMMMMMMMMMMMMMMMMMMMN: .. `.. `:: ',
+ ' `mMMMMMMMN- yMMMMMMMs NMMMMMMMMM- sMMMNho-` .+hNMMMMMMMMMMMMMMMMMMMMy dm: :mh `sy` ',
+ ' -NMMMMMMMm .mMMMMMMM+ `NMMMMMMMMM: sMMMMMMNds:` .+hNMMMNhsshMMMMMMMMMm mN: /Mm -osss+- .o::os:`os``ooooss/ `:osss/` :s:+sss/` ',
+ ' /MMMMMMMMh `////////. //////////` sMm////////. .+Mm: `+MMMMMMMMN` mMhsssshMm +Nh/-/dm: -MNy+/..mM` --:oMd- hNs:-+Nd. oMm+-:hMo ',
+ ' /MMMMMMMMh --------` ----------` sMm--------. `:Mm. /NMMMMMMMN` mMo////sMm `dM. /Md -Mm` .mM` `oms. .Md yM/ oMs /My ',
+ ' -NMMMMMMMm .mNNNNNNN+ `mNNNNNNNNN- sMMNNNNNdy/. `:sdMMMms++sNMMMMMMMMN mN: /Mm yM/` `oMy -Mm` .mM` -hm/ `Nm- .hM: oMo /My ',
+ ' `mMMMMMMMN- `hMMMMMMMs NMMMMMMMMM- sMMMMmy/. `:sdMMMMMMMMMMMMMMMMMMMMh dm: /md .sdyyyds` -md` .dm`.mNhsss+ :hdyyhd+ +N+ /Ns ',
+ ' hMMMMMMMMs :MMMMMMMh` dMMMMMMMMM- sMNy/. `:smMMMMMMMMMMMMMMMMMMMMMMM/ ..` `.. `---` .. .. ....... .--.` `.` `.` ',
+ ' :NMMMMMMMM- oMMMMMMM: /MMMMMMMMM: sMd `:smMMMMMMMMMMMMNymMMMMMMMMMMd` ',
+ ' `yMMMMMMMMd. `///////- `/////////` sMmsmMMMMMMMMMMMMMMm+ /mMMMMMMMMM/ ',
+ ' .NMMMMMMMMd. ....... ........` sMMMMMMMMNyNMMMMMMMMNyNMMMMMMMMMy` `-- ',
+ ' :NMMMMMMMMm- :dNNNNm- sNNNNNNN- sMMMMMMMMs`+NMMMMMMMMMMMMMMMMMMd` `yyyhhhyys` .Nd ',
+ ' /NMMMMMMMMNo` `omMMMm- `hMMMMMM- sMMMMMMMMMmMMMMmsNMMMMMMMMMMMMd. ...hMo... .-/:-` .Nm `-:/-. .-//:- `-:/:. `-:/:.` -..:/:. .-//-. ',
+ ' :mMMMMMMMMMm/` `+dNMm/``yMMMMM- sMMMMMMMMMMMMMd:`+NMMMMMMMMMMy` yM+ `omy++dm: .Nm :dd++yNo -mN++oy` /mds+s+ :mms+ymh. .MmyoodN+ /mh hm/ ',
+ ' .hMMMMMMMMMMd+. `-odms.`/mMMM- sMMMMMMNNMMMMMMNhMMMMMMMMMMNo` yM+ /MN+++sNh .Nm `mMo+++Nm..dmy+-` -mM` `mM- +Ms .Md` .mM`.NNo++oNN ',
+ ' `+mMMMMMMMMMMNy:. `:+: `+hm- sMMMMMd-/mMMMMMMMMMMMMMMMMd: yM+ /Mm:::::- .Nm `mN/:::::` ./sNm--NM` `mM- /Ms .Md` .dM..NN/ ',
+ ' .sNMMMMMMMMMMMNho:``` ` sMMMMMm/omMMMMMMMMMMMMMMN/` yM+ `sNy+//o: .Nd /md+//++ -o///dN: oNd+/++ :mdo/omd. .MNy/+hNo +Nh+//o: ',
+ ' .smMMMMMMMMMMMMMMmhyo+/:---hMMMMMMMMMMMMMMMMMMMMMd+` -/. .:+++:` `/: `:+++:. `:+++:. `:++/. `-++/- .Mh-/+/. `:+++:` ',
+ ' `+dNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNh:` .Mh ',
+ ' -odNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNh/. `+: ',
+ ' ./sdNMMMMMMMMMMMMMMMMMMMMMMMMMMNho:` ',
+ ' `-/symNNMMMMMMMMMMMNNNdyo/.` ',
+ ' .----::::---.` ']
diff --git a/diagnostics.py b/diagnostics.py
new file mode 100644
index 00000000..f627ab9c
--- /dev/null
+++ b/diagnostics.py
@@ -0,0 +1,89 @@
+# diagnostics.py
+# useful diagnostic tests on images
+#
+# Copyright (C) 2018 Katie Bouman
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+
+
+def sumdown_lin(y, n=16):
+ """Sum segments of a line together to reduce its size to n bins,
+ conserving the total by linearly interpolating the cumulative sum.
+ """
+ nold = y.shape[0]
+ nnew = n
+
+ xold = np.linspace(0, 1, num=nold, endpoint=False)
+ xnew = np.linspace(0, 1, num=nnew, endpoint=False)
+
+ csum = np.zeros(nnew+1)
+ for i, x in enumerate(xnew):
+ u = np.argmax(xold > x)
+ k = u - 1
+ csum[i] = (np.sum(y[0:k]) +
+ y[k] * (x - xold[k]) / (xold[u] - xold[k]))
+ csum[nnew] = sum(y)
+
+ return np.diff(csum)
+
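+# A quick check of sumdown_lin (sketch; the total is conserved):
+#
+#   y = np.arange(32, dtype=float)
+#   y4 = sumdown_lin(y, n=4)
+#   assert np.isclose(y4.sum(), y.sum())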
+
+def sumdown_img(img, n=16):
+ """Sum patches of an image together to reduce its size
+
+ For simplicity, we just pad each side of the image to an
+ integer multiple of n and then downsample by summing up the
+ intensity of the subcells. This is not perfect but is good enough
+ for a proof of concept.
+ """
+ nx = img.shape[0]
+ ny = img.shape[1]
+ mx = (nx - 1)//n + 1
+ my = (ny - 1)//n + 1
+ Nx = mx * n
+ Ny = my * n
+ px = (Nx - nx)//2
+ py = (Ny - ny)//2
+
+ img = np.pad(img, ((px, Nx-nx-px), (py, Ny-ny-py)), 'constant')
+ return img.reshape(n, mx, n, my).sum(axis=(1, 3))
+
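+# Sketch: a 100x100 image is zero-padded to 112x112 and reduced to 16x16,
+# preserving the total intensity:
+#
+#   img = np.random.rand(100, 100)
+#   small = sumdown_img(img, n=16)  # shape (16, 16)
+#   assert np.isclose(small.sum(), img.sum())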
+
+def onedimize(imgs, n=16, gt=None):
+ """One-dimensionalize images by sorting pixels in order of intensity
+
+ Args:
+ imgs: a python list of two-dimensional numpy arrays
+ n: the number of pixels in both dimensions of the downsampled images
+ gt: optional ground-truth image that fixes the sort order;
+ if None, the mean of the downsampled images is used
+
+ Returns:
+ oneds: a python list of one-dimensional numpy arrays
+ gt: the one-dimensionalized ground-truth (or mean) image
+ """
+ imgs = [sumdown_img(img, n=n) for img in imgs]
+
+ if gt is None:
+ gt = np.dstack(imgs).mean(axis=2)
+ else:
+ gt = sumdown_img(gt, n=n)
+
+ idxs = np.argsort(-gt.reshape(n*n))
+ return [img.reshape(n*n)[idxs] for img in imgs], gt.reshape(n*n)[idxs]
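+
+# Sketch: one-dimensionalizing a set of reconstructions against their mean:
+#
+#   imgs = [np.random.rand(100, 100) for _ in range(5)]
+#   oneds, mean1d = onedimize(imgs, n=16)
+#   # each element of oneds is a length-256 profile, sorted by mean intensity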
diff --git a/ehtim/array.py b/ehtim/array.py
index f64d4316..8415f22d 100644
--- a/ehtim/array.py
+++ b/ehtim/array.py
@@ -113,7 +113,7 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop,
mjd=ehc.MJD_DEFAULT, timetype='UTC', polrep='stokes',
elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
no_elevcut_space=False,
- tau=ehc.TAUDEF, fix_theta_GMST=False):
+ tau=ehc.TAUDEF, fix_theta_GMST=False, reorder=True):
"""Generate u,v points and baseline uncertainties.
Args:
@@ -144,7 +144,7 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop,
elevmin=elevmin, elevmax=elevmax,
no_elevcut_space=no_elevcut_space,
timetype=timetype, fix_theta_GMST=fix_theta_GMST)
-
+
uniquetimes = np.sort(np.unique(obsarr['time']))
scans = np.array([[time - 0.5 * tadv, time + 0.5 * tadv] for time in uniquetimes])
source = str(ra) + ":" + str(dec)
@@ -152,7 +152,8 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop,
source=source, mjd=mjd, timetype=timetype, polrep=polrep,
ampcal=True, phasecal=True, opacitycal=True,
dcal=True, frcal=True,
- scantable=scans)
+ scantable=scans, reorder=reorder)
+
return obs
def make_subarray(self, sites):
diff --git a/ehtim/const_def.py b/ehtim/const_def.py
index 67ed1f28..4984d700 100644
--- a/ehtim/const_def.py
+++ b/ehtim/const_def.py
@@ -79,20 +79,20 @@
DTPOL_STOKES = [('time', 'f8'), ('tint', 'f8'),
('t1', 'U32'), ('t2', 'U32'),
('tau1', 'f8'), ('tau2', 'f8'),
- ('u', 'f8'), ('v', 'f8'),
+ ('u', 'f8'), ('v', 'f8'), ('w', 'f8'),
('vis', 'c16'), ('qvis', 'c16'), ('uvis', 'c16'), ('vvis', 'c16'),
('sigma', 'f8'), ('qsigma', 'f8'), ('usigma', 'f8'), ('vsigma', 'f8')]
DTPOL_CIRC = [('time', 'f8'), ('tint', 'f8'),
('t1', 'U32'), ('t2', 'U32'),
('tau1', 'f8'), ('tau2', 'f8'),
- ('u', 'f8'), ('v', 'f8'),
+ ('u', 'f8'), ('v', 'f8'), ('w', 'f8'),
('rrvis', 'c16'), ('llvis', 'c16'), ('rlvis', 'c16'), ('lrvis', 'c16'),
('rrsigma', 'f8'), ('llsigma', 'f8'), ('rlsigma', 'f8'), ('lrsigma', 'f8')]
DTAMP = [('time', 'f8'), ('tint', 'f8'),
('t1', 'U32'), ('t2', 'U32'),
- ('u', 'f8'), ('v', 'f8'),
+ ('u', 'f8'), ('v', 'f8'),
('amp', 'f8'), ('sigma', 'f8')]
DTBIS = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'),
@@ -100,11 +100,11 @@
('bispec', 'c16'), ('sigmab', 'f8')]
DTCPHASE = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'),
- ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'),
+ ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'),
('cphase', 'f8'), ('sigmacp', 'f8')]
DTCAMP = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), ('t4', 'U32'),
- ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'),
+ ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'),
('u3', 'f8'), ('v3', 'f8'), ('u4', 'f8'), ('v4', 'f8'),
('camp', 'f8'), ('sigmaca', 'f8')]
@@ -132,7 +132,7 @@
# Observation fields for plotting and retrieving data
FIELDS = ['time', 'time_utc', 'time_gmst',
- 'tint', 'u', 'v', 'uvdist',
+ 'tint', 'u', 'v', 'w', 'uvdist',
't1', 't2', 'tau1', 'tau2',
'el1', 'el2', 'hr_ang1', 'hr_ang2', 'par_ang1', 'par_ang2',
'vis', 'amp', 'phase', 'snr',
@@ -186,6 +186,7 @@ def show_noblock(pause=0.001):
'tint': 'Integration Time',
'u': r'$u$',
'v': r'$v$',
+ 'w': r'$w$',
'uvdist': r'$u-v$ Distance',
't1': 'Site 1',
't2': 'Site 2',
diff --git a/ehtim/image.py b/ehtim/image.py
index 3dbf1933..73ccafaa 100644
--- a/ehtim/image.py
+++ b/ehtim/image.py
@@ -2303,7 +2303,7 @@ def sample_uv(self, uv, polrep_obs='stokes',
def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft",
cache=False, fft_pad_factor=2,
- zero_empty_pol=True, verbose=True):
+ zero_empty_pol=True, verbose=True, reorder=True):
"""Observe the image on the same baselines as an existing observation without noise.
Args:
@@ -2327,10 +2327,10 @@ def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft",
if (np.abs(self.rf - obs.rf) / obs.rf > tolerance):
raise Exception("Image frequency is not the same as observation frequency!")
- if (ttype == 'direct' or ttype == 'fast' or ttype == 'nfft'):
+ if (ttype == 'direct' or ttype == 'fast' or ttype == 'nfft' or ttype == 'DFT' or ttype == 'DFT_i'):
if verbose: print("Producing clean visibilities from image with " + ttype + " FT . . . ")
else:
- raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft'" % ttype)
+ raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft', 'DFT', 'DFT_i'" % ttype)
# Copy data to be safe
obsdata = copy.deepcopy(obs.data)
@@ -2361,13 +2361,13 @@ def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft",
source=self.source, mjd=self.mjd, polrep=obs.polrep,
ampcal=True, phasecal=True, opacitycal=True,
dcal=True, frcal=True,
- timetype=obs.timetype, scantable=obs.scans)
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
return obs_no_noise
def observe_same(self, obs_in,
ttype='nfft', fft_pad_factor=2,
- sgrscat=False, add_th_noise=True,
+ sgrscat=False, add_th_noise=True, th_noise_factor=1,
jones=False, inv_jones=False,
opacitycal=True, ampcal=True, phasecal=True,
frcal=True, dcal=True, rlgaincal=True,
@@ -2379,7 +2379,7 @@ def observe_same(self, obs_in,
dterm_offset=ehc.DTERMPDEF,
rlratio_std=0., rlphase_std=0.,
sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None,
- caltable_path=None, seed=False, verbose=True):
+ caltable_path=None, seed=False, reorder=True, verbose=True):
"""Observe the image on the same baselines as an existing observation object and add noise.
Args:
@@ -2443,7 +2443,7 @@ def observe_same(self, obs_in,
obs = self.observe_same_nonoise(obs_in, sgrscat=sgrscat,ttype=ttype,
cache=False, fft_pad_factor=fft_pad_factor,
- zero_empty_pol=True, verbose=verbose)
+ zero_empty_pol=True, reorder=reorder, verbose=verbose)
# Jones Matrix Corruption & Calibration
if jones:
@@ -2467,7 +2467,7 @@ def observe_same(self, obs_in,
source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
dcal=dcal, frcal=frcal,
- timetype=obs.timetype, scantable=obs.scans)
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
if inv_jones:
obsdata = simobs.apply_jones_inverse(obs,
@@ -2478,7 +2478,7 @@ def observe_same(self, obs_in,
source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
ampcal=ampcal, phasecal=phasecal,
opacitycal=True, dcal=True, frcal=True,
- timetype=obs.timetype, scantable=obs.scans)
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
# No Jones Matrices, Add noise the old way
# NOTE There is an asymmetry here - in the old way, we don't offer the ability to
@@ -2489,7 +2489,7 @@ def observe_same(self, obs_in,
print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix')
# TODO -- clean up arguments
- obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise,
+ obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor,
opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal,
stabilize_scan_phase=stabilize_scan_phase,
stabilize_scan_amp=stabilize_scan_amp,
@@ -2504,7 +2504,7 @@ def observe_same(self, obs_in,
source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
ampcal=ampcal, phasecal=phasecal,
opacitycal=True, dcal=True, frcal=True,
- timetype=obs.timetype, scantable=obs.scans)
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
return obs
@@ -2513,8 +2513,8 @@ def observe(self, array, tint, tadv, tstart, tstop, bw,
elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
no_elevcut_space=False,
ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False,
- sgrscat=False, add_th_noise=True,
- jones=False, inv_jones=False,
+ sgrscat=False, add_th_noise=True, th_noise_factor=1,
+ jones=False, inv_jones=False, noise=True,
opacitycal=True, ampcal=True, phasecal=True,
frcal=True, dcal=True, rlgaincal=True,
stabilize_scan_phase=False, stabilize_scan_amp=False,
@@ -2525,7 +2525,7 @@ def observe(self, array, tint, tadv, tstart, tstop, bw,
dterm_offset=ehc.DTERMPDEF,
rlratio_std=0.,rlphase_std=0.,
sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None,
- caltable_path=None, seed=False, verbose=True):
+ caltable_path=None, seed=False, reorder=True, verbose=True):
"""Generate baselines from an array object and observe the image.
Args:
@@ -2614,26 +2614,33 @@ def observe(self, array, tint, tadv, tstart, tstop, bw,
polrep=polrep_obs, tau=tau,
elevmin=elevmin, elevmax=elevmax,
no_elevcut_space=no_elevcut_space,
- timetype=timetype, fix_theta_GMST=fix_theta_GMST)
+ timetype=timetype, fix_theta_GMST=fix_theta_GMST, reorder=reorder)
+
# Observe on the same baselines as the empty observation and add noise
- obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor,
- sgrscat=sgrscat, add_th_noise=add_th_noise,
- jones=jones, inv_jones=inv_jones,
- opacitycal=opacitycal, ampcal=ampcal,
- phasecal=phasecal, dcal=dcal,
- frcal=frcal, rlgaincal=rlgaincal,
- stabilize_scan_phase=stabilize_scan_phase,
- stabilize_scan_amp=stabilize_scan_amp,
- neggains=neggains,
- taup=taup,
- gain_offset=gain_offset, gainp=gainp,
- phase_std=phase_std,
- dterm_offset=dterm_offset,
- rlratio_std=rlratio_std,rlphase_std=rlphase_std,
- sigmat=sigmat,phasesigmat=phasesigmat,
- rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat,
- caltable_path=caltable_path, seed=seed, verbose=verbose)
+ if noise:
+ obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor,
+ sgrscat=sgrscat, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor,
+ jones=jones, inv_jones=inv_jones,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, dcal=dcal,
+ frcal=frcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ phase_std=phase_std,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ sigmat=sigmat,phasesigmat=phasesigmat,
+ rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat,
+ caltable_path=caltable_path, seed=seed, reorder=reorder, verbose=verbose)
+ else:
+ obs = self.observe_same_nonoise(obs, sgrscat=sgrscat, ttype=ttype,
+ cache=False, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=True, reorder=reorder, verbose=verbose)
+
obs.mjd = mjd
diff --git a/ehtim/io/load.py b/ehtim/io/load.py
index 8e38a4f5..7ed4d8b5 100644
--- a/ehtim/io/load.py
+++ b/ehtim/io/load.py
@@ -1436,7 +1436,7 @@ def load_obs_uvfits(filename, polrep='stokes', flipbl=False,
((
times[i], tints[i],
t1[i], t2[i], tau1[i], tau2[i],
- u[i], v[i],
+ u[i], v[i], 0,
rr[i], ll[i], rl[i], lr[i],
rrsig[i], llsig[i], rlsig[i], lrsig[i]
), dtype=dtpol_out
diff --git a/ehtim/obsdata.py b/ehtim/obsdata.py
index 65ca8318..951f7e07 100644
--- a/ehtim/obsdata.py
+++ b/ehtim/obsdata.py
@@ -103,7 +103,7 @@ class Obsdata(object):
def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None,
polrep='stokes', source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC',
ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True,
- trial_speedups=False):
+ trial_speedups=False, reorder=True):
"""A polarimetric VLBI observation of visibility amplitudes and phases (in Jy).
Args:
@@ -175,7 +175,11 @@ def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None,
self.reorder_tarr_sefd(reorder_baselines=False)
# reorder baselines to uvfits convention
- self.reorder_baselines(trial_speedups=trial_speedups)
+ if reorder:
+ self.reorder_baselines(trial_speedups=trial_speedups)
+ else:
+ # skip the uvfits-convention reordering and sort by time only
+ # (required for, e.g., the closure-invariant calculation)
+ self.data = np.array(sorted(self.data, key=lambda x: x['time']))
+
# Get tstart, mjd and tstop
times = self.unpack(['time'])['time']
@@ -615,7 +619,8 @@ def tlist(self, conj=False, t_gather=0., scan_gather=False):
data, lambda x: np.searchsorted(self.scans[:, 0], x['time'])):
datalist.append(np.array([obs for obs in group]))
- return np.array(datalist, dtype=object)
+ # return a plain list rather than a ragged object array
+ return datalist
def split_obs(self, t_gather=0., scan_gather=False):
@@ -1224,7 +1229,7 @@ def polchisq(self, im, dtype='pvis', ttype='nfft', pol_trans=True, mask=[], **kw
return chisq
- def recompute_uv(self):
+ def recompute_uv(self, with_w=False):
"""Recompute u,v points using observation times and metadata
Args:
@@ -1240,10 +1245,17 @@ def recompute_uv(self):
print("Recomputing U,V Points using MJD %d \n RA %e \n DEC %e \n RF %e GHz"
% (self.mjd, self.ra, self.dec, self.rf / 1.e9))
- (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times,
- self.mjd, self.ra, self.dec, self.rf,
- timetype=self.timetype,
- elevmin=0, elevmax=90, no_elevcut_space=False)
+ if with_w:
+ (timesout, uout, vout, wout) = obsh.compute_uv_coordinates(arr, site1, site2, times,
+ self.mjd, self.ra, self.dec, self.rf,
+ timetype=self.timetype,
+ elevmin=0, elevmax=90, no_elevcut_space=False,
+ w_term=with_w)
+ else:
+ (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times,
+ self.mjd, self.ra, self.dec, self.rf,
+ timetype=self.timetype,
+ elevmin=0, elevmax=90, no_elevcut_space=False)
if len(timesout) != len(times):
raise Exception(
@@ -1253,10 +1265,14 @@ def recompute_uv(self):
datatable['u'] = uout
datatable['v'] = vout
+ if with_w:
+ datatable['w'] = wout
+
arglist, argdict = self.obsdata_args()
arglist[DATPOS] = np.array(datatable)
out = Obsdata(*arglist, **argdict)
+
return out
def avg_coherent(self, inttime, scan_avg=False, moving=False):
@@ -2749,6 +2765,97 @@ def errfunc(p):
return gparams
+ def ClosureInvariants(self):
+ """
+ Calculate copolar closure invariants for the visibilities at each
+ integration time, assuming an n-element interferometer array
+ (method 1 of the reference below).
+
+ Thyagarajan, N., Nityananda, R., & Samuel, J. 2022, "Invariants in copolar
+ interferometry: An Abelian gauge theory", Phys. Rev. D 105, 043019.
+ https://doi.org/10.1103/PhysRevD.105.043019
+
+ Returns:
+ ci (np.ndarray): closure invariants, concatenated over all times
+ """
+ tlist = self.tlist()
+ out_ci = np.array([])
+ for tdata in tlist:
+ # for a fully-connected snapshot, t1 takes n-1 unique values
+ num_antenna = len(np.unique(tdata['t1'])) + 1
+ if num_antenna < 3:
+ continue
+ vis = tdata['vis'].reshape(1,1,-1)
+ ant_pairs = np.array([tdata['t1'], tdata['t2']]).T
+ # first-appearance-ordered unique antennas (a numpy equivalent of
+ # pd.unique, avoiding a pandas dependency)
+ flat_pairs = ant_pairs.flatten()
+ _, first_idx = np.unique(flat_pairs, return_index=True)
+ unique_ant = flat_pairs[np.sort(first_idx)]
+ ant_pairs = np.array([np.where(i == unique_ant)[0][0] for i in ant_pairs.flatten()]).reshape(-1, 2)
+ ant_pairs = [tuple(i) for i in ant_pairs]
+ reverse_idx = [i for i in range(len(ant_pairs)) if ant_pairs[i][0] > ant_pairs[i][1]]
+ ant_pairs = [ant_pairs[i] if i not in reverse_idx else (ant_pairs[i][1], ant_pairs[i][0]) for i in range(len(ant_pairs))]
+ vis[:,:,reverse_idx] = np.conjugate(vis[:,:,reverse_idx])
+ _, btriads = self.Triads(num_antenna, pairs=ant_pairs)
+ C_oa = vis[:, :, btriads[:, 0]]
+ C_ab = vis[:, :, btriads[:, 1]]
+ C_bo = np.conjugate(vis[:, :, btriads[:, 2]])
+ A_oab = C_oa / np.conjugate(C_ab) * C_bo
+ A_oab = np.dstack((A_oab.real, A_oab.imag))
+ A_max = np.nanmax(np.abs(A_oab), axis=-1, keepdims=True)
+ ci = A_oab / A_max
+ ci = ci.reshape(-1)
+ out_ci = np.concatenate([out_ci, ci], axis=0)
+
+ return out_ci
+
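+ # Sketch: closure invariants from a simulated observation (names hypothetical);
+ # reorder=False keeps the baseline ordering needed here:
+ #
+ #   obs = im.observe(arr, tint, tadv, tstart, tstop, bw, reorder=False)
+ #   ci = obs.ClosureInvariants()  # (n-1)*(n-2) real values per timestamp
+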
+ def Triads(self, n:int, pairs=None):
+ """
+ Generate arrays of antenna and baseline indices that form triangular
+ loops pivoted on the 0th antenna. Used to calculate closure invariants,
+ whereby specific baseline correlations need to be indexed according
+ to those triangular loops.
+ Baseline array format [ant1, ant2]:
+ [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6] ...
+ [1, 2], [1, 3], [1, 4], [1, 5], [1, 6] ...
+ [2, 3], [2, 4], [2, 5], [2, 6] ...
+ [3, 4], [3, 5], [3, 6] ...
+ [4, 5], [4, 6] ...
+ [5, 6] ...
+
+ Args:
+ n (int): number of antennas in the array
+ pairs (list): optional list of (ant1, ant2) tuples for the baselines
+ actually sampled; defaults to all pairs with ant1 < ant2
+
+ Returns:
+ atriads (np.ndarray): antenna triangular loop indices
+ btriads (np.ndarray): baseline triangular loop indices
+ """
+ ntriads = (n-1)*(n-2)//2
+ ant1 = np.zeros(ntriads, dtype=np.uint8)
+ ant2 = np.arange(1, n, dtype=np.uint8).reshape(n-1, 1) + np.zeros(n-2, dtype=np.uint8).reshape(1, n-2)
+ ant3 = np.arange(2, n, dtype=np.uint8).reshape(1, n-2) + np.zeros(n-1, dtype=np.uint8).reshape(n-1, 1)
+ anti = np.where(ant3 > ant2)
+ ant2, ant3 = ant2[anti], ant3[anti]
+ atriads = np.concatenate([ant1.reshape(-1, 1), ant2.reshape(-1, 1), ant3.reshape(-1, 1)], axis=-1)
+
+ ant_pairs_01 = list(zip(ant1, ant2))
+ ant_pairs_12 = list(zip(ant2, ant3))
+ ant_pairs_20 = list(zip(ant3, ant1))
+
+ t1 = np.arange(n, dtype=int).reshape(n, 1) + np.zeros(n, dtype=int).reshape(1, n)
+ t2 = np.arange(n, dtype=int).reshape(1, n) + np.zeros(n, dtype=int).reshape(n, 1)
+ bli = np.where(t2 > t1)
+ t1, t2 = t1[bli], t2[bli]
+ if pairs is None:
+ bl_pairs = list(zip(t1, t2))
+ else:
+ bl_pairs = pairs
+
+ bl_01 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_01])
+ bl_12 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_12])
+ bl_20 = np.asarray([bl_pairs.index(tuple(reversed(apair))) for apair in ant_pairs_20])
+ btriads = np.concatenate([bl_01.reshape(-1, 1), bl_12.reshape(-1, 1), bl_20.reshape(-1, 1)], axis=-1)
+ return atriads, btriads
+
def bispectra(self, vtype='vis', mode='all', count='min',
timetype=False, uv_min=False, snrcut=0.):
"""Return a recarray of the equal time bispectra.
diff --git a/ehtim/observing/obs_helpers.py b/ehtim/observing/obs_helpers.py
index 6dc1dbaa..893b73fa 100644
--- a/ehtim/observing/obs_helpers.py
+++ b/ehtim/observing/obs_helpers.py
@@ -55,7 +55,7 @@
def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype='UTC',
elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False,
- fix_theta_GMST=False):
+ fix_theta_GMST=False, w_term=False):
"""Compute u,v coordinates for an array at a given time for a source at a given ra,dec,rf
"""
@@ -189,6 +189,7 @@ def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype
# u,v coordinates
u = np.dot((coord1 - coord2)/wvl, projU) # u (lambda)
v = np.dot((coord1 - coord2)/wvl, projV) # v (lambda)
+ w = np.dot((coord1 - coord2)/wvl, sourcevec) # w (lambda)
# mask out below elevation cut
mask_elev_1 = elevcut(coord1, sourcevec, elevmin=elevmin, elevmax=elevmax)
@@ -205,8 +206,12 @@ def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype
time = time[mask]
u = u[mask]
v = v[mask]
+ w = w[mask]
- # return times and uv points where we have data
+ if w_term:
+ return (time, u, v, w)
+
+ # return times and uv points where we have data
return (time, u, v)
diff --git a/ehtim/observing/obs_simulate.py b/ehtim/observing/obs_simulate.py
index 1d3d2e6a..99028e08 100644
--- a/ehtim/observing/obs_simulate.py
+++ b/ehtim/observing/obs_simulate.py
@@ -148,9 +148,10 @@ def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop,
ra, dec, rf, timetype=timetype,
elevmin=elevmin, elevmax=elevmax,
no_elevcut_space=no_elevcut_space,
- fix_theta_GMST=fix_theta_GMST)
+ fix_theta_GMST=fix_theta_GMST,
+ w_term=True)
- (timesout, uout, vout) = uvdat
+ (timesout, uout, vout, wout) = uvdat
for k in range(len(timesout)):
outlist.append(np.array((
timesout[k],
@@ -161,6 +162,7 @@ def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop,
tau2, # Station 1 zenith optical depth
uout[k], # u (lambda)
vout[k], # v (lambda)
+ wout[k], # w (lambda)
0.0, # 1st Visibility (Jy)
0.0, # 2nd Visibility
0.0, # 3rd Visibility
@@ -345,6 +347,38 @@ def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes',
obsdata.append(vis)
+ elif ttype == "DFT":
+ # field of view in microarcseconds (1 uas = 4.84813681109536e-12 rad)
+ xfov, yfov = im.xdim*im.psize/ehc.RADPERUAS, im.ydim*im.psize/ehc.RADPERUAS
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ imarr = imvec.reshape(im.ydim, im.xdim)
+ vis = DFT(imarr, uv, xfov=xfov, yfov=yfov)
+ obsdata.append(vis)
+
+ elif ttype == "DFT_i":
+ xfov, yfov = im.xdim*im.psize/ehc.RADPERUAS, im.ydim*im.psize/ehc.RADPERUAS
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ uv = np.array([uv[:,1], uv[:,0]]).T  # swap u and v to match the DFT l,m axis convention
+ imarr = imvec.reshape(im.ydim, im.xdim)
+ vis = DFT(imarr, uv, xfov=xfov, yfov=yfov)
+ obsdata.append(vis)
+
+
# Get visibilities from DTFT
else:
# Construct Fourier matrix
@@ -375,6 +409,35 @@ def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes',
return obsdata
+
+def DFT(data, uv, xfov=225, yfov=225):
+ """Direct Fourier transform of an image (or stack of images) sampled at
+ the given u,v points; xfov and yfov are in microarcseconds.
+ """
+ if data.ndim == 2:
+ data = data[np.newaxis,...]
+ out_shape = (uv.shape[0],)
+ elif data.ndim > 2:
+ # record the batch shape before flattening the leading dimensions
+ out_shape = data.shape[:-2] + (uv.shape[0],)
+ data = data.reshape((-1,) + data.shape[-2:])
+ ny, nx = data.shape[-2:]
+ dx = xfov*ehc.RADPERUAS / nx
+ dy = yfov*ehc.RADPERUAS / ny
+ angx = (np.arange(nx) - nx//2) * dx
+ angy = (np.arange(ny) - ny//2) * dy
+ lvect = np.sin(angx)
+ mvect = np.sin(angy)
+ l, m = np.meshgrid(lvect, mvect)
+ lm = np.concatenate([l.reshape(1,-1), m.reshape(1,-1)], axis=0)
+ imgvect = data.reshape((data.shape[0],-1))
+ x = -2*np.pi*np.dot(uv,lm)[np.newaxis, ...]
+ visr = np.sum(imgvect[:, np.newaxis, :] * np.cos(x, dtype=np.float32), axis=-1)
+ visi = np.sum(imgvect[:, np.newaxis, :] * np.sin(x, dtype=np.float32), axis=-1)
+ vis = visr.ravel() + 1j*visi.ravel()
+ vis = vis.reshape(out_shape)  # a no-op for 2-D input
+ return vis
+
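+# Sketch: directly sampling visibilities from an image array (values hypothetical):
+#
+#   img = np.random.rand(64, 64)             # pixel brightness (Jy)
+#   uv = np.array([[0.0, 0.0], [1e9, 2e9]])  # (u, v) in wavelengths
+#   vis = DFT(img, uv, xfov=225, yfov=225)   # fields of view in microarcseconds
+#   # vis[0] equals the total flux, img.sum()
+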
+
##################################################################################################
+# Noise + miscalibration functions
##################################################################################################
@@ -1199,7 +1262,7 @@ def apply_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True, verbose=Tru
# The old noise generating function.
-def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=True,
+def add_noise(obs, add_th_noise=True, th_noise_factor=1, opacitycal=True, ampcal=True, phasecal=True,
stabilize_scan_amp=False, stabilize_scan_phase=False,
neggains=False,
taup=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
@@ -1297,7 +1360,7 @@ def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=Tru
sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'],
obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw)
for i in range(len(tint))), float)
- sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'],
+ sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'],
obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw)
for i in range(len(tint))), float)
sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'],
@@ -1401,10 +1464,10 @@ def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=Tru
sigma_est4 = sigma_perf4 * gain_true * tau_est
if add_th_noise:
- vis1 = (vis1 + obsh.cerror(sigma_true1))
- vis2 = (vis2 + obsh.cerror(sigma_true2))
- vis3 = (vis3 + obsh.cerror(sigma_true3))
- vis4 = (vis4 + obsh.cerror(sigma_true4))
+ vis1 = (vis1 + th_noise_factor*obsh.cerror(sigma_true1))
+ vis2 = (vis2 + th_noise_factor*obsh.cerror(sigma_true2))
+ vis3 = (vis3 + th_noise_factor*obsh.cerror(sigma_true3))
+ vis4 = (vis4 + th_noise_factor*obsh.cerror(sigma_true4))
# Add the gain error to the true visibilities
vis1 = vis1 * gain_true * tau_est / tau_true
diff --git a/features/__init__.py b/features/__init__.py
new file mode 100644
index 00000000..d7cc51cb
--- /dev/null
+++ b/features/__init__.py
@@ -0,0 +1,11 @@
+"""
+.. module:: ehtim.features
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: feature extraction functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from . import rex
+
+from ..const_def import *
diff --git a/features/rex.py b/features/rex.py
new file mode 100644
index 00000000..62fafd31
--- /dev/null
+++ b/features/rex.py
@@ -0,0 +1,1057 @@
+# rex.py
+# ring fitting code for ehtim
+#
+# Copyright (C) 2019 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import os
+import glob
+import matplotlib.pyplot as plt
+import numpy as np
+import astropy.io.fits as fits
+import subprocess
+
+import scipy.interpolate
+import scipy.optimize
+import scipy.stats
+from astropy.stats import median_absolute_deviation
+
+from ehtim.image import load_image
+import ehtim.imaging.dynamical_imaging as di
+import ehtim.parloop as ploop
+import ehtim.const_def as ehc
+
+###################################################################################################
+# Parameters
+###################################################################################################
+
+EP = 1.e-16
+BIG = 1./EP
+
+IMSIZE = 160*ehc.RADPERUAS # 250*ehc.RADPERUAS # FOV of resampled image (muas)
+NPIX = 160 # 128 # pixels in resampled image
+
+NRAYS = 360 # number of angular rays in final profile
+NRS = 100 # number of radial points in final profile
+
+RMAX = 50 # maximum radius in every profile slice (muas)
+RMIN = 5 # radius threshold for averaging inside ring (muas)
+
+RPRIOR_MIN = 15. # 5. # minimum radius for search (muas)
+RPRIOR_MAX = 50. # 60. # maximum radius for search (muas)
+NRAYS_SEARCH = 25 # number of angular rays in search profiles
+NRS_SEARCH = 50 # number of radial points in search profiles
+THRESH = 0.05 # thresholding level for the images in the search
+# BLUR_VALUE_MIN=2 # blur to this value for initial centroid search (uas)
+FOVP_SEARCH = 0.1 # fractional FOV around image center for brute force search
+NSEARCH = 10 # number of points in each dimension for brute force search
+NORMFLUX = 1  # normalized image flux for output profiles (Jy)
+
+POSTPROCDIR = '.' # default postprocessing directory
+
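+# Example usage (a sketch; assumes `im` is an ehtim.image.Image loaded elsewhere):
+#
+#   from ehtim.features import rex
+#   pp = rex.FindProfile(im, blur=5)   # optionally blur by 5 uas, then fit the ring
+#   print(pp.RingSize1)                # (mean diameter, std. dev.) in uas
+#   pp.plot_unwrapped()                # display the polar-unwrapped ring image
+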
+###################################################################################################
+# Profiles class
+###################################################################################################
+
+
+class Profiles(object):
+
+ def __init__(self, im, x0, y0, profs, thetas, rmin=RMIN, rmax=RMAX,
+ interptype='cubic',flux_norm=NORMFLUX,
+ profsQ=[], profsU=[]):
+
+ self.x0 = x0
+ self.y0 = y0
+ self.im = im
+ self.rmin = rmin
+ self.rmax = rmax
+
+ # store the center image
+ deltay = -(im.fovy()/2. - y0*ehc.RADPERUAS)/im.psize
+ deltax = (im.fovx()/2. - x0*ehc.RADPERUAS)/im.psize
+ self.im_center = im.shift([int(np.round(deltay)), int(np.round(deltax))])
+
+ # total flux and normalization
+ self.flux = im.total_flux()
+ self.parea = (im.psize/ehc.RADPERUAS)**2
+
+ # factor to convert to normalized brightness temperature (total flux of 1 Jy)
+ self.flux_norm = flux_norm
+ self.normfactor = self.flux_norm / im.total_flux()
+
+ # image array and profiles
+ factor = 3.254e13/(im.rf**2 * im.psize**2) # factor to convert to brightness temperature
+ self.imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in Tb
+
+ self.xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS
+ self.ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS
+ self.interp = scipy.interpolate.interp2d(self.ys, self.xs, self.imarr, kind=interptype)
+
+ self.profiles = np.array(profs)
+ self.profilesQ = np.array(profsQ)
+ self.profilesU = np.array(profsU)
+ self.profilesP = np.sqrt(self.profilesQ**2 + self.profilesU**2)
+ self.thetas = np.array(thetas)
+ self.nang = len(thetas)
+ self.nrs = len(self.profiles[0])
+ self.nthetas = len(self.thetas)
+ self.rs = np.linspace(0, self.rmax, self.nrs)
+ self.dr = self.rs[-1] - self.rs[-2]
+ self.pks = []
+ self.pk_vals = []
+ self.diameters = []
+
+ for prof in self.profiles:
+ pk, vpk = self.calc_pkrad_from_prof(prof)
+
+ self.pks.append(pk)
+ self.pk_vals.append(vpk)
+ self.diameters.append(2*np.abs(pk))
+
+ self.pks = np.array(self.pks)
+ self.pk_vals = np.array(self.pk_vals)
+ self.diameters = np.array(self.diameters)
+
+ # ring size
+ self.RingSize1 = (np.mean(self.diameters), np.std(self.diameters))
+ self.RingSize1_med = (np.median(self.diameters), median_absolute_deviation(self.diameters))
+
+ def calc_pkrad_from_prof(self, prof):
+ """calculate peak radius and value with linear interpolation"""
+ args = np.argsort(prof)
+ pkpos = args[-1]
+ pk = self.rs[pkpos]
+ vpk = prof[pkpos]
+ if pkpos > 0 and pkpos < self.nrs-1:
+ vals = [prof[pkpos-1], prof[pkpos], prof[pkpos+1]]
+ pk, vpk = quad_interp_radius(pk, self.dr, vals)
+ return (pk, vpk)
+
+ def calc_meanprof_and_stats(self):
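+        """Compute the mean radial profile and the derived ring statistics
+           (diameter, width, orientation, asymmetry, contrast, dynamic range)."""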
+
+ # calculate mean profile
+ self.meanprof = np.mean(self.profiles, axis=0)
+ args = np.argsort(self.meanprof)
+
+ self.pkloc = args[-1]
+ self.pkrad = self.rs[self.pkloc]
+ self.meanpk = self.meanprof[self.pkloc]
+
+ # absolute peak in angle and radius
+ profile_peak_loc = np.unravel_index(np.argmax(self.profiles), self.profiles.shape)
+ self.abspk_loc_rad = profile_peak_loc[1]
+ self.abspk_rad = self.rs[self.abspk_loc_rad]
+ self.abspk_loc_ang = profile_peak_loc[0]
+ self.abspk_ang = self.thetas[self.abspk_loc_ang]
+
+ # find inside mean flux
+ inner_loc = np.argmin((self.rs-self.rmin)**2)
+ self.in_level = np.mean(self.meanprof[0:inner_loc]) # profile avg inside ring
+
+ # find outside mean flux
+ outer_loc = np.argmin((self.rs-(self.rmax-self.rmin))**2)
+ self.out_level = np.mean(self.meanprof[outer_loc:]) # profile avg outside ring
+
+ # find mean profile FWHM with spline interpolation
+ meanprof_zeroed = self.meanprof - self.out_level
+ (lh_meanprof, rh_meanprof) = self.calc_width(meanprof_zeroed)
+ lhloc_meanprof = np.argmin((self.rs-lh_meanprof)**2)
+ rhloc_meanprof = np.argmin((self.rs-rh_meanprof)**2)
+
+ self.lh = lh_meanprof
+ self.rh = rh_meanprof
+ self.lhloc = lhloc_meanprof
+ self.rhloc = rhloc_meanprof
+
+ # ring diameter and width from the mean profile
+ meanprof_diameter = 2*self.pkrad
+ meanprof_width = np.abs(rh_meanprof - lh_meanprof)
+ self.RingSize2 = (meanprof_diameter, meanprof_width)
+
+ # find ring width with all angular profiles
+ ringwidths = []
+ for i in range(self.nang):
+ rprof = self.profiles[i]
+ # TODO zero min profile before taking width???
+ # rprof_zeroed = rprof - np.max((np.min(rprof), 0))
+ (lh, rh) = self.calc_width(rprof)
+ width = rh-lh
+ if width <= 0 or width >= 2*meanprof_width:
+ continue # AC ?? ok to exclude huge widths???
+ ringwidths.append(width)
+
+ self.RingWidth = (np.mean(ringwidths), np.std(ringwidths))
+
+ # ring angle 1: mean and std deviation of individual profiles
+ ringangles = []
+ ringasyms = []
+ for i in range(self.lhloc, self.rhloc+1):
+ angprof = self.profiles.T[i]
+ if i == self.lhloc:
+ prof_mean_r = angprof.reshape(1, len(self.profiles.T[i]))
+ else:
+ prof_mean_r = np.vstack((prof_mean_r, angprof))
+
+ angle_asym = self.calc_ringangle_asymmetry(angprof)
+ ringangles.append(angle_asym[0])
+ ringasyms.append(angle_asym[1])
+
+ self.RingAngle1 = (scipy.stats.circmean(ringangles), scipy.stats.circstd(ringangles))
+
+ # ring angle 2: ring angle function on avg profile
+ prof_mean_r = np.mean(np.array(prof_mean_r), axis=0)
+ self.meanprof_theta = prof_mean_r
+ ringangle2 = self.calc_ringangle_asymmetry(prof_mean_r)
+ self.RingAngle2 = (ringangle2[0], ringangle2[-1])
+
+ # contrast 1: maximum profile value / mean of inner region
+ # self.RingContrast1 = np.max(self.profiles.T[self.pkloc]) / self.in_level
+ # self.RingContrast1 = np.max(self.profiles) / self.in_level
+ self.RingContrast1 = np.max(self.profiles[:, self.lhloc:self.rhloc+1]) / self.in_level
+
+        # contrast 2: mean profile max value / mean of inner region
+ self.RingContrast2 = self.meanpk / self.in_level
+
+ # asymmetry 1: m1 mode of angular profile
+ self.RingAsym1 = (np.mean(ringasyms), np.std(ringasyms))
+
+ # asymmetry 2: integrated flux in bottom half of ring vs top half of ring
+ mask_inner = self.im.copy()
+ mask_outer = self.im.copy()
+ immask = self.im.copy()
+
+ x0_c = self.im.fovx()/2. - self.x0*ehc.RADPERUAS
+ y0_c = self.y0*ehc.RADPERUAS - self.im.fovy()/2.
+
+ # mask annulus
+ rad_inner = (self.RingSize1[0]/2. - self.RingWidth[0]/2.)*ehc.RADPERUAS
+ rad_outer = (self.RingSize1[0]/2. + self.RingWidth[0]/2.)*ehc.RADPERUAS
+
+ mask_inner.imvec *= 0
+ mask_outer.imvec *= 0
+ mask_inner = mask_inner.add_gauss(1, [2*rad_inner, 2*rad_inner, 0, x0_c, y0_c])
+ mask_inner = mask_inner.mask(cutoff=.5)
+ mask_outer = mask_outer.add_gauss(1, [2*rad_outer, 2*rad_outer, 0, x0_c, y0_c])
+ mask_outer = mask_outer.mask(cutoff=.5)
+
+ maskvec_annulus = np.logical_xor(
+ mask_inner.imvec.astype(bool), mask_outer.imvec.astype(bool))
+
+ # mask angle
+ xlist = np.arange(0, -self.im.xdim, -1)*self.im.psize + \
+ (self.im.psize*self.im.xdim)/2.0 - self.im.psize/2.0
+ ylist = np.arange(0, -self.im.ydim, -1)*self.im.psize + \
+ (self.im.psize*self.im.ydim)/2.0 - self.im.psize/2.0
+
+ cangle = self.RingAngle1[0]
+
+        def anglemask(x, y):
+            ang = np.mod(-np.arctan2(y-y0_c, x-x0_c)+np.pi/2., 2*np.pi)
+            # wrapped angular distance to the ring position angle
+            dist = np.abs(np.mod(ang - cangle + np.pi, 2*np.pi) - np.pi)
+            return dist <= 0.5*np.pi
+
+ maskvec_ang = np.array([[anglemask(i, j) for i in xlist]
+ for j in ylist]).flatten().astype(bool)
+
+ # combine masks and get the bright and dim flux
+ maskvec_brighthalf = maskvec_annulus * maskvec_ang
+ brightflux = np.sum(immask.imvec[(maskvec_brighthalf)])
+
+ maskvec_dimhalf = maskvec_annulus * ~maskvec_ang
+ dimflux = np.sum(immask.imvec[(maskvec_dimhalf)])
+ self.RingFlux = brightflux + dimflux
+ self.RingAsym2 = ((brightflux-dimflux)/(brightflux+dimflux), brightflux/dimflux)
+
+ # Polarization brightness ratio
+ # AC TODO FOR PAPER VIII ANALYSIS
+ self.RingAsymPol = (0., 0.)
+ if len(self.profilesP)>0 and len(self.im.qvec) > 0 and len(self.im.uvec) > 0:
+ pvec = np.sqrt(self.im.qvec**2 + self.im.uvec**2)
+ pvec_C = (self.im.qvec + 1j*self.im.uvec)
+
+ ringanglesPol = []
+ # ringasymsPol = []
+ for i in range(self.lhloc, self.rhloc+1):
+ angprof = self.profilesP.T[i]
+ # simple maximum AC TODO
+ ringanglesPol.append(self.thetas[np.argmax(angprof)])
+
+ # weighted avg
+ # angle_asym = self.calc_ringangle_asymmetry(angprof)
+ # ringanglesPol.append(angle_asym[0])
+ # ringasymsPol.append(angle_asym[1])
+
+ self.RingAnglePol = (scipy.stats.circmean(ringanglesPol),
+ scipy.stats.circstd(ringanglesPol))
+
+ cangle = self.RingAnglePol[0]
+
+            def anglemask_pol(x, y):
+                ang = np.mod(-np.arctan2(y-y0_c, x-x0_c)+np.pi/2., 2*np.pi)
+                # wrapped angular distance to the polarized ring angle
+                dist = np.abs(np.mod(ang - cangle + np.pi, 2*np.pi) - np.pi)
+                return dist <= 0.5*np.pi
+
+ maskvec_ang = np.array([[anglemask_pol(i, j) for i in xlist]
+ for j in ylist]).flatten().astype(bool)
+
+ # combine masks and get the bright and dim pol flux
+ maskvec_brighthalf = maskvec_annulus * maskvec_ang
+ maskvec_dimhalf = maskvec_annulus * ~maskvec_ang
+ # maskvec_brighthalf = maskvec_ang
+ # maskvec_dimhalf = ~maskvec_ang
+
+            # calculate polarized asymmetry / brightness ratio
+ brightflux_pol_C = np.abs(np.sum(pvec_C[(maskvec_brighthalf)]))
+ dimflux_pol_C = np.abs(np.sum(pvec_C[(maskvec_dimhalf)]))
+ brightflux_pol = np.sum(pvec[(maskvec_brighthalf)])
+ dimflux_pol = np.sum(pvec[(maskvec_dimhalf)])
+ self.RingAsymPol = ((brightflux_pol_C/dimflux_pol_C),
+ brightflux_pol/dimflux_pol)
+
+ # calculate dynamic range
+ mask = self.im.copy()
+ immask = self.im.copy()
+
+ x0_c = mask.fovx()/2. - self.x0*ehc.RADPERUAS
+ y0_c = self.y0*ehc.RADPERUAS - mask.fovy()/2.
+ rad = self.RingSize1[0]*ehc.RADPERUAS
+
+ mask.imvec *= 0
+ mask = mask.add_gauss(1, [2*rad, 2*rad, 0, x0_c, y0_c])
+ mask = mask.mask(cutoff=.5)
+ maskvec = mask.imvec.astype(bool) + (immask.imvec < EP*self.flux)
+ offsource_vec = immask.imvec[~(maskvec)]
+
+ self.impeak = np.max(self.im.imvec)
+ self.std_offsource = np.std(offsource_vec) + EP
+ self.mean_offsource = np.mean(offsource_vec) + EP
+ self.dynamic_range = self.impeak / self.std_offsource
+
+ def calc_width(self, prof):
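+        """Find the FWHM crossing radii of a profile: spline-interpolate
+           prof - 0.5*peak and return the roots bracketing the peak radius."""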
+ pkrad, maxval = self.calc_pkrad_from_prof(prof)
+ spline = scipy.interpolate.UnivariateSpline(self.rs, prof-0.5*maxval, s=0)
+ roots = spline.roots() # find the roots
+
+ if len(roots) == 0:
+ return(self.rs[0], self.rs[-1])
+
+ lh = self.rs[0]
+ rh = self.rs[-1]
+ for root in np.sort(roots):
+ if root < pkrad:
+ lh = root
+ else:
+ rh = root
+ break
+
+ return (lh, rh)
+
+ def calc_ringangle_asymmetry(self, prof):
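+        """Ring orientation and asymmetry from the first circular moment:
+           x = sum(p(theta)*exp(1j*theta)*dtheta) for the normalized profile p;
+           returns (arg(x), |x|, circular std sqrt(-2*ln|x|))."""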
+ dtheta = self.thetas[-1]-self.thetas[-2]
+ prof = prof / np.sum(prof*dtheta) # normalize
+ x = np.sum(prof * np.exp(1j*self.thetas) * dtheta)
+ ang = np.mod(np.angle(x), 2*np.pi)
+ asym = np.abs(x)
+ std = np.sqrt(-2*np.log(np.abs(x)))
+ return (ang, asym, std)
+
+ def plot_img(self, postprocdir=POSTPROCDIR, save_png=False):
+ plt.figure()
+
+ fovx = self.im.fovx()/ehc.RADPERUAS
+ fovy = self.im.fovy()/ehc.RADPERUAS
+ xsplot = self.xs - fovx/2.
+ ysplot = self.ys - fovy/2.
+ x0plot = self.x0 - fovx/2.
+ y0plot = self.y0 - fovy/2.
+
+ plt.pcolormesh(xsplot,ysplot,self.imarr,cmap='afmhot')
+ plt.contour(xsplot,ysplot, self.imarr, colors='k',levels=5)
+
+ plt.xlabel(r"-RA ($\mu$as)")
+ plt.ylabel(r"Dec ($\mu$as)")
+
+ plt.plot(x0plot,y0plot, 'r*', markersize=20)
+
+ thetas = np.linspace(0, 2*np.pi, 100)
+ plt.plot(x0plot+ np.cos(thetas) * self.RingSize1[0]/2,
+ y0plot + np.sin(thetas) * self.RingSize1[0]/2,
+ 'r-', markersize=1)
+
+ #plt.axes().set_aspect('equal', 'datalim')
+
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_contour.png'
+
+ plt.savefig(fname)
+ plt.close()
+ else:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ def plot_unwrapped(self, postprocdir=POSTPROCDIR, save_png=False,
+ xlabel=True, ylabel=True, xticklabel=True, yticklabel=True,
+ ax=False, imrange=[], show=True, cfun='jet', linecolor='r', labelsize=14):
+
+ # line colors
+ angcolor = np.array([100, 149, 237])/256.
+ pkcolor = np.array([219, 0., 219])/256.
+ # pkcolor = np.array([0,255,0])/256.
+
+ imarr = np.array(self.profiles).T/1.e9
+ if ax is False:
+ plt.figure()
+ ax = plt.gca()
+
+ if imrange:
+ plt.imshow(imarr, cmap=plt.get_cmap(cfun), origin='lower',
+ vmin=imrange[0], vmax=imrange[1], interpolation='gaussian')
+ else:
+ plt.imshow(imarr, cmap=plt.get_cmap(cfun), origin='lower', interpolation='gaussian')
+
+ uas_to_pix = self.nrs/np.max(self.rs) # convert radius to pixels
+ rad_to_pix = self.nang/(2*np.pi) # convert az. angle to pixels
+
+ # horizontal lines -- radius
+ pkloc = self.RingSize1[0]/2. * uas_to_pix
+ lhloc = (self.RingSize1[0] - self.RingSize1[1])/2. * uas_to_pix
+ rhloc = (self.RingSize1[0] + self.RingSize1[1]) / 2. * uas_to_pix
+
+ plt.axhline(y=pkloc, color=linecolor, linewidth=1)
+ plt.axhline(y=lhloc, color=linecolor, linewidth=1, linestyle=':')
+ plt.axhline(y=rhloc, color=linecolor, linewidth=1, linestyle=':')
+
+ # horizontal lines -- width
+ # add radius and half width sigma in quadrature
+ bandloc_sigma = np.sqrt((self.RingWidth[1]/2)**2 + (self.RingSize1[1]/2)**2)
+
+ rhloc = (self.RingSize1[0]/2. + self.RingWidth[0]/2.) * uas_to_pix
+ rhloc2 = (self.RingSize1[0]/2. + self.RingWidth[0]/2. + bandloc_sigma) * uas_to_pix
+ rhloc3 = (self.RingSize1[0]/2. + self.RingWidth[0]/2. - bandloc_sigma) * uas_to_pix
+
+ lhloc = (self.RingSize1[0]/2. - self.RingWidth[0]/2.) * uas_to_pix
+ lhloc2 = (self.RingSize1[0]/2. - self.RingWidth[0]/2. + bandloc_sigma) * uas_to_pix
+ lhloc3 = (self.RingSize1[0]/2. - self.RingWidth[0]/2. - bandloc_sigma) * uas_to_pix
+
+ plt.axhline(y=lhloc, color=linecolor, linewidth=1, linestyle='--')
+ plt.axhline(y=lhloc2, color=linecolor, linewidth=1, linestyle=':')
+ plt.axhline(y=lhloc3, color=linecolor, linewidth=1, linestyle=':')
+ plt.axhline(y=rhloc, color=linecolor, linewidth=1, linestyle='--')
+ plt.axhline(y=rhloc2, color=linecolor, linewidth=1, linestyle=':')
+ plt.axhline(y=rhloc3, color=linecolor, linewidth=1, linestyle=':')
+
+ # position angle line
+ pkloc = self.RingAngle1[0] * rad_to_pix
+ lhloc = (self.RingAngle1[0] + self.RingAngle1[1]) * rad_to_pix
+ rhloc = (self.RingAngle1[0] - self.RingAngle1[1]) * rad_to_pix
+
+ plt.axvline(x=pkloc, color=angcolor, linewidth=1)
+ plt.axvline(x=lhloc, color=angcolor, linewidth=1, linestyle=':')
+ plt.axvline(x=rhloc, color=angcolor, linewidth=1, linestyle=':')
+
+ # bright peak point
+ plt.plot([self.abspk_loc_ang], [self.abspk_loc_rad], marker='x', mew=2, ms=6, color=pkcolor)
+
+ # labels
+ if xlabel:
+ plt.xlabel(r"$\theta $ ($^\circ$)", size=labelsize)
+ if ylabel:
+ plt.ylabel(r"$r$ ($\mu$as)", size=labelsize)
+
+ xticklabels = np.arange(0, 360, 60)
+        xticks = (imarr.shape[1]/360)*xticklabels
+
+ yticks = np.floor(np.arange(0, imarr.shape[0], imarr.shape[0]/5)).astype(int)
+ yticklabels = ["%0.0f" % r for r in self.rs[yticks]]
+
+ if not xticklabel:
+ xticklabels = []
+ if not yticklabel:
+ yticklabels = []
+
+ plt.xticks(xticks, xticklabels)
+ plt.yticks(yticks, yticklabels)
+ plt.tick_params(axis='both', which='major', length=6)
+
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_unwrapped.png'
+ plt.savefig(fname)
+ plt.close()
+ elif show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ def save_unwrapped(self, fname):
+
+ imarr = np.array(self.profiles).T
+
+ header = fits.Header()
+ header['CTYPE1'] = 'RA---SIN'
+ header['CTYPE2'] = 'DEC--SIN'
+ header['CDELT1'] = 2*np.pi/float(len(self.profiles))
+ header['CDELT2'] = np.max(self.rs)/float(len(self.rs))
+ header['BUNIT'] = 'K'
+ hdu = fits.PrimaryHDU(imarr, header=header)
+ hdulist = [hdu]
+ hdulist = fits.HDUList(hdulist)
+ hdulist.writeto(fname, overwrite=True)
+
+ def plot_profs(self, postprocdir=POSTPROCDIR, save_png=False, colors=ehc.SCOLORS):
+ plt.figure()
+ plt.xlabel(r"distance from center ($\mu$as)")
+ plt.ylabel(r"$T_{\rm b}$")
+ #plt.ylim([0, 1])
+ plt.xlim([-10, 60])
+ plt.title('All Profiles')
+ for j in range(len(self.profiles)):
+ plt.plot(self.rs, self.profiles[j], color=colors[j % len(colors)],
+ linestyle='-', linewidth=1)
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_profiles.png'
+
+ plt.savefig(fname)
+ plt.close()
+ else:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ def plot_prof_band(self, postprocdir=POSTPROCDIR, save_png=False,
+ color='b', fontsize=14, show=True, axis=None, xlabel=True, ylabel=False):
+ """2-sided plot of radial profiles, cut across orthogonal to position angle"""
+ if axis is None:
+ plt.figure()
+ ax = plt.gca()
+ else:
+ ax = axis
+
+ if xlabel:
+ plt.xlabel(r"$r$ ($\mu$as)", size=fontsize)
+
+ yticks = [0, 2, 4, 6, 8, 10]
+ yticklabels = []
+ if ylabel:
+ plt.ylabel(r'Brightness Temperature ($10^9$ K)', size=fontsize)
+ yticklabels = yticks
+
+ plt.yticks(yticks, yticklabels)
+
+ #plt.ylim([0, 11])
+ plt.xlim([-55, 55])
+
+        # cut the ring in half orthogonal to the position angle
+ cutloc1 = np.argmin(np.abs(self.thetas-np.mod(self.RingAngle1[0] - np.pi/2., 2*np.pi)))
+ cutloc2 = np.argmin(np.abs(self.thetas-np.mod(self.RingAngle1[0] + np.pi/2., 2*np.pi)))
+
+ if cutloc1 < cutloc2:
+ prof_half_1 = self.profiles[cutloc1:cutloc2+1]
+ prof_half_2 = np.vstack((self.profiles[cutloc2+1:], self.profiles[0:cutloc1]))
+ else:
+ prof_half_1 = np.vstack((self.profiles[cutloc1:], self.profiles[0:cutloc2+1]))
+ prof_half_2 = self.profiles[cutloc2+1:cutloc1]
+
+ # plot left half
+ radii = -np.flip(self.rs)
+ tho_m = np.flip(np.median(np.array(prof_half_1), axis=0))
+ tho_l = np.flip(np.percentile(np.array(prof_half_1), 0, axis=0))
+ tho_u = np.flip(np.percentile(np.array(prof_half_1), 100, axis=0))
+ tho_l1 = np.flip(np.percentile(np.array(prof_half_1), 25, axis=0))
+ tho_u1 = np.flip(np.percentile(np.array(prof_half_1), 75, axis=0))
+
+ ax.plot(radii, tho_m/1.e9, ls='-', linewidth=2, color=color)
+ ax.fill_between(radii, tho_l/1.e9, tho_u/1.e9, alpha=.2, edgecolor=None, facecolor=color)
+ ax.fill_between(radii, tho_l1/1.e9, tho_u1/1.e9, alpha=.4, edgecolor=None, facecolor=color)
+
+        # plot right half
+ radii = self.rs
+ tho_m = np.median(np.array(prof_half_2), axis=0)
+ tho_l = np.percentile(np.array(prof_half_2), 0, axis=0)
+ tho_u = np.percentile(np.array(prof_half_2), 100, axis=0)
+ tho_l1 = np.percentile(np.array(prof_half_2), 25, axis=0)
+ tho_u1 = np.percentile(np.array(prof_half_2), 75, axis=0)
+
+ ax.plot(radii, tho_m/1.e9, ls='-', linewidth=2, color=color)
+ ax.fill_between(radii, tho_l/1.e9, tho_u/1.e9, alpha=.2, edgecolor=None, facecolor=color)
+ ax.fill_between(radii, tho_l1/1.e9, tho_u1/1.e9, alpha=.4, edgecolor=None, facecolor=color)
+
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_band_profile.png'
+
+ plt.savefig(fname)
+ plt.close()
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ def plot_meanprof(self, postprocdir=POSTPROCDIR, save_png=False, color='k'):
+ fig = plt.figure()
+ plt.plot(self.rs, self.meanprof,
+ color=color, linestyle='-', linewidth=1)
+ plt.plot((self.lh, self.rh), (0.5*self.meanpk, 0.5*self.meanpk),
+ color=color, linestyle='--', linewidth=1)
+ plt.xlabel(r"distance from center ($\mu$as)")
+ plt.ylabel(r"Flux (mJy/$\mu$as$^2$)")
+ #plt.ylim([0, 1])
+ plt.xlim([-10, 60])
+ plt.title('Mean Profile')
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_meanprofile.png'
+
+ plt.savefig(fname)
+ plt.close()
+ else:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ def plot_meanprof_theta(self, postprocdir=POSTPROCDIR, save_png=False, color='k'):
+ fig = plt.figure()
+ plt.plot(self.thetas/ehc.DEGREE, self.meanprof_theta,
+ color=color, linestyle='-', linewidth=1)
+
+ ang1 = self.RingAngle1[0]/ehc.DEGREE
+ std1 = self.RingAngle1[1]/ehc.DEGREE
+ up = np.mod(ang1+std1, 360)
+ down = np.mod(ang1-std1, 360)
+ plt.axvline(x=ang1, color='b', linewidth=1)
+ plt.axvline(x=up, color='b', linewidth=1, linestyle='--')
+ plt.axvline(x=down, color='b', linewidth=1, linestyle='--')
+
+ ang2 = self.RingAngle2[0]/ehc.DEGREE
+ std2 = self.RingAngle2[1]/ehc.DEGREE
+ up = np.mod(ang2+std2, 360)
+ down = np.mod(ang2-std2, 360)
+ plt.axvline(x=ang2, color='r', linewidth=1)
+ plt.axvline(x=up, color='r', linewidth=1, linestyle='--')
+ plt.axvline(x=down, color='r', linewidth=1, linestyle='--')
+
+ plt.xlabel(r"Angle E of N ($^{\circ}$)")
+ plt.ylabel("Normalized Flux")
+ plt.title('Mean Angular Profile')
+ if save_png:
+ dirname = os.path.basename(os.path.dirname(self.imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.basename(self.imname)
+ fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_meanangprofile.png'
+
+ plt.savefig(fname)
+ plt.close()
+ else:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+###################################################################################################
+# Other functions
+###################################################################################################
+
+
+def quad_interp_radius(r_max, dr, val_list):
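+    """Refine a peak by fitting a parabola through three samples
+       [v_L, v_max, v_R] spaced dr apart around r_max; returns (r_peak, v_peak)."""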
+ v_L = val_list[0]
+ v_max = val_list[1]
+ v_R = val_list[2]
+
+ rpk = r_max + dr*(v_L - v_R) / (2 * (v_L + v_R - 2*v_max))
+
+ vpk = 8*v_max*(v_L + v_R) - (v_L - v_R)**2 - 16*v_max**2
+ vpk /= (8*(v_L + v_R - 2*v_max))
+
+ return (rpk, vpk)
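+
+# Sanity check (a sketch): sampling f(r) = 1 - (r-10)**2 at r = 9, 10, 11 gives
+# quad_interp_radius(10, 1, [0, 1, 0]) == (10.0, 1.0), recovering the true peak.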
+
+
+def compute_ring_profile(im, x0, y0, title="",
+ nrays=NRAYS, nrs=NRS, rmin=RMIN, rmax=RMAX,
+ flux_norm=NORMFLUX, pol_profs=False,
+ interptype='cubic'):
+ """compute a ring profile given a center location
+ """
+
+ rs = np.linspace(0, rmax, nrs)
+ thetas = np.linspace(0, 2*np.pi, nrays)
+
+ factor = 3.254e13/(im.rf**2 * im.psize**2) # convert to brightness temperature
+ imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K
+ xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS
+ ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS
+
+ # TODO: test fiducial images with linear?
+ interp = scipy.interpolate.interp2d(ys, xs, imarr, kind=interptype)
+
+ def ringVals(theta):
+ xxs = x0 - rs*np.sin(theta)
+ yys = y0 + rs*np.cos(theta)
+
+ vals = [interp(xxs[i], yys[i])[0] for i in np.arange(len(rs))]
+ return vals
+
+ profs = []
+ for j in range(nrays):
+ vals = ringVals(thetas[j])
+ profs.append(vals)
+
+ # polarization profiles
+ profsQ = []
+ profsU = []
+    if len(im.qvec) > 0 and len(im.uvec) > 0 and pol_profs:
+ qarr = im.qvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K
+ uarr = im.uvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K
+ interpQ = scipy.interpolate.interp2d(ys, xs, qarr, kind=interptype)
+ interpU = scipy.interpolate.interp2d(ys, xs, uarr, kind=interptype)
+
+ def ringValsQ(theta):
+ xxs = x0 - rs*np.sin(theta)
+ yys = y0 + rs*np.cos(theta)
+
+ vals = [interpQ(xxs[i], yys[i])[0] for i in np.arange(len(rs))]
+ return vals
+
+ def ringValsU(theta):
+ xxs = x0 - rs*np.sin(theta)
+ yys = y0 + rs*np.cos(theta)
+
+ vals = [interpU(xxs[i], yys[i])[0] for i in np.arange(len(rs))]
+ return vals
+
+ for j in range(nrays):
+ valsQ = ringValsQ(thetas[j])
+ profsQ.append(valsQ)
+ valsU = ringValsU(thetas[j])
+ profsU.append(valsU)
+
+
+ profiles = Profiles(im, x0, y0, profs, thetas, rmin=rmin, rmax=rmax, interptype=interptype,
+ flux_norm=flux_norm,profsQ=profsQ, profsU=profsU)
+
+ return profiles
+
+
+def findCenter(im,
+ rmin=RMIN, rmax=RMAX,
+ rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX,
+ nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH,
+ fov_search=FOVP_SEARCH, n_search=NSEARCH, flux_norm=NORMFLUX):
+ """Find the ring center by looking at profiles over a given range
+ """
+
+ print("nrays", nrays_search, "nrs", nrs_search, "fov", fov_search, "n", n_search)
+
+ rs = np.linspace(0, rmax_search, nrs_search)
+ dr = rs[-1] - rs[-2]
+ thetas = np.linspace(0, 2*np.pi, nrays_search)
+ factor = 3.254e13/(im.rf**2 * im.psize**2) # convert to brightness temperature
+ imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K
+ xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS
+ ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS
+
+ # TODO: test fiducial images with linear?
+ #iminterp = scipy.interpolate.interp2d(ys, xs, imarr, kind='cubic')
+ iminterp = scipy.interpolate.interp2d(ys, xs, imarr, kind='linear')
+ #iminterp = scipy.interpolate.RegularGridInterpolator((ys,xs),imarr)
+ def objFunc(pos):
+ (x0, y0) = pos
+ diameters = []
+ for j in range(nrays_search):
+ xxs = x0 - rs*np.sin(thetas[j])
+ yys = y0 + rs*np.cos(thetas[j])
+ prof = np.array([iminterp(xxs[i], yys[i])[0] for i in np.arange(len(rs))])
+
+ args = np.argsort(prof)
+ pkpos = args[-1]
+ pk = rs[pkpos]
+ vpk = prof[pkpos]
+ if pkpos > 0 and pkpos < nrs_search-1:
+ vals = [prof[pkpos-1], prof[pkpos], prof[pkpos+1]]
+ pk, vpk = quad_interp_radius(pk, dr, vals)
+
+ diameters.append(2*np.abs(pk))
+
+ # ring size
+ mean,std = (np.mean(diameters), np.std(diameters))
+
+ if mean < rmin_search or mean > rmax_search:
+ return np.inf
+ else:
+ J = np.abs(std/mean)
+ return J
+
+ fovx = im.fovx()/ehc.RADPERUAS
+ fovy = im.fovy()/ehc.RADPERUAS
+
+    # brute force grid search (with an fmin polish) to find the ring center
+ fovmin_x = (.5-fov_search) * fovx
+ fovmax_x = (.5+fov_search) * fovx
+ fovmin_y = (.5-fov_search) * fovy
+ fovmax_y = (.5+fov_search) * fovy
+ res = scipy.optimize.brute(objFunc, ranges=((fovmin_x, fovmax_x), (fovmin_y, fovmax_y)),
+ Ns=n_search)
+
+ return res
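+
+# Note on the search above (a sketch): objFunc returns the fractional spread
+# std/mean of the fitted ring diameters, so scipy.optimize.brute grids the
+# central fov_search fraction of the FOV and (via its default fmin polish)
+# returns the center where the ring radius is most nearly constant.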
+
+def FindProfile(im, blur=0, pol_profs=False,
+ imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS,
+ rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX,
+ nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH,
+ thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH,
+ flux_norm=NORMFLUX,center=False):
+ """find the best ring profile for an image object and save results
+ """
+
+ im_raw = im.copy()
+ # blur image if requested
+ if blur > 0:
+ im_raw = im_raw.blur_circ(blur*ehc.RADPERUAS, blur*ehc.RADPERUAS)
+
+    # center image and regrid to uniform pixel size and FOV
+ if center:
+ im = di.center_core(im_raw) # TODO -- why isn't this working?
+ else:
+ im = im_raw
+
+ im_search = im.regrid_image(imsize, npix)
+ im = im.regrid_image(imsize, npix)
+
+ # blur image if requested
+ # if blur > 0:
+ # im_search = im_search.blur_circ(blur*ehc.RADPERUAS)
+ # im = im.blur_circ(blur*ehc.RADPERUAS)
+
+ # blur and threshold image FOR SEARCH ONLY
+ # if blur==0:
+ # im_search = im.blur_circ(BLUR_VALUE_MIN*ehc.RADPERUAS)
+ # else:
+ # im_search = im.copy()
+
+    # threshold the search image at thresh_search (default 5%) of its maximum
+ im_search.imvec[im_search.imvec < thresh_search*np.max(im_search.imvec)] = 0
+
+ # find center
+ res = findCenter(im_search, rmin=rmin, rmax=rmax,
+ rmin_search=rmin_search, rmax_search=rmax_search,
+ nrays_search=nrays_search, nrs_search=nrs_search,
+ fov_search=fov_search, n_search=n_search, flux_norm=flux_norm)
+
+ # compute profiles using the original (regridded, flux centroid centered) image
+ print("compute profile")
+ pp = compute_ring_profile(im, res[0], res[1], nrs=nrs, nrays=nrays,
+ rmin=rmin, rmax=rmax, flux_norm=flux_norm,
+ pol_profs=pol_profs)
+ pp.calc_meanprof_and_stats()
+
+ return pp
+
+def FindProfileSingle(imname, postprocdir,
+ save_files=False, blur=0, aipscc=False, tag='',
+ rerun=True, return_pp=True,
+ imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS,
+ rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX,
+ nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH,
+ thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH,
+ flux_norm=NORMFLUX,center=False):
+ """find the best ring profile for an image fits file and save results
+ """
+
+ dirname = os.path.basename(os.path.dirname(imname))
+ basename = os.path.basename(imname)
+ txtname = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '.txt'
+ if rerun is False and os.path.exists(txtname):
+ return -1
+
+ # print("nrays",nrays_search,"nrs",nrs_search,"fov",fov_search)
+ with ploop.HiddenPrints():
+
+ # load image
+ im_raw = load_image(imname, aipscc=aipscc)
+
+ # calculate profile
+        pp = FindProfile(im_raw, blur=blur,
+ imsize=imsize, npix=npix, rmin=rmin, rmax=rmax, nrays=nrays, nrs=nrs,
+ rmin_search=rmin_search, rmax_search=rmax_search,
+ nrays_search=nrays_search, nrs_search=nrs_search,
+ thresh_search=thresh_search, fov_search=fov_search, n_search=n_search,
+ flux_norm=flux_norm,center=center)
+
+
+ # save files
+ if save_files:
+ print("save files")
+ dirname = os.path.basename(os.path.dirname(imname))
+ if not os.path.exists(postprocdir + '/' + dirname):
+ subprocess.call(['mkdir', postprocdir + '/' + dirname])
+
+ basename = os.path.splitext(os.path.basename(imname))[0]
+ txtname = postprocdir + '/' + dirname + '/' + basename + tag + '.txt'
+
+ if os.path.exists(txtname):
+ os.remove(txtname)
+
+ f = open(txtname, 'a')
+        f.write('ring_x0 ' + str(pp.x0) + '\n')
+        f.write('ring_y0 ' + str(pp.y0) + '\n')
+
+ f.write('ring_diameter ' + str(pp.RingSize1[0]) + '\n')
+ f.write('ring_diameter_sigma ' + str(pp.RingSize1[1]) + '\n')
+
+ f.write('meanprof_ring_diameter ' + str(pp.RingSize2[0]) + '\n')
+ f.write('meanprof_ring_diameter_sigma ' + str(pp.RingSize2[1]) + '\n')
+
+        f.write('ring_orientation ' + str(pp.RingAngle1[0]) + '\n')
+        f.write('ring_orientation_sigma ' + str(pp.RingAngle1[1]) + '\n')
+
+        f.write('meanprof_ring_orientation ' + str(pp.RingAngle2[0]) + '\n')
+        f.write('meanprof_ring_orientation_sigma ' + str(pp.RingAngle2[1]) + '\n')
+
+        f.write('ring_width ' + str(pp.RingWidth[0]) + '\n')
+        f.write('ring_width_sigma ' + str(pp.RingWidth[1]) + '\n')
+
+ f.write('total_flux ' + str(pp.flux) + '\n')
+ f.write('total_ring_flux ' + str(pp.RingFlux) + '\n')
+
+ f.write('ring_asym_1 ' + str(pp.RingAsym1[0]) + '\n')
+ f.write('ring_asym_1_sigma ' + str(pp.RingAsym1[1]) + '\n')
+ f.write('ring_asym_2 ' + str(pp.RingAsym2[0]) + '\n')
+ f.write('ring_brighthalf_over_dimhalf ' + str(pp.RingAsym2[1]) + '\n')
+
+ f.write('in_flux_mean_ring ' + str(pp.in_level) + '\n')
+ f.write('out_flux_mean_ring ' + str(pp.out_level) + '\n')
+ f.write('max_flux_mean_ring ' + str(pp.meanpk) + '\n')
+
+        f.write('max_ring_contrast ' + str(pp.RingContrast1) + '\n')
+        f.write('mean_ring_contrast ' + str(pp.RingContrast2) + '\n')
+ f.write('dynamic_range ' + str(pp.dynamic_range) + '\n')
+
+ f.write('norm_factor ' + str(pp.normfactor) + '\n')
+
+ f.write('ring_diameter_med ' + str(pp.RingSize1_med[0]) + '\n')
+ f.write('ring_diameter_medabsdev ' + str(pp.RingSize1_med[1]) + '\n')
+
+ f.write('ring_angle_pol ' + str(pp.RingAnglePol[0]) + '\n')
+ f.write('ring_angle_pol_sigma ' + str(pp.RingAnglePol[1]) + '\n')
+
+ f.write('ring_pol_ratio_p ' + str(pp.RingAsymPol[0]) + '\n')
+ f.write('ring_pol_ratio_m ' + str(pp.RingAsymPol[1]) + '\n')
+
+ f.close()
+
+ # save unwrapped and centered fits image
+ # fitsname = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '.fits'
+ # fitsname_centered = postprocdir + '/' + dirname + \
+ # '/' + basename[:-5] + tag + '_cent.fits'
+
+ # pp.save_unwrapped(fitsname)
+ # pp.im_center.save_fits(fitsname_centered)
+
+ # save radial profile
+ # radprof_name = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '_radprof.txt
+ # data=np.hstack((pp.rs.reshape(pp.nrs,1),
+ # pp.meanprof.reshape(pp.nrs,1),
+ # pp.normfactor * pp.meanprof.reshape(pp.nrs,1)))
+ # np.savetxt(radprof_name, data)
+
+ # save angular profile
+ # angprof_name = postprocdir + '/' + dirname + '/' + basename[:-5] + tag+'_angprof.txt'
+ # data=np.hstack((pp.thetas.reshape(pp.nthetas,1),
+ # pp.meanprof_theta.reshape(pp.nthetas,1),
+ # pp.normfactor * pp.meanprof_theta.reshape(pp.nthetas,1)))
+ # np.savetxt(angprof_name, data)
+
+ # pp.plot_unwrapped(save_png=True)
+ # pp.plot_img(save_png=True)
+ # pp.plot_meanprof(save_png=True)
+ # pp.plot_meanprof_theta(save_png=True)
+ # plt.close('all')
+
+ if return_pp:
+ return pp
+ else:
+ del pp
+ return
+
+
+def FindProfiles(foldername, postprocdir, processes=-1,
+ save_files=False, blur=0,
+ aipscc=False, tag='', rerun=True, return_pp=True,
+ imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS,
+ rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX,
+ nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH,
+ thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH,
+ flux_norm=NORMFLUX
+ ):
+ """find profiles for all image fits files in a directory
+ """
+
+ foldername = os.path.abspath(foldername)
+ imlist = np.array(glob.glob(foldername + '/*.fits'))
+ ext = '.fits'
+
+ # Look for hdf5 files for EHT library runs
+ if len(imlist) == 0:
+ imlist = np.array(glob.glob(foldername + '/*.h5'))
+ ext = '.h5'
+
+ if len(imlist) == 0:
+ print("\nfound no image files in ", foldername)
+ return []
+
+ print("\nfound ", len(imlist), " ", ext, " files in ", foldername)
+
+ imlist = np.sort(imlist)
+ arglist = [[imlist[i], postprocdir,
+ save_files, blur,
+ aipscc, tag, rerun, return_pp,
+ imsize, npix, rmin, rmax, nrays, nrs,
+ rmin_search, rmax_search, nrays_search, nrs_search,
+ thresh_search, fov_search, n_search, flux_norm]
+ for i in range(len(imlist))]
+
+ parloop = ploop.Parloop(FindProfileSingle)
+ pplist = parloop.run_loop(arglist, processes)
+ return pplist
diff --git a/image.py b/image.py
new file mode 100644
index 00000000..73ccafaa
--- /dev/null
+++ b/image.py
@@ -0,0 +1,4320 @@
+# image.py
+# an interferometric image class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import sys
+import copy
+import math
+import numpy as np
+import numpy.matlib as matlib
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import scipy.optimize as opt
+import scipy.signal
+import scipy.ndimage.filters as filt
+import scipy.interpolate
+from scipy import ndimage as ndi
+
+
+try:
+ from skimage.feature import canny
+ from skimage.transform import hough_circle, hough_circle_peaks
+except ImportError:
+ print("Warning: scikit-image not installed! Cannot use hough transform")
+
+import ehtim.observing.obs_simulate as simobs
+import ehtim.observing.pulses as pulses
+import ehtim.io.save
+import ehtim.io.load
+import ehtim.const_def as ehc
+import ehtim.observing.obs_helpers as obsh
+
+# TODO : add time to all images
+# TODO : add arbitrary center location
+
+###################################################################################################
+# Image object
+###################################################################################################
+
+
+class Image(object):
+
+ """A polarimetric image (in units of Jy/pixel).
+
+ Attributes:
+ pulse (function): The function convolved with the pixel values for continuous image.
+ psize (float): The pixel dimension in radians
+ xdim (int): The number of pixels along the x dimension
+ ydim (int): The number of pixels along the y dimension
+ mjd (int): The integer MJD of the image
+ time (float): The observing time of the image (UTC hours)
+ source (str): The astrophysical source name
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The image frequency in Hz
+
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular
+ _imdict (dict): The dictionary with the polarimetric images
+ _mflist (list): List of spectral index images (and higher order terms)
+ """
+
+ def __init__(self, image, psize, ra, dec, pa=0.0,
+ polrep='stokes', pol_prim=None,
+ rf=ehc.RF_DEFAULT, pulse=ehc.PULSE_DEFAULT, source=ehc.SOURCE_DEFAULT,
+ mjd=ehc.MJD_DEFAULT, time=0.):
+ """A polarimetric image (in units of Jy/pixel).
+
+ Args:
+ image (numpy.array): The 2D intensity values in a Jy/pixel array
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+
+ psize (float): The pixel dimension in radians
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+            pa (float): position angle of the image
+ rf (float): The image frequency in Hz
+ pulse (function): The function convolved with the pixel values for continuous image.
+ source (str): The source name
+ mjd (int): The integer MJD of the image
+ time (float): The observing time of the image (UTC hours)
+
+ Returns:
+ (Image): the Image object
+ """
+
+ if len(image.shape) != 2:
+ raise Exception("image must be a 2D numpy array")
+ if polrep not in ['stokes', 'circ']:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ # Save the image vector
+ imvec = image.flatten()
+
+ if polrep == 'stokes':
+ if pol_prim is None:
+ pol_prim = 'I'
+ if pol_prim == 'I':
+ self._imdict = {'I': imvec, 'Q': np.array([]), 'U': np.array([]), 'V': np.array([])}
+ elif pol_prim == 'V':
+ self._imdict = {'I': np.array([]), 'Q': np.array([]), 'U': np.array([]), 'V': imvec}
+ elif pol_prim == 'Q':
+ self._imdict = {'I': np.array([]), 'Q': imvec, 'U': np.array([]), 'V': np.array([])}
+ elif pol_prim == 'U':
+ self._imdict = {'I': np.array([]), 'Q': np.array([]), 'U': imvec, 'V': np.array([])}
+ else:
+ raise Exception("for polrep=='stokes', pol_prim must be 'I','Q','U', or 'V'!")
+
+ elif polrep == 'circ':
+ if pol_prim is None:
+ print("polrep is 'circ' and no pol_prim specified! Setting pol_prim='RR'")
+ pol_prim = 'RR'
+ if pol_prim == 'RR':
+ self._imdict = {'RR': imvec, 'LL': np.array([]), 'RL': np.array([]), 'LR': np.array([])}
+ elif pol_prim == 'LL':
+ self._imdict = {'RR': np.array([]), 'LL': imvec, 'RL': np.array([]), 'LR': np.array([])}
+ else:
+ raise Exception("for polrep=='circ', pol_prim must be 'RR' or 'LL'!")
+ else:
+ raise Exception("polrep must be 'circ' or 'stokes'!")
+
+ # multifrequency spectral index, curvature arrays
+ # TODO -- higher orders?
+ # TODO -- don't initialize to zero?
+ avec = np.array([]) # np.zeros(imvec.shape)
+ bvec = np.array([]) # np.zeros(imvec.shape)
+ self._mflist = [avec, bvec]
+
+ # Save the image dimension data
+ self.pol_prim = pol_prim
+ self.polrep = polrep
+ self.pulse = pulse
+ self.psize = float(psize)
+ self.xdim = image.shape[1]
+ self.ydim = image.shape[0]
+
+ # Save the image metadata
+ self.ra = float(ra)
+ self.dec = float(dec)
+ self.pa = float(pa)
+ self.rf = float(rf)
+ self.source = str(source)
+ self.mjd = int(mjd)
+
+ # Cached FFT of the image
+ self.cached_fft = {}
+
+ if time > 24:
+ self.mjd += int((time - time % 24) / 24)
+ self.time = float(time % 24)
+ else:
+ self.time = time
+
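+    # Example (a sketch; ra/dec values are placeholders): a blank 64x64 Stokes I
+    # image at 230 GHz with 1 uas pixels:
+    #   im = Image(np.zeros((64, 64)), 1.0*ehc.RADPERUAS, ra=12.51, dec=12.39, rf=230e9)
+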
+ @property
+ def imvec(self):
+ imvec = self._imdict[self.pol_prim]
+ return imvec
+
+ @imvec.setter
+ def imvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("imvec size is not consistent with xdim*ydim!")
+
+ self._imdict[self.pol_prim] = vec
+
+ @property
+ def specvec(self):
+ specvec = self._mflist[0]
+ return specvec
+
+ @specvec.setter
+ def specvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ self._mflist[0] = vec
+
+ @property
+ def curvvec(self):
+ curvvec = self._mflist[1]
+ return curvvec
+
+ @curvvec.setter
+ def curvvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ self._mflist[1] = vec
+
+ @property
+ def ivec(self):
+# if self.polrep != 'stokes':
+# raise Exception("ivec is not defined unless self.polrep=='stokes'")
+
+ ivec = np.array([])
+ if self.polrep == 'stokes':
+ ivec = self._imdict['I']
+ elif self.polrep == 'circ':
+ if len(self.rrvec) != 0 and len(self.llvec) != 0:
+ ivec = 0.5 * (self.rrvec + self.llvec)
+
+ return ivec
+
+ @ivec.setter
+ def ivec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'stokes':
+ raise Exception("ivec cannot be set unless self.polrep=='stokes'")
+
+ self._imdict['I'] = vec
+
+ @property
+ def qvec(self):
+# if self.polrep != 'stokes':
+# raise Exception("qvec is not defined unless self.polrep=='stokes'")
+
+ qvec = np.array([])
+ if self.polrep == 'stokes':
+ qvec = self._imdict['Q']
+ elif self.polrep == 'circ':
+ if len(self.rlvec) != 0 and len(self.lrvec) != 0:
+ qvec = np.real(0.5 * (self.lrvec + self.rlvec))
+
+ return qvec
+
+ @qvec.setter
+ def qvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'stokes':
+ raise Exception("ivec cannot be set unless self.polrep=='stokes'")
+
+ self._imdict['Q'] = vec
+
+ @property
+ def uvec(self):
+# if self.polrep != 'stokes':
+# raise Exception("qvec is not defined unless self.polrep=='stokes'")
+
+ uvec = np.array([])
+ if self.polrep == 'stokes':
+ uvec = self._imdict['U']
+ elif self.polrep == 'circ':
+ if len(self.rlvec) != 0 and len(self.lrvec) != 0:
+ uvec = np.real(0.5j * (self.lrvec - self.rlvec))
+
+ return uvec
+
+ @uvec.setter
+ def uvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'stokes':
+ raise Exception("uvec cannot be set unless self.polrep=='stokes'")
+
+ self._imdict['U'] = vec
+
+ @property
+ def vvec(self):
+# if self.polrep != 'stokes':
+# raise Exception("vvec is not defined unless self.polrep=='stokes'")
+
+ vvec = np.array([])
+ if self.polrep == 'stokes':
+ vvec = self._imdict['V']
+ elif self.polrep == 'circ':
+ if len(self.rrvec) != 0 and len(self.llvec) != 0:
+ vvec = 0.5 * (self.rrvec - self.llvec)
+
+ return vvec
+
+ @vvec.setter
+ def vvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'stokes':
+ raise Exception("vvec cannot be set unless self.polrep=='stokes'")
+
+ self._imdict['V'] = vec
+
+ @property
+ def rrvec(self):
+# if self.polrep != 'circ':
+# raise Exception("rrvec is not defined unless self.polrep=='circ'")
+
+ rrvec = np.array([])
+ if self.polrep == 'circ':
+ rrvec = self._imdict['RR']
+ elif self.polrep == 'stokes':
+ if len(self.ivec) != 0 and len(self.vvec) != 0:
+ rrvec = (self.ivec + self.vvec)
+
+ return rrvec
+
+ @rrvec.setter
+ def rrvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'circ':
+ raise Exception("rrvec cannot be set unless self.polrep=='circ'")
+
+ self._imdict['RR'] = vec
+
+ @property
+ def llvec(self):
+# if self.polrep != 'circ':
+# raise Exception("llvec is not defined unless self.polrep=='circ'")
+
+ llvec = np.array([])
+ if self.polrep == 'circ':
+ llvec = self._imdict['LL']
+ elif self.polrep == 'stokes':
+ if len(self.ivec) != 0 and len(self.vvec) != 0:
+ llvec = (self.ivec - self.vvec)
+
+ return llvec
+
+ @llvec.setter
+ def llvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'circ':
+ raise Exception("llvec cannot be set unless self.polrep=='circ'")
+
+ self._imdict['LL'] = vec
+
+ @property
+ def rlvec(self):
+# if self.polrep != 'circ':
+# raise Exception("rlvec is not defined unless self.polrep=='circ'")
+
+ rlvec = np.array([])
+ if self.polrep == 'circ':
+ rlvec = self._imdict['RL']
+ elif self.polrep == 'stokes':
+ if len(self.qvec) != 0 and len(self.uvec) != 0:
+ rlvec = (self.qvec + 1j * self.uvec)
+
+ return rlvec
+
+ @rlvec.setter
+ def rlvec(self, vec):
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'circ':
+ raise Exception("rlvec cannot be set unless self.polrep=='circ'")
+
+ self._imdict['RL'] = vec
+
+ @property
+ def lrvec(self):
+ """Return the imvec of LR"""
+# if self.polrep != 'circ':
+# raise Exception("lrvec is not defined unless self.polrep=='circ'")
+
+ lrvec = np.array([])
+ if self.polrep == 'circ':
+ lrvec = self._imdict['LR']
+ elif self.polrep == 'stokes':
+ if len(self.qvec) != 0 and len(self.uvec) != 0:
+ lrvec = (self.qvec - 1j * self.uvec)
+
+
+ return lrvec
+
+ @lrvec.setter
+ def lrvec(self, vec):
+ """Set the imvec"""
+
+ if len(vec) != self.xdim * self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+ if self.polrep != 'circ':
+ raise Exception("lrvec cannot be set unless self.polrep=='circ'")
+
+ self._imdict['LR'] = vec
+
+ @property
+ def pvec(self):
+ """Return the polarization magnitude for each pixel"""
+ if self.polrep == 'circ':
+ pvec = np.abs(self.rlvec)
+ elif self.polrep == 'stokes':
+ pvec = np.abs(self.qvec + 1j * self.uvec)
+
+ return pvec
+
+ @property
+ def mvec(self):
+ """Return the fractional polarization for each pixel"""
+ if self.polrep == 'circ':
+ mvec = 2 * np.abs(self.rlvec) / (self.rrvec + self.llvec)
+ elif self.polrep == 'stokes':
+ mvec = np.abs(self.qvec + 1j * self.uvec) / self.ivec
+
+ return mvec
+
+ @property
+ def chivec(self):
+ """Return the fractional polarization angle for each pixel"""
+ if self.polrep == 'circ':
+ chivec = 0.5 * np.angle(self.rlvec / (self.rrvec + self.llvec))
+ elif self.polrep == 'stokes':
+ chivec = 0.5 * np.angle((self.qvec + 1j * self.uvec) / self.ivec)
+
+ return chivec
+
+ @property
+ def evpavec(self):
+ """Return the fractional polarization angle for each pixel"""
+
+ return self.chivec
+
+ @property
+ def evec(self):
+ """Return the E mode image vector"""
+ if self.polrep == 'circ':
+ qvec = np.real(0.5 * (self.lrvec + self.rlvec))
+ uvec = np.real(0.5j * (self.lrvec - self.rlvec))
+ elif self.polrep == 'stokes':
+ qvec = self.qvec
+ uvec = self.uvec
+
+ qarr = qvec.reshape((self.ydim, self.xdim))
+ uarr = uvec.reshape((self.ydim, self.xdim))
+ qarr_fft = np.fft.fftshift(np.fft.fft2(qarr))
+ uarr_fft = np.fft.fftshift(np.fft.fft2(uarr))
+
+ # TODO -- check conventions for u,v angle
+ s, t = np.meshgrid(np.flip(np.fft.fftshift(np.fft.fftfreq(self.xdim, d=1.0 / self.xdim))),
+ np.flip(np.fft.fftshift(np.fft.fftfreq(self.ydim, d=1.0 / self.ydim))))
+ s = s + .5 # .5 offset to reference to pixel center
+ t = t + .5 # .5 offset to reference to pixel center
+ uvangle = np.arctan2(s, t)
+
+        # TODO -- these conventions for e,b are from Kamionkowski & Kovetz, ARA&A 54:227-69, sec 4.1
+ # TODO -- check!
+ cos2arr = np.round(np.cos(2 * uvangle), decimals=10)
+ sin2arr = np.round(np.sin(2 * uvangle), decimals=10)
+ earr_fft = (cos2arr * qarr_fft + sin2arr * uarr_fft)
+
+ earr = np.fft.ifft2(np.fft.ifftshift(earr_fft))
+ return np.real(earr.flatten())
+
+ @property
+ def bvec(self):
+ """Return the B mode image vector"""
+
+ if self.polrep == 'circ':
+ qvec = np.real(0.5 * (self.lrvec + self.rlvec))
+ uvec = np.real(0.5j * (self.lrvec - self.rlvec))
+ elif self.polrep == 'stokes':
+ qvec = self.qvec
+ uvec = self.uvec
+
+ # TODO -- check conventions for u,v angle
+ qarr = qvec.reshape((self.ydim, self.xdim))
+ uarr = uvec.reshape((self.ydim, self.xdim))
+ qarr_fft = np.fft.fftshift(np.fft.fft2(qarr))
+ uarr_fft = np.fft.fftshift(np.fft.fft2(uarr))
+
+ # TODO -- are these conventions for u,v right?
+ s, t = np.meshgrid(np.flip(np.fft.fftshift(np.fft.fftfreq(self.xdim, d=1.0 / self.xdim))),
+ np.flip(np.fft.fftshift(np.fft.fftfreq(self.ydim, d=1.0 / self.ydim))))
+ s = s + .5 # .5 offset to reference to pixel center
+ t = t + .5 # .5 offset to reference to pixel center
+ uvangle = np.arctan2(s, t)
+
+ # TODO -- check!
+ cos2arr = np.round(np.cos(2 * uvangle), decimals=10)
+ sin2arr = np.round(np.sin(2 * uvangle), decimals=10)
+ barr_fft = (-sin2arr * qarr_fft + cos2arr * uarr_fft)
+
+ barr = np.fft.ifft2(np.fft.ifftshift(barr_fft))
+ return np.real(barr.flatten())
+
+ def get_polvec(self, pol):
+ """Get the imvec corresponding to the chosen polarization
+ """
+ if self.polrep == 'stokes' and pol is None:
+ pol = 'I'
+ elif self.polrep == 'circ' and pol is None:
+ pol = 'RR'
+
+ if pol.lower() == 'i':
+ outvec = self.ivec
+ elif pol.lower() == 'q':
+ outvec = self.qvec
+ elif pol.lower() == 'u':
+ outvec = self.uvec
+ elif pol.lower() == 'v':
+ outvec = self.vvec
+ elif pol.lower() == 'rr':
+ outvec = self.rrvec
+ elif pol.lower() == 'll':
+ outvec = self.llvec
+ elif pol.lower() == 'lr':
+ outvec = self.lrvec
+ elif pol.lower() == 'rl':
+ outvec = self.rlvec
+ elif pol.lower() == 'p':
+ outvec = self.pvec
+ elif pol.lower() == 'm':
+ outvec = self.mvec
+        elif pol.lower() == 'chi' or pol.lower() == 'evpa':
+ outvec = self.chivec
+ elif pol.lower() == 'e':
+ outvec = self.evec
+ elif pol.lower() == 'b':
+ outvec = self.bvec
+ else:
+ raise Exception("Requested polvec type not recognized!")
+ return outvec
+
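+    # Example (a sketch): for a 'stokes' image im, im.get_polvec('p') returns
+    # sqrt(Q**2 + U**2) per pixel and im.get_polvec('chi') the EVPA in radians.
+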
+ def image_args(self):
+ """Copy arguments for making a new Image into a list and dictonary
+ """
+
+ arglist = [self.imarr(), self.psize, self.ra, self.dec]
+ argdict = {'rf': self.rf, 'pa': self.pa,
+ 'polrep': self.polrep, 'pol_prim': self.pol_prim,
+ 'pulse': self.pulse, 'source': self.source,
+ 'mjd': self.mjd, 'time': self.time}
+
+ return (arglist, argdict)
+
+ def copy(self):
+ """Return a copy of the Image object.
+
+ Args:
+
+ Returns:
+ (Image): copy of the Image.
+ """
+
+ # Make new image with primary polarization
+ arglist, argdict = self.image_args()
+ newim = Image(*arglist, **argdict)
+
+ # Copy over all polarization images
+ newim.copy_pol_images(self)
+
+ # Copy over spectral index information
+ newim._mflist = copy.deepcopy(self._mflist)
+
+ return newim
+
+ def copy_pol_images(self, old_image):
+ """Copy polarization images from old_image over to self.
+
+ Args:
+ old_image (Image): image object to copy from
+
+ """
+
+ for pol in list(self._imdict.keys()):
+
+ if (pol == self.pol_prim):
+ continue
+
+ polvec = old_image._imdict[pol]
+ if len(polvec):
+ self.add_pol_image(polvec.reshape(self.ydim, self.xdim), pol)
+
+ def add_pol_image(self, image, pol):
+ """Add another image polarization.
+
+ Args:
+            image (numpy.array): 2D image frame (possibly complex) in a Jy/pixel array
+ pol (str): The image type: 'I','Q','U','V' for stokes, 'RR','LL','RL','LR' for circ
+ """
+
+ if pol == self.pol_prim:
+ raise Exception("new pol in add_pol_image is the same as pol_prim!")
+ if image.shape != (self.ydim, self.xdim):
+ raise Exception("add_pol_image image shapes incompatible with primary image!")
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol in add_pol_image in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+
+ if self.polrep == 'stokes':
+ if pol == 'I':
+ self.ivec = image.flatten()
+ elif pol == 'Q':
+ self.qvec = image.flatten()
+ elif pol == 'U':
+ self.uvec = image.flatten()
+ elif pol == 'V':
+ self.vvec = image.flatten()
+
+ elif self.polrep == 'circ':
+ if pol == 'RR':
+ self.rrvec = image.flatten()
+ elif pol == 'LL':
+ self.llvec = image.flatten()
+ elif pol == 'RL':
+ self.rlvec = image.flatten()
+ elif pol == 'LR':
+ self.lrvec = image.flatten()
+
+ return
+
+ # TODO deprecated -- replace with generic add_pol_image
+ def add_qu(self, qimage, uimage):
+ """Add Stokes Q and U images. self.polrep must be 'stokes'
+
+ Args:
+ qimage (numpy.array): The 2D Stokes Q values in Jy/pixel array
+ uimage (numpy.array): The 2D Stokes U values in Jy/pixel array
+
+ Returns:
+ """
+
+ if self.polrep != 'stokes':
+ raise Exception("polrep must be 'stokes' for add_qu() !")
+ self.add_pol_image(qimage, 'Q')
+ self.add_pol_image(uimage, 'U')
+
+ return
+
+ # TODO deprecated -- replace with generic add_pol_image
+ def add_v(self, vimage):
+ """Add Stokes V image. self.polrep must be 'stokes'
+
+ Args:
+            vimage (numpy.array): The 2D Stokes V values in Jy/pixel array
+ """
+
+ if self.polrep != 'stokes':
+ raise Exception("polrep must be 'stokes' for add_v() !")
+ self.add_pol_image(vimage, 'V')
+
+ return
+
+ def switch_polrep(self, polrep_out='stokes', pol_prim_out=None):
+ """Return a new image with the polarization representation changed
+ Args:
+ polrep_out (str): the polrep of the output data
+ pol_prim_out (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for circ
+
+ Returns:
+ (Image): new Image object with potentially different polrep
+ """
+
+ if polrep_out not in ['stokes', 'circ']:
+ raise Exception("polrep_out must be either 'stokes' or 'circ'")
+ if pol_prim_out is None:
+ if polrep_out == 'stokes':
+ pol_prim_out = 'I'
+ elif polrep_out == 'circ':
+ pol_prim_out = 'RR'
+
+ # Simply copy if the polrep is unchanged
+ if polrep_out == self.polrep and pol_prim_out == self.pol_prim:
+ return self.copy()
+
+ # Assemble a dictionary of new polarization vectors
+ if polrep_out == 'stokes':
+ if self.polrep == 'stokes':
+ imdict = {'I': self.ivec, 'Q': self.qvec, 'U': self.uvec, 'V': self.vvec}
+ else:
+ if len(self.rrvec) == 0 or len(self.llvec) == 0:
+ ivec = np.array([])
+ vvec = np.array([])
+ else:
+ ivec = 0.5 * (self.rrvec + self.llvec)
+ vvec = 0.5 * (self.rrvec - self.llvec)
+
+ if len(self.rlvec) == 0 or len(self.lrvec) == 0:
+ qvec = np.array([])
+ uvec = np.array([])
+ else:
+ qvec = np.real(0.5 * (self.lrvec + self.rlvec))
+ uvec = np.real(0.5j * (self.lrvec - self.rlvec))
+
+ imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec}
+
+ elif polrep_out == 'circ':
+ if self.polrep == 'circ':
+ imdict = {'RR': self.rrvec, 'LL': self.llvec, 'RL': self.rlvec, 'LR': self.lrvec}
+ else:
+ if len(self.ivec) == 0 or len(self.vvec) == 0:
+ rrvec = np.array([])
+ llvec = np.array([])
+ else:
+ rrvec = (self.ivec + self.vvec)
+ llvec = (self.ivec - self.vvec)
+
+ if len(self.qvec) == 0 or len(self.uvec) == 0:
+ rlvec = np.array([])
+ lrvec = np.array([])
+ else:
+ rlvec = (self.qvec + 1j * self.uvec)
+ lrvec = (self.qvec - 1j * self.uvec)
+
+ imdict = {'RR': rrvec, 'LL': llvec, 'RL': rlvec, 'LR': lrvec}
+
+ # Assemble the new image
+ imvec = imdict[pol_prim_out]
+ if len(imvec) == 0:
+ raise Exception("for switch_polrep to %s with pol_prim_out=%s, \n" %
+ (polrep_out, pol_prim_out) + "output image is not defined")
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imvec.reshape(self.ydim, self.xdim)
+ argdict['polrep'] = polrep_out
+ argdict['pol_prim'] = pol_prim_out
+ newim = Image(*arglist, **argdict)
+
+ # Add in any other polarizations
+ for pol in list(imdict.keys()):
+ if pol == newim.pol_prim:
+ continue
+ polvec = imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim)
+ newim.add_pol_image(polarr, pol)
+
+ # Add in spectral index
+ newim._mflist = copy.deepcopy(self._mflist)
+
+ return newim
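+
+    # A minimal usage sketch for switch_polrep (hypothetical Stokes image `im`
+    # with all four polarizations defined). The identities used above are
+    # RR = I + V, LL = I - V, RL = Q + 1j*U, LR = Q - 1j*U:
+    #
+    #   im_circ = im.switch_polrep(polrep_out='circ', pol_prim_out='RR')
+    #   np.allclose(im_circ.rrvec, im.ivec + im.vvec)          # -> True
+    #   im_back = im_circ.switch_polrep(polrep_out='stokes')   # round-trips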
+
+ def orth_chi(self):
+ """Rotate the EVPA 90 degrees
+
+ Args:
+
+ Returns:
+ (Image): image with rotated EVPA
+ """
+ im = self.copy()
+ if im.polrep == 'stokes':
+ im.qvec *= -1
+ im.uvec *= -1
+ elif im.polrep == 'circ':
+            im.lrvec *= -1
+            im.rlvec *= -1
+
+ return im
+
+ def get_image_mf(self, nu):
+ """Get image at a given frequency given the spectral information in self._mflist
+
+ Args:
+ nu (float): frequency in Hz
+
+ Returns:
+ (Image): image at the desired frequency
+ """
+ # TODO -- what to do about polarization? Faraday rotation?
+
+ nuref = self.rf
+ log_nufrac = np.log(nu / nuref)
+ log_imvec = np.log(self.imvec)
+
+ for n, mfvec in enumerate(self._mflist):
+ if len(mfvec):
+ log_imvec += mfvec * (log_nufrac**(n + 1))
+ imvec = np.exp(log_imvec)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imvec.reshape(self.ydim, self.xdim)
+ argdict['rf'] = nu
+ outim = Image(*arglist, **argdict)
+
+ # Copy over all polarization images -- unchanged for now
+ outim.copy_pol_images(self)
+
+ # DON'T copy over spectral index information for now
+ # outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
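+
+    # Sketch of the spectral model applied above: with per-pixel mfvec entries
+    # (alpha, beta, ...), log I(nu) = log I(nu_ref) + alpha*log(nu/nu_ref)
+    # + beta*log(nu/nu_ref)**2 + ... Hypothetical usage, assuming `im` carries
+    # spectral information in _mflist:
+    #
+    #   im230 = im.get_image_mf(230e9)   # image evaluated at 230 GHz
+    #   im345 = im.get_image_mf(345e9)   # image extrapolated to 345 GHz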
+
+ def imarr(self, pol=None):
+ """Return the 2D image array of a given pol parameter.
+
+ Args:
+ pol (str): I,Q,U or V for Stokes, or RR,LL,LR,RL for Circ
+
+ Returns:
+ (numpy.array): 2D image array of dimension (ydim, xdim)
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+
+ imvec = self.get_polvec(pol)
+ if len(imvec):
+ imarr = imvec.reshape(self.ydim, self.xdim)
+ else:
+ imarr = np.array([])
+ return imarr
+
+ def sourcevec(self):
+ """Return the source position vector in geocentric coordinates at 0h GMST.
+
+ Args:
+
+ Returns:
+ (numpy.array): normal vector pointing to source in geocentric coordinates (m)
+ """
+
+ sourcevec = np.array([np.cos(self.dec * ehc.DEGREE), 0, np.sin(self.dec * ehc.DEGREE)])
+ return sourcevec
+
+ def fovx(self):
+ """Return the image fov in x direction in radians.
+
+ Args:
+
+ Returns:
+ (float) : image fov in x direction (radian)
+ """
+
+ return self.psize * self.xdim
+
+ def fovy(self):
+ """Returns the image fov in y direction in radians.
+
+ Args:
+
+ Returns:
+ (float) : image fov in y direction (radian)
+ """
+
+ return self.psize * self.ydim
+
+ def total_flux(self):
+ """Return the total flux of the image in Jy.
+
+ Args:
+
+ Returns:
+ (float) : image total flux (Jy)
+ """
+ if self.polrep == 'stokes':
+ flux = np.sum(self.ivec)
+ elif self.polrep == 'circ':
+ flux = 0.5 * (np.sum(self.rrvec) + np.sum(self.llvec))
+
+ return flux
+
+ def lin_polfrac(self):
+ """Return the total fractional linear polarized flux
+
+ Args:
+
+ Returns:
+ (float) : image fractional linear polarized flux
+ """
+ if self.polrep == 'stokes':
+ frac = np.abs(np.sum(self.qvec + 1j * self.uvec)) / np.abs(np.sum(self.ivec))
+ elif self.polrep == 'circ':
+ frac = 2 * np.abs(np.sum(self.rlvec)) / np.abs(np.sum(self.rrvec + self.llvec))
+
+ return frac
+
+ def evpa(self):
+ """Return the total evpa
+
+ Args:
+
+ Returns:
+ (float) : image average evpa (E of N) in radian
+ """
+ if self.polrep == 'stokes':
+ frac = 0.5 * np.angle(np.sum(self.qvec + 1j * self.uvec))
+ elif self.polrep == 'circ':
+ frac = np.angle(np.sum(self.rlvec))
+
+ return frac
+
+ def circ_polfrac(self):
+ """Return the total fractional circular polarized flux
+
+ Args:
+
+ Returns:
+ (float) : image fractional circular polarized flux
+ """
+ if self.polrep == 'stokes':
+ frac = np.sum(self.vvec) / np.abs(np.sum(self.ivec))
+ elif self.polrep == 'circ':
+ frac = np.sum(self.rrvec - self.llvec) / np.abs(np.sum(self.rrvec + self.llvec))
+
+ return frac
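+
+    # Sketch of the image-integrated polarization summaries defined above
+    # (hypothetical Stokes image `im`):
+    #
+    #   m   = im.lin_polfrac()   # |sum(Q + iU)| / |sum(I)|
+    #   chi = im.evpa()          # 0.5 * arg(sum(Q + iU)), E of N, in radian
+    #   v   = im.circ_polfrac()  # sum(V) / |sum(I)|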
+
+ def center(self, pol=None):
+ """Center the image based on the coordinates of the centroid().
+ A non-integer shift is used, which wraps the image when rotating.
+
+ Args:
+ pol (str): The polarization for which to find the image centroid
+
+ Returns:
+            (Image): centered image
+ """
+
+ return self.shift_fft(-self.centroid(pol=pol))
+
+ def centroid(self, pol=None):
+ """Compute the location of the image centroid (corresponding to the polarization pol)
+
+ Args:
+ pol (str): The polarization for which to find the image centroid
+
+ Returns:
+ (np.array): centroid positions (x0,y0) in radians
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ imvec = self.get_polvec(pol)
+ pdim = self.psize
+
+ if len(imvec):
+ xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0
+ x0 = np.sum(np.outer(0.0 * ylist + 1.0, xlist).ravel() * imvec) / np.sum(imvec)
+ y0 = np.sum(np.outer(ylist, 0.0 * xlist + 1.0).ravel() * imvec) / np.sum(imvec)
+ centroid = np.array([x0, y0])
+ else:
+ raise Exception("No %s image found!" % pol)
+
+ return centroid
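+
+    # Usage sketch (hypothetical image `im`): center() above is shorthand for
+    # measuring the intensity-weighted centroid and undoing it with an
+    # FFT-based subpixel shift:
+    #
+    #   x0, y0 = im.centroid()
+    #   im_centered = im.shift_fft(-im.centroid())   # same as im.center()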
+
+ def pad(self, fovx, fovy):
+ """Pad an image to new fov_x by fov_y in radian.
+ Args:
+ fovx (float): new fov in x dimension (rad)
+ fovy (float): new fov in y dimension (rad)
+
+ Returns:
+ im_pad (Image): padded image
+ """
+
+ # Find pad widths
+ fovoldx = self.fovx()
+ fovoldy = self.fovy()
+ padx = int(0.5 * (fovx - fovoldx) / self.psize)
+ pady = int(0.5 * (fovy - fovoldy) / self.psize)
+
+ # Pad main image vector
+ imarr = self.imvec.reshape(self.ydim, self.xdim)
+ imarr = np.pad(imarr, ((pady, pady), (padx, padx)), 'constant')
+
+ # Make new image
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Pad all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim)
+ polarr = np.pad(polarr, ((pady, pady), (padx, padx)), 'constant')
+ outim.add_pol_image(polarr, pol)
+
+ # Add in spectral index
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = mfvec.reshape(self.ydim, self.xdim)
+ mfarr = np.pad(mfarr, ((pady, pady), (padx, padx)), 'constant')
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
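+
+    # Usage sketch (hypothetical image `im`): pad symmetrically to double the
+    # field of view; the pixel size is unchanged, so the pixel count roughly
+    # doubles per axis:
+    #
+    #   im_big = im.pad(2 * im.fovx(), 2 * im.fovy())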
+
+ def resample_square(self, xdim_new, ker_size=5):
+ """Exactly resample a square image to new dimensions using the pulse function.
+
+ Args:
+ xdim_new (int): new pixel dimension
+ ker_size (int): kernel size for resampling
+
+ Returns:
+ im_resampled (Image): resampled image
+ """
+
+ if self.xdim != self.ydim:
+ raise Exception("Image must be square to use Image.resample_square!")
+ if self.pulse == pulses.deltaPulse2D:
+ raise Exception("Image.resample_squre does not work with delta pulses!")
+
+ ydim_new = xdim_new
+ fov = self.xdim * self.psize
+ psize_new = float(fov) / float(xdim_new)
+
+ # Define an interpolation function using the pulse
+ ij = np.array([[[i * self.psize + (self.psize * self.xdim) / 2.0 - self.psize / 2.0,
+ j * self.psize + (self.psize * self.ydim) / 2.0 - self.psize / 2.0]
+ for i in np.arange(0, -self.xdim, -1)]
+ for j in np.arange(0, -self.ydim, -1)]).reshape((self.xdim * self.ydim, 2))
+
+ def im_new_val(imvec, x_idx, y_idx):
+ x = x_idx * psize_new + (psize_new * xdim_new) / 2.0 - psize_new / 2.0
+ y = y_idx * psize_new + (psize_new * ydim_new) / 2.0 - psize_new / 2.0
+ mask = (((x - ker_size * self.psize / 2.0) < ij[:, 0]) *
+ (ij[:, 0] < (x + ker_size * self.psize / 2.0)) *
+ ((y - ker_size * self.psize / 2.0) < ij[:, 1]) *
+ (ij[:, 1] < (y + ker_size * self.psize / 2.0))
+ ).flatten()
+ interp = np.sum([imvec[n] * self.pulse(x - ij[n, 0], y - ij[n, 1], self.psize, dom="I")
+ for n in np.arange(len(imvec))[mask]])
+ return interp
+
+ def im_new(imvec):
+ imarr_new = np.array([[im_new_val(imvec, x_idx, y_idx)
+ for x_idx in np.arange(0, -xdim_new, -1)]
+ for y_idx in np.arange(0, -ydim_new, -1)])
+ return imarr_new
+
+ # Compute new primary image vector
+ imarr_new = im_new(self.imvec)
+
+ # Normalize
+ scaling = np.sum(self.imvec) / np.sum(imarr_new)
+ imarr_new *= scaling
+
+ # Make new image
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_new
+ arglist[1] = psize_new
+ outim = Image(*arglist, **argdict)
+
+ # Interpolate all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr_new = im_new(polvec)
+ polarr_new *= scaling
+ outim.add_pol_image(polarr_new, pol)
+
+ # Interpolate spectral index and copy over
+        mflist_out = []
+        for mfvec in self._mflist:
+            if len(mfvec):
+                print("WARNING: resample_square not debugged for spectral index resampling!")
+                mfarr = im_new(mfvec)
+                mfvec_out = mfarr.flatten()
+            else:
+                mfvec_out = np.array([])
+            mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
+
+ def regrid_image(self, targetfov, npix, interp='linear'):
+ """Resample the image to new (square) dimensions.
+
+ Args:
+ targetfov (float): new field of view (radian)
+ npix (int): new pixel dimension
+ interp ('linear', 'cubic', 'quintic'): type of interpolation. default is linear
+
+ Returns:
+ (Image): resampled image
+ """
+
+ psize_new = float(targetfov) / float(npix)
+ fov_x = self.fovx()
+ fov_y = self.fovy()
+
+ # define an interpolation function
+ x = np.linspace(-fov_x / 2, fov_x / 2, self.xdim)
+ y = np.linspace(-fov_y / 2, fov_y / 2, self.ydim)
+
+ xtarget = np.linspace(-targetfov / 2, targetfov / 2, npix)
+ ytarget = np.linspace(-targetfov / 2, targetfov / 2, npix)
+
+ def interp_imvec(imvec, specind=False):
+            if np.any(np.imag(imvec) != 0):
+                return (interp_imvec(np.real(imvec), specind) +
+                        1j * interp_imvec(np.imag(imvec), specind))
+
+ interpfunc = scipy.interpolate.interp2d(y, x, np.reshape(imvec, (self.ydim, self.xdim)),
+ kind=interp)
+ tmpimg = interpfunc(ytarget, xtarget)
+ tmpimg[np.abs(xtarget) > fov_x / 2., :] = 0.0
+ tmpimg[:, np.abs(ytarget) > fov_y / 2.] = 0.0
+
+ if not specind: # adjust pixel size if not a spectral index map
+ tmpimg = tmpimg * (psize_new)**2 / self.psize**2
+ return tmpimg
+
+ # Make new image
+ imarr_new = interp_imvec(self.imvec)
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_new
+ arglist[1] = psize_new
+
+ outim = Image(*arglist, **argdict)
+
+ # Interpolate all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr_new = interp_imvec(polvec)
+ outim.add_pol_image(polarr_new, pol)
+
+ # Interpolate spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = interp_imvec(mfvec, specind=True)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
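+
+    # Usage sketch (hypothetical images `im1`, `im2`): interpolate both onto a
+    # common grid before a pixel-by-pixel comparison. Note that Jy/pixel
+    # values are rescaled by (psize_new / psize_old)**2 above:
+    #
+    #   fov = min(im1.fovx(), im2.fovx())
+    #   im1_c = im1.regrid_image(fov, 128)
+    #   im2_c = im2.regrid_image(fov, 128)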
+
+ def rotate(self, angle, interp='cubic'):
+ """Rotate the image counterclockwise by the specified angle.
+
+ Args:
+ angle (float): CCW angle to rotate the image (radian)
+ interp ('linear', 'cubic', 'quintic'): type of interpolation. default is cubic
+ Returns:
+ (Image): resampled image
+ """
+
+ order = 3
+ if interp == 'linear':
+ order = 1
+ elif interp == 'cubic':
+ order = 3
+ elif interp == 'quintic':
+ order = 5
+
+ # Define an interpolation function
+ def rot_imvec(imvec):
+ if np.any(np.imag(imvec) != 0):
+ return rot_imvec(np.real(imvec)) + 1j * rot_imvec(np.imag(imvec))
+ imarr_rot = scipy.ndimage.interpolation.rotate(imvec.reshape((self.ydim, self.xdim)),
+ angle * 180.0 / np.pi, reshape=False,
+ order=order, mode='constant',
+ cval=0.0, prefilter=True)
+
+ return imarr_rot
+
+ # pol_prim needs to be RR,LL,I,or V for a simple rotation to work!
+ if(not (self.pol_prim in ['RR', 'LL', 'I', 'V'])):
+ raise Exception("im.pol_prim must be a scalar ('I','V','RR','LL') for simple rotation!")
+
+ # Make new image
+ imarr_rot = rot_imvec(self.imvec)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_rot
+ outim = Image(*arglist, **argdict)
+
+ # Rotate all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr_rot = rot_imvec(polvec)
+ if pol == 'RL':
+ polarr_rot *= np.exp(1j * 2 * angle)
+ elif pol == 'LR':
+ polarr_rot *= np.exp(-1j * 2 * angle)
+ elif pol == 'Q':
+ polarr_rot = polarr_rot + 1j * rot_imvec(self._imdict['U'])
+ polarr_rot = np.real(np.exp(1j * 2 * angle) * polarr_rot)
+ elif pol == 'U':
+ polarr_rot = rot_imvec(self._imdict['Q']) + 1j * polarr_rot
+ polarr_rot = np.imag(np.exp(1j * 2 * angle) * polarr_rot)
+
+ outim.add_pol_image(polarr_rot, pol)
+
+ # Rotate spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = rot_imvec(mfvec)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
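+
+    # Note on the polarization transform above: under a CCW rotation by chi,
+    # linear polarization rotates as (Q + iU) -> exp(2i*chi) * (Q + iU),
+    # equivalently RL -> exp(2i*chi)*RL and LR -> exp(-2i*chi)*LR.
+    # Hypothetical usage:
+    #
+    #   im_rot = im.rotate(np.pi / 4)   # rotate 45 degrees CCW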
+
+ def shift(self, shiftidx):
+ """Shift the image by a given number of pixels.
+
+ Args:
+            shiftidx (list): pixel offsets [y_offset, x_offset] for the image shift
+                             (applied along axes 0 and 1 of the 2D image array)
+
+ Returns:
+            (Image): shifted image
+ """
+
+ # Define shifting function
+ def shift_imvec(imvec):
+ im_shift = np.roll(imvec.reshape(self.ydim, self.xdim), shiftidx[0], axis=0)
+ im_shift = np.roll(im_shift, shiftidx[1], axis=1)
+ return im_shift
+
+ # Make new image
+ imarr_shift = shift_imvec(self.imvec)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_shift
+ outim = Image(*arglist, **argdict)
+
+ # Shift all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr_shift = shift_imvec(polvec)
+ outim.add_pol_image(polarr_shift, pol)
+
+ # Shift spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = shift_imvec(mfvec)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
+
+ def shift_fft(self, shift):
+ """Shift the image by a given vector in radians.
+ This allows non-integer pixel shifts, via FFT.
+
+ Args:
+ shift (list): offsets [x_offset, y_offset] for the image shift in radians
+
+ Returns:
+ (Image): shifted image
+ """
+
+ Nx = self.xdim
+ Ny = self.ydim
+
+ [dx_pixels, dy_pixels] = np.array(shift) / self.psize
+
+        s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0 / Nx), np.fft.fftfreq(Ny, d=1.0 / Ny))
+        # phase ramp implementing the shift (normalized per axis)
+        rotate = np.exp(2.0 * np.pi * 1j * (s * dx_pixels / float(Nx) + t * dy_pixels / float(Ny)))
+
+ imarr = self.imvec.reshape((Ny, Nx))
+ imarr_rotate = np.real(np.fft.ifft2(np.fft.fft2(imarr) * rotate))
+
+ # make new Image
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_rotate
+ outim = Image(*arglist, **argdict)
+
+ # Shift all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ imarr = polvec.reshape((Ny, Nx))
+ imarr_rotate = np.real(np.fft.ifft2(np.fft.fft2(imarr) * rotate))
+ outim.add_pol_image(imarr_rotate, pol)
+
+ # Shift spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = mfvec.reshape((Ny, Nx))
+ mfarr = np.real(np.fft.ifft2(np.fft.fft2(mfarr) * rotate))
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
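+
+    # The shift above uses the Fourier shift theorem: multiplying the image
+    # FFT by the phase ramp exp(2*pi*1j*(s*dx/Nx + t*dy/Ny)) translates the
+    # image by (dx, dy) pixels with periodic wrapping, allowing non-integer
+    # shifts. Hypothetical half-pixel shift in x:
+    #
+    #   im_shift = im.shift_fft([0.5 * im.psize, 0.0])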
+
+ def blur_gauss(self, beamparams, frac=1., frac_pol=0):
+ """Blur image with a Gaussian beam w/ beamparams [fwhm_max, fwhm_min, theta] in radians.
+
+ Args:
+ beamparams (list): [fwhm_maj, fwhm_min, theta, x, y] in radians
+ frac (float): fractional beam size for blurring the main image
+ frac_pol (float): fractional beam size for blurring the other polarizations
+
+ Returns:
+ (Image): output image
+ """
+
+ if frac <= 0.0 or beamparams[0] <= 0:
+ return self.copy()
+
+ # Make a Gaussian image
+ xlist = np.arange(0, -self.xdim, -1) * self.psize + \
+ (self.psize * self.xdim) / 2.0 - self.psize / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * self.psize + \
+ (self.psize * self.ydim) / 2.0 - self.psize / 2.0
+ sigma_maj = beamparams[0] / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_min = beamparams[1] / (2. * np.sqrt(2. * np.log(2.)))
+ cth = np.cos(beamparams[2])
+ sth = np.sin(beamparams[2])
+
+ def gaussim(blurfrac):
+ gauss = np.array([[np.exp(-(j * cth + i * sth)**2 / (2 * (blurfrac * sigma_maj)**2) -
+ (i * cth - j * sth)**2 / (2 * (blurfrac * sigma_min)**2))
+ for i in xlist]
+ for j in ylist])
+ gauss = gauss[0:self.ydim, 0:self.xdim]
+ gauss = gauss / np.sum(gauss) # normalize to 1
+ return gauss
+
+ gauss = gaussim(frac)
+ if frac_pol:
+ gausspol = gaussim(frac_pol)
+
+ # Define a convolution function
+ def blur(imarr, gauss):
+ imarr_blur = scipy.signal.fftconvolve(gauss, imarr, mode='same')
+ return imarr_blur
+
+ # Convolve the primary image
+ imarr = (self.imvec).reshape(self.ydim, self.xdim).astype('float64')
+ imarr_blur = blur(imarr, gauss)
+
+ # Make new image object
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_blur
+ outim = Image(*arglist, **argdict)
+
+ # Blur all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).astype('float64')
+ if frac_pol:
+ polarr = blur(polarr, gausspol)
+ outim.add_pol_image(polarr, pol)
+
+ # Blur spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = mfvec.reshape(self.ydim, self.xdim).astype('float64')
+ mfarr = blur(mfarr, gauss)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
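+
+    # Usage sketch (hypothetical image `im`): blur with a circular 20 uas FWHM
+    # beam at full fractional width, beam position angle 0:
+    #
+    #   beam = [20 * ehc.RADPERUAS, 20 * ehc.RADPERUAS, 0]
+    #   im_blur = im.blur_gauss(beam, frac=1.0)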
+
+ def blur_circ(self, fwhm_i, fwhm_pol=0, filttype='gauss'):
+ """Apply a circular gaussian filter to the image, with FWHM in radians.
+
+ Args:
+ fwhm_i (float): circular beam size for Stokes I blurring in radian
+ fwhm_pol (float): circular beam size for Stokes Q,U,V blurring in radian
+ filttype (str): "gauss" or "butter"
+
+ Returns:
+ (Image): output image
+ """
+
+        fwhmp = fwhm_i / self.psize
+        fwhmp_pol = fwhm_pol / self.psize
+
+        # Define a gaussian convolution function (fwhm given in pixels)
+        def blur_gauss(imarr, fwhm):
+            sigma = fwhm / (2. * np.sqrt(2. * np.log(2.)))
+            if np.any(np.imag(imarr) != 0):
+                return (blur_gauss(np.real(imarr), fwhm) +
+                        1j * blur_gauss(np.imag(imarr), fwhm))
+            imarr_blur = filt.gaussian_filter(imarr, (sigma, sigma))
+            return imarr_blur
+
+        # Define a 2nd-order Butterworth low-pass filter
+        # (the cutoff frequency is the inverse of the kernel size in pixels)
+        def blur_butter(imarr, size):
+            if size == 0:
+                return imarr
+
+            cutoff = 1. / size
+            Nx = self.xdim
+            Ny = self.ydim
+
+            s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0), np.fft.fftfreq(Ny, d=1.0))
+            r = np.sqrt(s**2 + t**2)
+            bfilt = 1. / np.sqrt(1 + (r / cutoff)**4)
+
+            imarr_filt = np.real(np.fft.ifft2(np.fft.fft2(imarr) * bfilt))
+            return imarr_filt
+
+        if filttype == 'gauss':
+            blur = blur_gauss
+        elif filttype == 'butter':
+            blur = blur_butter
+ else:
+ raise Exception("filttype not recognized in blur_circ!")
+
+ # Blur the primary image
+ imarr = self.imvec.reshape(self.ydim, self.xdim)
+ imarr_blur = blur(imarr, fwhmp)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr_blur
+ outim = Image(*arglist, **argdict)
+
+ # Blur spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = mfvec.reshape(self.ydim, self.xdim)
+ mfarr = blur(mfarr, fwhmp)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+        # Blur all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim)
+ if fwhm_pol:
+ polarr = blur(polarr, fwhmp_pol)
+ outim.add_pol_image(polarr, pol)
+
+ return outim
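+
+    # Usage sketch (hypothetical image `im`): Gaussian vs. Butterworth
+    # smoothing at the same nominal 15 uas scale; 'butter' applies the
+    # 2nd-order low-pass filter defined above in the Fourier domain:
+    #
+    #   im_g = im.blur_circ(15 * ehc.RADPERUAS)
+    #   im_b = im.blur_circ(15 * ehc.RADPERUAS, filttype='butter')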
+
+ def blur_mf(self, freqs, fwhm, fit_order=1, filttype='gauss'):
+ """Blur image correctly across multiple frequencies
+ WARNING: does not currently do polarization correctly!
+
+ Args:
+            freqs (list): frequencies to include in the blurring & spectral index fit
+ fwhm (float): circular beam size
+ fit_order (int): how many orders to fit spectrum: 1 or 2
+ filttype (str): "gauss" or "butter"
+
+ Returns:
+ (Image): output image
+
+ """
+ if fit_order not in [1,2]:
+ raise Exception("fit_order must be 1 or 2 in blur_mf!")
+
+ reffreq = self.rf
+
+        # Blur at each frequency, then clip non-positive pixels so the logs are finite
+        imlist = [self.get_image_mf(rf).blur_circ(fwhm, filttype=filttype) for rf in freqs]
+ for image in imlist:
+ image.imvec[image.imvec<=0] = np.min(image.imvec[image.imvec!=0])
+
+ xfit = np.log(np.array(freqs)/reffreq)
+ yfit = np.log(np.array([im.imvec for im in imlist]))
+
+        if fit_order == 2:
+            coeffs = np.polyfit(xfit, yfit, 2)
+            beta = coeffs[0]
+            alpha = coeffs[1]
+        else:  # fit_order == 1, guaranteed by the check above
+            coeffs = np.polyfit(xfit, yfit, 1)
+            alpha = coeffs[0]
+            beta = 0 * alpha
+
+        outim = self.blur_circ(fwhm, filttype=filttype)
+ outim.specvec = alpha
+ outim.curvvec = beta
+ return outim
+
+ def grad(self, gradtype='abs'):
+ """Return the gradient image
+
+ Args:
+ gradtype (str): 'x','y',or 'abs' for the image gradient dimension
+
+ Returns:
+ Image : an image object containing the gradient image
+ """
+
+ # Define the desired gradient function
+ def gradim(imvec):
+ if np.any(np.imag(imvec) != 0):
+ return gradim(np.real(imvec)) + 1j * gradim(np.imag(imvec))
+
+ imarr = imvec.reshape(self.ydim, self.xdim)
+
+            sx = ndi.sobel(imarr, axis=0, mode='nearest')
+            sy = ndi.sobel(imarr, axis=1, mode='nearest')
+
+            # TODO: are these in the right order??
+            if gradtype == 'x':
+                gradarr = sx
+            elif gradtype == 'y':
+                gradarr = sy
+            else:
+                gradarr = np.hypot(sx, sy)
+            return gradarr
+
+ # Find the gradient for the primary image
+ gradarr = gradim(self.imvec)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = gradarr
+ outim = Image(*arglist, **argdict)
+
+ # Find the gradient for all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polvec = self._imdict[pol]
+ if len(polvec):
+ gradarr = gradim(polvec)
+ outim.add_pol_image(gradarr, pol)
+
+ # Find the spectral index gradients and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfarr = gradim(mfvec)
+ mfvec_out = mfarr.flatten()
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
+
+ def mask(self, cutoff=0.05, beamparams=None, frac=0.0):
+ """Produce an image mask that shows all pixels above the specified cutoff frac of the max
+ Works off the primary image
+
+ Args:
+            cutoff (float): the mask selects pixels brighter than min + cutoff * (max - min)
+ beamparams (list): either [fwhm_maj, fwhm_min, pos_ang] or a single fwhm
+ frac (float): the fraction of nominal beam to blur with
+
+ Returns:
+ (Image): output mask image
+
+ """
+
+ # Blur the image
+ if beamparams is not None:
+ try:
+ len(beamparams)
+ except TypeError:
+ beamparams = [beamparams, beamparams, 0]
+ if len(beamparams) == 3:
+ mask = self.blur_gauss(beamparams, frac)
+ else:
+ raise Exception("beamparams should be a length 3 array [maj, min, posang]!")
+ else:
+ mask = self.copy()
+
+ # Mask pixels outside the desired intensity range
+ maxval = np.max(mask.imvec)
+ minval = np.min(mask.imvec)
+ intensityrange = maxval - minval
+ thresh = intensityrange * cutoff + minval
+ maskvec = (mask.imvec > thresh).astype(int)
+
+ # make the primary image
+ maskarr = maskvec.reshape(mask.ydim, mask.xdim)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = maskarr
+ mask = Image(*arglist, **argdict)
+
+ # Replace all polarization imvecs with mask
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ mask.add_pol_image(maskarr, pol)
+
+ # No spectral index information in mask
+
+ return mask
+
+ # TODO make this work with a mask image of different dimensions & fov
+ def apply_mask(self, mask_im, fill_val=0.):
+ """Apply a mask to the image
+
+ Args:
+ mask_im (Image): a mask image with the same dimensions as the Image
+ fill_val (float): masked pixels of all polarizations are set to this value
+
+ Returns:
+ (Image): the masked image
+
+ """
+ if ((self.psize != mask_im.psize) or
+ (self.xdim != mask_im.xdim) or (self.ydim != mask_im.ydim)):
+ raise Exception("mask image does not match dimensions of the current image!")
+
+        # Get the boolean mask vector
+        maskvec = mask_im.imvec.astype(bool)
+
+        # Mask the primary image (copy first so the original image is unchanged)
+        imvec = self.imvec.copy()
+        imvec[~maskvec] = fill_val
+ imarr = imvec.reshape(self.ydim, self.xdim)
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Apply mask to all polarizations and copy over
+ for pol in list(self._imdict.keys()):
+ if pol == self.pol_prim:
+ continue
+            polvec = self._imdict[pol]
+            if len(polvec):
+                polvec = polvec.copy()
+                polvec[~maskvec] = fill_val
+                polarr = polvec.reshape(self.ydim, self.xdim)
+ outim.add_pol_image(polarr, pol)
+
+ # Apply mask to spectral index and copy over
+ mflist_out = []
+ for mfvec in self._mflist:
+ if len(mfvec):
+ mfvec_out = copy.deepcopy(mfvec)
+ mfvec_out[~maskvec] = 0.
+ else:
+ mfvec_out = np.array([])
+ mflist_out.append(mfvec_out)
+ outim._mflist = mflist_out
+
+ return outim
+
+ def threshold(self, cutoff=0.05, beamparams=None, frac=0.0, fill_val=None):
+ """Apply a hard threshold to the primary polarization image.
+ Leave other polarizations untouched.
+
+ Args:
+            cutoff (float): pixels below min + cutoff * (max - min) are replaced with fill_val
+ beamparams (list): either [fwhm_maj, fwhm_min, pos_ang] or a single fwhm
+ frac (float): the fraction of nominal beam to blur with
+ fill_val (float): masked pixels are set to this value.
+ If fill_val==None, they are set to the min unmasked value
+
+ Returns:
+            (Image): thresholded image
+ """
+
+ if fill_val is None or fill_val is False:
+ maxval = np.max(self.imvec)
+ minval = np.min(self.imvec)
+ intensityrange = maxval - minval
+ fill_val = (intensityrange * cutoff + minval)
+
+ mask = self.mask(cutoff=cutoff, beamparams=beamparams, frac=frac)
+ out = self.apply_mask(mask, fill_val=fill_val)
+ return out
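+
+    # Usage sketch (hypothetical image `im`): keep pixels above 5% of the
+    # intensity range. threshold() composes mask() and apply_mask(), so the
+    # two-step form below is equivalent apart from the fill value:
+    #
+    #   im_thresh = im.threshold(cutoff=0.05)
+    #   support = im.mask(cutoff=0.05)                # 0/1 mask image
+    #   im_cut = im.apply_mask(support, fill_val=0.)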
+
+ def add_flat(self, flux, pol=None):
+ """Add a flat background flux to the main polarization image.
+
+ Args:
+ flux (float): total flux to add to image
+ pol (str): the polarization to add the flux to. None defaults to pol_prim.
+ Returns:
+ (Image): output image
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+ if not len(self._imdict[pol]):
+ raise Exception("no image for pol %s" % pol)
+
+ # Make a flat image array
+ flatarr = ((flux / float(len(self.imvec))) * np.ones(len(self.imvec)))
+ flatarr = flatarr.reshape(self.ydim, self.xdim)
+
+ # Add to the main image and create the new image object
+ imarr = self.imvec.reshape(self.ydim, self.xdim).copy()
+ if pol == self.pol_prim:
+ imarr += flatarr
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ for pol2 in list(self._imdict.keys()):
+ if pol2 == self.pol_prim:
+ continue
+ polvec = self._imdict[pol2]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ if pol2 == pol:
+ polarr += flatarr
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_tophat(self, flux, radius, pol=None):
+ """Add centered tophat flux to the Stokes I image inside a given radius.
+
+ Args:
+ flux (float): total flux to add to image
+ radius (float): radius of top hat flux in radians
+ pol (str): the polarization to add the flux to. None defaults to pol_prim
+
+ Returns:
+ (Image): output image
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+ if not len(self._imdict[pol]):
+ raise Exception("no image for pol %s" % pol)
+
+ # Make a tophat image array
+ xlist = np.arange(0, -self.xdim, -1) * self.psize + \
+ (self.psize * self.xdim) / 2.0 - self.psize / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * self.psize + \
+ (self.psize * self.ydim) / 2.0 - self.psize / 2.0
+
+ hatarr = np.array([[1.0 if np.sqrt(i**2 + j**2) <= radius else 0.
+ for i in xlist]
+ for j in ylist])
+
+ hatarr = hatarr[0:self.ydim, 0:self.xdim]
+ hatarr *= flux / np.sum(hatarr)
+
+ # Add to the main image and create the new image object
+ imarr = self.imvec.reshape(self.ydim, self.xdim).copy()
+ if pol == self.pol_prim:
+ imarr += hatarr
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ for pol2 in list(self._imdict.keys()):
+ if pol2 == self.pol_prim:
+ continue
+ polvec = self._imdict[pol2]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ if pol2 == pol:
+ polarr += hatarr
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_gauss(self, flux, beamparams, pol=None):
+ """Add a gaussian to an image.
+
+ Args:
+ flux (float): the total flux contained in the Gaussian in Jy
+ beamparams (list): [fwhm_maj, fwhm_min, theta, x, y], all in radians
+ pol (str): the polarization to add the flux to. None defaults to pol_prim.
+
+ Returns:
+ (Image): output image
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+ if not len(self._imdict[pol]):
+ raise Exception("no image for pol %s" % pol)
+
+ # Make a Gaussian image
+ try:
+ x = beamparams[3]
+ y = beamparams[4]
+ except IndexError:
+ x = y = 0.0
+
+ sigma_maj = beamparams[0] / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_min = beamparams[1] / (2. * np.sqrt(2. * np.log(2.)))
+ cth = np.cos(beamparams[2])
+ sth = np.sin(beamparams[2])
+ xlist = np.arange(0, -self.xdim, -1) * self.psize + \
+ (self.psize * self.xdim) / 2.0 - self.psize / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * self.psize + \
+ (self.psize * self.ydim) / 2.0 - self.psize / 2.0
+
+ def gaussian(x2, y2):
+ gauss = np.exp(-((y2) * cth + (x2) * sth)**2 / (2 * sigma_maj**2) +
+ -((x2) * cth - (y2) * sth)**2 / (2 * sigma_min**2))
+ return gauss
+
+ gaussarr = np.array([[gaussian(i - x, j - y) for i in xlist] for j in ylist])
+ gaussarr = gaussarr[0:self.ydim, 0:self.xdim]
+ gaussarr *= flux / np.sum(gaussarr)
+
+ # TODO: if we want to add a gaussian to V, we might also want to make sure we add it to I
+ # Add to the main image and create the new image object
+ imarr = self.imvec.reshape(self.ydim, self.xdim).copy()
+ if pol == self.pol_prim:
+ imarr += gaussarr
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ for pol2 in list(self._imdict.keys()):
+ if pol2 == self.pol_prim:
+ continue
+ polvec = self._imdict[pol2]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ if pol2 == pol:
+ polarr += gaussarr
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
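+
+    # Usage sketch (hypothetical image `im`): build up a simple model by
+    # adding components; each call returns a new Image and leaves the input
+    # unchanged:
+    #
+    #   im2 = im.add_flat(0.1)    # add a 0.1 Jy uniform floor
+    #   im3 = im2.add_gauss(0.5, [20 * ehc.RADPERUAS,
+    #                             10 * ehc.RADPERUAS, 0.3, 0, 0])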
+
+ def add_crescent(self, flux, Rp, Rn, a, b, x=0, y=0, pol=None):
+ """Add a crescent to an image; see Kamruddin & Dexter (2013).
+
+ Args:
+ flux (float): the total flux contained in the crescent in Jy
+ Rp (float): the larger radius in radians
+ Rn (float): the smaller radius in radians
+ a (float): the relative x offset of smaller disk in radians
+ b (float): the relative y offset of smaller disk in radians
+ x (float): the center x coordinate of the larger disk in radians
+ y (float): the center y coordinate of the larger disk in radians
+ pol (str): the polarization to add the flux to. None defaults to pol_prim.
+
+ Returns:
+            (Image): output image
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+ if not len(self._imdict[pol]):
+ raise Exception("no image for pol %s" % pol)
+
+ # Make a crescent image
+ xlist = np.arange(0, -self.xdim, -1) * self.psize + \
+ (self.psize * self.xdim) / 2.0 - self.psize / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * self.psize + \
+ (self.psize * self.ydim) / 2.0 - self.psize / 2.0
+
+ def crescent(x2, y2):
+ if (x2 - a)**2 + (y2 - b)**2 > Rn**2 and x2**2 + y2**2 < Rp**2:
+ return 1.0
+ else:
+ return 0.0
+
+ crescarr = np.array([[crescent(i - x, j - y) for i in xlist] for j in ylist])
+ crescarr = crescarr[0:self.ydim, 0:self.xdim]
+ crescarr *= flux / np.sum(crescarr)
+
+ # Add to the main image and create the new image object
+ imarr = self.imvec.reshape(self.ydim, self.xdim).copy()
+ if pol == self.pol_prim:
+ imarr += crescarr
+
+ arglist, argdict = self.image_args()
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ for pol2 in list(self._imdict.keys()):
+ if pol2 == self.pol_prim:
+ continue
+ polvec = self._imdict[pol2]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ if pol2 == pol:
+ polarr += crescarr
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_ring_m1(self, I0, I1, r0, phi, sigma, x=0, y=0, pol=None):
+ """Add a ring to an image with an m=1 mode
+
+ Args:
+            I0 (float): the mean (m=0) intensity coefficient of the ring
+            I1 (float): the amplitude of the m=1 azimuthal mode
+ r0 (float): the radius
+ phi (float): angle of m1 mode
+ sigma (float): the blurring size
+ x (float): the center x coordinate of the larger disk in radians
+ y (float): the center y coordinate of the larger disk in radians
+ pol (str): the polarization to add the flux to. None defaults to pol_prim.
+ Returns:
+            (Image): output image
+ """
+
+ if pol is None:
+ pol = self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+ if not len(self._imdict[pol]):
+ raise Exception("no image for pol %s" % pol)
+
+ # Make a ring image
+ flux = I0 - 0.5 * I1
+ phi = phi + np.pi
+ psize = self.psize
+ xlist = np.arange(0, -self.xdim, -1) * self.psize + \
+ (self.psize * self.xdim) / 2.0 - self.psize / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * self.psize + \
+ (self.psize * self.ydim) / 2.0 - self.psize / 2.0
+
+ def ringm1(x2, y2):
+ if (x2**2 + y2**2) > (r0 - psize)**2 and (x2**2 + y2**2) < (r0 + psize)**2:
+ theta = np.arctan2(y2, x2)
+ flux = (I0 - 0.5 * I1 * (1 + np.cos(theta - phi))) / (2 * np.pi * r0)
+ return flux
+ else:
+ return 0.0
+
+ ringarr = np.array([[ringm1(i - x, j - y)
+ for i in xlist]
+ for j in ylist])
+ ringarr = ringarr[0:self.ydim, 0:self.xdim]
+
+ arglist, argdict = self.image_args()
+ arglist[0] = ringarr
+ outim = Image(*arglist, **argdict)
+
+ outim = outim.blur_circ(sigma)
+ outim.imvec *= flux / (outim.total_flux())
+ ringarr = outim.imvec.reshape(self.ydim, self.xdim)
+
+ # Add to the main image and create the new image object
+ imarr = self.imvec.reshape(self.ydim, self.xdim).copy()
+ if pol == self.pol_prim:
+ imarr += ringarr
+
+ arglist[0] = imarr
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ for pol2 in list(self._imdict.keys()):
+ if pol2 == self.pol_prim:
+ continue
+ polvec = self._imdict[pol2]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ if pol2 == pol:
+ polarr += ringarr
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_const_pol(self, mag, angle, cmag=0, csign=1):
+ """Return an with constant fractional linear and circular polarization
+
+ Args:
+ mag (float): constant polarization fraction to add to the image
+ angle (float): constant EVPA
+ cmag (float): constant circular polarization fraction to add to the image
+            csign (int): constant circular polarization sign +/- 1
+
+ Returns:
+ (Image): output image
+ """
+
+ if not (0 <= mag < 1):
+ raise Exception("fractional polarization magnitude must be between 0 and 1!")
+
+ if not (0 <= cmag < 1):
+ raise Exception("circular polarization magnitude must be between 0 and 1!")
+
+ if self.polrep == 'stokes':
+ im_stokes = self
+ elif self.polrep == 'circ':
+ im_stokes = self.switch_polrep(polrep_out='stokes')
+ ivec = im_stokes.ivec.copy()
+ qvec = obsh.qimage(ivec, mag * np.ones(len(ivec)), angle * np.ones(len(ivec)))
+ uvec = obsh.uimage(ivec, mag * np.ones(len(ivec)), angle * np.ones(len(ivec)))
+ vvec = cmag * np.sign(csign) * ivec
+
+ # create the new stokes image object
+ iarr = ivec.reshape(self.ydim, self.xdim).copy()
+
+ arglist, argdict = self.image_args()
+ arglist[0] = iarr
+ argdict['polrep'] = 'stokes'
+ argdict['pol_prim'] = 'I'
+ outim = Image(*arglist, **argdict)
+
+ # Copy over the rest of the polarizations
+ imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec}
+ for pol in list(imdict.keys()):
+ if pol == 'I':
+ continue
+ polvec = imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ outim.add_pol_image(polarr, pol)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_random_pol(self, mag, corr, cmag=0., ccorr=0., seed=0):
+ """Return an image random linear and circular polarizations with certain correlation lengths
+
+ Args:
+ mag (float): linear polarization fraction
+ corr (float): EVPA correlation length (radians)
+ cmag (float): circular polarization fraction
+ ccorr (float): CP correlation length (radians)
+ seed (int): Seed for random number generation
+
+ Returns:
+ (Image): output image
+ """
+
+ import ehtim.scattering.stochastic_optics as so
+
+ if not (0 <= mag < 1):
+ raise Exception("fractional polarization magnitude must be between 0 and 1!")
+
+ if not (0 <= cmag < 1):
+ raise Exception("circular polarization magnitude must be between 0 and 1!")
+
+ if self.polrep == 'stokes':
+ im_stokes = self
+ elif self.polrep == 'circ':
+ im_stokes = self.switch_polrep(polrep_out='stokes')
+ ivec = im_stokes.ivec.copy()
+
+ # create the new stokes image object
+ iarr = ivec.reshape(self.ydim, self.xdim).copy()
+
+ arglist, argdict = self.image_args()
+ arglist[0] = iarr
+ argdict['polrep'] = 'stokes'
+ argdict['pol_prim'] = 'I'
+ outim = Image(*arglist, **argdict)
+
+ # Make a random phase screen using the scattering tools
+ # Use this screen to define the EVPA
+        dist = 1.0 * 3.086e21  # 3.086e21 cm = 1 kpc
+ rdiff = np.abs(corr) * dist / 1e3
+ theta_mas = 0.37 * 1.0 / rdiff * 1000. * 3600. * 180. / np.pi
+ sm = so.ScatteringModel(scatt_alpha=1.67, observer_screen_distance=dist,
+ source_screen_distance=1.e5 * dist,
+ theta_maj_mas_ref=theta_mas, theta_min_mas_ref=theta_mas,
+ r_in=rdiff * 2, r_out=1e30)
+ ep = so.MakeEpsilonScreen(self.xdim, self.ydim, rngseed=seed)
+ ps = np.array(sm.MakePhaseScreen(ep, outim, obs_frequency_Hz=29.979e9).imvec)
+ ps = ps / 1000**(1.66 / 2)
+ qvec = ivec * mag * np.sin(ps)
+ uvec = ivec * mag * np.cos(ps)
+
+ # Make a random phase screen using the scattering tools
+ # Use this screen to define the CP magnitude
+ if cmag != 0.0 and ccorr > 0.0:
+ dist = 1.0 * 3.086e21
+ rdiff = np.abs(ccorr) * dist / 1e3
+ theta_mas = 0.37 * 1.0 / rdiff * 1000. * 3600. * 180. / np.pi
+ sm = so.ScatteringModel(scatt_alpha=1.67, observer_screen_distance=dist,
+ source_screen_distance=1.e5 * dist,
+ theta_maj_mas_ref=theta_mas, theta_min_mas_ref=theta_mas,
+ r_in=rdiff * 2, r_out=1e30)
+ ep = so.MakeEpsilonScreen(self.xdim, self.ydim, rngseed=seed * 2)
+ ps = np.array(sm.MakePhaseScreen(ep, outim, obs_frequency_Hz=29.979e9).imvec)
+ ps = ps / 1000**(1.66 / 2)
+ vvec = ivec * cmag * np.sin(ps)
+ else:
+ vvec = ivec * cmag
+
+ # Copy over the rest of the polarizations
+ imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec}
+ for pol in list(imdict.keys()):
+ if pol == 'I':
+ continue
+ polvec = imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ outim.add_pol_image(polarr, pol)
+
+ # Copy the spectral index (unchanged)
+ outim._mflist = copy.deepcopy(self._mflist)
+
+ return outim
+
+ def add_const_mf(self, alpha, beta=0.):
+ """Add a constant spectral index and curvature term
+
+ Args:
+ alpha (float): spectral index (with no - sign)
+ beta (float): curvature
+
+ Returns:
+ (Image): output image with constant mf information added
+ """
+
+ avec = alpha * np.ones(len(self.imvec))
+ bvec = beta * np.ones(len(self.imvec))
+
+ # create the new image object
+ outim = self.copy()
+ outim._mflist = [avec, bvec]
+
+ return outim
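+
+    # Usage sketch (hypothetical image `im`): attach a constant spectral index
+    # and evaluate the image at another frequency with get_image_mf:
+    #
+    #   im_mf = im.add_const_mf(2.5)        # alpha = 2.5, beta = 0
+    #   im345 = im_mf.get_image_mf(345e9)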
+
+ def add_zblterm(self, obs, uv_min, zblval=None, new_fov=False,
+ gauss_sz=False, gauss_sz_factor=0.75, debias=True):
+ """Add a large Gaussian term to account for missing flux in the zero baseline.
+
+ Args:
+            obs : an Obsdata object used to determine the min non-zero baseline and 0-bl flux
+            uv_min (float): the cutoff in Glambda used to determine what is a 0-bl
+            zblval (float): the total flux assumed to already be in the image;
+                            if None, np.sum(self.imvec) is used
+            new_fov (rad): the size of the padded image once the Gaussian is added
+                           (if False, it is set to the larger of the current fov
+                           and 5 x the Gaussian sigma)
+            gauss_sz (rad): the size of the Gaussian added to add flux to the 0-bl
+                            (if False, it is computed from the min non-zero baseline)
+            gauss_sz_factor (float): the fraction of the min non-zero baseline
+                                     used to calculate the Gaussian FWHM
+            debias (bool): True to use debiased amplitudes to calculate the 0-bl flux in Jy
+
+ Returns:
+ (Image): a padded image with a large Gaussian component
+ """
+
+ if gauss_sz is False:
+ obs_flag = obs.flag_uvdist(uv_min=uv_min)
+ minuvdist = np.min(np.sqrt(obs_flag.data['u']**2 + obs_flag.data['v']**2))
+ gauss_sz_sigma = (1 / (gauss_sz_factor * minuvdist))
+ gauss_sz = gauss_sz_sigma * 2.355 # convert from stdev to fwhm
+
+ factor = 5.0
+ if new_fov is False:
+ im_fov = np.max((self.xdim * self.psize, self.ydim * self.psize))
+ new_fov = np.max((factor * (gauss_sz / 2.355), im_fov))
+
+ if new_fov < factor * (gauss_sz / 2.355):
+ print('WARNING! The specified new fov may not be large enough')
+
+ # calculate the amount of flux to include in the Gaussian
+ obs_zerobl = obs.flag_uvdist(uv_max=uv_min)
+ obs_zerobl.add_amp(debias=debias)
+ orig_totflux = np.sum(obs_zerobl.amp['amp'] * (1 / obs_zerobl.amp['sigma']**2))
+ orig_totflux /= np.sum(1 / obs_zerobl.amp['sigma']**2)
+
+ if zblval is None:
+ addedflux = orig_totflux - np.sum(self.imvec)
+ else:
+ addedflux = orig_totflux - zblval
+
+ print('Adding a ' + str(addedflux) + ' Jy circular Gaussian of FWHM size ' +
+ str(gauss_sz / ehc.RADPERUAS) + ' uas')
+
+ im_new = self.copy()
+ im_new = im_new.pad(new_fov, new_fov)
+ im_new = im_new.add_gauss(addedflux, (gauss_sz, gauss_sz, 0, 0, 0))
+ return im_new
+
+ def sample_uv(self, uv, polrep_obs='stokes',
+ sgrscat=False, ttype='nfft',
+ cache=False, fft_pad_factor=2,
+ zero_empty_pol=True, verbose=True):
+ """Sample the image on the selected uv points without creating an Obsdata object.
+
+ Args:
+ uv (ndarray): an array of uv points
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+
+ ttype (str): "fast" or "nfft" or "direct"
+ cache (bool): Use cached fft for 'fast' mode -- deprecated, use nfft instead!
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+ zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist.
+ Otherwise return None
+ verbose (bool): Boolean value controls output prints.
+
+ Returns:
+ (list): a list of [I,Q,U,V] visibilities
+ """
+
+ if polrep_obs not in ['stokes', 'circ']:
+ raise Exception("polrep_obs must be either 'stokes' or 'circ'")
+
+ data = simobs.sample_vis(self, uv, polrep_obs=polrep_obs, sgrscat=sgrscat,
+ ttype=ttype, cache=cache, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=zero_empty_pol, verbose=verbose)
+ return data
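+
+    # Usage sketch (hypothetical uv array): sample model visibilities on a few
+    # baselines without constructing an Obsdata object; with polrep_obs
+    # 'stokes' the return value is a list of [I, Q, U, V] visibilities:
+    #
+    #   uv = np.array([[0.0, 0.0], [1e9, 0.0], [0.0, 1e9]])  # (u, v) in lambda
+    #   vis_i, vis_q, vis_u, vis_v = im.sample_uv(uv)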
+
+ def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft",
+ cache=False, fft_pad_factor=2,
+ zero_empty_pol=True, verbose=True, reorder=True):
+ """Observe the image on the same baselines as an existing observation without noise.
+
+ Args:
+ obs (Obsdata): the existing observation
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ ttype (str): "fast" or "nfft" or "direct"
+ cache (bool): Use cached fft for 'fast' mode -- deprecated, use nfft instead!
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+ zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist.
+ Otherwise return None
+ verbose (bool): Boolean value controls output prints.
+
+ Returns:
+ (Obsdata): an observation object with no noise
+ """
+
+ # Check for agreement in coordinates and frequency
+ tolerance = 1e-8
+ if (np.abs(self.ra - obs.ra) > tolerance) or (np.abs(self.dec - obs.dec) > tolerance):
+ raise Exception("Image coordinates are not the same as observtion coordinates!")
+ if (np.abs(self.rf - obs.rf) / obs.rf > tolerance):
+ raise Exception("Image frequency is not the same as observation frequency!")
+
+        if ttype in ('direct', 'fast', 'nfft', 'DFT', 'DFT_i'):
+            if verbose:
+                print("Producing clean visibilities from image with " + ttype + " FT . . . ")
+        else:
+            raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft', 'DFT', 'DFT_i'" % ttype)
+
+ # Copy data to be safe
+ obsdata = copy.deepcopy(obs.data)
+
+ # Extract uv datasample
+ uv = obsh.recarr_to_ndarr(obsdata[['u', 'v']], 'f8')
+ data = simobs.sample_vis(self, uv, sgrscat=sgrscat, polrep_obs=obs.polrep,
+ ttype=ttype, cache=cache, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=zero_empty_pol, verbose=verbose)
+
+ # put visibilities into the obsdata
+ if obs.polrep == 'stokes':
+ obsdata['vis'] = data[0]
+ if not(data[1] is None):
+ obsdata['qvis'] = data[1]
+ obsdata['uvis'] = data[2]
+ obsdata['vvis'] = data[3]
+
+ elif obs.polrep == 'circ':
+ obsdata['rrvis'] = data[0]
+ if not(data[1] is None):
+ obsdata['llvis'] = data[1]
+ if not(data[2] is None):
+ obsdata['rlvis'] = data[2]
+ obsdata['lrvis'] = data[3]
+
+ obs_no_noise = ehtim.obsdata.Obsdata(self.ra, self.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=self.source, mjd=self.mjd, polrep=obs.polrep,
+ ampcal=True, phasecal=True, opacitycal=True,
+ dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
+
+ return obs_no_noise
+
+ def observe_same(self, obs_in,
+ ttype='nfft', fft_pad_factor=2,
+ sgrscat=False, add_th_noise=True, th_noise_factor=1,
+ jones=False, inv_jones=False,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ phase_std=-1,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0., rlphase_std=0.,
+ sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None,
+ caltable_path=None, seed=False, reorder=True, verbose=True):
+ """Observe the image on the same baselines as an existing observation object and add noise.
+
+ Args:
+ obs_in (Obsdata): the existing observation
+ ttype (str): "fast" or "nfft" or "direct"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to a priori calibrate data
+
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ dcal (bool): if False, time-dependent gaussian errors added to D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ phase_std (float): std. dev. of LCP phase,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ sigmat (float): temporal std for a Gaussian Process used to generate gains.
+ If sigmat=None then an iid gain noise is applied.
+ phasesigmat (float): temporal std for a Gaussian Process used to generate phases.
+ If phasesigmat=None then an iid gain noise is applied.
+ rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios.
+ If rlgsigmat=None then an iid gain noise is applied.
+ rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff.
+ If rlpsigmat=None then an iid gain noise is applied.
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ verbose (bool): print updates and warnings
+ Returns:
+ (Obsdata): an observation object
+ """
+
+ if seed:
+ np.random.seed(seed=seed)
+
+ obs = self.observe_same_nonoise(obs_in, sgrscat=sgrscat,ttype=ttype,
+ cache=False, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=True, reorder=reorder, verbose=verbose)
+
+ # Jones Matrix Corruption & Calibration
+ if jones:
+ obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, frcal=frcal, dcal=dcal,
+ rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ phase_std=phase_std,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std, rlphase_std=rlphase_std,
+ sigmat=sigmat, phasesigmat=phasesigmat,
+ rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat,
+ caltable_path=caltable_path, seed=seed,verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
+ dcal=dcal, frcal=frcal,
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
+
+ if inv_jones:
+ obsdata = simobs.apply_jones_inverse(obs,
+ opacitycal=opacitycal, dcal=dcal, frcal=frcal,
+ verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
+
+ # No Jones Matrices, Add noise the old way
+ # NOTE There is an asymmetry here - in the old way, we don't offer the ability to
+ # *not* unscale estimated noise.
+ else:
+
+ if caltable_path:
+ print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix')
+
+ # TODO -- clean up arguments
+ obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor,
+ opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ sigmat=sigmat,
+ caltable_path=caltable_path, seed=seed,
+ verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans, reorder=reorder)
+
+ return obs
+
+ def observe(self, array, tint, tadv, tstart, tstop, bw,
+ mjd=None, timetype='UTC', polrep_obs=None,
+ elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
+ no_elevcut_space=False,
+ ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False,
+ sgrscat=False, add_th_noise=True, th_noise_factor=1,
+ jones=False, inv_jones=False, noise=True,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ tau=ehc.TAUDEF, taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ phase_std=-1,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0.,rlphase_std=0.,
+ sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None,
+ caltable_path=None, seed=False, reorder=True, verbose=True):
+ """Generate baselines from an array object and observe the image.
+
+ Args:
+ array (Array): an array object containing sites with which to generate baselines
+ tint (float): the scan integration time in seconds
+ tadv (float): the uniform cadence between scans in seconds
+ tstart (float): the start time of the observation in hours
+ tstop (float): the end time of the observation in hours
+ bw (float): the observing bandwidth in Hz
+
+            mjd (int): the MJD of the observation; if None, defaults to the image mjd
+ timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC'
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+ elevmin (float): station minimum elevation in degrees
+ elevmax (float): station maximum elevation in degrees
+ no_elevcut_space (bool): if True, do not apply elevation cut to orbiters
+
+ ttype (str): "fast", "nfft" or "dtft"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in the FFT
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v
+
+            sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ otherwise uses old formalism without D-terms
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to calibrate data
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrix.
+ dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ tau (float): the base opacity at all sites, or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ phase_std (float): std. dev. of LCP phase,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+            sigmat (float): temporal std. dev. for a Gaussian process used to generate gains.
+                If sigmat=None, i.i.d. gain noise is applied.
+            phasesigmat (float): temporal std. dev. for a Gaussian process used to generate phases.
+                If phasesigmat=None, i.i.d. phase noise is applied.
+            rlgsigmat (float): temporal std. dev. for a Gaussian process used to generate R/L gain ratios.
+                If rlgsigmat=None, i.i.d. R/L gain-ratio noise is applied.
+            rlpsigmat (float): temporal std. dev. for a Gaussian process used to generate R/L phase differences.
+                If rlpsigmat=None, i.i.d. R/L phase noise is applied.
+
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+
+ verbose (bool): print updates and warnings
+
+ Returns:
+ (Obsdata): an observation object
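+
+        Example (illustrative sketch; the array file path and parameter values
+        are hypothetical, and `im` is assumed to be an Image object):
+
+            >>> import ehtim as eh
+            >>> eht = eh.array.load_txt('arrays/EHT2017.txt')  # hypothetical path
+            >>> obs = im.observe(eht, tint=60., tadv=600., tstart=0., tstop=24.,
+            ...                  bw=4.e9, ampcal=False, phasecal=False)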
+ """
+
+ # Generate empty observation
+ if verbose: print("Generating empty observation file . . . ")
+
+ if mjd is None:
+ mjd = self.mjd
+ if polrep_obs is None:
+ polrep_obs = self.polrep
+
+ obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop, mjd=mjd,
+ polrep=polrep_obs, tau=tau,
+ elevmin=elevmin, elevmax=elevmax,
+ no_elevcut_space=no_elevcut_space,
+ timetype=timetype, fix_theta_GMST=fix_theta_GMST, reorder=reorder)
+
+
+ # Observe on the same baselines as the empty observation and add noise
+ if noise:
+ obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor,
+ sgrscat=sgrscat, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor,
+ jones=jones, inv_jones=inv_jones,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, dcal=dcal,
+ frcal=frcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ phase_std=phase_std,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ sigmat=sigmat,phasesigmat=phasesigmat,
+ rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat,
+ caltable_path=caltable_path, seed=seed, reorder=reorder, verbose=verbose)
+ else:
+ obs = self.observe_same_nonoise(obs, sgrscat=sgrscat, ttype=ttype,
+ cache=False, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=True, reorder=reorder, verbose=verbose)
+
+
+ obs.mjd = mjd
+
+ return obs
+
+ def observe_vex(self, vex, source, t_int=0.0, tight_tadv=False,
+ polrep_obs=None, ttype='nfft', fft_pad_factor=2,
+ fix_theta_GMST=False,
+ sgrscat=False, add_th_noise=True,
+ jones=False, inv_jones=False,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ tau=ehc.TAUDEF, taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ phase_std=-1,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0.,rlphase_std=0.,
+ sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None,
+ caltable_path=None, seed=False, verbose=True):
+ """Generate baselines from a vex file and observes the image.
+
+ Args:
+            vex (Vex): a Vex object containing sites and scan information
+ source (str): the source to observe
+
+ t_int (float): if not zero, overrides the vex scan lengths
+            tight_tadv (bool): if True, advance right after each integration;
+                otherwise advance after 2x the scan length
+
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+ ttype (str): "fast" or "nfft" or "dtft"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v
+
+            sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ otherwise uses old formalism without D-terms
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to calibrate data
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrix.
+ dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ tau (float): the base opacity at all sites,
+ or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ phase_std (float): std. dev. of LCP phase,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+            sigmat (float): temporal std. dev. for a Gaussian process used to generate gains.
+                If sigmat=None, i.i.d. gain noise is applied.
+            phasesigmat (float): temporal std. dev. for a Gaussian process used to generate phases.
+                If phasesigmat=None, i.i.d. phase noise is applied.
+            rlgsigmat (float): temporal std. dev. for a Gaussian process used to generate R/L gain ratios.
+                If rlgsigmat=None, i.i.d. R/L gain-ratio noise is applied.
+            rlpsigmat (float): temporal std. dev. for a Gaussian process used to generate R/L phase differences.
+                If rlpsigmat=None, i.i.d. R/L phase noise is applied.
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ verbose (bool): print updates and warnings
+
+ Returns:
+ (Obsdata): an observation object
+
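+        Example (illustrative sketch; the vex file path is hypothetical and
+        `im` is assumed to be an Image object):
+
+            >>> import ehtim as eh
+            >>> vex = eh.vex.Vex('track.vex')  # hypothetical vex file
+            >>> obs = im.observe_vex(vex, 'M87', t_int=30.)
+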
+ """
+
+ if polrep_obs is None:
+ polrep_obs = self.polrep
+
+ t_int_flag = False
+ if t_int == 0.0:
+ t_int_flag = True
+
+ # Loop over all scans and assemble a list of scan observations
+ obs_List = []
+ for i_scan in range(len(vex.sched)):
+
+            if t_int_flag:
+                t_int = vex.sched[i_scan]['scan'][0]['scan_sec']
+
+            # always set the advance time, even when t_int is user-supplied,
+            # so t_adv is defined before observing the scan
+            if tight_tadv:
+                t_adv = t_int
+            else:
+                t_adv = 2.0 * vex.sched[i_scan]['scan'][0]['scan_sec']
+
+ # If this scan doesn't observe the source, advance
+ if vex.sched[i_scan]['source'] != source:
+ continue
+
+ # What subarray is observing now?
+ scankeys = list(vex.sched[i_scan]['scan'].keys())
+ subarray = vex.array.make_subarray([vex.sched[i_scan]['scan'][key]['site']
+ for key in scankeys])
+
+ # Observe with the subarray over the scan interval
+ t_start = vex.sched[i_scan]['start_hr']
+ t_stop = t_start + vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0 - ehc.EP
+
+ obs = self.observe(subarray, t_int, t_adv, t_start, t_stop, vex.bw_hz,
+ mjd=vex.sched[i_scan]['mjd_floor'], timetype='UTC',
+ polrep_obs=polrep_obs,
+ elevmin=.01, elevmax=89.99,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ fix_theta_GMST=fix_theta_GMST,
+ sgrscat=sgrscat,
+ add_th_noise=add_th_noise,
+ jones=jones, inv_jones=inv_jones,
+ opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal,
+ frcal=frcal, dcal=dcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ tau=tau, taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ phase_std=phase_std,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ sigmat=sigmat,phasesigmat=phasesigmat,
+ rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat,
+ caltable_path=caltable_path, seed=seed, verbose=verbose)
+
+ obs_List.append(obs)
+
+ # Merge the scans together
+ obs = ehtim.obsdata.merge_obs(obs_List)
+
+ return obs
+
+ def compare_images(self, im_compare, pol=None, psize=None,target_fov=None, blur_frac=0.0,
+ beamparams=[1., 1., 1.], metric=['nxcorr', 'nrmse', 'rssd'],
+ blursmall=False, shift=True):
+ """Compare to another image by computing normalized cross correlation,
+ normalized root mean squared error, or square root of the sum of squared differences.
+ Returns metrics only for the primary polarization imvec!
+
+ Args:
+ im_compare (Image): the image to compare to
+ pol (str): which polarization image to compare. Default is self.pol_prim
+ psize (float): pixel size of comparison image (rad).
+                If None it is the smallest of the input image pixel sizes
+ target_fov (float): fov of the comparison image (rad).
+ If None it is twice the largest fov of the input images
+
+ beamparams (list): the nominal Gaussian beam parameters [fovx, fovy, position angle]
+ blur_frac (float): fractional beam to blur each image to before comparison
+
+ metric (list) : a list of fidelity metrics from ['nxcorr','nrmse','rssd']
+ blursmall (bool) : True to blur the unpadded image rather than the large image.
+            shift (bool or tuple): manual image shift; if a bool is given,
+                the shift from the maximum cross-correlation is used
+
+ Returns:
+            (tuple): (error, im1_pad, im2_shift)
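+
+        Example (illustrative sketch; `im` and `im2` are assumed Image objects
+        and the beam parameters are hypothetical):
+
+            >>> import ehtim as eh
+            >>> beam = [20 * eh.RADPERUAS, 20 * eh.RADPERUAS, 0.]  # hypothetical beam
+            >>> metrics, im1_pad, im2_shift = im.compare_images(
+            ...     im2, metric=['nxcorr', 'nrmse'], blur_frac=0.5, beamparams=beam)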
+ """
+
+ im1 = self.copy()
+ im2 = im_compare.switch_polrep(polrep_out=im1.polrep, pol_prim_out=im1.pol_prim)
+
+        if im1.polrep != im2.polrep:
+            raise Exception("In compare_images, im1 and im2 must have the same polrep!")
+        if im1.pol_prim != im2.pol_prim:
+            raise Exception("In compare_images, im1 and im2 must have the same pol_prim!")
+
+ # Shift the comparison image to maximize normalized cross-corr.
+ [idx, xcorr, im1_pad, im2_pad] = im1.find_shift(im2, psize=psize, target_fov=target_fov,
+ beamparams=beamparams, pol=pol,
+ blur_frac=blur_frac, blursmall=blursmall)
+
+ if not isinstance(shift, bool):
+ idx = shift
+
+ im2_shift = im2_pad.shift(idx)
+
+ # Compute error metrics
+ error = []
+ imvec1 = im1_pad.get_polvec(pol)
+ imvec2 = im2_shift.get_polvec(pol)
+ if 'nxcorr' in metric:
+ error.append(xcorr[idx[0], idx[1]] / (im1_pad.xdim * im1_pad.ydim))
+ if 'nrmse' in metric:
+ error.append(np.sqrt(np.sum((np.abs(imvec1 - imvec2)**2 * im1_pad.psize**2)) /
+ np.sum((imvec1)**2 * im1_pad.psize**2)))
+ if 'rssd' in metric:
+ error.append(np.sqrt(np.sum(np.abs(imvec1 - imvec2)**2) * im1_pad.psize**2))
+
+ return (error, im1_pad, im2_shift)
+
+ def align_images(self, im_list, pol=None, shift=True, final_fov=False, scale='lin',
+ gamma=0.5, dynamic_range=[1.e3]):
+ """Align all the images in im_list to the current image (self)
+ Aligns all images by comparison of the primary pol image.
+
+ Args:
+ im_list (list): list of images to align to the current image
+ shift (list): list of manual image shifts,
+ otherwise use the shift from maximum cross-correlation
+ pol (str): which polarization image to compare. Default is self.pol_prim
+ final_fov (float): fov of the comparison image (rad).
+                If False it is the largest input image fov
+
+ scale (str) : compare images in 'log','lin',or 'gamma' scale
+ gamma (float): exponent for gamma scale comparison
+ dynamic_range (float): dynamic range for log and gamma scale comparisons
+
+ Returns:
+ (tuple): (im_list_shift, shifts, im0_pad)
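+
+        Example (illustrative sketch; `im`, `im2`, `im3` are assumed Image
+        objects with the same polrep and pol_prim):
+
+            >>> im_list_shift, shifts, im0_pad = im.align_images([im2, im3])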
+ """
+
+ im0 = self.copy()
+ if not np.all(im0.polrep == np.array([im.polrep for im in im_list])):
+ raise Exception("In align_images, all images must have the same polrep!")
+ if not np.all(im0.pol_prim == np.array([im.pol_prim for im in im_list])):
+ raise Exception("In find_shift, all images must have the same pol_prim!")
+
+ if len(dynamic_range) == 1:
+ dynamic_range = dynamic_range * np.ones(len(im_list) + 1)
+
+ useshift = True
+ if isinstance(shift, bool):
+ useshift = False
+
+ # Find the minimum psize and the maximum field of view
+ psize = im0.psize
+ max_fov = np.max([im0.xdim * im0.psize, im0.ydim * im0.psize])
+ for i in range(0, len(im_list)):
+ psize = np.min([psize, im_list[i].psize])
+ max_fov = np.max([max_fov,
+ im_list[i].xdim * im_list[i].psize,
+ im_list[i].ydim * im_list[i].psize])
+
+ if not final_fov:
+ final_fov = max_fov
+
+ # Shift all images in the list
+ im_list_shift = []
+ shifts = []
+ for i in range(0, len(im_list)):
+ (idx, _, im0_pad_orig, im_pad) = im0.find_shift(im_list[i], target_fov=2 * max_fov,
+ psize=psize, pol=pol,
+ scale=scale, gamma=gamma,
+ dynamic_range=dynamic_range[i + 1])
+
+ if i == 0:
+ npix = int(im0_pad_orig.xdim / 2)
+ im0_pad = im0_pad_orig.regrid_image(final_fov, npix)
+ if useshift:
+ idx = shift[i]
+
+ tmp = im_pad.shift(idx)
+ shifts.append(idx)
+ im_list_shift.append(tmp.regrid_image(final_fov, npix))
+
+ return (im_list_shift, shifts, im0_pad)
+
+ def find_shift(self, im_compare, pol=None, psize=None, target_fov=None,
+ beamparams=[1., 1., 1.], blur_frac=0.0, blursmall=False,
+ scale='lin', gamma=0.5, dynamic_range=1.e3):
+ """Find image shift that maximizes normalized cross correlation with a second image im2.
+ Finds shift only by comparison of the primary pol image.
+
+ Args:
+            im_compare (Image): the image to shift relative to this image
+ pol (str): which polarization image to compare. Default is self.pol_prim
+ psize (float): pixel size of comparison image (rad).
+                If None it is the smallest of the input image pixel sizes
+ target_fov (float): fov of the comparison image (rad).
+ If None it is twice the largest fov of the input images
+
+ beamparams (list): the nominal Gaussian beam parameters [fovx, fovy, position angle]
+ blur_frac (float): fractional beam to blur each image to before comparison
+ blursmall (bool) : True to blur the unpadded image rather than the large image.
+
+ scale (str) : compare images in 'log','lin',or 'gamma' scale
+ gamma (float): exponent for gamma scale comparison
+ dynamic_range (float): dynamic range for log and gamma scale comparisons
+
+ Returns:
+            (list): [idx, xcorr, im1_pad, im2_pad]
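+
+        Example (illustrative sketch; `im` and `im2` are assumed Image objects):
+
+            >>> idx, xcorr, im1_pad, im2_pad = im.find_shift(im2)
+            >>> im2_aligned = im2_pad.shift(idx)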
+ """
+
+ im1 = self.copy()
+ im2 = im_compare.switch_polrep(polrep_out=im1.polrep, pol_prim_out=im1.pol_prim)
+        if pol == 'RL' or pol == 'LR':
+            raise Exception("find_shift currently doesn't work with complex RL or LR imvecs!")
+ if im1.polrep != im2.polrep:
+ raise Exception("In find_shift, im1 and im2 must have the same polrep!")
+ if im1.pol_prim != im2.pol_prim:
+ raise Exception("In find_shift, im1 and im2 must have the same pol_prim!")
+
+ # Find maximum FOV and minimum pixel size for comparison
+ if target_fov is None:
+ max_fov = np.max([im1.fovx(), im1.fovy(), im2.fovx(), im2.fovy()])
+ target_fov = 2 * max_fov
+ if psize is None:
+ psize = np.min([im1.psize, im2.psize])
+
+ npix = int(target_fov / psize)
+
+ # Blur images, then pad
+ if ((blur_frac > 0.0) and (blursmall is True)):
+ im1 = im1.blur_gauss(beamparams, blur_frac, blur_frac)
+ im2 = im2.blur_gauss(beamparams, blur_frac, blur_frac)
+
+ im1_pad = im1.regrid_image(target_fov, npix)
+ im2_pad = im2.regrid_image(target_fov, npix)
+
+ # or, pad images, then blur
+ if ((blur_frac > 0.0) and (blursmall is False)):
+ im1_pad = im1_pad.blur_gauss(beamparams, blur_frac, blur_frac)
+ im2_pad = im2_pad.blur_gauss(beamparams, blur_frac, blur_frac)
+
+ # Rescale the image vectors into log or gamma scale
+ # TODO -- what about negative values? complex values?
+ im1_pad_vec = im1_pad.get_polvec(pol)
+ im2_pad_vec = im2_pad.get_polvec(pol)
+ if scale == 'log':
+ im1_pad_vec[im1_pad_vec < 0.0] = 0.0
+ im1_pad_vec = np.log(im1_pad_vec + np.max(im1_pad_vec) / dynamic_range)
+ im2_pad_vec[im2_pad_vec < 0.0] = 0.0
+ im2_pad_vec = np.log(im2_pad_vec + np.max(im2_pad_vec) / dynamic_range)
+ if scale == 'gamma':
+ im1_pad_vec[im1_pad_vec < 0.0] = 0.0
+ im1_pad_vec = (im1_pad_vec + np.max(im1_pad_vec) / dynamic_range)**(gamma)
+ im2_pad_vec[im2_pad_vec < 0.0] = 0.0
+ im2_pad_vec = (im2_pad_vec + np.max(im2_pad_vec) / dynamic_range)**(gamma)
+
+ # Normalize images and compute cross correlation with FFT
+ im1_norm = (im1_pad_vec.reshape(im1_pad.ydim, im1_pad.xdim) - np.mean(im1_pad_vec))
+ im1_norm /= np.std(im1_pad_vec)
+ im2_norm = (im2_pad_vec.reshape(im2_pad.ydim, im2_pad.xdim) - np.mean(im2_pad_vec))
+ im2_norm /= np.std(im2_pad_vec)
+
+ fft_im1 = np.fft.fft2(im1_norm)
+ fft_im2 = np.fft.fft2(im2_norm)
+
+ xcorr = np.real(np.fft.ifft2(fft_im1 * np.conj(fft_im2)))
+
+        # Find the index of the shift that maximizes the cross-correlation
+ idx = np.unravel_index(xcorr.argmax(), xcorr.shape)
+
+ return [idx, xcorr, im1_pad, im2_pad]
+
+ def hough_ring(self, edgetype='canny', thresh=0.2, num_circles=3, radius_range=None,
+ return_type='rad', display_results=True):
+ """Use a circular hough transform to find a circle in the image
+ Returns metrics only for the primary polarization imvec!
+
+ Args:
+ num_circles (int) : number of circles to return
+ radius_range (tuple): range of radii to search in Hough transform, in radian
+ edgetype (str): edge detection type, 'gradient' or 'canny'
+ thresh(float): fractional threshold for the gradient image
+ display_results (bool): True to display results of the fit
+ return_type (str): 'rad' to return in radian, 'pixel' to return in pixel units
+
+ Returns:
+ list : a list of fitted circles (xpos, ypos, radius, objFunc), in radian
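+
+        Example (illustrative sketch; requires scikit-image, and the radius
+        range values are hypothetical):
+
+            >>> import ehtim as eh
+            >>> radius_range = (15 * eh.RADPERUAS, 35 * eh.RADPERUAS)  # hypothetical
+            >>> circles = im.hough_ring(num_circles=1, radius_range=radius_range,
+            ...                         display_results=False)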
+ """
+
+ if 'skimage' not in sys.modules:
+ raise Exception("scikit-image not installed: cannot use hough_ring!")
+
+ # coordinate values
+ pdim = self.psize
+ xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0
+
+ # normalize to range 0, 1
+ im = self.copy()
+ maxval = np.max(im.imvec)
+ meanval = np.mean(im.imvec)
+
+ im_norm = im.imvec / (maxval + .01 * meanval)
+        im_norm = im_norm.astype('float')  # ensure a float array for the edge detectors
+ im_norm[np.isnan(im.imvec)] = 0 # mask nans to 0
+ im.imvec = im_norm
+
+ # detect edges
+ if edgetype == 'canny':
+ imarr = im.imvec.reshape(self.ydim, self.xdim)
+ edges = canny(imarr, sigma=0, high_threshold=thresh, low_threshold=0.01)
+ im_edges = self.copy()
+ im_edges.imvec = edges.flatten()
+
+ elif edgetype == 'grad':
+ im_edges = self.grad()
+ if not (thresh is None):
+ thresh_val = thresh * np.max(im_edges.imvec)
+ mask = im_edges.imvec > thresh_val
+ # im_edges.imvec[mask] = 1
+ im_edges.imvec[~mask] = 0
+ edges = im_edges.imvec.reshape(self.ydim, self.xdim)
+ else:
+ im_edges = im.copy()
+ if not (thresh is None):
+ thresh_val = thresh * np.max(im_edges.imvec)
+ mask = im_edges.imvec > thresh_val
+                # im_edges.imvec[mask] = 1
+ im_edges.imvec[~mask] = 0
+ edges = im_edges.imvec.reshape(self.ydim, self.xdim)
+
+ # define radius range for Hough transform search
+ if radius_range is None:
+ hough_radii = np.arange(int(10 * ehc.RADPERUAS / self.psize),
+ int(50 * ehc.RADPERUAS / self.psize))
+        else:
+            hough_radii = np.linspace(radius_range[0] / self.psize,
+                                      radius_range[1] / self.psize, 25)
+
+ # perform the hough transform and select the most prominent circles
+ hough_res = hough_circle(edges, hough_radii)
+ accums, cy, cx, radii = hough_circle_peaks(hough_res, hough_radii,
+ total_num_peaks=num_circles)
+ accum_tot = np.sum(accums)
+
+ # print results, plot circles, and return
+ outlist = []
+ if display_results:
+ plt.ion()
+ fig = self.display()
+ ax = fig.gca()
+
+ i = 0
+ colors = ['b', 'r', 'w', 'lime', 'magenta', 'aqua']
+ for accum, center_y, center_x, radius in zip(accums, cy, cx, radii):
+ accum_frac = accum / accum_tot
+ if return_type == 'rad':
+ x_rad = xlist[int(np.round(center_x))]
+ y_rad = ylist[int(np.round(center_y))]
+ r_rad = radius * self.psize
+ outlist.append([x_rad, y_rad, r_rad, accum_frac])
+ else:
+ outlist.append([center_x, center_y, radius, accum_frac])
+            print("%i ring diameter: %0.1f microarcsec (accumulator fraction %0.2f)"
+                  % (i, 2 * radius * pdim / ehc.RADPERUAS, accum_frac))
+ if display_results:
+                if i >= len(colors):
+ color = colors[-1]
+ else:
+ color = colors[i]
+ circ = mpl.patches.Circle((center_y, center_x), radius, fill=False, color=color)
+ ax.add_patch(circ)
+ i += 1
+
+ return outlist
+
+ def fit_gauss(self, units='rad'):
+ """Determine the Gaussian parameters that short baselines would measure for the source
+ by diagonalizing the image covariance matrix.
+ Returns parameters only for the primary polarization!
+
+ Args:
+ units (string): 'rad' returns values in radians,
+ 'natural' returns FWHM in uas and PA in degrees
+
+ Returns:
+ (tuple) : a tuple (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters
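+
+        Example (illustrative sketch; `im` is an assumed Image object):
+
+            >>> fwhm_maj, fwhm_min, pa = im.fit_gauss(units='natural')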
+ """
+
+ (x1, y1) = self.centroid()
+ pdim = self.psize
+ im = self.imvec
+
+ xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0
+ ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0
+
+ x2 = (np.sum(np.outer(0.0 * ylist + 1.0, (xlist - x1)**2).ravel() * im) / np.sum(im))
+ y2 = (np.sum(np.outer((ylist - y1)**2, 0.0 * xlist + 1.0).ravel() * im) / np.sum(im))
+ xy = (np.sum(np.outer(ylist - y1, xlist - x1).ravel() * im) / np.sum(im))
+
+ eig = np.linalg.eigh(np.array(((x2, xy), (xy, y2))))
+ gauss_params = np.array((eig[0][1]**0.5 * (8. * np.log(2.))**0.5,
+ eig[0][0]**0.5 * (8. * np.log(2.))**0.5,
+ np.mod(np.arctan2(eig[1][1][0], eig[1][1][1]) + np.pi, np.pi)))
+ if units == 'natural':
+ gauss_params[0] /= ehc.RADPERUAS
+ gauss_params[1] /= ehc.RADPERUAS
+ gauss_params[2] *= 180. / np.pi
+
+ return gauss_params
+
+ def fit_gauss_empirical(self, paramguess=None):
+ """Determine the Gaussian parameters that short baselines would measure
+ Returns parameters only for the primary polarization!
+
+ Args:
+ paramguess (tuple): Initial guess (fwhm_maj, fwhm_min, theta) of fit parameters
+
+ Returns:
+ (tuple) : a tuple (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters.
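+
+        Example (illustrative sketch; the initial-guess values are hypothetical):
+
+            >>> import ehtim as eh
+            >>> guess = (30 * eh.RADPERUAS, 20 * eh.RADPERUAS, 0.)  # hypothetical guess
+            >>> fwhm_maj, fwhm_min, pa = im.fit_gauss_empirical(paramguess=guess)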
+ """
+
+ # This could be done using moments of the intensity distribution (self.fit_gauss)
+ # but we'll use the visibility approach
+ u_max = 1.0 / (self.psize * self.xdim) / 5.0
+ uv = np.array([[u, v]
+ for u in np.arange(-u_max, u_max * 1.001, u_max / 4.0)
+ for v in np.arange(-u_max, u_max * 1.001, u_max / 4.0)])
+ u = uv[:, 0]
+ v = uv[:, 1]
+ vis = np.dot(obsh.ftmatrix(self.psize, self.xdim, self.ydim, uv, pulse=self.pulse),
+ self.imvec)
+
+ if paramguess is None:
+ paramguess = (self.psize * self.xdim / 4.0, self.psize * self.xdim / 4.0, 0.)
+
+ def errfunc(p):
+ vismodel = obsh.gauss_uv(u, v, self.total_flux(), p, x=0., y=0.)
+ err = np.sum((np.abs(vis) - np.abs(vismodel))**2)
+ return err
+
+ # minimizer params
+ optdict = {'maxiter': 5000, 'maxfev': 5000, 'xtol': paramguess[0] / 1e9, 'ftol': 1e-10}
+ res = opt.minimize(errfunc, paramguess, method='Nelder-Mead', options=optdict)
+
+ # Return in the form [maj, min, PA]
+ x = res.x
+ x[0] = np.abs(x[0])
+ x[1] = np.abs(x[1])
+ x[2] = np.mod(x[2], np.pi)
+ if x[0] < x[1]:
+ maj = x[1]
+ x[1] = x[0]
+ x[0] = maj
+ x[2] = np.mod(x[2] + np.pi / 2.0, np.pi)
+
+ return x
+
+ def contour(self, contour_levels=[0.1, 0.25, 0.5, 0.75],
+ contour_cfun=None, color='w', legend=True, show_im=True,
+ cfun='afmhot', scale='lin', interp='gaussian', gamma=0.5, dynamic_range=1.e3,
+ plotp=False, nvec=20, pcut=0.01, mcut=0.1, label_type='ticks', has_title=True,
+ has_cbar=True, cbar_lims=(), cbar_unit=('Jy', 'pixel'),
+ contour_im=False, power=0, beamcolor='w',
+ export_pdf="", show=True, beamparams=None, cbar_orientation="vertical",
+ scale_lw=1, beam_lw=1, cbar_fontsize=12, axis=None, scale_fontsize=12):
+ """Display the image in a contour plot.
+
+ Args:
+ contour_levels (arr): the fractional contour levels relative to the max flux plotted
+ contour_cfun (pyplot colormap function): the function used to get the RGB colors
+ legend (bool): True to show a legend that says what each contour line corresponds to
+ cfun (str): matplotlib.pyplot color function
+ scale (str): image scaling in ['log','gamma','lin']
+ interp (str): image interpolation 'gauss' or 'lin'
+
+ gamma (float): index for gamma scaling
+ dynamic_range (float): dynamic range for log and gamma scaling
+
+ plotp (bool): True to plot linear polarimetic image
+ nvec (int): number of polarimetric vectors to plot
+ pcut (float): minimum stokes P value for displaying polarimetric vectors
+ as fraction of maximum Stokes I pixel
+ mcut (float): minimum fractional polarization for plotting vectors
+ label_type (string): specifies the type of axes labeling: 'ticks', 'scale', 'none'
+ has_title (bool): True if you want a title on the plot
+ has_cbar (bool): True if you want a colorbar on the plot
+ cbar_lims (tuple): specify the lower and upper limit of the colorbar
+ cbar_unit (tuple of strings): the unit of each pixel for the colorbar:
+ 'Jy', 'm-Jy', '$\mu$Jy'
+
+ export_pdf (str): path to exported PDF with plot
+ show (bool): Display the plot if true
+ show_im (bool): Display the image with the contour plot if True
+
+ Returns:
+ (matplotlib.figure.Figure): figure object with image
+
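+        Example (illustrative sketch; `im` is an assumed Image object):
+
+            >>> fig = im.contour(contour_levels=[0.25, 0.5, 0.75], color='w', legend=True)
+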
+ """
+
+ image = self.copy()
+
+        # pixel-coordinate grids for the contour plot
+ y = np.linspace(0, image.ydim, image.ydim)
+ x = np.linspace(0, image.xdim, image.xdim)
+
+ # make the image grid
+ z = image.imvec.reshape((image.ydim, image.xdim))
+ maxz = max(image.imvec)
+        if axis is None:
+            ax = plt.gca()
+        else:
+            ax = axis
+            plt.sca(axis)
+
+ if show_im:
+ if axis is not None:
+ axis = image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma,
+ dynamic_range=dynamic_range,
+ plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut,
+ label_type=label_type, has_title=has_title,
+ has_cbar=has_cbar, cbar_lims=cbar_lims,
+ cbar_unit=cbar_unit,
+ beamparams=beamparams,
+ cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1,
+ cbar_fontsize=cbar_fontsize, axis=axis,
+ scale_fontsize=scale_fontsize, power=power,
+ beamcolor=beamcolor)
+ else:
+ image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma,
+ dynamic_range=dynamic_range,
+ plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, label_type=label_type,
+ has_title=has_title, has_cbar=has_cbar,
+ cbar_lims=cbar_lims, cbar_unit=cbar_unit, beamparams=beamparams,
+ cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1,
+ cbar_fontsize=cbar_fontsize,
+ axis=None, scale_fontsize=scale_fontsize,
+ power=power, beamcolor=beamcolor)
+ else:
+ if contour_im is False:
+ image.imvec = 0.0 * image.imvec
+ else:
+ image = contour_im.copy()
+
+ if axis is not None:
+ axis = image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma,
+ dynamic_range=dynamic_range,
+ plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut,
+ label_type=label_type, has_title=has_title,
+ has_cbar=has_cbar, cbar_lims=cbar_lims, cbar_unit=cbar_unit,
+ beamparams=beamparams,
+ cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1,
+ cbar_fontsize=cbar_fontsize,
+ axis=axis,
+ scale_fontsize=scale_fontsize, power=power,
+ beamcolor=beamcolor)
+ else:
+ image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma,
+ dynamic_range=dynamic_range,
+ plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, label_type=label_type,
+ has_title=has_title,
+ has_cbar=has_cbar, cbar_lims=cbar_lims, cbar_unit=cbar_unit,
+ beamparams=beamparams,
+ cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1,
+ cbar_fontsize=cbar_fontsize, axis=None,
+ scale_fontsize=scale_fontsize, power=power, beamcolor=beamcolor)
+
+        if axis is None:
+            ax = plt.gcf()
+        else:
+            ax = axis
+            plt.sca(axis)
+
+ count = 0.
+
+ for level in contour_levels:
+            if contour_cfun is not None:
+                rgbval = contour_cfun(count / len(contour_levels))
+                # convert the RGBA tuple to integer channel values for the hex string
+                rgbstring = '#%02x%02x%02x' % tuple(int(255 * c) for c in rgbval[:3])
+ else:
+ rgbstring = color
+ cs = plt.contour(x, y, z, levels=[level * maxz], colors=rgbstring, cmap=None)
+ count += 1
+ cs.collections[0].set_label(str(int(level * 100)) + '%')
+ if legend:
+ plt.legend()
+
+        if show:
+            #plt.show(block=False)
+            ehc.show_noblock()
+
+        if export_pdf != "":
+            # save via the figure: ax may be an Axes object when an axis was passed in
+            plt.gcf().savefig(export_pdf, bbox_inches='tight', pad_inches=0)
+
+        if axis is not None:
+            return axis
+        return ax
+
+ def display(self, pol=None, cfun=False, interp='gaussian',
+ scale='lin', gamma=0.5, dynamic_range=1.e3,
+ plotp=False, plot_stokes=False, nvec=20,
+ vec_cfun=None,
+ scut=0, pcut=0.1, mcut=0.01, scale_ticks=False,
+ log_offset=False,
+ label_type='ticks', has_title=True, alpha=1,
+ has_cbar=True, only_cbar=False, cbar_lims=(), cbar_unit=('Jy', 'pixel'),
+ export_pdf="", pdf_pad_inches=0.0, show=True, beamparams=None,
+ cbar_orientation="vertical", scinot=False,
+ scale_lw=1, beam_lw=1, cbar_fontsize=12, axis=None,
+ scale_fontsize=12,
+ power=0,
+ beamcolor='w', beampos='right', scalecolor='w',dpi=500):
+ """Display the image.
+
+ Args:
+ pol (str): which polarization image to plot. Default is self.pol_prim
+ pol='spec' will plot spectral index
+ pol='curv' will plot spectral curvature
+ cfun (str): matplotlib.pyplot color function.
+ False changes with 'pol', but is 'afmhot' for most
+ interp (str): image interpolation 'gauss' or 'lin'
+
+ scale (str): image scaling in ['log','gamma','lin']
+ gamma (float): index for gamma scaling
+ dynamic_range (float): dynamic range for log and gamma scaling
+
+ plotp (bool): True to plot linear polarimetic image
+ plot_stokes (bool): True to plot stokes subplots along with plotp
+ nvec (int): number of polarimetric vectors to plot
+ vec_cfun (str): color function for vectors colored by lin pol frac
+
+ scut (float): minimum stokes I value for displaying spectral index
+ pcut (float): minimum stokes I value for displaying polarimetric vectors
+ (fraction of maximum Stokes I)
+ mcut (float): minimum fractional polarization value for displaying vectors
+ label_type (string): specifies the type of axes labeling: 'ticks', 'scale', 'none'
+ has_title (bool): True if you want a title on the plot
+ has_cbar (bool): True if you want a colorbar on the plot
+ cbar_lims (tuple): specify the lower and upper limit of the colorbar
+ cbar_unit (tuple): specifies the unit of the colorbar: e.g.,
+ ('Jy','pixel'),('m-Jy','$\mu$as$^2$'),['Tb']
+ beamparams (list): [fwhm_maj, fwhm_min, theta], set to plot beam contour
+
+ export_pdf (str): path to exported PDF with plot
+ show (bool): Display the plot if true
+ scinot (bool): Display numbers/units in scientific notation
+ scale_lw (float): Linewidth of the scale overlay
+ beam_lw (float): Linewidth of the beam overlay
+ cbar_fontsize (float): Fontsize of the text elements of the colorbar
+ axis (matplotlib.axes.Axes): An axis object
+ scale_fontsize (float): Fontsize of the scale label
+
+ power (float): Passed to colorbar for division of ticks by 1e(power)
+ beamcolor (str): color of the beam overlay
+ scalecolor (str): color of the scale label overlay
+ Returns:
+ (matplotlib.figure.Figure): figure object with image
+
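+        Example (illustrative sketch; the export path is hypothetical):
+
+            >>> im.display(scale='log', dynamic_range=1.e3,
+            ...            label_type='scale', export_pdf='image.pdf')  # hypothetical path
+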
+ """
+
+ if (interp in ['gauss', 'gaussian', 'Gaussian', 'Gauss']):
+ interp = 'gaussian'
+ elif (interp in ['linear','bilinear']):
+ interp = 'bilinear'
+ else:
+ interp = 'none'
+
+ if not(beamparams is None or beamparams is False):
+            if beamparams[0] > self.fovx() or beamparams[1] > self.fovy():
+ raise Exception("beam FWHM must be smaller than fov!")
+
+ if self.polrep == 'stokes' and pol is None:
+ pol = 'I'
+ elif self.polrep == 'circ' and pol is None:
+ pol = 'RR'
+
+ if only_cbar:
+ has_cbar = True
+ label_type = 'none'
+ has_title = False
+
+        if axis is None:
+            f = plt.figure()
+            plt.clf()
+        else:
+            plt.sca(axis)
+            f = plt.gcf()
+
+ # Get unit scale factor
+ factor = 1.
+ fluxunit = 'Jy'
+ areaunit = 'pixel'
+
+ if cbar_unit[0] in ['m-Jy', 'mJy']:
+ fluxunit = 'mJy'
+ factor *= 1.e3
+ elif cbar_unit[0] in ['muJy', r'$\mu$-Jy', r'$\mu$Jy']:
+ fluxunit = r'$\mu$Jy'
+ factor *= 1.e6
+ elif cbar_unit[0] == 'Tb':
+ factor = 3.254e13 / (self.rf**2 * self.psize**2)
+ fluxunit = 'Brightness Temperature (K)'
+ areaunit = ''
+ if power != 0:
+ fluxunit = (r'Brightness Temperature ($10^{{' + str(power) + '}}$ K)')
+ else:
+ fluxunit = 'Brightness Temperature (K)'
+ elif cbar_unit[0] in ['Jy']:
+ fluxunit = 'Jy'
+ factor *= 1.
+ else:
+ factor = 1
+ fluxunit = cbar_unit[0]
+ areaunit = ''
+
+ if len(cbar_unit) == 1 or cbar_unit[0] == 'Tb':
+ factor *= 1.
+
+ elif cbar_unit[1] == 'pixel':
+ factor *= 1.
+ if power != 0:
+ areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)')
+
+        elif cbar_unit[1] in ['arcseconds$^2$', 'as$^2$', 'as2']:
+ areaunit = 'as$^2$'
+ fovfactor = self.xdim * self.psize * (1 / ehc.RADPERAS)
+ factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2
+ if power != 0:
+ areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)')
+
+        elif cbar_unit[1] in ['m-arcseconds$^2$', 'mas$^2$', 'mas2']:
+ areaunit = 'mas$^2$'
+ fovfactor = self.xdim * self.psize * (1 / ehc.RADPERUAS) / 1000.
+ factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2
+ if power != 0:
+ areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)')
+
+ elif cbar_unit[1] in [r'$\mu$-arcseconds$^2$', r'$\mu$as$^2$', 'muas2']:
+ areaunit = r'$\mu$as$^2$'
+ fovfactor = self.xdim * self.psize * (1 / ehc.RADPERUAS)
+ factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2
+ if power != 0:
+ areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)')
+
+ elif cbar_unit[1] == 'beam':
+ if (beamparams is None or beamparams is False):
+ print("Cannot convert to Jy/beam without beamparams!")
+ else:
+ areaunit = 'beam'
+ beamarea = (2.0 * np.pi * beamparams[0] * beamparams[1] / (8.0 * np.log(2)))
+ factor *= beamarea / (self.psize**2)
+ if power != 0:
+ areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)')
+
+ else:
+ raise ValueError('cbar_unit ' + cbar_unit[1] + ' is not a possible option')
+
+ if not plotp: # Plot a single polarization image
+ cbar_lims_p = ()
+
+ if pol.lower() == 'spec':
+ imvec = self.specvec.copy()
+
+ # mask out low total intensity values
+ mask = self.imvec < (scut * np.max(self.imvec))
+ imvec[mask] = np.nan
+
+ unit = r'$\alpha$'
+ factor = 1
+ cbar_lims_p = [-5, 5]
+ cfun_p = 'seismic'
+ elif pol.lower() == 'curv':
+ imvec = self.curvvec.copy()
+
+ # mask out low total intensity values
+ mask = self.imvec < (scut * np.max(self.imvec))
+ imvec[mask] = np.nan
+
+ unit = r'$\beta$'
+ factor = 1
+ cbar_lims_p = [-5, 5]
+ cfun_p = 'seismic'
+ elif pol.lower() == 'm':
+ imvec = self.mvec.copy()
+                unit = r'$|\breve{m}|$'
+ factor = 1
+ cbar_lims_p = [0, 1]
+ cfun_p = 'cool'
+ elif pol.lower() == 'p':
+ imvec = self.mvec * self.ivec
+                unit = r'$|P|$'
+ cfun_p = 'afmhot'
+ elif pol.lower() == 'chi' or pol.lower() == 'evpa':
+ imvec = self.chivec.copy() / ehc.DEGREE
+ unit = r'$\chi (^\circ)$'
+ factor = 1
+ cbar_lims_p = [0, 180]
+ cfun_p = 'hsv'
+ elif pol.lower() == 'e':
+ imvec = self.evec.copy()
+ unit = r'$E$-mode'
+ cfun_p = 'Spectral'
+ elif pol.lower() == 'b':
+ imvec = self.bvec.copy()
+ unit = r'$B$-mode'
+ cfun_p = 'Spectral'
+ else:
+ pol = pol.upper()
+ if pol == 'V':
+ cfun_p = 'bwr'
+ else:
+ cfun_p = 'afmhot'
+ try:
+ imvec = np.array(self._imdict[pol]).reshape(-1) / (10.**power)
+ except KeyError:
+ try:
+ if self.polrep == 'stokes':
+ im2 = self.switch_polrep('circ')
+ elif self.polrep == 'circ':
+ im2 = self.switch_polrep('stokes')
+ imvec = np.array(im2._imdict[pol]).reshape(-1) / (10.**power)
+ except KeyError:
+ raise Exception("Cannot make pol %s image in display()!" % pol)
+
+ unit = fluxunit
+ if areaunit != '':
+ unit += ' / ' + areaunit
+
+ if np.any(np.imag(imvec)):
+                print('casting complex image to its real part')
+ imvec = np.real(imvec)
+
+ imvec = imvec * factor
+ imarr = imvec.reshape(self.ydim, self.xdim)
+
+ if scale == 'log':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0 in display')
+ imarr[imarr < 0.0] = 0.0
+ if log_offset:
+ imarr = np.log10(imarr + log_offset / dynamic_range)
+ else:
+ imarr = np.log10(imarr + np.max(imarr) / dynamic_range)
+ unit = r'$\log_{10}$(' + unit + ')'
+
+ if scale == 'gamma':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0 in display')
+ imarr[imarr < 0.0] = 0.0
+ imarr = (imarr + np.max(imarr) / dynamic_range)**(gamma)
+ unit = '(' + unit + ')^' + str(gamma)
+
+ if not cbar_lims and cbar_lims_p:
+ cbar_lims = cbar_lims_p
+
+ if cbar_lims:
+ cbar_lims[0] = cbar_lims[0] / (10.**power)
+ cbar_lims[1] = cbar_lims[1] / (10.**power)
+ imarr[imarr > cbar_lims[1]] = cbar_lims[1]
+ imarr[imarr < cbar_lims[0]] = cbar_lims[0]
+
+ if has_title:
+ plt.title("%s %.2f GHz %s" % (self.source, self.rf / 1e9, pol), fontsize=16)
+
+ if not cfun:
+ cfun = cfun_p
+ cmap = plt.get_cmap(cfun).copy()
+ cmap.set_bad(color='whitesmoke')
+
+ if cbar_lims:
+ im = plt.imshow(imarr, alpha=alpha, cmap=cmap, interpolation=interp,
+ vmin=cbar_lims[0], vmax=cbar_lims[1])
+ else:
+ im = plt.imshow(imarr, alpha=alpha, cmap=cmap, interpolation=interp)
+
+ if not(beamparams is None or beamparams is False):
+ if beampos=='left':
+ beamparams = [beamparams[0], beamparams[1], beamparams[2],
+ +.4 * self.fovx(), -.4 * self.fovy()]
+ else:
+ beamparams = [beamparams[0], beamparams[1], beamparams[2],
+ -.35 * self.fovx(), -.35 * self.fovy()]
+ beamimage = self.copy()
+ beamimage.imvec *= 0
+ beamimage = beamimage.add_gauss(1, beamparams)
+ halflevel = 0.5 * np.max(beamimage.imvec)
+ beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim)
+ plt.contour(beamimarr, levels=[halflevel], colors=beamcolor, linewidths=beam_lw)
+
+ if has_cbar:
+ if only_cbar:
+ im.set_visible(False)
+ cb = plt.colorbar(im, fraction=0.046, pad=0.04, orientation=cbar_orientation)
+ cb.set_label(unit, fontsize=float(cbar_fontsize))
+
+ if cbar_fontsize != 12:
+ cb.set_label(unit, fontsize=float(cbar_fontsize) / 1.5)
+ cb.ax.tick_params(labelsize=cbar_fontsize)
+
+ if cbar_lims:
+ plt.clim(cbar_lims[0], cbar_lims[1])
+ if scinot:
+ cb.formatter.set_powerlimits((0, 0))
+ cb.update_ticks()
+
+ else: # plot polarization with ticks!
+
+ im_stokes = self.switch_polrep(polrep_out='stokes')
+ imvec = np.array(im_stokes.imvec).reshape(-1) / (10**power)
+ qvec = np.array(im_stokes.qvec).reshape(-1) / (10**power)
+ uvec = np.array(im_stokes.uvec).reshape(-1) / (10**power)
+ vvec = np.array(im_stokes.vvec).reshape(-1) / (10**power)
+
+ if len(imvec) == 0:
+ imvec = np.zeros(im_stokes.ydim * im_stokes.xdim)
+ if len(qvec) == 0:
+ qvec = np.zeros(im_stokes.ydim * im_stokes.xdim)
+ if len(uvec) == 0:
+ uvec = np.zeros(im_stokes.ydim * im_stokes.xdim)
+ if len(vvec) == 0:
+ vvec = np.zeros(im_stokes.ydim * im_stokes.xdim)
+
+ imvec *= factor
+ qvec *= factor
+ uvec *= factor
+ vvec *= factor
+
+ imarr = (imvec).reshape(im_stokes.ydim, im_stokes.xdim)
+ qarr = (qvec).reshape(im_stokes.ydim, im_stokes.xdim)
+ uarr = (uvec).reshape(im_stokes.ydim, im_stokes.xdim)
+ varr = (vvec).reshape(im_stokes.ydim, im_stokes.xdim)
+
+ unit = fluxunit
+ if areaunit != '':
+ unit = fluxunit + ' / ' + areaunit
+
+ # only the stokes I image gets transformed! TODO
+ imarr2 = imarr.copy()
+ if scale == 'log':
+ if (imarr2 < 0.0).any():
+ print('clipping values less than 0 in display')
+ imarr2[imarr2 < 0.0] = 0.0
+ imarr2 = np.log10(imarr2 + np.max(imarr2) / dynamic_range)
+ unit = r'$\log_{10}$(' + unit + ')'
+
+ if scale == 'gamma':
+ if (imarr2 < 0.0).any():
+ print('clipping values less than 0 in display')
+ imarr2[imarr2 < 0.0] = 0.0
+ imarr2 = (imarr2 + np.max(imarr2) / dynamic_range)**(gamma)
+                unit = '(' + unit + ')^' + str(gamma)
+
+ if cbar_lims:
+ cbar_lims[0] = cbar_lims[0] / (10.**power)
+ cbar_lims[1] = cbar_lims[1] / (10.**power)
+ imarr2[imarr2 > cbar_lims[1]] = cbar_lims[1]
+ imarr2[imarr2 < cbar_lims[0]] = cbar_lims[0]
+
+            # polarization ticks
+
+ thin = self.xdim // nvec
+ maska = (imvec).reshape(self.ydim, self.xdim) > pcut * np.max(imvec)
+ maskb = (np.abs(qvec + 1j * uvec) / imvec).reshape(self.ydim, self.xdim) > mcut
+ mask = maska * maskb
+ mask2 = mask[::thin, ::thin]
+ x = (np.array([[i for i in range(self.xdim)]
+ for j in range(self.ydim)])[::thin, ::thin])
+ x = x[mask2]
+ y = (np.array([[j for i in range(self.xdim)]
+ for j in range(self.ydim)])[::thin, ::thin])
+ y = y[mask2]
+ a = (-np.sin(np.angle(qvec + 1j * uvec) /
+ 2).reshape(self.ydim, self.xdim)[::thin, ::thin])
+ a = a[mask2]
+ b = (np.cos(np.angle(qvec + 1j * uvec) /
+ 2).reshape(self.ydim, self.xdim)[::thin, ::thin])
+ b = b[mask2]
+
+ m = (np.abs(qvec + 1j * uvec) / imvec).reshape(self.ydim, self.xdim)
+ p = (np.abs(qvec + 1j * uvec)).reshape(self.ydim, self.xdim)
+ m[np.logical_not(mask)] = np.nan
+ p[np.logical_not(mask)] = np.nan
+ qarr[np.logical_not(mask)] = np.nan
+ uarr[np.logical_not(mask)] = np.nan
+
+ voi = (vvec / imvec).reshape(self.ydim, self.xdim)
+ voi[np.logical_not(mask)] = np.nan
+
+ if scale_ticks:
+ pticks = ((np.abs(qvec + 1j * uvec)).reshape(self.ydim, self.xdim))[::thin, ::thin][mask2]
+ pscale = (pticks - np.min(pticks))/(np.max(pticks) - np.min(pticks))
+ a *= pscale
+ b *= pscale
+
+ # Little pol plots
+ if plot_stokes:
+
+ maxval = 1.1 * np.max((np.max(np.abs(uarr)),
+ np.max(np.abs(qarr)), np.max(np.abs(varr))))
+
+ # P Plot
+ ax = plt.subplot2grid((2, 5), (0, 0))
+ im = plt.imshow(p, cmap=plt.get_cmap('bwr'), interpolation=interp,
+ vmin=-maxval, vmax=maxval)
+ plt.contour(imarr, colors='k', linewidths=.25)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_title:
+ plt.title('P')
+ if has_cbar:
+ cbaxes = plt.gcf().add_axes([0.1, 0.2, 0.01, 0.6])
+ cbar = plt.colorbar(im, fraction=0.046, pad=0.04, cax=cbaxes,
+ label=unit, orientation='vertical')
+ cbar.ax.tick_params(labelsize=cbar_fontsize)
+ cbaxes.yaxis.set_ticks_position('left')
+ cbaxes.yaxis.set_label_position('left')
+ if cbar_lims:
+ plt.clim(-maxval, maxval)
+
+                cmap = plt.get_cmap('bwr').copy()  # copy to avoid mutating the registered cmap
+                cmap.set_bad('whitesmoke')
+ # V Plot
+ ax = plt.subplot2grid((2, 5), (0, 1))
+ plt.imshow(varr, cmap=cmap, interpolation=interp,
+ vmin=-maxval, vmax=maxval)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_title:
+ plt.title('V')
+
+ # Q Plot
+ ax = plt.subplot2grid((2, 5), (1, 0))
+ plt.imshow(qarr, cmap=cmap, interpolation=interp,
+ vmin=-maxval, vmax=maxval)
+ plt.contour(imarr, colors='k', linewidths=.25)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_title:
+ plt.title('Q')
+
+ # U Plot
+ ax = plt.subplot2grid((2, 5), (1, 1))
+ plt.imshow(uarr, cmap=cmap, interpolation=interp,
+ vmin=-maxval, vmax=maxval)
+ plt.contour(imarr, colors='k', linewidths=.25)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_title:
+ plt.title('U')
+
+ # V/I plot
+ ax = plt.subplot2grid((2, 5), (0, 2))
+                cmap = plt.get_cmap('seismic').copy()  # copy to avoid mutating the registered cmap
+                cmap.set_bad('whitesmoke')
+
+ im = plt.imshow(voi, cmap=cmap, interpolation=interp,
+ vmin=-1, vmax=1)
+ if has_title:
+ plt.title('V/I')
+ plt.contour(imarr, colors='k', linewidths=.25)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_cbar:
+ cbaxes = plt.gcf().add_axes([0.125, 0.1, 0.425, 0.01])
+ cbar = plt.colorbar(im, fraction=0.046, pad=0.04, cax=cbaxes,
+ label='|m|', orientation='horizontal')
+ cbar.ax.tick_params(labelsize=cbar_fontsize)
+ cbaxes.yaxis.set_ticks_position('right')
+ cbaxes.yaxis.set_label_position('right')
+
+ if cbar_lims:
+ plt.clim(-1, 1)
+
+ # m plot
+ ax = plt.subplot2grid((2, 5), (1, 2))
+ plt.imshow(m, cmap=plt.get_cmap('seismic'), interpolation=interp, vmin=-1, vmax=1)
+ ax.set_xticks([])
+ ax.set_yticks([])
+ if has_title:
+ plt.title('m')
+ plt.contour(imarr, colors='k', linewidths=.25)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01 * self.xdim, units='x', pivot='mid', color='k', angles='uv',
+ scale=1.0 / thin)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005 * self.xdim, units='x', pivot='mid', color='w', angles='uv',
+ scale=1.1 / thin)
+
+ # Big Stokes I plot --axis
+ ax = plt.subplot2grid((2, 5), (0, 3), rowspan=2, colspan=2)
+ else:
+ ax = plt.gca()
+
+ if not cfun:
+ cfun = 'afmhot'
+            cmap = plt.get_cmap(cfun).copy()  # copy to avoid mutating the registered cmap
+            cmap.set_bad(color='whitesmoke')
+
+ # Big Stokes I plot
+ if cbar_lims:
+ im = plt.imshow(imarr2, cmap=cmap, interpolation=interp,
+ vmin=cbar_lims[0], vmax=cbar_lims[1])
+ else:
+ im = plt.imshow(imarr2, cmap, interpolation=interp)
+
+ if vec_cfun is None:
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01 * self.xdim, units='x', pivot='mid', color='k', angles='uv',
+ scale=1.0 / thin)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005 * self.xdim, units='x', pivot='mid', color='w', angles='uv',
+ scale=1.1 / thin)
+ else:
+                mthin = (np.abs(qvec + 1j * uvec) / imvec).reshape(
+                    self.ydim, self.xdim)[::thin, ::thin]
+                mthin = mthin[mask2]
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01 * self.xdim, units='x', pivot='mid', color='w', angles='uv',
+ scale=1.0 / thin)
+ plt.quiver(x, y, a, b, mthin,
+ norm=mpl.colors.Normalize(vmin=0, vmax=1.), cmap=vec_cfun,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.007 * self.xdim, units='x', pivot='mid', angles='uv',
+ scale=1.1 / thin)
+
+ if not(beamparams is None or beamparams is False):
+ beamparams = [beamparams[0], beamparams[1], beamparams[2],
+ -.35 * self.fovx(), -.35 * self.fovy()]
+ beamimage = self.copy()
+ beamimage.imvec *= 0
+ beamimage = beamimage.add_gauss(1, beamparams)
+ halflevel = 0.5 * np.max(beamimage.imvec)
+ beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim)
+ plt.contour(beamimarr, levels=[halflevel], colors=beamcolor, linewidths=beam_lw)
+
+ if has_cbar:
+
+ cbar = plt.colorbar(im, fraction=0.046, pad=0.04,
+ label=unit, orientation=cbar_orientation)
+ cbar.ax.tick_params(labelsize=cbar_fontsize)
+ if cbar_lims:
+ plt.clim(cbar_lims[0], cbar_lims[1])
+ if has_title:
+ plt.title("%s %.1f GHz : m=%.1f%% , v=%.1f%%" % (self.source, self.rf / 1e9,
+ self.lin_polfrac() * 100,
+ self.circ_polfrac() * 100),
+ fontsize=12)
+ f.subplots_adjust(hspace=.1, wspace=0.3)
+
+ # Label the plot
+ ax = plt.gca()
+ if label_type == 'ticks':
+ xticks = obsh.ticks(self.xdim, self.psize / ehc.RADPERAS / 1e-6)
+ yticks = obsh.ticks(self.ydim, self.psize / ehc.RADPERAS / 1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel(r'Relative RA ($\mu$as)')
+ plt.ylabel(r'Relative Dec ($\mu$as)')
+
+ elif label_type == 'scale':
+ plt.axis('off')
+ fov_uas = self.xdim * self.psize / ehc.RADPERUAS # get the fov in uas
+ roughfactor = 1. / 3. # make the bar about 1/3 the fov
+ fov_scale = int(math.ceil(fov_uas * roughfactor / 10.0)) * 10
+ start = self.xdim * roughfactor / 3.0 # select the start location
+ end = start + fov_scale / fov_uas * self.xdim # determine the end location
+ plt.plot([start, end], [self.ydim - start - 5, self.ydim - start - 5],
+ color=scalecolor, lw=scale_lw) # plot a line
+ plt.text(x=(start + end) / 2.0, y=self.ydim - start + self.ydim / 30,
+ s=str(fov_scale) + r" $\mu$as", color=scalecolor,
+ ha="center", va="center", fontsize=scale_fontsize)
+ ax = plt.gca()
+ if axis is None:
+ ax.axes.get_xaxis().set_visible(False)
+ ax.axes.get_yaxis().set_visible(False)
+
+ elif label_type == 'none' or label_type is None:
+ plt.axis('off')
+ ax = plt.gca()
+ if axis is None:
+ ax.axes.get_xaxis().set_visible(False)
+ ax.axes.get_yaxis().set_visible(False)
+
+ # Show or save to file
+ if axis is not None:
+ return axis
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ if export_pdf != "":
+ f.savefig(export_pdf, bbox_inches='tight', pad_inches=pdf_pad_inches, dpi=dpi)
+
+ return f
+
+ def overlay_display(self, im_list, color_coding=np.array([[1, 0, 1], [0, 1, 0]]),
+ export_pdf="", show=True, f=False,
+ shift=[0, 0], final_fov=False, interp='gaussian',
+ scale='lin', gamma=0.5, dynamic_range=[1.e3], rescale=True):
+ """Overlay primary polarization images of a list of images to compare structures.
+
+ Args:
+ im_list (list): list of images to align to the current image
+ color_coding (numpy.array): Color coding of each image in the composite
+
+ f (matplotlib.pyplot.figure): Figure to overlay on top of
+ export_pdf (str): path to exported PDF with plot
+ show (bool): Display the plot if true
+
+ shift (list): list of manual image shifts,
+ otherwise use the shift from maximum cross-correlation
+ final_fov (float): fov of the comparison image (rad).
+                If False it is the largest input image fov
+
+ scale (str) : compare images in 'log','lin',or 'gamma' scale
+ gamma (float): exponent for gamma scale comparison
+ dynamic_range (float): dynamic range for log and gamma scale comparisons
+
+ Returns:
+ (matplotlib.figure.Figure): figure object with image
+
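+        Example (illustrative sketch; `im` and `im2` are assumed Image objects):
+
+            >>> import numpy as np
+            >>> fig, shift = im.overlay_display([im2],
+            ...                                 color_coding=np.array([[1, 0, 1], [0, 1, 0]]))
+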
+ """
+
+ if not f:
+ f = plt.figure()
+ plt.clf()
+
+ if len(dynamic_range) == 1:
+ dynamic_range = dynamic_range * np.ones(len(im_list) + 1)
+
+ if not isinstance(shift, np.ndarray) and not isinstance(shift, bool):
+ shift = matlib.repmat(shift, len(im_list), 1)
+
+ psize = self.psize
+ max_fov = np.max([self.xdim * self.psize, self.ydim * self.psize])
+ for i in range(0, len(im_list)):
+ psize = np.min([psize, im_list[i].psize])
+ max_fov = np.max([max_fov, im_list[i].xdim * im_list[i].psize,
+ im_list[i].ydim * im_list[i].psize])
+
+ if not final_fov:
+ final_fov = max_fov
+
+ (im_list_shift, shifts, im0_pad) = self.align_images(im_list, shift=shift,
+ final_fov=final_fov,
+ scale=scale, gamma=gamma,
+ dynamic_range=dynamic_range)
+
+ # unit = 'Jy/pixel'
+ if scale == 'log':
+ # unit = 'log(Jy/pixel)'
+ log_offset = np.max(im0_pad.imvec) / dynamic_range[0]
+ im0_pad.imvec = np.log10(im0_pad.imvec + log_offset)
+ for i in range(0, len(im_list)):
+ log_offset = np.max(im_list_shift[i].imvec) / dynamic_range[i + 1]
+ im_list_shift[i].imvec = np.log10(im_list_shift[i].imvec + log_offset)
+
+ if scale == 'gamma':
+ # unit = '(Jy/pixel)^gamma'
+ log_offset = np.max(im0_pad.imvec) / dynamic_range[0]
+ im0_pad.imvec = (im0_pad.imvec + log_offset)**(gamma)
+ for i in range(0, len(im_list)):
+ log_offset = np.max(im_list_shift[i].imvec) / dynamic_range[i + 1]
+ im_list_shift[i].imvec = (im_list_shift[i].imvec + log_offset)**(gamma)
+
+ composite_img = np.zeros((im0_pad.ydim, im0_pad.xdim, 3))
+ for i in range(-1, len(im_list)):
+
+ if i == -1:
+ immtx = im0_pad.imvec.reshape(im0_pad.ydim, im0_pad.xdim)
+ else:
+ immtx = im_list_shift[i].imvec.reshape(im0_pad.ydim, im0_pad.xdim)
+
+ if rescale:
+ immtx = immtx - np.min(np.min(immtx))
+ immtx = immtx / np.max(np.max(immtx))
+
+ for c in range(0, 3):
+ composite_img[:, :, c] = composite_img[:, :, c] + (color_coding[i + 1, c] * immtx)
+
+ if rescale is False:
+ composite_img = composite_img - np.min(np.min(np.min(composite_img)))
+ composite_img = composite_img / np.max(np.max(np.max(composite_img)))
+
+ plt.subplot(111)
+ plt.title('%s MJD %i %.2f GHz' % (self.source, self.mjd, self.rf / 1e9), fontsize=20)
+ plt.imshow(composite_img, interpolation=interp)
+ xticks = obsh.ticks(im0_pad.xdim, im0_pad.psize / ehc.RADPERAS / 1e-6)
+ yticks = obsh.ticks(im0_pad.ydim, im0_pad.psize / ehc.RADPERAS / 1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel(r'Relative RA ($\mu$as)')
+ plt.ylabel(r'Relative Dec ($\mu$as)')
+
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+ if export_pdf != "":
+ f.savefig(export_pdf, bbox_inches='tight')
+
+ return (f, shift)
+
+ def save_txt(self, fname):
+ """Save image data to text file.
+
+ Args:
+ fname (str): path to output text file
+
+ Returns:
+ """
+
+ ehtim.io.save.save_im_txt(self, fname)
+ return
+
+ def save_fits(self, fname):
+ """Save image data to a fits file.
+
+ Args:
+ fname (str): path to output fits file
+
+ Returns:
+ """
+ ehtim.io.save.save_im_fits(self, fname)
+ return
+
+
+###################################################################################################
+# Image creation functions
+###################################################################################################
+
+def make_square(obs, npix, fov, pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None):
+ """Make an empty square image.
+
+ Args:
+ obs (Obsdata): an obsdata object with the image metadata
+        npix (int): the number of pixels along each axis
+ fov (float): the field of view of each axis in radians
+ pulse (function): the function convolved with the pixel values for continuous image
+
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular
+ Returns:
+ (Image): an image object
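+
+    Example (illustrative sketch; `obs` is an assumed Obsdata object and the
+    field of view is hypothetical):
+
+        >>> import ehtim as eh
+        >>> prior = eh.image.make_square(obs, 128, 200 * eh.RADPERUAS)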
+ """
+
+ outim = make_empty(npix, fov, obs.ra, obs.dec, rf=obs.rf, source=obs.source,
+ polrep=polrep, pol_prim=pol_prim, pulse=pulse,
+ mjd=obs.mjd, time=obs.tstart)
+
+ return outim
+
+
+def make_empty(npix, fov, ra, dec, rf=ehc.RF_DEFAULT, source=ehc.SOURCE_DEFAULT,
+ polrep='stokes', pol_prim=None, pulse=ehc.PULSE_DEFAULT,
+ mjd=ehc.MJD_DEFAULT, time=0.):
+ """Make an empty square image.
+
+ Args:
+        npix (int): the number of pixels along each axis
+ fov (float): the field of view of each axis in radians
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The image frequency in Hz
+
+ source (str): The source name
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ pulse (function): The function convolved with the pixel values for continuous image.
+
+ mjd (int): The integer MJD of the image
+ time (float): The observing time of the image (UTC hours)
+
+ Returns:
+ (Image): an image object
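+
+    Example (illustrative sketch; the coordinates and field of view are hypothetical):
+
+        >>> import ehtim as eh
+        >>> im = eh.image.make_empty(128, 200 * eh.RADPERUAS,
+        ...                          ra=12.51, dec=12.39, rf=230.e9, source='M87')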
+ """
+
+ pdim = fov / float(npix)
+ npix = int(npix)
+ imarr = np.zeros((npix, npix))
+ outim = Image(imarr, pdim, ra, dec,
+ polrep=polrep, pol_prim=pol_prim,
+ rf=rf, source=source, mjd=mjd, time=time, pulse=pulse)
+ return outim
+
+
+def load_image(image, display=False, aipscc=False):
+ """Read in an image from a text, .fits, .h5, or ehtim.image.Image object
+
+ Args:
+ image (str/Image): path to input file
+ display (boolean): determine whether to display the image default
+ aipscc (boolean): if True, then AIPS CC table will be loaded instead
+ of the original brightness distribution.
+ Returns:
+ (Image): loaded image object
+ (boolean): False if the image cannot be read
+ """
+
+ is_unicode = False
+ try:
+ if isinstance(image, basestring):
+ is_unicode = True
+ except NameError: # python 3
+ pass
+ if isinstance(image, str) or is_unicode:
+ if image.endswith('.fits'):
+ im = ehtim.io.load.load_im_fits(image, aipscc=aipscc)
+ elif image.endswith('.txt'):
+ im = ehtim.io.load.load_im_txt(image)
+ elif image.endswith('.h5'):
+ im = ehtim.io.load.load_im_hdf5(image)
+ else:
+ print("Image format is not recognized. Was expecting .fits, .txt, or Image.")
+ print(" Got <.{0}>. Returning False.".format(image.split('.')[-1]))
+ return False
+
+ elif isinstance(image, ehtim.image.Image):
+ im = image
+
+ else:
+ print("Image format is not recognized. Was expecting .fits, .txt, or Image.")
+ print(" Got {0}. Returning False.".format(type(image)))
+ return False
+
+ if display:
+ im.display()
+
+ return im
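+
+# Example (illustrative; 'myimage.fits' is a placeholder path):
+#
+#     im = load_image('myimage.fits')
+#     if im is not False:
+#         im.display()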
+
+
+def load_txt(fname, polrep='stokes', pol_prim=None, pulse=ehc.PULSE_DEFAULT, zero_pol=True):
+ """Read in an image from a text file.
+
+ Args:
+ fname (str): path to input text file
+ pulse (function): The function convolved with the pixel values for continuous image.
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+
+ Returns:
+ (Image): loaded image object
+ """
+
+    return ehtim.io.load.load_im_txt(fname, pulse=pulse, polrep=polrep,
+                                     pol_prim=pol_prim, zero_pol=zero_pol)
+
+
+def load_fits(fname, aipscc=False, pulse=ehc.PULSE_DEFAULT,
+ polrep='stokes', pol_prim=None, zero_pol=False):
+ """Read in an image from a FITS file.
+
+ Args:
+ fname (str): path to input fits file
+ aipscc (bool): if True, then AIPS CC table will be loaded
+ pulse (function): The function convolved with the pixel values for continuous image.
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+
+ Returns:
+ (Image): loaded image object
+ """
+
+ return ehtim.io.load.load_im_fits(fname, aipscc=aipscc, pulse=pulse,
+ polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol)
+
+def avg_imlist(imlist):
+ """Average a list of images.
+
+ Args:
+ imlist (list): list of image objects
+
+ Returns:
+ (Image): average image object
+ """
+
+    imavg = imlist[0].copy()  # copy so the in-place accumulation below does not modify imlist[0]
+
+ if np.any(np.array([im.polrep for im in imlist]) != imavg.polrep):
+ raise Exception("im.polrep in all images are not the same in avg_imlist!")
+ if np.any(np.array([im.source for im in imlist]) != imavg.source):
+ raise Exception("im.source in all images are not the same in avg_imlist!")
+ if np.any(np.array([im.rf for im in imlist]) != imavg.rf):
+ raise Exception("im.rf in all images are not the same in avg_imlist!")
+
+ keys = imavg._imdict.keys()
+
+ for im in imlist[1:]:
+ for key in keys:
+ imavg._imdict[key] += im._imdict[key]
+
+ for key in keys:
+ imavg._imdict[key] /= float(len(imlist))
+
+ return imavg
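+
+# Example (illustrative): average several snapshot reconstructions with
+# matching polrep, source, and frequency:
+#
+#     im_mean = avg_imlist([im1, im2, im3])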
+
+def get_specim(imlist, reffreq, fit_order=2):
+ """get the spectral index/curvature from a list of images"""
+ freqs = [im.rf for im in imlist]
+
+ # remove any zeros in the images
+ for im in imlist:
+ im.imvec[im.imvec<=0] = np.min(im.imvec[im.imvec!=0])
+
+ # fit
+ xfit = np.log(np.array(freqs)/reffreq)
+ yfit = np.log(np.array([im.imvec for im in imlist]))
+
+ if fit_order == 2:
+ coeffs = np.polyfit(xfit,yfit,2)
+ beta = coeffs[0]
+ alpha = coeffs[1]
+ imvec = np.exp(coeffs[2])
+ elif fit_order == 1:
+ coeffs = np.polyfit(xfit,yfit,1)
+ alpha = coeffs[0]
+ beta = 0*alpha
+ imvec = np.exp(coeffs[1])
+ else:
+        raise Exception("fit_order must be 1 or 2 in get_specim!")
+
+    outim = imlist[0].copy()  # TODO: polarization is not copied to the output
+ outim.imvec = imvec
+ outim.rf = reffreq
+ outim.specvec = alpha
+ outim.curvvec = beta
+
+ return outim
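+
+# Example (illustrative): from images of one source at 86, 230, and 345 GHz,
+# fit a per-pixel spectral index referenced to 230 GHz:
+#
+#     im_spec = get_specim([im86, im230, im345], reffreq=230e9, fit_order=1)
+#     alpha_map = im_spec.specvec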
+
+
+def blur_mf(im, freqs, kernel, fit_order=2):
+    """Blur a multifrequency image with the same circular beam at every
+       frequency, then refit the spectral index/curvature.
+    """
+    reffreq = im.rf
+
+    # blur the image at each frequency with the same circular kernel
+    imlist = [im.get_image_mf(rf).blur_circ(kernel) for rf in freqs]
+
+    # clip zero/negative pixels so the logarithm below is finite
+    for image in imlist:
+        image.imvec[image.imvec <= 0] = np.min(image.imvec[image.imvec != 0])
+
+ xfit = np.log(np.array(freqs)/reffreq)
+ yfit = np.log(np.array([im.imvec for im in imlist]))
+
+ if fit_order == 2:
+ coeffs = np.polyfit(xfit,yfit,2)
+ beta = coeffs[0]
+ alpha = coeffs[1]
+ elif fit_order == 1:
+ coeffs = np.polyfit(xfit,yfit,1)
+ alpha = coeffs[0]
+ beta = 0*alpha
+    else:
+        # no fit requested: zero spectral index and curvature maps
+        alpha = np.zeros(len(im.imvec))
+        beta = np.zeros(len(im.imvec))
+
+ outim = im.blur_circ(kernel)
+ outim.specvec = alpha
+ outim.curvvec = beta
+ return outim
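+
+# Example (illustrative): blur a multifrequency image to a common 20 uas beam
+# before comparing spectral index maps:
+#
+#     im_blur = blur_mf(im_mf, [86e9, 230e9], 20*ehc.RADPERUAS, fit_order=1)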
diff --git a/imager.py b/imager.py
new file mode 100644
index 00000000..bae00748
--- /dev/null
+++ b/imager.py
@@ -0,0 +1,2010 @@
+# imager.py
+# a general interferometric imager class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import copy
+import time
+import numpy as np
+import scipy.optimize as opt
+
+import ehtim.scattering as so
+import ehtim.imaging.imager_utils as imutils
+import ehtim.imaging.pol_imager_utils as polutils
+import ehtim.imaging.multifreq_imager_utils as mfutils
+import ehtim.image
+import ehtim.const_def as ehc
+
+MAXIT = 200       # maximum number of iterations
+NHIST = 50 # number of steps to store for hessian approx
+MAXLS = 40 # maximum number of line search steps in BFGS-B
+STOP = 1e-6 # convergence criterion
+EPS = 1e-8
+
+DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag']
+REGULARIZERS = ['gs', 'tv', 'tvlog','tv2', 'tv2log', 'l1', 'l1w', 'lA', 'patch',
+ 'flux', 'cm', 'simple', 'compact', 'compact2', 'rgauss']
+REGULARIZERS_SPECIND = ['l2_alpha', 'tv_alpha']
+REGULARIZERS_CURV = ['l2_beta', 'tv_beta']
+
+DATATERMS_POL = ['pvis', 'm', 'pbs','vvis']
+REGULARIZERS_POL = ['msimple', 'hw', 'ptv','l1v','l2v','vtv','vtv2','vflux']
+
+GRIDDER_P_RAD_DEFAULT = 2
+GRIDDER_CONV_FUNC_DEFAULT = 'gaussian'
+FFT_PAD_DEFAULT = 2
+FFT_INTERP_DEFAULT = 3
+
+REG_DEFAULT = {'simple': 1}
+DAT_DEFAULT = {'vis': 100}
+
+POL_TRANS = True  # solve for polarization in the (m, chi) basis
+# POL_WHICH_SOLVE = (0, 1, 1)  # deprecated: pol imaging solved for m & chi (not I);
+                               # now determined by 'pol_next'
+MF_WHICH_SOLVE = (1, 1, 0)  # default: mf imaging solves for I0 and alpha (not beta)
+                            # DEFAULT ONLY: the object uses self.mf_which_solve
+
+REGPARAMS_DEFAULT = {'major':50*ehc.RADPERUAS,
+ 'minor':50*ehc.RADPERUAS,
+ 'PA':0.,
+ 'alpha_A':1.0,
+ 'epsilon_tv':0.0}
+
+POLARIZATION_MODES = ['P','QU','IP','IQU','V','IV','IQUV','IPV'] # TODO: treatment of V may be inconsistent
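+
+# Example (illustrative): data and regularizer terms are passed to Imager as
+# {name: weight} dictionaries drawn from the lists above, e.g.
+#     data_term = {'amp': 100, 'cphase': 100}
+#     reg_term = {'simple': 1, 'tv2': 10, 'flux': 100}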
+
+###################################################################################################
+# Imager object
+###################################################################################################
+
+
+class Imager(object):
+ """A general interferometric imager.
+ """
+
+ def __init__(self, obs_in, init_im,
+ prior_im=None, flux=None, data_term=DAT_DEFAULT, reg_term=REG_DEFAULT, **kwargs):
+
+ self.logstr = ""
+ self._obs_list = []
+ self._init_list = []
+ self._prior_list = []
+ self._out_list = []
+ self._out_list_epsilon = []
+ self._out_list_scattered = []
+ self._reg_term_list = []
+ self._dat_term_list = []
+ self._clipfloor_list = []
+ self._maxset_list = []
+ self._pol_list = []
+ self._maxit_list = []
+ self._stop_list = []
+ self._flux_list = []
+ self._pflux_list = []
+ self._vflux_list = []
+ self._snrcut_list = []
+ self._debias_list = []
+ self._systematic_noise_list = []
+ self._systematic_cphase_noise_list = []
+ self._transform_list = []
+ self._weighting_list = []
+
+ # Regularizer/data terms for the next imaging iteration
+ self.reg_term_next = reg_term # e.g. [('simple',1), ('l1',10), ('flux',500), ('cm',500)]
+ self.dat_term_next = data_term # e.g. [('amp', 1000), ('cphase',100)]
+
+ # Observations, frequencies
+ self.reffreq = init_im.rf
+ if isinstance(obs_in, list):
+ self._obslist_next = obs_in
+ self.obslist_next = obs_in
+ else:
+ self._obslist_next = [obs_in]
+ self.obslist_next = [obs_in]
+
+ # Init, prior, flux
+ self.init_next = init_im
+
+ if prior_im is None:
+ self.prior_next = self.init_next
+ else:
+ self.prior_next = prior_im
+
+ if flux is None:
+ self.flux_next = self.prior_next.total_flux()
+ else:
+ self.flux_next = flux
+
+ # set polarimetric flux values equal to Stokes I flux by default
+ # used in regularizer normalization
+        self.pflux_next = kwargs.get('pflux', self.flux_next)
+        self.vflux_next = kwargs.get('vflux', self.flux_next)
+
+ # Polarization
+ self.pol_next = kwargs.get('pol', self.init_next.pol_prim)
+
+ # Weighting/debiasing/snr cut/systematic noise
+ self.debias_next = kwargs.get('debias', True)
+ snrcut = kwargs.get('snrcut', 0.)
+ self.snrcut_next = {key: 0. for key in set(DATATERMS+DATATERMS_POL)}
+
+ if type(snrcut) is dict:
+ for key in snrcut.keys():
+ self.snrcut_next[key] = snrcut[key]
+ else:
+ for key in self.snrcut_next.keys():
+ self.snrcut_next[key] = snrcut
+
+ self.systematic_noise_next = kwargs.get('systematic_noise', 0.)
+ self.systematic_cphase_noise_next = kwargs.get('systematic_cphase_noise', 0.)
+ self.weighting_next = kwargs.get('weighting', 'natural')
+
+ # Maximal/minimal closure set
+ self.maxset_next = kwargs.get('maxset', False)
+
+        # Clipping, iteration limits, and convergence
+ self.clipfloor_next = kwargs.get('clipfloor', 0.)
+ self.maxit_next = kwargs.get('maxit', MAXIT)
+ self.stop_next = kwargs.get('stop', STOP)
+ self.transform_next = kwargs.get('transform', ['log','mcv'])
+        self.transform_next = np.array([self.transform_next]).flatten()  # so we can handle multiple transforms
+
+ # Normalize or not?
+ self.norm_init = kwargs.get('norm_init', True)
+ self.norm_reg = kwargs.get('norm_reg', False)
+ self.beam_size = self.obslist_next[0].res()
+ self.regparams = {k: kwargs.get(k, REGPARAMS_DEFAULT[k]) for k in REGPARAMS_DEFAULT.keys()}
+
+ self.chisq_transform = False
+ self.chisq_offset_gradient = 0.0
+
+ # FFT parameters
+ self._ttype = kwargs.get('ttype', 'nfft')
+ self._fft_gridder_prad = kwargs.get('fft_gridder_prad', GRIDDER_P_RAD_DEFAULT)
+ self._fft_conv_func = kwargs.get('fft_conv_func', GRIDDER_CONV_FUNC_DEFAULT)
+ self._fft_pad_factor = kwargs.get('fft_pad_factor', FFT_PAD_DEFAULT)
+ self._fft_interp_order = kwargs.get('fft_interp_order', FFT_INTERP_DEFAULT)
+
+ # UV minimum for closure phases
+ self.cp_uv_min = kwargs.get('cp_uv_min', False)
+
+ # Parameters related to scattering
+ self.epsilon_list_next = []
+ self.scattering_model = kwargs.get('scattering_model', None)
+ self._sqrtQ = None
+ self._ea_ker = None
+ self._ea_ker_gradient_x = None
+ self._ea_ker_gradient_y = None
+ self._alpha_phi_list = []
+ self.alpha_phi_next = kwargs.get('alpha_phi', 1e4)
+
+ # Imager history
+ self._change_imgr_params = True
+ self.nruns = 0
+
+ # multifrequency
+ self.mf_next = False
+ self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf',False)
+ self.mf_which_solve = kwargs.get('mf_which_solve',MF_WHICH_SOLVE)
+
+ # Set embedding matrices and prepare imager
+ self.check_params()
+ self.check_limits()
+ self.init_imager()
+
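+    # Example (illustrative; assumes `obs` is an Obsdata and `prior` an Image):
+    #
+    #     imgr = Imager(obs, prior, prior_im=prior, flux=prior.total_flux(),
+    #                   data_term={'vis': 100}, reg_term={'simple': 1, 'flux': 100})
+    #     im_out = imgr.make_image_I(niter=3, blur_frac=1)
+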
+ @property
+ def obslist_next(self):
+ return self._obslist_next
+
+ @obslist_next.setter
+ def obslist_next(self, obslist):
+ if not isinstance(obslist, list):
+ raise Exception("obslist_next must be a list!")
+ self._obslist_next = obslist
+ self.freq_list = [obs.rf for obs in self.obslist_next]
+ #self.reffreq = self.freq_list[0] #Changed so that reffreq is determined by initial image/prior rf
+ self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]
+
+ @property
+ def obs_next(self):
+ """the next Obsdata to be used in imaging
+ """
+ return self.obslist_next[0]
+
+ @obs_next.setter
+ def obs_next(self, obs):
+ """the next Obsdata to be used in imaging
+ """
+ self.obslist_next = [obs]
+
+ def make_image(self, pol=None, grads=True, mf=False, **kwargs):
+ """Make an image using current imager settings.
+
+ Args:
+ pol (str): which polarization to image
+ grads (bool): whether or not to use image gradients
+ mf (bool): whether or not to do multifrequency (spectral index only for now)
+
+ Returns:
+ (Image): output image
+
+ """
+
+ self.mf_next = mf
+ self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf', self.reg_all_freq_mf)
+ self.mf_which_solve = kwargs.get('mf_which_solve', self.mf_which_solve)
+
+ if pol is None:
+ pol_prim = self.pol_next
+ else:
+ self.pol_next = pol
+ pol_prim = pol
+
+ print("==============================")
+ print("Imager run %i " % (int(self.nruns)+1))
+
+ # For polarimetric imaging, switch polrep to Stokes
+ if self.pol_next in POLARIZATION_MODES:
+ print("Imaging Polarization: switching to Stokes!")
+ self.prior_next = self.prior_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
+ self.init_next = self.init_next.switch_polrep(polrep_out='stokes', pol_prim_out='I')
+ pol_prim = 'I'
+
+ # Checks and initialize
+ self.check_params()
+ self.check_limits()
+ self.init_imager()
+
+ # Print initial stats
+ self._nit = 0
+ self._show_updates = kwargs.get('show_updates', True)
+ self._update_interval = kwargs.get('update_interval', 1)
+
+ # Plot initial image
+ self.plotcur(self._xinit, **kwargs)
+
+ # Minimize
+ optdict = {'maxiter': self.maxit_next,
+ 'ftol': self.stop_next, 'gtol': self.stop_next,
+ 'maxcor': NHIST, 'maxls': MAXLS}
+ def callback_func(xcur):
+ self.plotcur(xcur, **kwargs)
+
+ print("Imaging . . .")
+ tstart = time.time()
+ if grads:
+ res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B', jac=self.objgrad,
+ options=optdict, callback=callback_func)
+ else:
+ res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B',
+ options=optdict, callback=callback_func)
+ tstop = time.time()
+
+ # Format output
+ out = res.x[:]
+ self.tmpout = res.x
+
+ if self.pol_next in POLARIZATION_MODES: # polarization
+ if self.pol_next == 'P':
+ out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (0,1,1))
+ if 'mcv' in self.transform_next:
+ out = polutils.mcv(out)
+
+ elif self.pol_next == 'IP' or self.pol_next == 'IQU':
+ out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,1,1))
+ if 'mcv' in self.transform_next:
+ out = polutils.mcv(out)
+ if 'log' in self.transform_next:
+ out[0] = np.exp(out[0])
+
+ elif self.pol_next == 'V':
+ out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (0,0,0,1))
+ if 'mcv' in self.transform_next:
+ out = polutils.mcv(out)
+
+ elif self.pol_next == 'IV':
+ out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,0,0,1))
+ if 'mcv' in self.transform_next:
+ out = polutils.mcv(out)
+ if 'log' in self.transform_next:
+ out[0] = np.exp(out[0])
+
+ elif self.pol_next == 'IQUV':
+ out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,1,1,1))
+ if 'mcv' in self.transform_next:
+ out = polutils.mcv(out)
+ if 'log' in self.transform_next:
+ out[0] = np.exp(out[0])
+
+ elif self.mf_next: # multi-frequency
+ out = mfutils.unpack_mftuple(out, self._xtuple, self._nimage, self.mf_which_solve)
+ if 'log' in self.transform_next:
+ out[0] = np.exp(out[0])
+
+ elif 'log' in self.transform_next: # simple single-frequency
+ out = np.exp(out)
+
+ # Print final stats
+ outstr = ""
+ chi2_term_dict = self.make_chisq_dict(out)
+ for dname in sorted(self.dat_term_next.keys()):
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+ outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key])
+
+ try:
+ print("time: %f s" % (tstop - tstart))
+ print("J: %f" % res.fun)
+ print(outstr)
+ if isinstance(res.message,str): print(res.message)
+ else: print(res.message.decode())
+ except: # TODO -- issues for some users with res.message
+ pass
+
+ print("==============================")
+
+ # Embed image
+ if self.pol_next in POLARIZATION_MODES: # polarization
+ if np.any(np.invert(self._embed_mask)):
+ out = polutils.embed_pol(out, self._embed_mask)
+ iimage_out = out[0]
+ qimage_out = polutils.make_q_image(out, POL_TRANS)
+ uimage_out = polutils.make_u_image(out, POL_TRANS)
+ vimage_out = polutils.make_v_image(out, POL_TRANS)
+
+ elif self.mf_next: # multi-frequency
+ if np.any(np.invert(self._embed_mask)):
+ out = mfutils.embed_mf(out, self._embed_mask)
+ iimage_out = out[0]
+ specind_out = out[1]
+ curv_out = out[2]
+
+ else: # simple single-pol
+ if np.any(np.invert(self._embed_mask)):
+ out = imutils.embed(out, self._embed_mask)
+ iimage_out = out
+
+ # Return image
+ arglist, argdict = self.prior_next.image_args()
+ arglist[0] = iimage_out.reshape(self.prior_next.ydim, self.prior_next.xdim)
+ argdict['pol_prim'] = pol_prim
+ outim = ehtim.image.Image(*arglist, **argdict)
+
+ # Copy over other polarizations
+ for pol2 in list(outim._imdict.keys()):
+
+ # Is it the base image?
+ if pol2 == outim.pol_prim:
+ continue
+
+            # Did we solve for a polarimetric image, or are we copying over old pols?
+ if self.pol_next in POLARIZATION_MODES and pol2 == 'Q':
+ polvec = qimage_out
+ elif self.pol_next in POLARIZATION_MODES and pol2 == 'U':
+ polvec = uimage_out
+ elif self.pol_next in POLARIZATION_MODES and pol2 == 'V':
+ polvec = vimage_out
+ else:
+ polvec = self.init_next._imdict[pol2]
+
+ if len(polvec):
+ polarr = polvec.reshape(outim.ydim, outim.xdim)
+ outim.add_pol_image(polarr, pol2)
+
+ # Copy over spectral index information
+ outim._mflist = copy.deepcopy(self.init_next._mflist)
+ if self.mf_next:
+ outim._mflist[0] = specind_out
+ outim._mflist[1] = curv_out
+
+ # Append to history
+ logstr = str(self.nruns) + ": make_image(pol=%s)" % pol
+ self._append_image_history(outim, logstr)
+ self.nruns += 1
+
+ # Return Image object
+ return outim
+
+    def converge(self, niter, blur_frac, pol, grads=True, **kwargs):
+        """Run make_image niter-1 more times, re-initializing each run with
+           the last output blurred by blur_frac times the observation resolution.
+        """
+
+ blur = blur_frac * self.obs_next.res()
+ for repeat in range(niter-1):
+ init = self.out_last()
+ init = init.blur_circ(blur, blur)
+ self.init_next = init
+ self.make_image(pol=pol, grads=grads, **kwargs)
+
+
+ def make_image_I(self, grads=True, niter=1, blur_frac=1, **kwargs):
+ """Make Stokes I image using current imager settings.
+ """
+ pol = 'I'
+ self.make_image(pol=pol, grads=grads, **kwargs)
+ self.converge(niter, blur_frac, pol, grads, **kwargs)
+
+ return self.out_last()
+
+
+ def make_image_P(self, grads=True, niter=1, blur_frac=1, **kwargs):
+ """Make Stokes P polarimetric image using current imager settings.
+ """
+ pol = 'P'
+ self.make_image(pol=pol, grads=grads, **kwargs)
+ self.converge(niter, blur_frac, pol, grads, **kwargs)
+
+ return self.out_last()
+
+
+ def make_image_IP(self, grads=True, niter=1, blur_frac=1, **kwargs):
+ """Make Stokes I and P polarimetric image simultaneously using current imager settings.
+ """
+ pol = 'IP'
+ self.make_image(pol=pol, grads=grads, **kwargs)
+ self.converge(niter, blur_frac, pol, grads, **kwargs)
+
+ return self.out_last()
+
+ def make_image_V(self, grads=True, niter=1, blur_frac=1, **kwargs):
+ """Make Stokes I image using current imager settings.
+ """
+ pol = 'V'
+ self.make_image(pol=pol, grads=grads, **kwargs)
+ self.converge(niter, blur_frac, pol, grads, **kwargs)
+
+ return self.out_last()
+
+ def make_image_IV(self, grads=True, niter=1, blur_frac=1, **kwargs):
+ """Make Stokes I image using current imager settings.
+ """
+ pol = 'IV'
+ self.make_image(pol=pol, grads=grads, **kwargs)
+ self.converge(niter, blur_frac, pol, grads, **kwargs)
+
+ return self.out_last()
+
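+    # Example (illustrative): image Stokes I first, then linear polarization
+    # starting from the converged Stokes I map:
+    #
+    #     im_I = imgr.make_image_I(niter=3, blur_frac=1)
+    #     imgr.init_next = im_I.blur_circ(imgr.obs_next.res())
+    #     im_P = imgr.make_image_P()
+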
+ def set_embed(self):
+ """Set embedding matrix.
+ """
+ self._embed_mask = self.prior_next.imvec > self.clipfloor_next
+ if not np.any(self._embed_mask):
+ raise Exception("clipfloor_next too large: all prior pixels have been clipped!")
+
+ xmax = self.prior_next.xdim//2
+ ymax = self.prior_next.ydim//2
+
+ if self.prior_next.xdim % 2: xmin=-xmax-1
+ else: xmin=-xmax
+
+ if self.prior_next.ydim % 2: ymin=-ymax-1
+ else: ymin=-ymax
+
+ coord = np.array([[[x, y]
+ for x in np.arange(xmax, xmin, -1)]
+ for y in np.arange(ymax, ymin, -1)])
+
+ coord = coord.reshape(self.prior_next.ydim * self.prior_next.xdim, 2)
+ coord = coord * self.prior_next.psize
+
+ self._coord_matrix = coord[self._embed_mask]
+
+ return
+
+ def check_params(self):
+ """Check parameter consistency.
+ """
+ if ((self.prior_next.psize != self.init_next.psize) or
+ (self.prior_next.xdim != self.init_next.xdim) or
+ (self.prior_next.ydim != self.init_next.ydim)):
+ raise Exception("Initial image does not match dimensions of the prior image!")
+
+ if ((self.prior_next.rf != self.init_next.rf)):
+ raise Exception("Initial image does not have same frequency as prior image!")
+
+ if (self.prior_next.polrep != self.init_next.polrep):
+ raise Exception(
+ "Initial image polrep does not match prior polrep!")
+
+ if (self.prior_next.polrep == 'circ' and not(self.pol_next in ['RR', 'LL'])):
+ raise Exception("Initial image polrep is 'circ': pol_next must be 'RR' or 'LL'")
+
+        if (self.prior_next.polrep == 'stokes' and
+                self.pol_next not in ['I', 'Q', 'U', 'V'] + POLARIZATION_MODES):
+            raise Exception(
+                "Initial image polrep is 'stokes': pol_next must be "
+                "'I', 'Q', 'U', 'V', or one of " + ','.join(POLARIZATION_MODES) + "!")
+
+ # TODO single-polarization imaging. should we still support?
+ if ('log' in self.transform_next and self.pol_next in ['Q', 'U', 'V']):
+ raise Exception("Cannot image Stokes Q, U, V with log image transformation!")
+
+ if(self.pol_next in ['Q', 'U', 'V'] and
+ ('gs' in self.reg_term_next.keys() or 'simple' in self.reg_term_next.keys())):
+ raise Exception(
+ "'simple' and 'gs' methods do not work with Stokes Q, U, or V images!")
+
+ if self._ttype not in ['fast', 'direct', 'nfft']:
+ raise Exception("Possible ttype values are 'fast', 'direct','nfft'!")
+
+ # Catch errors in multifrequency imaging setup
+ if self.mf_next and len(set(self.freq_list)) < 2:
+ raise Exception(
+ "must have observations at at least two frequencies for multifrequency imaging!")
+
+ # Catch errors for polarimetric imaging setup
+ if self.pol_next in POLARIZATION_MODES:
+ if 'mcv' not in self.transform_next:
+ raise Exception("Polarimetric imaging needs 'mcv' transform!")
+ if (self._ttype not in ["direct", "nfft"]):
+ raise Exception("FFT not yet implemented in polarimetric imaging -- use NFFT!")
+ if 'I' in self.pol_next:
+ rlist = REGULARIZERS + REGULARIZERS_POL
+ dlist = DATATERMS + DATATERMS_POL
+ else:
+ rlist = REGULARIZERS_POL
+ dlist = DATATERMS_POL
+ else:
+ rlist = REGULARIZERS + REGULARIZERS_SPECIND + REGULARIZERS_CURV
+ dlist = DATATERMS
+
+ # catch errors in general imaging setup
+ dt_here = False
+ dt_type = True
+ for term in sorted(self.dat_term_next.keys()):
+ if (term is not None) and (term is not False):
+ dt_here = True
+ if not ((term in dlist) or (term is False)):
+ dt_type = False
+
+ st_here = False
+ st_type = True
+ for term in sorted(self.reg_term_next.keys()):
+ if (term is not None) and (term is not False):
+ st_here = True
+ if not ((term in rlist) or (term is False)):
+ st_type = False
+
+ if not dt_here:
+ raise Exception("Must have at least one data term!")
+ if not st_here:
+ raise Exception("Must have at least one regularizer term!")
+ if not dt_type:
+ raise Exception("Invalid data term: valid data terms are: " + ','.join(dlist))
+ if not st_type:
+ raise Exception("Invalid regularizer: valid regularizers are: " + ','.join(rlist))
+
+
+ # Determine if we need to recompute the saved imager parameters on the next imager run
+ if self.nruns == 0:
+ return
+
+ if self.pol_next != self.pol_last():
+ print("changed polarization!")
+ self._change_imgr_params = True
+ return
+
+ if self.obslist_next != self.obslist_last():
+ print("changed observation!")
+ self._change_imgr_params = True
+ return
+
+ if len(self.reg_term_next) != len(self.reg_terms_last()):
+ print("changed number of regularizer terms!")
+ self._change_imgr_params = True
+ return
+
+ if len(self.dat_term_next) != len(self.dat_terms_last()):
+ print("changed number of data terms!")
+ self._change_imgr_params = True
+ return
+
+ for term in sorted(self.dat_term_next.keys()):
+ if term not in self.dat_terms_last().keys():
+ print("added %s to data terms" % term)
+ self._change_imgr_params = True
+ return
+
+ for term in sorted(self.reg_term_next.keys()):
+ if term not in self.reg_terms_last().keys():
+ print("added %s to regularizers!" % term)
+ self._change_imgr_params = True
+ return
+
+ if ((self.prior_next.psize != self.prior_last().psize) or
+ (self.prior_next.xdim != self.prior_last().xdim) or
+ (self.prior_next.ydim != self.prior_last().ydim)):
+ print("changed prior dimensions!")
+            self._change_imgr_params = True
+            return
+
+ if self.debias_next != self.debias_last():
+ print("changed debiasing!")
+ self._change_imgr_params = True
+ return
+ if self.snrcut_next != self.snrcut_last():
+ print("changed snrcut!")
+ self._change_imgr_params = True
+ return
+ if self.weighting_next != self.weighting_last():
+ print("changed data weighting!")
+ self._change_imgr_params = True
+ return
+ if self.systematic_noise_next != self.systematic_noise_last():
+ print("changed systematic noise!")
+ self._change_imgr_params = True
+ return
+ if self.systematic_cphase_noise_next != self.systematic_cphase_noise_last():
+ print("changed systematic cphase noise!")
+ self._change_imgr_params = True
+ return
+
+ def check_limits(self):
+ """Check image parameter consistency with observation.
+ """
+ uvmax = 1.0/self.prior_next.psize
+ uvmin = 1.0/(self.prior_next.psize*np.max((self.prior_next.xdim, self.prior_next.ydim)))
+ uvdists = self.obs_next.unpack('uvdist')['uvdist']
+ maxbl = np.max(uvdists)
+        minbl = np.min(uvdists[uvdists > 0])
+
+ if uvmax < maxbl:
+ print("Warning! Pixel size is larger than smallest spatial wavelength!")
+ if uvmin > minbl:
+ print("Warning! Field of View is smaller than largest nonzero spatial wavelength!")
+
+ if self.pol_next in ['I', 'RR', 'LL']:
+ maxamp = np.max(np.abs(self.obs_next.unpack('amp')['amp']))
+ if self.flux_next > 1.2*maxamp:
+ print("Warning! Specified flux is > 120% of maximum visibility amplitude!")
+ if self.flux_next < .8*maxamp:
+ print("Warning! Specified flux is < 80% of maximum visibility amplitude!")
+
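+    # Worked example of the limits above (illustrative numbers): with a pixel
+    # size of 1 uas ~ 4.85e-12 rad, uvmax = 1/psize ~ 2.1e11 lambda, so any
+    # baseline longer than ~2.1e11 lambda varies on scales the pixel grid
+    # cannot represent; likewise the field of view 1/uvmin must exceed the
+    # scale 1/minbl probed by the shortest nonzero baseline.
+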
+ def reg_terms_last(self):
+ """Return last used regularizer terms.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._reg_term_list[-1]
+
+ def dat_terms_last(self):
+ """Return last used data terms.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._dat_term_list[-1]
+
+ def obslist_last(self):
+ """Return last used observation.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._obs_list[-1]
+
+ def obs_last(self):
+ """Return last used observation.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._obs_list[-1][0]
+
+ def prior_last(self):
+ """Return last used prior image.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._prior_list[-1]
+
+ def out_last(self):
+ """Return last result.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._out_list[-1]
+
+ def out_scattered_last(self):
+ """Return last result with scattering.
+ """
+ if self.nruns == 0 or len(self._out_list_scattered) == 0:
+ print("No stochastic optics imager runs yet!")
+ return
+ return self._out_list_scattered[-1]
+
+ def out_epsilon_last(self):
+ """Return last result with scattering.
+ """
+ if self.nruns == 0 or len(self._out_list_epsilon) == 0:
+ print("No stochastic optics imager runs yet!")
+ return
+ return self._out_list_epsilon[-1]
+
+ def init_last(self):
+ """Return last initial image.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._init_list[-1]
+
+ def flux_last(self):
+ """Return last total flux constraint.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._flux_list[-1]
+
+ def pflux_last(self):
+ """Return last total linear polarimetric flux constraint.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._pflux_list[-1]
+
+ def vflux_last(self):
+ """Return last total circular polarimetric flux constraint.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._vflux_list[-1]
+
+ def clipfloor_last(self):
+ """Return last clip floor.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._clipfloor_list[-1]
+
+ def pol_last(self):
+ """Return last polarization imaged.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._pol_list[-1]
+
+ def maxit_last(self):
+ """Return last max_iterations value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._maxit_list[-1]
+
+ def debias_last(self):
+ """Return last debias value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._debias_list[-1]
+
+ def snrcut_last(self):
+ """Return last snrcut value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._snrcut_list[-1]
+
+ def weighting_last(self):
+ """Return last weighting value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._weighting_list[-1]
+
+ def systematic_noise_last(self):
+ """Return last systematic_noise value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._systematic_noise_list[-1]
+
+ def systematic_cphase_noise_last(self):
+ """Return last closure phase systematic noise value (in degree).
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._systematic_cphase_noise_list[-1]
+
+ def stop_last(self):
+ """Return last convergence value.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._stop_list[-1]
+
+ def transform_last(self):
+ """Return last image transform used.
+ """
+ if self.nruns == 0:
+ print("No imager runs yet!")
+ return
+ return self._transform_list[-1]
+
+ def init_imager(self):
+ """Set up Stokes I imager.
+ """
+ # Set embedding
+ self.set_embed()
+
+ # Set prior & initial image vectors for polarimetric imaging
+ if self.pol_next in POLARIZATION_MODES:
+
+ # initial I image
+ if self.norm_init and ('I' in self.pol_next):
+ self._nprior = (self.flux_next * self.prior_next.imvec /
+ np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
+ iinit = (self.flux_next * self.init_next.imvec /
+ np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
+ else:
+ self._nprior = self.prior_next.imvec[self._embed_mask]
+ iinit = self.init_next.imvec[self._embed_mask]
+
+ self._nimage = len(iinit)
+
+ # Initialize m & phi & v
+ if (len(self.init_next.qvec) and
+ (np.any(self.init_next.qvec != 0) or np.any(self.init_next.uvec != 0))):
+
+ init1 = np.abs(self.init_next.qvec + 1j*self.init_next.uvec) / self.init_next.imvec
+ init1 = init1[self._embed_mask]
+ init2 = (np.arctan2(self.init_next.uvec, self.init_next.qvec) / 2.0)
+ init2 = init2[self._embed_mask]
+ else:
+ # !AC TODO get the actual zero baseline polarization fraction from the data?
+ print("No polarimetric image in init_next!")
+ print("--initializing with 20% pol and random orientation!")
+ init1 = 0.2 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage))
+ init2 = np.zeros(self._nimage) + 1e-2 * np.random.rand(self._nimage)
+
+ # Initialize v
+ if 'V' in self.pol_next:
+ if len(self.init_next.vvec) and (np.any(self.init_next.vvec != 0)):
+ init3 = self.init_next.vvec / self.init_next.imvec
+ init3 = init3[self._embed_mask]
+ else:
+ # !AC TODO get the actual zero baseline polarization fraction from the data?
+ print("No V polarimetric image in init_next!")
+ print("--initializing with random vector")
+ #init3 = 0.05 * np.random.randn(self._nimage)
+ init3 = 0.01 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage))
+ self._inittuple = np.array((iinit, init1, init2, init3))
+ else:
+ self._inittuple = np.array((iinit, init1, init2))
+
+ # Change of variables
+ if 'mcv' in self.transform_next:
+ self._xtuple = polutils.mcv_r(self._inittuple)
+ else:
+ raise Exception("Polarimetric imaging only works with mcv transform!")
+
+ # Only apply log transformation to Stokes I if simultaneous imaging
+ if ('log' in self.transform_next) and ('I' in self.pol_next):
+ self._xtuple[0] = np.log(self._xtuple[0])
+
+ # Determine pol_which_solve
+ if self.pol_next in ['P','QU']:
+ self._pol_which_solve = (0,1,1)
+ elif self.pol_next in ['IP','IQU']:
+ self._pol_which_solve = (1,1,1)
+ elif self.pol_next in ['V']:
+ self._pol_which_solve = (0,0,0,1)
+ elif self.pol_next in ['IV']:
+ self._pol_which_solve = (1,0,0,1)
+ elif self.pol_next in ['IQUV']:
+ self._pol_which_solve = (1,1,1,1)
+ else:
+ raise Exception("Do not know correct pol_which_solve for self.pol_next=%s!"%self.pol_next)
+
+ # Pack into single vector
+ self._xinit = polutils.pack_poltuple(self._xtuple, self._pol_which_solve)
+
+ # Set prior & initial image vectors for multifrequency imaging
+ elif self.mf_next:
+
+            self.reffreq = self.init_next.rf  # reference frequency matches the initial image (== prior rf)
+ # reset logfreqratios in case reference frequency changed
+ self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list]
+
+ if self.norm_init:
+ nprior_I = (self.flux_next * self.prior_next.imvec /
+ np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
+ ninit_I = (self.flux_next * self.init_next.imvec /
+ np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
+ else:
+ nprior_I = self.prior_next.imvec[self._embed_mask]
+ ninit_I = self.init_next.imvec[self._embed_mask]
+
+            self._nimage = len(ninit_I)
+            npix = self.prior_next.xdim * self.prior_next.ydim
+
+            # zero defaults are built at full image size, then masked
+            if len(self.init_next.specvec):
+                ninit_a = self.init_next.specvec[self._embed_mask]
+            else:
+                ninit_a = np.zeros(npix)[self._embed_mask]
+            if len(self.prior_next.specvec):
+                nprior_a = self.prior_next.specvec[self._embed_mask]
+            else:
+                nprior_a = np.zeros(npix)[self._embed_mask]
+
+            if len(self.init_next.curvvec):
+                ninit_b = self.init_next.curvvec[self._embed_mask]
+            else:
+                ninit_b = np.zeros(npix)[self._embed_mask]
+            if len(self.prior_next.curvvec):
+                nprior_b = self.prior_next.curvvec[self._embed_mask]
+            else:
+                nprior_b = np.zeros(npix)[self._embed_mask]
+
+            self.inittuple = np.array((ninit_I, ninit_a, ninit_b))
+            self.priortuple = np.array((nprior_I, nprior_a, nprior_b))
+
+ # Change of variables
+ if 'log' in self.transform_next:
+ self._xtuple = np.array((np.log(ninit_I), ninit_a, ninit_b))
+ else:
+ self._xtuple = self.inittuple
+
+ # Pack into single vector
+ self._xinit = mfutils.pack_mftuple(self._xtuple, self.mf_which_solve)
+
+ # Set prior & initial image vectors for single stokes or RR, LL imaging
+ else:
+
+ if self.norm_init:
+ self._nprior = (self.flux_next * self.prior_next.imvec /
+ np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask]
+ ninit = (self.flux_next * self.init_next.imvec /
+ np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask]
+ else:
+ self._nprior = self.prior_next.imvec[self._embed_mask]
+ ninit = self.init_next.imvec[self._embed_mask]
+
+ self._nimage = len(ninit)
+ # Change of variables
+ if 'log' in self.transform_next:
+ self._xinit = np.log(ninit)
+ else:
+ self._xinit = ninit
+
+ # Make data term tuples
+ if self._change_imgr_params:
+ if self.nruns == 0:
+ print("Initializing imager data products . . .")
+ if self.nruns > 0:
+ print("Recomputing imager data products . . .")
+
+ self._data_tuples = {}
+
+ # Loop over all data term types
+ for dname in sorted(self.dat_term_next.keys()):
+
+ # Loop over all observations in the list
+ for i, obs in enumerate(self.obslist_next):
+ # Each entry in the dterm dictionary past the first has an appended number
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+
+ # Polarimetric data products
+ if dname in DATATERMS_POL:
+ tup = polutils.polchisqdata(obs, self.prior_next, self._embed_mask, dname,
+ ttype=self._ttype,
+ fft_pad_factor=self._fft_pad_factor,
+ conv_func=self._fft_conv_func,
+ p_rad=self._fft_gridder_prad)
+
+ # Single polarization data products
+ elif dname in DATATERMS:
+ if self.pol_next in POLARIZATION_MODES:
+ if not 'I' in self.pol_next:
+ raise Exception("cannot use dterm %s with pol=%s"%(dname,self.pol_next))
+ pol_next = 'I'
+ else:
+ pol_next = self.pol_next
+
+ tup = imutils.chisqdata(obs, self.prior_next, self._embed_mask, dname,
+ pol=pol_next, maxset=self.maxset_next,
+ debias=self.debias_next,
+ snrcut=self.snrcut_next[dname],
+ weighting=self.weighting_next,
+ systematic_noise=self.systematic_noise_next,
+ systematic_cphase_noise=self.systematic_cphase_noise_next,
+ ttype=self._ttype, order=self._fft_interp_order,
+ fft_pad_factor=self._fft_pad_factor,
+ conv_func=self._fft_conv_func,
+ p_rad=self._fft_gridder_prad,
+ cp_uv_min=self.cp_uv_min)
+ else:
+ raise Exception("data term %s not recognized!" % dname)
+
+ self._data_tuples[dname_key] = tup
+
+ self._change_imgr_params = False
+
+ return
+
+ def init_imager_scattering(self):
+ """Set up scattering imager.
+ """
+ N = self.prior_next.xdim
+
+ if self.scattering_model is None:
+ self.scattering_model = so.ScatteringModel()
+
+ # First some preliminary definitions
+ wavelength = ehc.C/self.obs_next.rf*100.0 # Observing wavelength [cm]
+ N = self.prior_next.xdim
+
+ # Field of view, in cm, at the scattering screen
+ FOV = self.prior_next.psize * N * self.scattering_model.observer_screen_distance
+
+ # The ensemble-average convolution kernel and its gradients
+ self._ea_ker = self.scattering_model.Ensemble_Average_Kernel(
+ self.prior_next, wavelength_cm=wavelength)
+ ea_ker_gradient = so.Wrapped_Gradient(self._ea_ker/(FOV/N))
+ self._ea_ker_gradient_x = -ea_ker_gradient[1]
+ self._ea_ker_gradient_y = -ea_ker_gradient[0]
+
+ # The power spectrum
+ # Note: rotation is not currently implemented;
+ # the gradients would need to be modified slightly
+ self._sqrtQ = np.real(self.scattering_model.sqrtQ_Matrix(self.prior_next, t_hr=0.0))
+
+ # Generate the initial image+screen vector.
+ # By default, the screen is re-initialized to zero each time.
+ if len(self.epsilon_list_next) == 0:
+ self._xinit = np.concatenate((self._xinit, np.zeros(N**2-1)))
+ else:
+ self._xinit = np.concatenate((self._xinit, self.epsilon_list_next))
+
+ def make_chisq_dict(self, imcur):
+ """Make a dictionary of current chi^2 term values
+ i indexes the observation number in self.obslist_next
+ """
+
+ chi2_dict = {}
+ for dname in sorted(self.dat_term_next.keys()):
+ # Loop over all observations in the list
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+
+ (data, sigma, A) = self._data_tuples[dname_key]
+
+ if dname in DATATERMS_POL:
+ chi2 = polutils.polchisq(imcur, A, data, sigma, dname,
+ ttype=self._ttype, mask=self._embed_mask,
+ pol_trans=POL_TRANS)
+
+ elif dname in DATATERMS:
+ if self.mf_next: # multifrequency
+ logfreqratio = self._logfreqratio_list[i]
+ imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
+ elif self.pol_next in POLARIZATION_MODES: # polarization
+ imcur_nu = imcur[0]
+ else: # normal imaging
+ imcur_nu = imcur
+
+ chi2 = imutils.chisq(imcur_nu, A, data, sigma, dname,
+ ttype=self._ttype, mask=self._embed_mask)
+
+ else:
+ raise Exception("data term %s not recognized!" % dname)
+
+ chi2_dict[dname_key] = chi2
+
+ return chi2_dict
+
+    def make_chisqgrad_dict(self, imcur):
+        """Make a dictionary of current chi^2 term gradient values;
+           keys are suffixed with the observation index in self.obslist_next.
+        """
+ chi2grad_dict = {}
+ for dname in sorted(self.dat_term_next.keys()):
+ # Loop over all observations in the list
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+
+ (data, sigma, A) = self._data_tuples[dname_key]
+
+ # Polarimetric data products
+ if dname in DATATERMS_POL:
+ chi2grad = polutils.polchisqgrad(imcur, A, data, sigma, dname,
+ ttype=self._ttype, mask=self._embed_mask,
+ pol_solve=self._pol_which_solve,
+ pol_trans=POL_TRANS)
+
+ # Single polarization data products
+ elif dname in DATATERMS:
+ if self.mf_next: # multifrequency
+ logfreqratio = self._logfreqratio_list[i]
+ imref = imcur[0]
+ imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
+ elif self.pol_next in POLARIZATION_MODES: # polarization
+ imcur_nu = imcur[0]
+ else: # normal imaging
+ imcur_nu = imcur
+
+ chi2grad = imutils.chisqgrad(imcur_nu, A, data, sigma, dname,
+ ttype=self._ttype, mask=self._embed_mask)
+
+ # If multifrequency imaging,
+ # transform the image gradients for all the solved quantities
+ if self.mf_next:
+ logfreqratio = self._logfreqratio_list[i]
+ chi2grad = mfutils.mf_all_grads_chain(chi2grad, imcur_nu, imref, logfreqratio)
+
+ # If imaging polarization simultaneously, bundle the gradient properly
+ if self.pol_next in POLARIZATION_MODES:
+ if 'V' in self.pol_next:
+ chi2grad = np.array((chi2grad, np.zeros(self._nimage), np.zeros(self._nimage), np.zeros(self._nimage)))
+ else:
+ chi2grad = np.array((chi2grad, np.zeros(self._nimage), np.zeros(self._nimage)))
+
+ else:
+ raise Exception("data term %s not recognized!" % dname)
+
+ chi2grad_dict[dname_key] = np.array(chi2grad)
+
+ return chi2grad_dict
+
+ def make_reg_dict(self, imcur):
+ """Make a dictionary of current regularizer values
+ """
+ reg_dict = {}
+
+ for regname in sorted(self.reg_term_next.keys()):
+
+ # Polarimetric regularizer
+ if regname in REGULARIZERS_POL:
+ reg = polutils.polregularizer(imcur, self._embed_mask,
+ self.flux_next, self.pflux_next, self.vflux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ pol_trans=POL_TRANS)
+
+ # Multifrequency regularizers
+ elif self.mf_next:
+
+ # Image regularizer(s)
+ if regname in REGULARIZERS:
+ # new option to regularize ALL the images in multifrequency imaging
+ # TODO total fluxes not right?
+ if self.reg_all_freq_mf:
+ for i in range(len(self.obslist_next)):
+ regname_key = regname + ('_%i' % i)
+ logfreqratio = self._logfreqratio_list[i]
+ imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
+ prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio)
+                        imref = imcur[0]
+
+ reg = imutils.regularizer(imcur_nu, prior_nu, self._embed_mask,
+ self.flux_next, self.prior_next.xdim,
+ self.prior_next.ydim, self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+ reg_dict[regname_key] = reg
+
+ # normally we only regularize reference frequency image
+ else:
+ reg = imutils.regularizer(imcur[0], self.priortuple[0], self._embed_mask,
+ self.flux_next, self.prior_next.xdim,
+ self.prior_next.ydim, self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+
+ # Spectral index regularizer(s)
+ elif regname in REGULARIZERS_SPECIND:
+ reg = mfutils.regularizer_mf(imcur[1], self.priortuple[1], self._embed_mask,
+ self.flux_next, self.prior_next.xdim,
+ self.prior_next.ydim, self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+
+ # Curvature index regularizer(s)
+ elif regname in REGULARIZERS_CURV:
+ reg = mfutils.regularizer_mf(imcur[2], self.priortuple[2], self._embed_mask,
+ self.flux_next, self.prior_next.xdim,
+ self.prior_next.ydim, self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+
+ # Normal, single polarization, single-frequency regularizer
+ elif regname in REGULARIZERS:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur0 = imcur[0]
+ else:
+ imcur0 = imcur
+
+ reg = imutils.regularizer(imcur0, self._nprior, self._embed_mask,
+ self.flux_next, self.prior_next.xdim,
+ self.prior_next.ydim, self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+ else:
+ raise Exception("regularizer term %s not recognized!" % regname)
+
+ # multifrequency regularizer terms are already in the dictionary
+ # if we regularize all images with self.reg_all_freq_mf
+ if not(self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS)):
+ reg_dict[regname] = reg
+
+ return reg_dict
+
+ def make_reggrad_dict(self, imcur):
+ """Make a dictionary of current regularizer gradient values
+ """
+
+ reggrad_dict = {}
+
+ for regname in sorted(self.reg_term_next.keys()):
+
+ # Polarimetric regularizer
+ if regname in REGULARIZERS_POL:
+ reg = polutils.polregularizergrad(imcur, self._embed_mask,
+ self.flux_next, self.pflux_next, self.vflux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ pol_solve=self._pol_which_solve,
+ pol_trans=POL_TRANS)
+
+ # Multifrequency regularizer
+ elif self.mf_next:
+
+ # Image regularizer(s)
+ if regname in REGULARIZERS:
+ # new option to regularize ALL the images in multifrequency imaging
+ # TODO total fluxes not right?
+ if self.reg_all_freq_mf:
+ for i in range(len(self.obslist_next)):
+ regname_key = regname + ('_%i' % i)
+ logfreqratio = self._logfreqratio_list[i]
+ imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio)
+ prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio)
+                        imref = imcur[0]
+
+ reg = imutils.regularizergrad(imcur_nu, prior_nu,
+ self._embed_mask, self.flux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg,
+ beam_size=self.beam_size,
+ **self.regparams)
+
+ reg = mfutils.mf_all_grads_chain(reg, imcur_nu, imref, logfreqratio)
+ reg_dict[regname_key] = reg
+
+ # normally we only regularize the reference frequency image
+ else:
+ reg = imutils.regularizergrad(imcur[0], self.priortuple[0],
+ self._embed_mask, self.flux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg,
+ beam_size=self.beam_size,
+ **self.regparams)
+ reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage)))
+
+ # Spectral index regularizer(s)
+ elif regname in REGULARIZERS_SPECIND:
+ reg = mfutils.regularizergrad_mf(imcur[1], self.priortuple[1],
+ self._embed_mask, self.flux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg,
+ beam_size=self.beam_size,
+ **self.regparams)
+ reg = np.array((np.zeros(self._nimage), reg, np.zeros(self._nimage)))
+
+ # Curvature index regularizer(s)
+ elif regname in REGULARIZERS_CURV:
+ reg = mfutils.regularizergrad_mf(imcur[2], self.priortuple[2],
+ self._embed_mask, self.flux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize, regname,
+ norm_reg=self.norm_reg,
+ beam_size=self.beam_size,
+ **self.regparams)
+ reg = np.array((np.zeros(self._nimage), np.zeros(self._nimage), reg))
+
+ # Normal, single polarization, single-frequency regularizer
+ elif regname in REGULARIZERS:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur0 = imcur[0]
+ else:
+ imcur0 = imcur
+ reg = imutils.regularizergrad(imcur0, self._nprior, self._embed_mask, self.flux_next,
+ self.prior_next.xdim, self.prior_next.ydim,
+ self.prior_next.psize,
+ regname,
+ norm_reg=self.norm_reg, beam_size=self.beam_size,
+ **self.regparams)
+ if self.pol_next in POLARIZATION_MODES:
+ if 'V' in self.pol_next:
+ reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage), np.zeros(self._nimage)))
+ else:
+ reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage)))
+ else:
+ raise Exception("regularizer term %s not recognized!" % regname)
+
+ # multifrequency regularizer gradient terms are already in the dictionary
+ # if we regularize all images with self.reg_all_freq_mf
+ if not(self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS)):
+ reggrad_dict[regname] = reg
+
+ return reggrad_dict
+
+ def objfunc(self, imvec):
+ """Current objective function.
+ """
+
+ # Unpack polarimetric/multifrequency vector into an array
+ if self.pol_next in POLARIZATION_MODES:
+ imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
+ elif self.mf_next:
+ imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve)
+ else:
+ imcur = imvec
+
+ # Image change of variables
+ if self.pol_next in POLARIZATION_MODES and 'mcv' in self.transform_next:
+ imcur = polutils.mcv(imcur)
+
+ if 'log' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur[0] = np.exp(imcur[0])
+ elif self.mf_next:
+ imcur[0] = np.exp(imcur[0])
+ else:
+ imcur = np.exp(imcur)
+
+ # Data terms
+ datterm = 0.
+ chi2_term_dict = self.make_chisq_dict(imcur)
+ for dname in sorted(self.dat_term_next.keys()):
+ hyperparameter = self.dat_term_next[dname]
+
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+
+ chi2 = chi2_term_dict[dname_key]
+
+ if self.chisq_transform:
+ datterm += hyperparameter * (chi2 + 1./chi2 - 1.)
+ else:
+ datterm += hyperparameter * (chi2 - 1.)
+
+ # Regularizer terms
+ regterm = 0
+ reg_term_dict = self.make_reg_dict(imcur)
+ for regname in sorted(self.reg_term_next.keys()):
+ hyperparameter = self.reg_term_next[regname]
+ # multifrequency imaging, regularize every frequency
+ if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
+ for i in range(len(self.obslist_next)):
+ regname_key = regname + ('_%i' % i)
+ regularizer = reg_term_dict[regname_key]
+ regterm += hyperparameter * regularizer
+
+ # but normally just one regularizer
+ else:
+ regularizer = reg_term_dict[regname]
+ regterm += hyperparameter * regularizer
+
+ # Total cost
+ cost = datterm + regterm
+
+ return cost
+
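+    # The quantity minimized above is, schematically,
+    #     J(x) = sum_d w_d * (chi^2_d(x) - 1) + sum_r w_r * S_r(x),
+    # with weights w_d, w_r from dat_term_next and reg_term_next;
+    # objgrad below returns dJ/dx under the same change of variables.
+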
+ def objgrad(self, imvec):
+ """Current objective function gradient.
+ """
+
+ # Unpack polarimetric/multifrequency vector into an array
+ if self.pol_next in POLARIZATION_MODES:
+ imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
+ elif self.mf_next:
+ imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve)
+ else:
+ imcur = imvec
+
+ # Image change of variables
+ if 'mcv' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ cvcur = imcur.copy()
+ imcur = polutils.mcv(imcur)
+
+ if 'log' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur[0] = np.exp(imcur[0])
+ elif self.mf_next:
+ imcur[0] = np.exp(imcur[0])
+ else:
+ imcur = np.exp(imcur)
+
+ # Data terms
+ datterm = 0.
+ chi2_term_dict = self.make_chisqgrad_dict(imcur)
+ if self.chisq_transform:
+ chi2_value_dict = self.make_chisq_dict(imcur)
+ for dname in sorted(self.dat_term_next.keys()):
+ hyperparameter = self.dat_term_next[dname]
+
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+
+ chi2_grad = chi2_term_dict[dname_key]
+
+ if self.chisq_transform:
+                    chi2_val = chi2_value_dict[dname_key]
+ datterm += hyperparameter * chi2_grad * (1. - 1./(chi2_val**2))
+ else:
+ datterm += hyperparameter * (chi2_grad + self.chisq_offset_gradient)
+
+ # Regularizer terms
+ regterm = 0
+ reg_term_dict = self.make_reggrad_dict(imcur)
+ for regname in sorted(self.reg_term_next.keys()):
+ hyperparameter = self.reg_term_next[regname]
+
+ # multifrequency imaging, regularize every frequency
+ if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
+ for i in range(len(self.obslist_next)):
+ regname_key = regname + ('_%i' % i)
+ regularizer = reg_term_dict[regname_key]
+ regterm += hyperparameter * regularizer
+
+ # but normally just one regularizer
+ else:
+ regularizer_grad = reg_term_dict[regname]
+ regterm += hyperparameter * regularizer_grad
+
+ # Total gradient
+ grad = datterm + regterm
+
+ # Chain rule term for change of variables
+ if 'mcv' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ grad *= polutils.mchain(cvcur)
+
+ if 'log' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ grad[0] *= imcur[0]
+ elif self.mf_next:
+ grad[0] *= imcur[0]
+ else:
+ grad *= imcur
+
+ # Repack gradient for polarimetric imaging
+ if self.pol_next in POLARIZATION_MODES:
+ grad = polutils.pack_poltuple(grad, self._pol_which_solve)
+
+ # repack gradient for multifrequency imaging
+ elif self.mf_next:
+ grad = mfutils.pack_mftuple(grad, self.mf_which_solve)
+
+ return grad
+
+ def plotcur(self, imvec, **kwargs):
+ """Plot current image.
+ """
+
+ if self._show_updates:
+ if self._nit % self._update_interval == 0:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve)
+ elif self.mf_next:
+ imcur = mfutils.unpack_mftuple(
+ imvec, self._xtuple, self._nimage, self.mf_which_solve)
+ else:
+ imcur = imvec
+
+ # Image change of variables
+
+ if 'mcv' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur = polutils.mcv(imcur)
+
+ if 'log' in self.transform_next:
+ if self.pol_next in POLARIZATION_MODES:
+ imcur[0] = np.exp(imcur[0])
+ elif self.mf_next:
+ imcur[0] = np.exp(imcur[0])
+ else:
+ imcur = np.exp(imcur)
+
+ # Get chi^2 and regularizer
+ chi2_term_dict = self.make_chisq_dict(imcur)
+ reg_term_dict = self.make_reg_dict(imcur)
+
+ # Format print string
+ outstr = "------------------------------------------------------------------"
+ outstr += "\n%4d | " % self._nit
+ for dname in sorted(self.dat_term_next.keys()):
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+ outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key])
+ outstr += "\n "
+ for dname in sorted(self.dat_term_next.keys()):
+ for i, obs in enumerate(self.obslist_next):
+ if len(self.obslist_next)==1:
+ dname_key = dname
+ else:
+ dname_key = dname + ('_%i' % i)
+ dval = chi2_term_dict[dname_key]*self.dat_term_next[dname]
+ outstr += "%s : %0.1f " % (dname_key, dval)
+
+ outstr += "\n "
+ for regname in sorted(self.reg_term_next.keys()):
+
+ if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS):
+ for i in range(len(self.obslist_next)):
+ regname_key = regname + ('_%i' % i)
+ rval = reg_term_dict[regname_key]*self.reg_term_next[regname]
+ outstr += "%s : %0.1f " % (regname_key, rval)
+ else:
+ rval = reg_term_dict[regname]*self.reg_term_next[regname]
+ outstr += "%s : %0.1f " % (regname, rval)
+
+ # Embed and plot the image
+ if self.pol_next in POLARIZATION_MODES:
+ if np.any(np.invert(self._embed_mask)):
+ imcur = polutils.embed_pol(imcur, self._embed_mask)
+ polutils.plot_m(imcur, self.prior_next, self._nit, chi2_term_dict, **kwargs)
+
+ else:
+ if self.mf_next:
+ implot = imcur[0]
+ else:
+ implot = imcur
+ if np.any(np.invert(self._embed_mask)):
+ implot = imutils.embed(implot, self._embed_mask)
+
+ imutils.plot_i(implot, self.prior_next, self._nit,
+ chi2_term_dict, pol=self.pol_next, **kwargs)
+
+                if self._nit == 0:
+                    print()
+                print(outstr)
+
+ self._nit += 1
+
+ def objfunc_scattering(self, minvec):
+ """Current stochastic optics objective function.
+ """
+ N = self.prior_next.xdim
+
+ imvec = minvec[:N**2]
+ EpsilonList = minvec[N**2:]
+ if 'log' in self.transform_next:
+ imvec = np.exp(imvec)
+
+ IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize,
+ self.prior_next.ra, self.prior_next.dec,
+ self.prior_next.pa, rf=self.obs_next.rf,
+ source=self.prior_next.source, mjd=self.prior_next.mjd)
+
+ # The scattered image vector
+ screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
+ scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen,
+ ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
+ Linearized_Approximation=True)
+ scatt_im = scatt_im.imvec
+
+ # Calculate the chi^2 using the scattered image
+ datterm = 0.
+ chi2_term_dict = self.make_chisq_dict(scatt_im)
+ for dname in sorted(self.dat_term_next.keys()):
+ datterm += self.dat_term_next[dname] * (chi2_term_dict[dname] - 1.)
+
+ # Calculate the entropy using the unscattered image
+ regterm = 0
+ reg_term_dict = self.make_reg_dict(imvec)
+
+ # Make dict also for scattered image
+ reg_term_dict_scatt = self.make_reg_dict(scatt_im)
+
+ for regname in sorted(self.reg_term_next.keys()):
+ if regname == 'rgauss':
+                # Evaluate the rgauss regularizer on the scattered image vector
+ regterm += self.reg_term_next[regname] * reg_term_dict_scatt[regname]
+
+ else:
+ regterm += self.reg_term_next[regname] * reg_term_dict[regname]
+
+ # Scattering screen regularization term
+ chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0)
+ regterm_scattering = self.alpha_phi_next * (chisq_epsilon - 1.0)
+
+ return datterm + regterm + regterm_scattering
+
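+    # For reference, the objective assembled above is, in the code's notation,
+    #
+    #   J(I, eps) =   sum_d alpha_d * (chi^2_d(scatter(I, eps)) - 1)
+    #               + sum_r beta_r  * S_r(I)
+    #               + alpha_phi * (chi^2_eps - 1),
+    #
+    # with chi^2_eps = sum_i eps_i^2 / ((N^2 - 1)/2), so the screen term drives the
+    # epsilon coefficients toward unit variance (the 'rgauss' regularizer alone is
+    # evaluated on the scattered image).
+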
+ def objgrad_scattering(self, minvec):
+ """Current stochastic optics objective function gradient
+ """
+ wavelength = ehc.C/self.obs_next.rf*100.0 # Observing wavelength [cm]
+ wavelengthbar = wavelength/(2.0*np.pi) # lambda/(2pi) [cm]
+ N = self.prior_next.xdim
+
+ # Field of view, in cm, at the scattering screen
+ FOV = self.prior_next.psize * N * self.scattering_model.observer_screen_distance
+ rF = self.scattering_model.rF(wavelength)
+
+ imvec = minvec[:N**2]
+ EpsilonList = minvec[N**2:]
+ if 'log' in self.transform_next:
+ imvec = np.exp(imvec)
+
+ IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize,
+ self.prior_next.ra, self.prior_next.dec,
+ self.prior_next.pa, rf=self.obs_next.rf,
+ source=self.prior_next.source, mjd=self.prior_next.mjd)
+
+ # The scattered image vector
+ screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
+ scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen,
+ ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
+ Linearized_Approximation=True)
+ scatt_im = scatt_im.imvec
+
+ EA_Image = self.scattering_model.Ensemble_Average_Blur(IM, ker=self._ea_ker)
+ EA_Gradient = so.Wrapped_Gradient((EA_Image.imvec/(FOV/N)).reshape(N, N))
+
+ # The gradient signs don't actually matter, but let's make them match intuition
+ # (i.e., right to left, bottom to top)
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+
+ Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
+ phi_scr = self.scattering_model.MakePhaseScreen(Epsilon_Screen, IM,
+ obs_frequency_Hz=self.obs_next.rf,
+ sqrtQ_init=self._sqrtQ)
+ phi = phi_scr.imvec.reshape((N, N))
+ phi_Gradient = so.Wrapped_Gradient(phi/(FOV/N))
+ phi_Gradient_x = -phi_Gradient[1]
+ phi_Gradient_y = -phi_Gradient[0]
+
+ # Entropy gradient; wrt unscattered image so unchanged by scattering
+ regterm = 0
+ reg_term_dict = self.make_reggrad_dict(imvec)
+
+ # Make dict also for scattered image
+ reg_term_dict_scatt = self.make_reggrad_dict(scatt_im)
+
+ for regname in sorted(self.reg_term_next.keys()):
+            # 'rgauss' is a special case: its gradient is taken wrt the scattered image
+ if regname == 'rgauss':
+ # Get gradient of the scattered image vector
+ gaussterm = self.reg_term_next[regname] * reg_term_dict_scatt[regname]
+ dgauss_dIa = gaussterm.reshape((N, N))
+
+ # Now the chain rule factor to get the gauss gradient wrt the unscattered image
+ gx = so.Wrapped_Convolve(
+ self._ea_ker_gradient_x[::-1, ::-1], phi_Gradient_x * (dgauss_dIa))
+ gx = (rF**2.0 * gx).flatten()
+
+ gy = so.Wrapped_Convolve(
+ self._ea_ker_gradient_y[::-1, ::-1], phi_Gradient_y * (dgauss_dIa))
+ gy = (rF**2.0 * gy).flatten()
+
+ # Now we add the gradient for the unscattered image
+ regterm += so.Wrapped_Convolve(self._ea_ker[::-1, ::-1],
+ (dgauss_dIa)).flatten() + gx + gy
+
+ else:
+ regterm += self.reg_term_next[regname] * reg_term_dict[regname]
+
+ # Chi^2 gradient wrt the unscattered image
+        # First, the chi^2 gradient wrt the scattered image
+ datterm = 0.
+ chi2_term_dict = self.make_chisqgrad_dict(scatt_im)
+ for dname in sorted(self.dat_term_next.keys()):
+ datterm += self.dat_term_next[dname] * (chi2_term_dict[dname])
+ dchisq_dIa = datterm.reshape((N, N))
+
+ # Now the chain rule factor to get the chi^2 gradient wrt the unscattered image
+ gx = so.Wrapped_Convolve(self._ea_ker_gradient_x[::-1, ::-1], phi_Gradient_x * (dchisq_dIa))
+ gx = (rF**2.0 * gx).flatten()
+
+ gy = so.Wrapped_Convolve(self._ea_ker_gradient_y[::-1, ::-1], phi_Gradient_y * (dchisq_dIa))
+ gy = (rF**2.0 * gy).flatten()
+
+ chisq_grad_im = so.Wrapped_Convolve(
+ self._ea_ker[::-1, ::-1], (dchisq_dIa)).flatten() + gx + gy
+
+        # Gradient of the data chi^2 wrt the epsilon screen
+ # Preliminary Definitions
+ chisq_grad_epsilon = np.zeros(N**2-1)
+ i_grad = 0
+        # index grids: ell_mat[ell, m] = ell and m_mat[ell, m] = m
+        ell_mat, m_mat = np.meshgrid(np.arange(N), np.arange(N), indexing='ij')
+
+ # Real part; top row
+ for t in range(1, (N+1)//2):
+ s = 0
+ grad_term = (wavelengthbar/FOV*self._sqrtQ[s][t] *
+ 2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term = so.Wrapped_Gradient(grad_term)
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+
+ cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y)
+ chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term)
+
+ i_grad = i_grad + 1
+
+ # Real part; remainder
+ for s in range(1, (N+1)//2):
+ for t in range(N):
+ grad_term = (wavelengthbar/FOV*self._sqrtQ[s][t] *
+ 2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term = so.Wrapped_Gradient(grad_term)
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+
+ cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y)
+ chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term)
+
+ i_grad = i_grad + 1
+
+ # Imaginary part; top row
+ for t in range(1, (N+1)//2):
+ s = 0
+ grad_term = (-wavelengthbar/FOV*self._sqrtQ[s][t] *
+ 2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term = so.Wrapped_Gradient(grad_term)
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+
+ cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y)
+ chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term)
+
+ i_grad = i_grad + 1
+
+ # Imaginary part; remainder
+ for s in range(1, (N+1)//2):
+ for t in range(N):
+ grad_term = (-wavelengthbar/FOV*self._sqrtQ[s][t] *
+ 2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term = so.Wrapped_Gradient(grad_term)
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+
+ cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y)
+ chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term)
+ i_grad = i_grad + 1
+
+ # Gradient of the chi^2 regularization term for the epsilon screen
+ chisq_epsilon_grad = self.alpha_phi_next * 2.0*EpsilonList/((N*N-1)/2.0)
+
+ # Chain rule term for change of variables
+ if 'log' in self.transform_next:
+ regterm *= imvec
+ chisq_grad_im *= imvec
+
+ out = np.concatenate(((regterm + chisq_grad_im), (chisq_grad_epsilon + chisq_epsilon_grad)))
+ return out
+
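+    # The chain rule above follows the stochastic optics linearization
+    # (cf. Johnson 2016): with I_ea the ensemble-average (blurred) image,
+    #
+    #   I_scatt(x) ~= I_ea(x) + rF^2 * grad(phi)(x) . grad(I_ea)(x),
+    #
+    # so the gradient wrt the unscattered image is the mirrored-kernel convolution
+    # of dchi^2/dI_a plus the two rF^2 convolution terms (gx, gy) computed above.
+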
+ def plotcur_scattering(self, minvec):
+ """Plot current stochastic optics image/screen
+ """
+ if self._show_updates:
+ if self._nit % self._update_interval == 0:
+ N = self.prior_next.xdim
+
+ imvec = minvec[:N**2]
+ EpsilonList = minvec[N**2:]
+ if 'log' in self.transform_next:
+ imvec = np.exp(imvec)
+
+ IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize,
+ self.prior_next.ra, self.prior_next.dec,
+ self.prior_next.pa, rf=self.obs_next.rf,
+ source=self.prior_next.source, mjd=self.prior_next.mjd)
+
+ # The scattered image vector
+ screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
+ scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen,
+ ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
+ Linearized_Approximation=True)
+ scatt_im = scatt_im.imvec
+
+ # Calculate the chi^2 using the scattered image
+ datterm = 0.
+ chi2_term_dict = self.make_chisq_dict(scatt_im)
+ for dname in sorted(self.dat_term_next.keys()):
+ datterm += self.dat_term_next[dname] * (chi2_term_dict[dname] - 1.)
+
+ # Calculate the entropy using the unscattered image
+ regterm = 0
+ reg_term_dict = self.make_reg_dict(imvec)
+ for regname in sorted(self.reg_term_next.keys()):
+ regterm += self.reg_term_next[regname] * reg_term_dict[regname]
+
+ # Scattering screen regularization term
+ chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0)
+ # regterm_scattering = self.alpha_phi_next * (chisq_epsilon - 1.0)
+
+ outstr = "i: %d " % self._nit
+
+ for dname in sorted(self.dat_term_next.keys()):
+ outstr += "%s : %0.2f " % (dname, chi2_term_dict[dname])
+ for regname in sorted(self.reg_term_next.keys()):
+ outstr += "%s : %0.2f " % (regname, reg_term_dict[regname])
+ outstr += "Epsilon chi^2 : %0.2f " % (chisq_epsilon)
+ outstr += "Max |Epsilon| : %0.2f " % (max(abs(EpsilonList)))
+ print(outstr)
+
+ self._nit += 1
+
+ def make_image_I_stochastic_optics(self, grads=True, **kwargs):
+ """Reconstructs an image of total flux density
+ using the stochastic optics scattering mitigation technique.
+
+ Uses the scattering model in Imager.scattering_model.
+           If none has been specified, it defaults to the standard model for Sgr A*.
+ Returns the estimated unscattered image.
+
+ Args:
+ grads (bool): Flag for whether or not to use analytic gradients.
+               show_updates (bool): Flag for whether or not to show updates.
+ Returns:
+ out (Image): The estimated *unscattered* image.
+ """
+
+ N = self.prior_next.xdim
+
+ # Checks and initialize
+ self.check_params()
+ self.check_limits()
+ self.init_imager()
+ self.init_imager_scattering()
+ self._nit = 0
+
+ # Print stats
+ self._show_updates = kwargs.get('show_updates', True)
+ self._update_interval = kwargs.get('update_interval', 1)
+ self.plotcur_scattering(self._xinit)
+
+ # Minimize
+ optdict = {'maxiter': self.maxit_next, 'ftol': self.stop_next, 'maxcor': NHIST}
+ tstart = time.time()
+ if grads:
+ res = opt.minimize(self.objfunc_scattering, self._xinit, method='L-BFGS-B',
+ jac=self.objgrad_scattering, options=optdict,
+ callback=self.plotcur_scattering)
+ else:
+ res = opt.minimize(self.objfunc_scattering, self._xinit, method='L-BFGS-B',
+ options=optdict, callback=self.plotcur_scattering)
+ tstop = time.time()
+
+ # Format output
+ out = res.x[:N**2]
+ if 'log' in self.transform_next:
+ out = np.exp(out)
+ if np.any(np.invert(self._embed_mask)):
+ raise Exception("Embedding is not currently implemented!")
+                # out = imutils.embed(out, self._embed_mask)  # unreachable until embedding is implemented
+
+ outim = ehtim.image.Image(out.reshape(N, N), self.prior_next.psize,
+ self.prior_next.ra, self.prior_next.dec, self.prior_next.pa,
+ rf=self.prior_next.rf, source=self.prior_next.source,
+ mjd=self.prior_next.mjd, pulse=self.prior_next.pulse)
+ outep = res.x[N**2:]
+ screen = so.MakeEpsilonScreenFromList(outep, N)
+ outscatt = self.scattering_model.Scatter(outim,
+ Epsilon_Screen=screen,
+ ea_ker=self._ea_ker, sqrtQ=self._sqrtQ,
+ Linearized_Approximation=True)
+
+ # Preserving image complex polarization fractions
+ if len(self.prior_next.qvec):
+ qvec = self.prior_next.qvec * out / self.prior_next.imvec
+ uvec = self.prior_next.uvec * out / self.prior_next.imvec
+ outim.add_qu(qvec.reshape(N, N),
+ uvec.reshape(N, N))
+
+ # Print stats
+ print("time: %f s" % (tstop - tstart))
+ print("J: %f" % res.fun)
+ print(res.message)
+
+ # Append to history
+ logstr = str(self.nruns) + ": make_image_I_stochastic_optics()"
+ self._append_image_history(outim, logstr)
+ self._out_list_epsilon.append(res.x[N**2:])
+ self._out_list_scattered.append(outscatt)
+
+ self.nruns += 1
+
+ # Return Image object
+ return outim
+
+ def _append_image_history(self, outim, logstr):
+ self.logstr += (logstr + "\n")
+ self._obs_list.append(self.obslist_next)
+ self._init_list.append(self.init_next)
+ self._prior_list.append(self.prior_next)
+ self._debias_list.append(self.debias_next)
+ self._weighting_list.append(self.weighting_next)
+ self._systematic_noise_list.append(self.systematic_noise_next)
+ self._systematic_cphase_noise_list.append(self.systematic_cphase_noise_next)
+ self._snrcut_list.append(self.snrcut_next)
+ self._flux_list.append(self.flux_next)
+ self._pflux_list.append(self.pflux_next)
+ self._vflux_list.append(self.vflux_next)
+ self._pol_list.append(self.pol_next)
+ self._clipfloor_list.append(self.clipfloor_next)
+        self._maxset_list.append(self.clipfloor_next)  # TODO: should this be maxset_next?
+ self._maxit_list.append(self.maxit_next)
+ self._stop_list.append(self.stop_next)
+ self._transform_list.append(self.transform_next)
+ self._reg_term_list.append(self.reg_term_next)
+ self._dat_term_list.append(self.dat_term_next)
+ self._alpha_phi_list.append(self.alpha_phi_next)
+
+ self._out_list.append(outim)
+ return
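+
+# A minimal usage sketch for the stochastic optics imager above (the file name
+# and parameter values are hypothetical; see the Imager class for full options):
+#
+#   import ehtim as eh
+#   obs = eh.obsdata.load_uvfits('obs.uvfits')
+#   prior = eh.image.make_square(obs, 128, 200*eh.RADPERUAS)
+#   imgr = eh.imager.Imager(obs, prior, prior, flux=1.0,
+#                           data_term={'vis': 1}, reg_term={'tv': 1e2})
+#   im_unscattered = imgr.make_image_I_stochastic_optics(grads=True)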
diff --git a/imaging/__init__.py b/imaging/__init__.py
new file mode 100644
index 00000000..9cd8c613
--- /dev/null
+++ b/imaging/__init__.py
@@ -0,0 +1,9 @@
+"""
+.. module:: ehtim.imaging
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: imaging functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from ..const_def import *
diff --git a/imaging/clean.py b/imaging/clean.py
new file mode 100644
index 00000000..5ed177ae
--- /dev/null
+++ b/imaging/clean.py
@@ -0,0 +1,1241 @@
+# clean.py
+# Clean-like imagers
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+from builtins import range
+
+import string
+import time
+import numpy as np
+import scipy.optimize as opt
+import scipy.ndimage as nd
+import scipy.ndimage.filters as filt
+import matplotlib.pyplot as plt
+try:
+ from pynfft.nfft import NFFT
+except ImportError:
+ pass
+ #print("Warning: No NFFT installed!")
+import numpy.polynomial.polynomial as p
+
+import ehtim.image as image
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+from ehtim.imaging.imager_utils import *
+
+##################################################################################################
+# Constants & Definitions
+##################################################################################################
+
+
+NHIST = 50  # number of L-BFGS steps to store for the Hessian approximation
+MAXIT = 100.
+
+DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp']
+REGULARIZERS = ['gs', 'tv', 'tv2','l1', 'patch', 'simple', 'compact', 'compact2']
+
+NFFT_KERSIZE_DEFAULT = 20
+GRIDDER_P_RAD_DEFAULT = 2
+GRIDDER_CONV_FUNC_DEFAULT = 'gaussian'
+FFT_PAD_DEFAULT = 2
+FFT_INTERP_DEFAULT = 3
+
+nit = 0 # global variable to track the iteration number in the plotting callback
+
+##################################################################################################
+# Imagers
+##################################################################################################
+def plot_i(Image, nit, chi2, fig=1, cmap='afmhot'):
+ """Plot the total intensity image at each iteration
+ """
+
+ plt.ion()
+ plt.figure(fig)
+ plt.pause(0.00001)
+ plt.clf()
+
+ plt.imshow(Image.imvec.reshape(Image.ydim,Image.xdim), cmap=plt.get_cmap(cmap), interpolation='gaussian')
+ xticks = ticks(Image.xdim, Image.psize/RADPERAS/1e-6)
+ yticks = ticks(Image.ydim, Image.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+    plt.xlabel(r'Relative RA ($\mu$as)')
+    plt.ylabel(r'Relative Dec ($\mu$as)')
+    plt.title(r"step: %i $\chi^2$: %f " % (nit, chi2), fontsize=20)
+
+def dd_clean_vis(Obsdata, InitIm, niter=1, clipfloor=-1, ttype="direct", loop_gain=1, method='min_chisq', weighting='uniform',
+ fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT, show_updates=False):
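+    """CLEAN-like deconvolution directly against the complex visibilities.
+
+       Each iteration adds a point-source component (flux = loop_gain times the
+       best-fit delta) at the pixel chosen by `method`: 'min_chisq' minimizes
+       the updated chi^2, 'max_delta' takes the largest delta. The returned
+       image is the raw component map (not convolved with the clean beam).
+    """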
+
+ # limit imager range to prior values > clipfloor
+ embed_mask = InitIm.imvec >= clipfloor
+
+ # get data
+ vis = Obsdata.data['vis']
+ sigma = Obsdata.data['sigma']
+ uv = np.hstack((Obsdata.data['u'].reshape(-1,1), Obsdata.data['v'].reshape(-1,1)))
+
+    # necessary NFFT info
+ npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim)))
+ nfft_info = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv)
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # Weights
+ weights_nat = 1./(sigma**2)
+
+ if weighting=='uniform':
+ weights = np.ones(len(weights_nat))
+ elif weighting=='natural':
+ weights=weights_nat
+ else:
+ raise Exception("weighting must be 'uniform' or 'natural'")
+ weights_norm = np.sum(weights)
+
+ # Coordinate matrix
+ coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)]
+ for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)])
+ coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2)
+ coord = coord[embed_mask]
+
+ # Initial imvec and visibilities
+ # TODO currently always initialized to zero!!
+ OutputIm = InitIm.copy()
+ DeltasIm = InitIm.copy()
+ ChisqIm = InitIm.copy()
+
+ res = Obsdata.res()
+ beamparams = Obsdata.fit_beam()
+
+ imvec_init = 0*InitIm.imvec[embed_mask]
+ vis_init = np.zeros(len(vis),dtype='complex128')
+
+ imvec_current = imvec_init
+ vis_current = vis_init
+
+ chisq_init = np.sum(weights*np.abs(vis-vis_init)**2)
+ rchisq_init = np.sum(weights_nat*np.abs(vis-vis_init)**2)/(2*len(weights_nat))
+ chisq_current = chisq_init
+ rchisq_current = rchisq_init
+
+ # clean loop
+ print("\n")
+ for it in range(niter):
+
+ resid_current = vis - vis_current
+
+ plan.f = weights * resid_current
+ plan.adjoint()
+ out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim))
+ deltas_all = out/ weights_norm
+ deltas = deltas_all[embed_mask]
+
+ chisq_map_all = chisq_current - (deltas_all**2)*weights_norm
+ chisq_map = chisq_map_all[embed_mask]
+
+ # Visibility space clean
+ if method=='min_chisq':
+ component_loc_idx = np.argmin(chisq_map)
+
+ # Image space clean
+ elif method=='max_delta':
+ component_loc_idx = np.argmax(deltas)
+
+ else:
+ raise Exception("method should be 'min_chisq' or 'max_delta'!")
+
+ # display images of delta and chisq
+ if show_updates:
+ DeltasIm.imvec = deltas_all
+ plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot')
+
+ ChisqIm.imvec = -chisq_map_all
+ plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool')
+
+ # clean component location
+ component_loc_x = coord[component_loc_idx][0]
+ component_loc_y = coord[component_loc_idx][1]
+ component_strength = loop_gain*deltas[component_loc_idx]
+
+ # update vis and imvec
+ imvec_current[component_loc_idx] += component_strength
+
+ #TODO how to incorporate pulse function?
+ vis_current += component_strength*np.exp(2*np.pi*1j*(uv[:,0]*component_loc_x + uv[:,1]*component_loc_y))
+
+ # update chi^2 and output image
+ chisq_current = np.sum(weights*np.abs(vis-vis_current)**2)
+ rchisq_current = np.sum(weights_nat*np.abs(vis-vis_current)**2)/(2*len(weights_nat))
+
+ print(it+1,component_strength, chisq_current, rchisq_current, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS)
+
+ OutputIm.imvec = imvec_current
+ if show_updates:
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ OutputImBlur = OutputIm.blur_gauss(beamparams)
+ plot_i(OutputImBlur, it, rchisq_current, fig=2)
+
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ return OutputIm
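+
+# Why the dd_clean_vis update above works: adding a component of flux delta at
+# pixel x changes the model visibilities by delta*exp(2*pi*i u.x), so with
+# weights w and residuals V - V',
+#
+#   chi^2(delta) = sum_k w_k |V_k - V'_k - delta e^{2 pi i u_k.x}|^2
+#                = chi^2 - 2*delta*Re[sum_k w_k (V_k - V'_k) e^{-2 pi i u_k.x}]
+#                  + delta^2 * sum_k w_k,
+#
+# which is minimized at delta = Re[adjoint NFFT of (w * resid)](x) / sum_k w_k
+# (the "deltas" image, up to the NFFT sign convention), with chi^2 dropping by
+# delta^2 * sum_k w_k -- exactly the chisq_map computed in the loop.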
+
+
+# solve the full 5th-order polynomial (the gradient of the 6th-order chi^2 in the component flux)
+def dd_clean_bispec_full(Obsdata, InitIm, niter=1, clipfloor=-1, loop_gain=.1,
+ weighting='uniform', bscount="min",show_updates=True,
+ fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT):
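+    """CLEAN-like deconvolution against the complex bispectrum.
+
+       The bispectrum chi^2 is a 6th-order polynomial in the flux of a trial
+       point-source component, so each iteration roots its 5th-order gradient
+       (coefficients P, Q, R, S, T, U) at every pixel and keeps the real root
+       that most decreases the chi^2.
+    """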
+
+
+ # limit imager range to prior values > clipfloor
+ embed_mask = InitIm.imvec >= clipfloor
+
+ # get data
+ biarr = Obsdata.bispectra(mode="all", count=bscount)
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bs = biarr['bispec']
+ sigma = biarr['sigmab']
+
+    # necessary NFFT info
+ npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim)))
+ nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1)
+ nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2)
+ nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3)
+
+ nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1)
+ nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2)
+ nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3)
+
+ nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2)
+ nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3)
+ nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1)
+
+ # TODO do we use pulse factors?
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ plan11 = nfft_info11.plan
+ pulsefac11 = nfft_info11.pulsefac
+ plan22 = nfft_info22.plan
+ pulsefac22 = nfft_info22.pulsefac
+ plan33 = nfft_info33.plan
+ pulsefac33 = nfft_info33.pulsefac
+
+ plan12 = nfft_info12.plan
+ pulsefac12 = nfft_info12.pulsefac
+ plan23 = nfft_info23.plan
+ pulsefac23 = nfft_info23.pulsefac
+ plan31 = nfft_info31.plan
+ pulsefac31 = nfft_info31.pulsefac
+
+ # Weights
+ weights_nat = 1./(sigma**2)
+ if weighting=='uniform':
+ weights = np.ones(len(weights_nat))
+ elif weighting=='natural':
+ weights=weights_nat
+ else:
+ raise Exception("weighting must be 'uniform' or 'natural'")
+ weights_norm = np.sum(weights)
+
+ # Coordinate matrix
+ # TODO what if the image is odd?
+ coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)]
+ for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)])
+ coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2)
+ coord = coord[embed_mask]
+
+ # Initial imvec and visibilities
+ # TODO currently initialized to zero!!
+ OutputIm = InitIm.copy()
+ DeltasIm = InitIm.copy()
+ ChisqIm = InitIm.copy()
+
+ res = Obsdata.res()
+ beamparams = Obsdata.fit_beam()
+
+ imvec_init = 0*InitIm.imvec[embed_mask]
+ vis1_init = np.zeros(len(bs), dtype='complex128')
+ vis2_init = np.zeros(len(bs), dtype='complex128')
+ vis3_init = np.zeros(len(bs), dtype='complex128')
+ bs_init = vis1_init*vis2_init*vis3_init
+ chisq_init = np.sum(weights*np.abs(bs - bs_init)**2)
+ rchisq_init = np.sum(weights_nat*np.abs(bs - bs_init)**2)/(2*len(weights_nat))
+
+ imvec_current = imvec_init
+ vis1_current = vis1_init
+ vis2_current = vis2_init
+ vis3_current = vis3_init
+ bs_current = bs_init
+ chisq_current = chisq_init
+ rchisq_current = rchisq_init
+
+ # clean loop
+ print("\n")
+ for it in range(niter):
+ t = time.time()
+ # compute delta at each location
+ resid_current = bs - bs_current
+ vis12_current = vis1_current*vis2_current
+ vis13_current = vis1_current*vis3_current
+ vis23_current = vis2_current*vis3_current
+
+ # center the first component automatically
+        # since the initial image is empty, we must go to higher order in delta for a solution
+ # TODO generalize to non-empty initial image!
+ if it==0:
+
+ A = np.sum(weights*np.real(resid_current))
+ B = np.sum(weights)
+
+ component_strength = np.cbrt(A/B)
+ component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1??
+
+ else:
+ # First calculate P (1st order)
+ plan1.f = weights * resid_current * vis23_current.conj()
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * resid_current * vis13_current.conj()
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * resid_current * vis12_current.conj()
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ P = -2 * (out1 + out2 + out3)
+
+ # Then calculate Q (2nd order)
+ plan12.f = weights * vis13_current*vis23_current.conj()
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * vis12_current*vis13_current.conj()
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * vis23_current*vis12_current.conj()
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights * resid_current.conj() * vis1_current
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * resid_current.conj() * vis2_current
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * resid_current.conj() * vis3_current
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ Q0 = np.sum(weights*(np.abs(vis12_current)**2 + np.abs(vis23_current)**2 + np.abs(vis13_current)**2))
+ Q1 = 2 * (out12 + out23 + out31)
+ Q2 = -2 * (out1 + out2 + out3)
+ Q = 2*(Q0 + Q1 + Q2)
+
+ #Calculate R (3rd order)
+ plan11.f = weights * vis1_current.conj() * vis23_current
+ plan11.adjoint()
+ out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim))
+ plan22.f = weights * vis2_current.conj() * vis13_current
+ plan22.adjoint()
+ out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim))
+ plan33.f = weights * vis3_current.conj() * vis12_current
+ plan33.adjoint()
+ out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim))
+
+ plan1.f = weights * vis1_current * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2)
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * vis2_current * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2)
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * vis3_current * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2)
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ R0 = -np.sum(weights*np.real(resid_current))
+ R1 = np.real(out11 + out22 + out33)
+ R2 = np.real(out1 + out2 + out3)
+ R = 6*(R0 + R1 + R2)
+
+ # Now find S (4th order)
+ plan12.f = weights * vis1_current*vis2_current.conj()
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * vis2_current*vis3_current.conj()
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * vis3_current*vis1_current.conj()
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights * vis23_current.conj()
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * vis13_current.conj()
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * vis12_current.conj()
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ S0 = np.sum(weights*(np.abs(vis1_current)**2 + np.abs(vis2_current)**2 + np.abs(vis3_current)**2))
+ S1 = 2 * (out12 + out23 + out31)
+ S2 = 2 * (out1 + out2 + out3)
+ S = 4*(S0 + S1 + S2)
+
+
+ # T (5th order)
+ plan1.f = weights * vis1_current
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * vis2_current
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * vis3_current
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ T = 10*(out1 + out2 + out3)
+
+ # Finally U (6th order)
+ U = 6*weights_norm * np.ones(T.shape)
+
+ # Find Component
+ deltas = np.zeros(len(P))
+ chisq_map = np.zeros(len(P))
+ for i in range(len(P)):
+ polynomial_params = np.array([P[i], Q[i], R[i], S[i], T[i], U[i]])
+ allroots = p.polyroots(polynomial_params)
+
+ # test roots to see which minimizes chi^2
+ newchisq = chisq_current
+ delta = 0
+ for j in range(len(allroots)):
+ root = allroots[j]
+ if np.imag(root)!=0: continue
+
+ trialchisq = chisq_current + P[i]*root + 0.5*Q[i]*root**2 + (1./3.)*R[i]*root**3 + 0.25*S[i]*root**4 + 0.2*T[i]*root**5 + (1./6.)*U[i]*root**6
+ if trialchisq < newchisq:
+ delta = root
+ newchisq = trialchisq
+
+ deltas[i] = delta
+ chisq_map[i] = newchisq
+
+ #plot deltas and chi^2 map
+ if show_updates:
+ DeltasIm.imvec = deltas
+ plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot')
+
+ ChisqIm.imvec = -chisq_map
+ plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool')
+
+
+ component_loc_idx = np.argmin(chisq_map[embed_mask])
+ component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx]
+
+ # clean component location
+ component_loc_x = coord[component_loc_idx][0]
+ component_loc_y = coord[component_loc_idx][1]
+
+ # update imvec, vis, bispec
+ imvec_current[component_loc_idx] += component_strength
+
+ #TODO how to incorporate pulse function?
+ vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y))
+ vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y))
+ vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y))
+ bs_current = vis1_current * vis2_current * vis3_current
+
+ # update chi^2 and output image
+ chisq_current = np.sum(weights*np.abs(bs - bs_current)**2)
+ rchisq_current = np.sum(weights_nat*np.abs(bs - bs_current)**2)/(2*len(weights_nat))
+
+ print("it %i: %f (%.2f , %.2f) %.4f" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_current))
+
+ OutputIm.imvec = imvec_current
+ if show_updates:
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ OutputImBlur = OutputIm.blur_gauss(beamparams)
+ plot_i(OutputImBlur, it, rchisq_current,fig=2)
+
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ return OutputIm
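+
+
+# A self-contained sketch (hypothetical helper, not called above) of the
+# root-selection step shared by the bispectrum imagers: given the gradient
+# coefficients [P, Q, R, S, T, U] at one pixel, pick the real root of the
+# 5th-order polynomial that most decreases the local 6th-order chi^2 model.
+def _best_real_root_sketch(coeffs, chisq_current):
+    """Return (delta, new_chisq) minimizing the local 6th-order chi^2 model."""
+    P, Q, R, S, T, U = coeffs
+    delta, newchisq = 0., chisq_current
+    for root in p.polyroots(np.asarray(coeffs, dtype=float)):
+        if np.imag(root) != 0:
+            continue
+        root = np.real(root)
+        trial = (chisq_current + P*root + 0.5*Q*root**2 + (1./3.)*R*root**3
+                 + 0.25*S*root**4 + 0.2*T*root**5 + (1./6.)*U*root**6)
+        if trial < newchisq:
+            delta, newchisq = root, trial
+    return delta, newchisq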
+
+
+
+
+# solve the full 5th-order polynomial, weighting the imaginary part of the
+# bispectrum residual differently (imweight)
+def dd_clean_bispec_imweight(Obsdata, InitIm, niter=1, clipfloor=-1, ttype="direct", loop_gain=.1, loop_gain_init=1,
+ weighting='uniform', bscount="min", imweight=1, show_updates=True,
+ fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT):
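+    """CLEAN-like bispectrum deconvolution with the imaginary part of the
+       residual bispectrum weighted by a factor imweight, so the real and
+       imaginary misfits can be traded off against each other.
+    """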
+
+
+ imag_weight=imweight
+ # limit imager range to prior values > clipfloor
+ embed_mask = InitIm.imvec >= clipfloor
+
+ # get data
+ biarr = Obsdata.bispectra(mode="all", count=bscount)
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bs = biarr['bispec']
+ sigma = biarr['sigmab']
+
+    # necessary NFFT info
+ npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim)))
+ nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1)
+ nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2)
+ nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3)
+
+ nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1)
+ nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2)
+ nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3)
+
+ nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2)
+ nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3)
+ nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1)
+
+ # TODO do we use pulse factors?
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ plan11 = nfft_info11.plan
+ pulsefac11 = nfft_info11.pulsefac
+ plan22 = nfft_info22.plan
+ pulsefac22 = nfft_info22.pulsefac
+ plan33 = nfft_info33.plan
+ pulsefac33 = nfft_info33.pulsefac
+
+ plan12 = nfft_info12.plan
+ pulsefac12 = nfft_info12.pulsefac
+ plan23 = nfft_info23.plan
+ pulsefac23 = nfft_info23.pulsefac
+ plan31 = nfft_info31.plan
+ pulsefac31 = nfft_info31.pulsefac
+
+ # Weights
+ weights_nat = 1./(sigma**2)
+ if weighting=='uniform':
+ weights = np.ones(len(weights_nat))
+ elif weighting=='natural':
+ weights=weights_nat
+ else:
+ raise Exception("weighting must be 'uniform' or 'natural'")
+ weights_norm = np.sum(weights)
+
+ # Coordinate matrix
+ # TODO do we need to make sure this corresponds exactly with what NFFT is doing?
+ # TODO what if the image is odd?
+ coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)]
+ for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)])
+ coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2)
+ coord = coord[embed_mask]
+
+ # Initial imvec and visibilities
+ # TODO currently initialized to zero!!
+ OutputIm = InitIm.copy()
+ DeltasIm = InitIm.copy()
+ ChisqIm = InitIm.copy()
+
+ res = Obsdata.res()
+ beamparams = Obsdata.fit_beam()
+
+ imvec_init = 0*InitIm.imvec[embed_mask]
+
+ vis1_init = np.zeros(len(bs), dtype='complex128')
+ vis2_init = np.zeros(len(bs), dtype='complex128')
+ vis3_init = np.zeros(len(bs), dtype='complex128')
+ bs_init = vis1_init*vis2_init*vis3_init
+ chisq_init = np.sum(weights*(np.real(bs - bs_init)**2 + imweight*np.imag(bs-bs_init)**2))
+ rchisq_init = np.sum(weights_nat*np.abs(bs - bs_init)**2)/(2*len(weights_nat))
+
+ imvec_current = imvec_init
+ vis1_current = vis1_init
+ vis2_current = vis2_init
+ vis3_current = vis3_init
+ bs_current = bs_init
+ chisq_current = chisq_init
+ rchisq_current = rchisq_init
+
+ # clean loop
+ print("\n")
+ for it in range(niter):
+ t = time.time()
+ # compute delta at each location
+ resid_current = bs - bs_current
+ vis12_current = vis1_current*vis2_current
+ vis13_current = vis1_current*vis3_current
+ vis23_current = vis2_current*vis3_current
+
+ # center the first component automatically
+        # since the initial image is empty, we must go to higher order in delta for a solution
+ # TODO generalize to non-empty initial image!
+ if it==0:
+
+ A = np.sum(weights*np.real(resid_current))
+ B = np.sum(weights)
+
+ component_strength = loop_gain_init*np.cbrt(A/B)
+ component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1??
+
+ else:
+ # First calculate P (1st order)
+ plan1.f = weights * np.real(resid_current.conj() * vis23_current)
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.real(resid_current.conj() * vis13_current)
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.real(resid_current.conj() * vis12_current)
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ PRe = -2 * (out1 + out2 + out3)
+
+ plan1.f = weights * np.imag(resid_current.conj() * vis23_current)
+ plan1.adjoint()
+ out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.imag(resid_current.conj() * vis13_current)
+ plan2.adjoint()
+ out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.imag(resid_current.conj() * vis12_current)
+ plan3.adjoint()
+ out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ PIm = -2 * (out1 + out2 + out3)
+
+ P = PRe + imag_weight*PIm
+
+ # Then calculate Q (2nd order)
+ plan12.f = weights * np.real(vis13_current.conj()*vis23_current)
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * np.real(vis12_current.conj()*vis13_current)
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * np.real(vis23_current.conj()*vis12_current)
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights * np.real(resid_current * vis1_current.conj())
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.real(resid_current * vis2_current.conj())
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.real(resid_current * vis3_current.conj())
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ Q0Re = np.sum(weights*(np.real(vis12_current)**2 + np.real(vis23_current)**2 + np.real(vis13_current)**2))
+ Q1Re = 2 * (out12 + out23 + out31)
+ Q2Re = -2*(out1 + out2 + out3)
+ QRe = 2*(Q0Re + Q1Re + Q2Re)
+
+ plan12.f = weights * np.imag(vis13_current.conj()*vis23_current)
+ plan12.adjoint()
+ out12 = np.imag((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * np.imag(vis12_current.conj()*vis13_current)
+ plan23.adjoint()
+ out23 = np.imag((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * np.imag(vis23_current.conj()*vis12_current)
+ plan31.adjoint()
+ out31 = np.imag((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights * np.imag(resid_current * vis1_current.conj())
+ plan1.adjoint()
+ out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.imag(resid_current * vis2_current.conj())
+ plan2.adjoint()
+ out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.imag(resid_current * vis3_current.conj())
+ plan3.adjoint()
+ out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ Q0Im = np.sum(weights*(np.imag(vis12_current)**2 + np.imag(vis23_current)**2 + np.imag(vis13_current)**2))
+ Q1Im = 2 * (out12 + out23 + out31)
+ Q2Im = -2*(out1 + out2 + out3)
+ QIm = 2*(Q0Im + Q1Im + Q2Im)
+
+ Q = QRe + imag_weight*QIm
+
+ #Calculate R (3rd order)
+ plan11.f = weights * np.real(vis1_current * vis23_current.conj())
+ plan11.adjoint()
+ out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim))
+ plan22.f = weights * np.real(vis2_current * vis13_current.conj())
+ plan22.adjoint()
+ out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim))
+ plan33.f = weights * np.real(vis3_current * vis12_current.conj())
+ plan33.adjoint()
+ out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim))
+
+ plan1.f = weights * np.real(vis1_current.conj()) * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2)
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.real(vis2_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2)
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.real(vis3_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2)
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ R0Re = -np.sum(weights*np.real(resid_current))
+ R1Re = out11 + out22 + out33
+ R2Re = out1 + out2 + out3
+ RRe = 6*(R0Re + R1Re + R2Re)
+
+ plan11.f = weights * np.imag(vis1_current * vis23_current.conj())
+ plan11.adjoint()
+ out11 = np.imag((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim))
+ plan22.f = weights * np.imag(vis2_current * vis13_current.conj())
+ plan22.adjoint()
+ out22 = np.imag((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim))
+ plan33.f = weights * np.imag(vis3_current * vis12_current.conj())
+ plan33.adjoint()
+ out33 = np.imag((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim))
+
+ plan1.f = weights * np.imag(vis1_current.conj()) * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2)
+ plan1.adjoint()
+ out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * np.imag(vis2_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2)
+ plan2.adjoint()
+ out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * np.imag(vis3_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2)
+ plan3.adjoint()
+ out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ R0Im = 0
+ R1Im = out11 + out22 + out33
+ R2Im = out1 + out2 + out3
+ RIm = 6*(R0Im + R1Im + R2Im)
+
+ R = RRe + imag_weight*RIm
+
+ # Now find S (4th order)
+ plan12.f = weights * np.real(vis1_current.conj()*vis2_current)
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * np.real(vis2_current.conj()*vis3_current)
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * np.real(vis3_current.conj()*vis1_current)
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights * vis23_current.conj()
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * vis13_current.conj()
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * vis12_current.conj()
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ S0Re = np.sum(weights*(np.real(vis1_current)**2 + np.real(vis2_current)**2 + np.real(vis3_current)**2))
+ S1Re = 2 * (out12 + out23 + out31)
+ S2Re = 2 * (out1 + out2 + out3)
+ SRe = 4*(S0Re + S1Re + S2Re)
+
+ plan12.f = weights * np.imag(vis1_current.conj()*vis2_current)
+ plan12.adjoint()
+ out12 = np.imag((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights * np.imag(vis2_current.conj()*vis3_current)
+ plan23.adjoint()
+ out23 = np.imag((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights * np.imag(vis3_current.conj()*vis1_current)
+ plan31.adjoint()
+ out31 = np.imag((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ S0Im = np.sum(weights*(np.imag(vis1_current)**2 + np.imag(vis2_current)**2 + np.imag(vis3_current)**2))
+ S1Im = 2 * (out12 + out23 + out31)
+ S2Im = 0
+ SIm = 4*(S0Im + S1Im + S2Im)
+
+ S = SRe + imag_weight*SIm
+
+ # T (5th order)
+ plan1.f = weights * vis1_current
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights * vis2_current
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights * vis3_current
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ TRe = 10*(out1 + out2 + out3)
+ TIm = 0.
+ T = TRe + imag_weight*TIm
+
+ # Finally U (6th order)
+ URe = 6*weights_norm * np.ones(T.shape)
+ UIm = 0.
+ U = URe + imag_weight*UIm
+
+ #Find Component
+ deltas = np.zeros(len(P))
+ chisq_map = np.zeros(len(P))
+ for i in range(len(P)):
+ if embed_mask[i]:
+ polynomial_params = np.array([P[i], Q[i], R[i], S[i], T[i], U[i]])
+ allroots = p.polyroots(polynomial_params)
+
+ # test roots to see which minimizes chi^2
+ newchisq = chisq_current
+ delta = 0
+ for j in range(len(allroots)):
+ root = allroots[j]
+ if np.imag(root)!=0: continue
+
+ trialchisq = chisq_current + P[i]*root + 0.5*Q[i]*root**2 + (1./3.)*R[i]*root**3 + 0.25*S[i]*root**4 + 0.2*T[i]*root**5 + (1./6.)*U[i]*root**6
+ if trialchisq < newchisq:
+ delta = root
+ newchisq = trialchisq
+
+ deltas[i] = delta
+ chisq_map[i] = newchisq
+ else:
+ deltas[i]=0
+ chisq_map[i]=chisq_current
+
+ #print ("step time %i: %f s" % (it+1, time.time() -t))
+ #chisq_map = chisq_current + P*deltas + 0.5*Q*deltas**2 + (1./3.)*R*deltas**3 + 0.25*S*deltas**4 + 0.2*T*deltas**5 + (1./6.)*U*deltas**6
+
+ #Plot delta and chi^2 map
+ if show_updates:
+ DeltasIm.imvec = deltas
+ plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot')
+
+ ChisqIm.imvec = -chisq_map
+ plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool')
+
+ component_loc_idx = np.argmin(chisq_map[embed_mask])
+ component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx]
+
+ # clean component location
+ component_loc_x = coord[component_loc_idx][0]
+ component_loc_y = coord[component_loc_idx][1]
+
+ # update imvec, vis, bispec
+ imvec_current[component_loc_idx] += component_strength
+
+ #TODO how to incorporate pulse function?
+ vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y))
+ vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y))
+ vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y))
+ bs_current = vis1_current * vis2_current * vis3_current
+
+ # update chi^2 and output image
+ chisq_current = np.sum(weights*np.abs(bs - bs_current)**2)
+ rchisq_current = np.sum(weights_nat*np.abs(bs - bs_current)**2)/(2*len(weights_nat))
+
+ print("it %i: %f (%.2f , %.2f) %.4f" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_current))
+
+ OutputIm.imvec = imvec_current
+ if show_updates:
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ OutputImBlur = OutputIm.blur_gauss(beamparams)
+ plot_i(OutputImBlur, it, rchisq_current,fig=2)
+
+ return OutputIm
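+
+
+# Usage sketch for the data-domain CLEAN imagers in this module (the file name
+# and parameter values are hypothetical; any dd_clean_* function applies):
+#
+#   import ehtim as eh
+#   from ehtim.imaging import clean
+#   obs = eh.obsdata.load_uvfits('obs.uvfits')
+#   prior = eh.image.make_square(obs, 64, 150*eh.RADPERUAS)
+#   out = clean.dd_clean_bispec_imweight(obs, prior, niter=200, loop_gain=0.1,
+#                                        weighting='natural', show_updates=False)
+#   out_blur = out.blur_gauss(obs.fit_beam())   # convolve with the clean beam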
+
+
+#amplitude and "closure phase" term
+def dd_clean_amp_cphase(Obsdata, InitIm, niter=1, clipfloor=-1, loop_gain=.1, loop_gain_init=1, phaseweight=1,
+                        weighting='uniform', bscount="min", no_neg_comps=False,
+ fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT, show_updates=True):
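+    """CLEAN-like deconvolution against squared visibility amplitudes plus a
+       bispectrum ("closure phase") term weighted by phaseweight.
+
+       The amplitude chi^2 is 4th order in a trial component flux (coefficients
+       A-D) and the bispectrum chi^2 is 6th order (P-U); their weighted sum is
+       minimized per pixel via polynomial root-finding.
+    """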
+
+
+ # limit imager range to prior values > clipfloor
+ embed_mask = InitIm.imvec >= clipfloor
+
+ # get data
+ amp2 = np.abs(Obsdata.data['vis'])**2 #TODO debias??
+ sigma_amp2 = Obsdata.data['sigma']**2
+ uv = np.hstack((Obsdata.data['u'].reshape(-1,1), Obsdata.data['v'].reshape(-1,1)))
+
+ biarr = Obsdata.bispectra(mode="all", count=bscount)
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bs = biarr['bispec']
+ sigma_bs = biarr['sigmab']
+
+    # necessary NFFT info
+ npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim)))
+
+ nfft_infoA = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv)
+ nfft_infoB = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, 2*uv)
+
+ nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1)
+ nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2)
+ nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3)
+
+ nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1)
+ nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2)
+ nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3)
+
+ nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2)
+ nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3)
+ nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1)
+
+ # TODO do we use pulse factors?
+ planA = nfft_infoA.plan
+ pulsefacA = nfft_infoA.pulsefac
+ planB = nfft_infoB.plan
+ pulsefacB = nfft_infoB.pulsefac
+
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ plan11 = nfft_info11.plan
+ pulsefac11 = nfft_info11.pulsefac
+ plan22 = nfft_info22.plan
+ pulsefac22 = nfft_info22.pulsefac
+ plan33 = nfft_info33.plan
+ pulsefac33 = nfft_info33.pulsefac
+
+ plan12 = nfft_info12.plan
+ pulsefac12 = nfft_info12.pulsefac
+ plan23 = nfft_info23.plan
+ pulsefac23 = nfft_info23.pulsefac
+ plan31 = nfft_info31.plan
+ pulsefac31 = nfft_info31.pulsefac
+
+ # Weights
+ weights_amp2_nat = 1./(sigma_amp2**2)
+ if weighting=='uniform':
+ weights_amp2 = np.ones(len(weights_amp2_nat)) * np.median(weights_amp2_nat) #TODO for scaling weights are all given by median natural weight??
+ elif weighting=='natural':
+ weights_amp2=weights_amp2_nat
+ else:
+ raise Exception("weighting must be 'uniform' or 'natural'")
+ weights_amp2 = weights_amp2 / float(len(weights_amp2_nat))
+ weights_amp2_norm = np.sum(weights_amp2)
+
+ weights_bs_nat = (np.abs(bs)**2) / (sigma_bs**2)
+ if weighting=='uniform':
+ weights_bs = np.ones(len(weights_bs_nat)) * np.median(weights_bs_nat) #TODO for scaling weights are all given by median natural weight??
+ elif weighting=='natural':
+ weights_bs = weights_bs_nat
+ else:
+ raise Exception("weighting must be 'uniform' or 'natural'")
+ weights_bs = weights_bs / np.abs(bs)**2 #weight down by 1/bs^2 only works for uniform??
+ weights_bs = weights_bs / float(len(weights_bs_nat))
+ weights_bs_norm = np.sum(weights_bs)
+
+ # Coordinate matrix
+ # TODO do we need to make sure this corresponds exactly with what NFFT is doing?
+ # TODO what if the image is odd?
+ coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)]
+ for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)])
+ coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2)
+ coord = coord[embed_mask]
+
+ # Initial imvec and visibilities
+ # TODO currently initialized to zero!!
+ OutputIm = InitIm.copy()
+ DeltasIm = InitIm.copy()
+ ChisqIm = InitIm.copy()
+
+ res = Obsdata.res()
+ beamparams = Obsdata.fit_beam()
+
+ imvec_init = 0*InitIm.imvec[embed_mask]
+ vis_init = np.zeros(len(amp2), dtype='complex128')
+ chisq_amp2_init = np.sum(weights_amp2*(amp2 - np.abs(vis_init)**2)**2)
+ rchisq_amp2_init = np.sum(weights_amp2_nat*(amp2 - np.abs(vis_init)**2)**2)
+
+ vis1_init = np.zeros(len(bs), dtype='complex128')
+ vis2_init = np.zeros(len(bs), dtype='complex128')
+ vis3_init = np.zeros(len(bs), dtype='complex128')
+ bs_init = vis1_init*vis2_init*vis3_init
+ chisq_bs_init = np.sum(weights_bs*np.abs(bs - bs_init)**2)
+ rchisq_bs_init = np.sum(weights_bs_nat*np.abs(bs - bs_init)**2)
+
+ chisq_init = chisq_amp2_init + phaseweight*chisq_bs_init
+
+ imvec_current = imvec_init.copy()
+ vis_current = vis_init.copy()
+ vis1_current = vis1_init.copy()
+ vis2_current = vis2_init.copy()
+ vis3_current = vis3_init.copy()
+ bs_current = bs_init.copy()
+ chisq_amp2_current = chisq_amp2_init
+ rchisq_amp2_current = rchisq_amp2_init
+ chisq_bs_current = chisq_bs_init
+ rchisq_bs_current = rchisq_bs_init
+ chisq_current = chisq_init
+
+ # clean loop
+ print("\n")
+ for it in range(niter):
+ t = time.time()
+
+ # center the first component automatically
+        # since the initial image is empty, we must go to higher order in delta for a solution
+ # TODO generalize to non-empty initial image!
+ # BASE INITIAL POINT SOURCE ENTIRELY ON VISIBILITY AMPLITUDES
+ if it==0:
+
+ A = np.sum(weights_amp2*amp2)
+ B = weights_amp2_norm
+
+ component_strength = loop_gain_init*np.sqrt(A/B)
+ component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1??
+
+ else:
+
+ #Amplitude part
+ # First calculate A (1st order)
+ planA.f = weights_amp2 *(amp2_current-amp2)*vis_current
+ planA.adjoint()
+ out = np.real((planA.f_hat.copy().T).reshape(nfft_infoA.xdim*nfft_infoA.ydim))
+ A = 4*out
+
+ # Then calculate B (2nd order)
+ planB.f = weights_amp2 * vis_current*vis_current
+ planB.adjoint()
+ out = np.real((planB.f_hat.copy().T).reshape(nfft_infoB.xdim*nfft_infoB.ydim))
+
+ B0 = np.sum(weights_amp2*(2*amp2_current - amp2))
+ B1 = out
+ B = 4*(B0 + B1)
+
+ #Calculate C (3rd order)
+ planA.f = weights_amp2 * vis_current
+ planA.adjoint()
+ out = np.real((planA.f_hat.copy().T).reshape(nfft_infoA.xdim*nfft_infoA.ydim))
+
+ C = 12*out
+
+ # Now find D (4th order)
+ D = 4*weights_amp2_norm * np.ones(C.shape)
+
+ #"Closure Phase" part
+ resid_current = bs - bs_current
+ vis12_current = vis1_current*vis2_current
+ vis13_current = vis1_current*vis3_current
+ vis23_current = vis2_current*vis3_current
+
+ # First calculate P (1st order)
+ plan1.f = weights_bs * resid_current * vis23_current.conj()
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights_bs * resid_current * vis13_current.conj()
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights_bs * resid_current * vis12_current.conj()
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ P = -2 * (out1 + out2 + out3)
+
+ # Then calculate Q (2nd order)
+ plan12.f = weights_bs * vis13_current*vis23_current.conj()
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights_bs * vis12_current*vis13_current.conj()
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights_bs * vis23_current*vis12_current.conj()
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights_bs * resid_current.conj() * vis1_current
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights_bs * resid_current.conj() * vis2_current
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights_bs * resid_current.conj() * vis3_current
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ Q0 = np.sum(weights_bs*(np.abs(vis12_current)**2 + np.abs(vis23_current)**2 + np.abs(vis13_current)**2))
+ Q1 = 2 * (out12 + out23 + out31)
+ Q2 = -2 * (out1 + out2 + out3)
+ Q = 2*(Q0 + Q1 + Q2)
+
+ #Calculate R (3rd order)
+ plan11.f = weights_bs * vis1_current.conj() * vis23_current
+ plan11.adjoint()
+ out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim))
+ plan22.f = weights_bs * vis2_current.conj() * vis13_current
+ plan22.adjoint()
+ out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim))
+ plan33.f = weights_bs * vis3_current.conj() * vis12_current
+ plan33.adjoint()
+ out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim))
+
+ plan1.f = weights_bs * vis1_current * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2)
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights_bs * vis2_current * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2)
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights_bs * vis3_current * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2)
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ R0 = -np.sum(weights_bs*np.real(resid_current))
+ R1 = np.real(out11 + out22 + out33)
+ R2 = np.real(out1 + out2 + out3)
+ R = 6*(R0 + R1 + R2)
+
+ # Now find S (4th order)
+ plan12.f = weights_bs * vis1_current*vis2_current.conj()
+ plan12.adjoint()
+ out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim))
+ plan23.f = weights_bs * vis2_current*vis3_current.conj()
+ plan23.adjoint()
+ out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim))
+ plan31.f = weights_bs * vis3_current*vis1_current.conj()
+ plan31.adjoint()
+ out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim))
+
+ plan1.f = weights_bs * vis23_current.conj()
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights_bs * vis13_current.conj()
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights_bs * vis12_current.conj()
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ S0 = np.sum(weights_bs*(np.abs(vis1_current)**2 + np.abs(vis2_current)**2 + np.abs(vis3_current)**2))
+ S1 = 2 * (out12 + out23 + out31)
+ S2 = 2 * (out1 + out2 + out3)
+ S = 4*(S0 + S1 + S2)
+
+ # T (5th order)
+ plan1.f = weights_bs * vis1_current
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+ plan2.f = weights_bs * vis2_current
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+ plan3.f = weights_bs * vis3_current
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ T = 10*(out1 + out2 + out3)
+
+ # Finally U (6th order)
+ U = 6*weights_bs_norm * np.ones(T.shape)
+
+ # Find Component based on minimizing chi^2
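+    # A minimal sketch of the logic below: chi^2(delta) is expanded to 6th order in the
+    # trial component amplitude delta at each pixel, with coeffs[k] the coefficient of
+    # delta**k in d(chi^2)/d(delta); polyroots(coeffs) then returns the stationary points,
+    # and trialchisq evaluates chi^2(delta) = chi^2(0) + sum_k coeffs[k]*delta**(k+1)/(k+1)
+    # at each real root to pick the amplitude that most reduces chi^2.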
+ deltas = np.zeros(len(A))
+ chisq_map = np.zeros(len(A))
+ for i in range(len(P)):
+ if embed_mask[i]:
+ coeffs = np.array([A[i] + phaseweight*P[i],
+ B[i] + phaseweight*Q[i],
+ C[i] + phaseweight*R[i],
+ D[i] + phaseweight*S[i],
+ phaseweight*T[i],
+ phaseweight*U[i]
+ ])
+ allroots = p.polyroots(coeffs)
+
+ # test roots to see which minimizes chi^2
+ newchisq = chisq_current
+ delta = 0
+ for j in range(len(allroots)):
+ root = allroots[j]
+                if np.imag(root) != 0: continue
+                if no_neg_comps and root < 0: continue
+                trialchisq = (chisq_current + coeffs[0]*root + 0.5*coeffs[1]*root**2
+                              + (1./3.)*coeffs[2]*root**3 + 0.25*coeffs[3]*root**4
+                              + 0.2*coeffs[4]*root**5 + (1./6.)*coeffs[5]*root**6)
+ if trialchisq < newchisq:
+ delta = root
+ newchisq = trialchisq
+
+ else:
+ delta = 0.
+ newchisq = chisq_current
+ deltas[i] = delta
+ chisq_map[i] = newchisq
+
+ #plot deltas and chi^2 map
+ if show_updates:
+ DeltasIm.imvec = deltas
+ plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot')
+ ChisqIm.imvec = -chisq_map
+ plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool')
+
+ #chisq_map = chisq_current + P*deltas + 0.5*Q*deltas**2 + (1./3.)*R*deltas**3 + 0.25*S*deltas**4 + 0.2*T*deltas**5 + (1./6.)*U*deltas**6
+ component_loc_idx = np.argmin(chisq_map[embed_mask])
+ component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx]
+
+ # clean component location
+ component_loc_x = coord[component_loc_idx][0]
+ component_loc_y = coord[component_loc_idx][1]
+
+ # update imvec, vis, bispec
+ imvec_current[component_loc_idx] += component_strength
+
+ #TODO how to incorporate pulse function?
+ vis_current += component_strength*np.exp(2*np.pi*1j*(uv[:,0]*component_loc_x + uv[:,1]*component_loc_y))
+ amp2_current = np.abs(vis_current)**2
+
+ vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y))
+ vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y))
+ vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y))
+ bs_current = vis1_current * vis2_current * vis3_current
+
+ # update chi^2 and output image
+ chisq_amp2_current = np.sum(weights_amp2*(amp2 - amp2_current)**2)
+ rchisq_amp2_current = np.sum(weights_amp2_nat*(amp2 - amp2_current)**2)
+
+ chisq_bs_current = np.sum(weights_bs*np.abs(bs - bs_current)**2)
+ rchisq_bs_current = np.sum(weights_bs_nat*np.abs(bs - bs_current)**2)
+
+ chisq_current = chisq_amp2_current + phaseweight*chisq_bs_current
+
+ print("it %i| %.4e (%.1f , %.1f) | %.4e %.4e | %.4e" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_amp2_current, chisq_bs_current, chisq_current))
+
+ OutputIm.imvec = imvec_current
+ if show_updates:
+ OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False)
+ OutputImBlur = OutputIm.blur_gauss(beamparams)
+ plot_i(OutputImBlur, it, chisq_current, fig=2)
+
+ return OutputIm
diff --git a/imaging/dynamical_imaging.py b/imaging/dynamical_imaging.py
new file mode 100644
index 00000000..a19d917f
--- /dev/null
+++ b/imaging/dynamical_imaging.py
@@ -0,0 +1,2586 @@
+# dynamical_imaging.py
+# imaging movies with interferometric data
+#
+# Copyright (C) 2018 Michael Johnson
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+# Note: This library is still under very active development and is likely to change considerably
+# Contact Michael Johnson (mjohnson@cfa.harvard.edu) with any questions
+# The methods/techniques used in this module are described in http://adsabs.harvard.edu/abs/2017ApJ...850..172J
+
+import time
+import numpy as np
+
+import scipy.optimize as opt
+import scipy.ndimage.filters as filt
+import scipy.signal
+
+import matplotlib.pyplot as plt
+
+import itertools as it
+
+from ehtim.const_def import * #Note: C is m/s rather than cm/s.
+from ehtim.observing.obs_helpers import *
+import ehtim.obsdata as obsdata
+import ehtim.image as image
+import ehtim.movie as movie
+from ehtim.imaging.imager_utils import *
+
+import ehtim.scattering as so
+
+from multiprocessing import Pool, cpu_count
+from functools import partial
+
+#imports from the blazarFileDownloader
+import calendar
+import requests
+import os
+from html.parser import HTMLParser
+#from HTMLParser import HTMLParser
+
+Fast_Convolve = True # This option will not wrap the convolution around the image
+
+# These parameters are only global to allow parallelizing the chi^2 calculation
+# without huge memory overhead. It would be nice to do this locally, using the
+# parallel array capabilities.
+# Fourier matrices:
+A1_List = [None,]
+A2_List = [None,]
+A3_List = [None,]
+# Data used:
+data1_List = [None,]
+data2_List = [None,]
+data3_List = [None,]
+# Standard deviation of data used:
+sigma1_List = [None,]
+sigma2_List = [None,]
+sigma3_List = [None,]
+
+##################################################################################################
+# Constants
+##################################################################################################
+#NHIST = 25 # number of steps to store for hessian approx
+nit = 0 # global variable to track the iteration number in the plotting callback
+
+##################################################################################################
+# Tools for AGN images
+##################################################################################################
+
+def get_KLS(im1, im2, shift = [0,0], blur_size_uas=100, dynamic_range=200):
+    # Symmetrized Kullback-Leibler divergence, with optional blurring and max dynamic range
+ ep = np.max(im1.imvec)/dynamic_range
+ A = im1.blur_circ(blur_size_uas*RADPERUAS).imvec.reshape((im1.ydim,im1.xdim)) + ep
+ B = im2.blur_circ(blur_size_uas*RADPERUAS).imvec.reshape((im2.ydim,im2.xdim)) + ep
+
+ B = np.roll(B, shift, (0,1))
+ return np.sum( (B - A)*np.log(B/A) )
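+
+# Example (hypothetical Image objects im1, im2): get_KLS(im1, im2, blur_size_uas=50)
+# returns a non-negative scalar that vanishes only when the blurred images agree
+# pixel-by-pixel; shift=[dy,dx] rolls im2 before the comparison.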
+
+def get_core_position(im, blur_size_uas=100):
+ #Estimate the core position (i.e., the brightest region)
+ #Convolve the input image with the beam, then find the brightest pixel
+ im_blur = im.blur_circ(blur_size_uas*RADPERUAS).imvec
+ return np.array(np.unravel_index(im_blur.argmax(),(im.ydim,im.xdim)))
+
+def center_core(im, blur_size_uas=100):
+ im_rotate = im.copy()
+ core_pos = get_core_position(im,blur_size_uas)
+ center = np.array((int((im.ydim-1)/2),int((im.xdim-1)/2)))
+ print ("Rotating By",center - core_pos)
+ im_rotate.imvec = np.roll(im.imvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten()
+ if len(im.qvec):
+ im_rotate.qvec = np.roll(im.qvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten()
+ if len(im.uvec):
+ im_rotate.uvec = np.roll(im.uvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten()
+ if len(im.vvec):
+ im_rotate.vvec = np.roll(im.vvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten()
+
+ return im_rotate
+
+def align_left(im,min_frac=0.1,opposite_frac_thresh=0.05):
+    #Aligns the leftmost significant emission to the horizontal center of the image, assuming that there is no appreciable flux to the left of the core
+ im_rotate = im.copy()
+ center = np.array((int((im.ydim-1)/2),int((im.xdim-1)/2)))
+ projected_flux = np.sum(im.imvec.reshape((im.ydim,im.xdim)),axis=0)
+ thresh = np.max(projected_flux)*min_frac
+ #opposite_thresh = np.max(projected_flux[-int(opposite_frac_thresh*im.xdim):])
+
+ for j in range(im.xdim):
+ if projected_flux[j] > thresh: #or projected_flux[j] > opposite_thresh
+ break
+
+    # shift along x (axis 1) so that the first column above threshold lands at the image center
+    im_rotate.imvec = np.roll(im.imvec.reshape((im.ydim,im.xdim)), center[1] - j, (1)).flatten()
+    if len(im.qvec):
+        im_rotate.qvec = np.roll(im.qvec.reshape((im.ydim,im.xdim)), center[1] - j, (1)).flatten()
+    if len(im.uvec):
+        im_rotate.uvec = np.roll(im.uvec.reshape((im.ydim,im.xdim)), center[1] - j, (1)).flatten()
+    if len(im.vvec):
+        im_rotate.vvec = np.roll(im.vvec.reshape((im.ydim,im.xdim)), center[1] - j, (1)).flatten()
+
+ return im_rotate
+
+
+##################################################################################################
+# Movie Export Tools
+##################################################################################################
+def export_multipanel_movie(im_List_Set, out='movie.mp4', fps=10, dpi=120, scale='linear', dynamic_range=1000.0, pad_factor=1, verbose=False, xlim = None, ylim = None, titles = [], size=8.0):
+ # Example: di.export_multipanel_movie([im_List,im_List_2],scale='log',xlim=[1000,-1000],ylim=[-3000,500],dynamic_range=[1000,5000], titles = ['43 GHz (BU)','15 GHz (MOJAVE)'])
+ import matplotlib
+ matplotlib.use('agg')
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ fig = plt.figure()
+
+    mjd_step = (im_List_Set[0][-1].mjd - im_List_Set[0][0].mjd)/len(im_List_Set[0])
+
+    # used by fig.suptitle in update_img below
+    if len(im_List_Set)%2 == 1:
+        verticalalignment = 'bottom'
+    else:
+        verticalalignment = 'top'
+
+ N_set = len(im_List_Set)
+ extent = [np.array((1,-1,-1,1))]*N_set
+ maxi = np.zeros(N_set)
+ plt_im = [None,]*N_set
+
+ if type(dynamic_range) == float or type(dynamic_range) == int:
+ dynamic_range = np.zeros(N_set) + dynamic_range
+
+ for j in range(N_set):
+ extent[j] = im_List_Set[j][0].psize/RADPERUAS*im_List_Set[j][0].xdim*np.array((1,-1,-1,1)) / 2.
+ maxi[j] = np.max(np.concatenate([im.imvec for im in im_List_Set[j]]))
+
+ def im_data(i_set, n):
+ n_data = (n-n%pad_factor)//pad_factor
+ if scale == 'linear':
+ return im_List_Set[i_set][n_data].imvec.reshape((im_List_Set[i_set][n_data].ydim,im_List_Set[i_set][n_data].xdim))
+ else:
+ return np.log(im_List_Set[i_set][n_data].imvec.reshape((im_List_Set[i_set][n_data].ydim,im_List_Set[i_set][n_data].xdim)) + maxi[i_set]/dynamic_range[i_set])
+
+ for j in range(N_set):
+ ax = plt.subplot(1, N_set, j+1)
+ plt_im[j] = plt.imshow(im_data(j, 0), extent=extent[j], cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+
+        if xlim is not None:
+            ax.set_xlim(xlim)
+        if ylim is not None:
+            ax.set_ylim(ylim)
+
+ if scale == 'linear':
+ plt_im[j].set_clim([0,maxi[j]])
+ else:
+ plt_im[j].set_clim([np.log(maxi[j]/dynamic_range[j]),np.log(maxi[j])])
+
+ if j == 0:
+            plt.xlabel(r'Relative RA ($\mu$as)')
+            plt.ylabel(r'Relative Dec ($\mu$as)')
+
+ if len(titles) > 0:
+ ax.set_title(titles[j])
+
+ fig.set_size_inches([size,size/len(im_List_Set)])
+ plt.tight_layout()
+
+ def update_img(n):
+ if verbose:
+            print ("processing frame {0} of {1}".format(n, len(im_List_Set[0])*pad_factor))
+ for j in range(N_set):
+ plt_im[j].set_data(im_data(j, n))
+
+ if mjd_step > 0.1:
+ fig.suptitle('MJD: ' + str(im_List_Set[0][int((n-n%pad_factor)//pad_factor)].mjd), verticalalignment = verticalalignment)
+ else:
+ time = im_List_Set[0][int((n-n%pad_factor)//pad_factor)].time
+ time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60))
+ fig.suptitle(time_str)
+
+ return plt_im
+
+ ani = animation.FuncAnimation(fig,update_img,len(im_List_Set[0])*pad_factor,interval=1e3/fps)
+ writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6)
+ ani.save(out,writer=writer,dpi=dpi)
+
+def export_movie(im_List, out='movie.mp4', fps=10, dpi=120, scale='linear', cbar_unit = 'Jy', gamma=0.5, dynamic_range=1000.0, pad_factor=1, verbose=False):
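+    # Example (hypothetical im_List of Image objects):
+    #   export_movie(im_List, out='movie.mp4', fps=20, scale='log', dynamic_range=500.0)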
+ import matplotlib
+ matplotlib.use('agg')
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ mjd_range = im_List[-1].mjd - im_List[0].mjd
+
+ fig = plt.figure()
+
+ extent = im_List[0].psize/RADPERUAS*im_List[0].xdim*np.array((1,-1,-1,1)) / 2.
+ maxi = np.max(np.concatenate([im.imvec for im in im_List]))
+
+# TODO: fix this
+# if cbar_unit == 'mJy':
+# imvec = imvec * 1.e3
+# qvec = qvec * 1.e3
+# uvec = uvec * 1.e3
+# elif cbar_unit == '$\mu$Jy':
+# imvec = imvec * 1.e6
+# qvec = qvec * 1.e6
+# uvec = uvec * 1.e6
+
+ unit = cbar_unit + '/pixel'
+
+ if scale=='log':
+ unit = 'log(' + cbar_unit + '/pixel)'
+
+ if scale=='gamma':
+ unit = '(' + cbar_unit + '/pixel)^gamma'
+
+ def im_data(n):
+ n_data = (n-n%pad_factor)//pad_factor
+ if scale == 'linear':
+ return im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim))
+ elif scale == 'log':
+ return np.log(im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim)) + maxi/dynamic_range)
+ elif scale == 'gamma':
+ return (im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim)) + maxi/dynamic_range)**(gamma)
+
+ plt_im = plt.imshow(im_data(0), extent=extent, cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ if scale == 'linear':
+ plt_im.set_clim([0,maxi])
+ elif scale == 'log':
+ plt_im.set_clim([np.log(maxi/dynamic_range),np.log(maxi)])
+ elif scale == 'gamma':
+ plt_im.set_clim([(maxi/dynamic_range)**gamma,(maxi)**(gamma)])
+
+    plt.xlabel(r'Relative RA ($\mu$as)')
+    plt.ylabel(r'Relative Dec ($\mu$as)')
+
+ fig.set_size_inches([5,5])
+ plt.tight_layout()
+
+ def update_img(n):
+ if verbose:
+ print ("processing frame {0} of {1}".format(n, len(im_List)*pad_factor))
+ plt_im.set_data(im_data(n))
+ if mjd_range != 0:
+ fig.suptitle('MJD: ' + str(im_List[int((n-n%pad_factor)//pad_factor)].mjd))
+ else:
+ time = im_List[int((n-n%pad_factor)//pad_factor)].time
+ time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60))
+ fig.suptitle(time_str)
+
+ return plt_im
+
+ ani = animation.FuncAnimation(fig,update_img,len(im_List)*pad_factor,interval=1e3/fps)
+ writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6)
+ ani.save(out,writer=writer,dpi=dpi)
+
+
+##################################################################################################
+# Convenience Functions for Data Processing
+##################################################################################################
+
+def split_obs(obs):
+ """Split single observation into multiple observation files, one per scan
+ """
+
+ print ("Splitting Observation File into " + str(len(obs.tlist())) + " scans")
+ return [ obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, tdata, obs.tarr, source=obs.source, mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal) for tdata in obs.tlist() ]
+
+def merge_obs(obs_List):
+ """Merge a list of observations into a single observation file
+ """
+
+    if (len(set([obs.ra for obs in obs_List])) > 1 or len(set([obs.dec for obs in obs_List])) > 1
+            or len(set([obs.rf for obs in obs_List])) > 1 or len(set([obs.bw for obs in obs_List])) > 1
+            or len(set([obs.source for obs in obs_List])) > 1):
+ print ("All observations must have the same parameters!")
+ return
+
+ #The important things to merge are the mjd and the data
+ data_merge = np.hstack([obs.data for obs in obs_List])
+
+ return obsdata.Obsdata(obs_List[0].ra, obs_List[0].dec, obs_List[0].rf, obs_List[0].bw, data_merge, obs_List[0].tarr, polrep=obs_List[0].polrep, source=obs_List[0].source, mjd=obs_List[0].mjd, ampcal=obs_List[0].ampcal, phasecal=obs_List[0].phasecal)
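+
+# Example (hypothetical obs): merge_obs(split_obs(obs)) splits an observation into
+# per-scan Obsdata objects and reassembles them into a single Obsdata object.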
+
+def average_im_list(im_List):
+ """Return the average of a list of images
+ """
+ avg_im = im_List[0].copy()
+ avg_im.imvec = np.mean([im.imvec for im in im_List],axis=0)
+ if len(im_List[0].qvec):
+ avg_im.qvec = np.mean([im.qvec for im in im_List],axis=0)
+ if len(im_List[0].uvec):
+ avg_im.uvec = np.mean([im.uvec for im in im_List],axis=0)
+ if len(im_List[0].vvec):
+ avg_im.vvec = np.mean([im.vvec for im in im_List],axis=0)
+
+ return avg_im
+
+def blur_im_list(im_List, fwhm_x, fwhm_t):
+ """Apply a gaussian filter to a list of images, with fwhm_x in radians and fwhm_t in frames. Currently only for Stokes I.
+
+ Args:
+ fwhm_x (float): circular beam size for spatial blurring in radians
+ fwhm_t (float): temporal blurring in frames
+ Returns:
+        (list): list of blurred Image objects
+ """
+
+ # Blur Stokes I
+ sigma_x = fwhm_x / im_List[0].psize / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_t = fwhm_t / (2. * np.sqrt(2. * np.log(2.)))
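+    # FWHM -> Gaussian sigma uses the standard factor 2*sqrt(2*ln 2) ~= 2.355;
+    # sigma_x is in pixels (fwhm_x divided by the pixel size) and sigma_t is in frames.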
+
+ arr = np.array([im.imvec.reshape(im.ydim, im.xdim) for im in im_List])
+ arr = filt.gaussian_filter(arr, (sigma_t, sigma_x, sigma_x))
+
+ ret = []
+ for j in range(len(im_List)):
+ ret.append(image.Image(arr[j], im_List[0].psize, im_List[0].ra, im_List[0].dec, rf=im_List[0].rf, source=im_List[0].source, mjd=im_List[j].mjd))
+
+ return ret
+
+##################################################################################################
+# Convenience Functions for Analytical Work
+##################################################################################################
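+
+# The Wrapped_* helpers below evaluate convolutions, gradients, and divergences with
+# periodic ('wrap') boundary padding so the dynamical regularizers see no artificial
+# edges; Wrapped_Convolve skips the wrapping when Fast_Convolve is set above.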
+
+def Wrapped_Convolve(sig,ker):
+ if np.sum(ker) == 0.0:
+ return sig
+
+ N = sig.shape[0]
+
+    if not Fast_Convolve:
+ return scipy.signal.fftconvolve(np.pad(sig,((N, N), (N, N)), 'wrap'), np.pad(ker,((N, N), (N, N)), 'constant'),mode='same')[N:(2*N),N:(2*N)]
+ else:
+ return scipy.signal.fftconvolve(sig, ker,mode='same')
+
+def Wrapped_Gradient(M):
+ G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap'))
+ Gx = G[0][1:-1,1:-1]
+ Gy = G[1][1:-1,1:-1]
+ return (Gx, Gy)
+
+def Wrapped_Gradient_Reorder(M):
+ G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap'))
+ Gx = G[0][1:-1,1:-1]
+ Gy = G[1][1:-1,1:-1]
+ return np.transpose(np.array([Gx, Gy]),axes=[1,2,0])
+
+def Wrapped_Divergence( vectorfield ):
+ Gx = np.gradient(np.pad(vectorfield[:,:,0],((1, 1), (1, 1)), 'wrap'), axis=0)[1:-1,1:-1]
+ Gy = np.gradient(np.pad(vectorfield[:,:,1],((1, 1), (1, 1)), 'wrap'), axis=1)[1:-1,1:-1]
+ return Gx+Gy
+
+def Wrapped_Weighted_Divergence( weight, M ): #(weight \cdot \nabla) M
+ grad = Wrapped_Gradient(M)
+ return weight[:,:,0]*grad[0] + weight[:,:,1]*grad[1]
+
+##################################################################################################
+# Dynamic Regularizers and their Gradients
+##################################################################################################
+
+#RdF Regularizer (continuity of total flux density from frame to frame)
+def RdF_clip(Frame_List, embed_mask_List):
+ F_List = [np.sum(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))]
+ return np.sum(np.diff(F_List)**2)
+
+def RdF_gradient_clip(Frame_List, embed_mask_List):
+ N_frame = Frame_List.shape[0]
+ F_List = [np.sum(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))]
+ F_grad_List = [1.0 + 0.0*np.copy(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))]
+ grad = np.copy(F_grad_List)*0.0
+
+ for i in range(1,N_frame):
+ grad[i] = grad[i] + 2.0*(F_List[i] - F_List[i-1])*F_grad_List[i]
+
+ for i in range(N_frame-1):
+ grad[i] = grad[i] + 2.0*(F_List[i] - F_List[i+1])*F_grad_List[i]
+
+ return np.concatenate([grad[i]*(Frame_List[i].ravel()[embed_mask_List[i]]) for i in range(N_frame)])
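+
+# Note: for R = sum_j (F_j - F_{j-1})^2 the gradient at an interior frame is
+# dR/dF_i = 2*(F_i - F_{i-1}) + 2*(F_i - F_{i+1}), which is what the two loops above
+# accumulate; the trailing factor of Frame_List[i] is the Jacobian for frames
+# parameterized as log(frame).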
+
+#RdS Regularizer (continuity of entropy from frame to frame)
+def RdS(Frame_List, Prior_List, embed_mask_List, entropy="simple", norm_reg=True, **kwargs):
+    S_List = [static_regularizer(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, stype=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))]
+ return np.sum(np.diff(S_List)**2)
+
+def RdS_gradient(Frame_List, Prior_List, embed_mask_List, entropy="simple", norm_reg=True, **kwargs):
+ #The Jacobian_Factor is already part of the entropy gradient that this function calls
+ N_frame = Frame_List.shape[0]
+    S_List = [static_regularizer(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, stype=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))]
+    S_grad_List = np.array([static_regularizer_gradient(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, stype=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))])
+
+ grad = np.copy(S_grad_List)*0.0
+
+ for i in range(1,N_frame):
+ grad[i] = grad[i] + 2.0*(S_List[i] - S_List[i-1])*S_grad_List[i]
+
+ for i in range(N_frame-1):
+ grad[i] = grad[i] + 2.0*(S_List[i] - S_List[i+1])*S_grad_List[i]
+
+ return np.concatenate([grad[i] for i in range(N_frame)])
+
+
+######## Rdt, RdI, and Rflow master functions ########
+def Rdt(Frames, ker, metric='SymKL', p=2.0, **kwargs):
+ if metric == 'KL':
+ return Rdt_KL(Frames, ker)
+ elif metric == 'SymKL':
+ return Rdt_SymKL(Frames, ker)
+ elif metric == 'D2':
+ return Rdt_Dp(Frames, ker, p=2.0)
+ elif metric == 'Dp':
+ return Rdt_Dp(Frames, ker, p=p)
+ else:
+ return 0.0
+
+def Rdt_gradient(Frames, ker, metric='SymKL', p=2.0, **kwargs):
+ if metric == 'KL':
+ return Rdt_KL_gradient(Frames, ker)
+ elif metric == 'SymKL':
+ return Rdt_SymKL_gradient(Frames, ker)
+ elif metric == 'D2':
+ return Rdt_Dp_gradient(Frames, ker, p=2.0)
+ elif metric == 'Dp':
+ return Rdt_Dp_gradient(Frames, ker, p=p)
+ else:
+ return 0.0
+
+def RdI(Frames, metric='SymKL', p=2.0, **kwargs):
+ if metric == 'KL':
+ return RdI_KL(Frames)
+ elif metric == 'SymKL':
+ return RdI_SymKL(Frames)
+ elif metric == 'D2':
+ return RdI_Dp(Frames, p=2.0)
+ elif metric == 'Dp':
+ return RdI_Dp(Frames, p=p)
+ else:
+ return 0.0
+
+def RdI_gradient(Frames, metric='SymKL', p=2.0, **kwargs):
+ if metric == 'KL':
+ return RdI_KL_gradient(Frames)
+ elif metric == 'SymKL':
+ return RdI_SymKL_gradient(Frames)
+ elif metric == 'D2':
+ return RdI_Dp_gradient(Frames, p=2.0)
+ elif metric == 'Dp':
+ return RdI_Dp_gradient(Frames, p=p)
+ else:
+ return 0.0
+
+def Rflow(Frames, Flow, metric='D2', p=2.0, **kwargs):
+ if metric == 'KL':
+ return Rflow_KL(Frames, Flow)
+ elif metric == 'SymKL':
+ return Rflow_SymKL(Frames, Flow)
+ elif metric == 'D2':
+ return Rflow_D2(Frames, Flow)
+ elif metric == 'Dp':
+ return Rflow_Dp(Frames, Flow, p=p)
+ else:
+ return 0.0
+
+def Rflow_gradient_I(Frames, Flow, metric='D2', p=2.0, **kwargs):
+ if metric == 'KL':
+ return Rflow_KL_gradient_I(Frames, Flow)
+ elif metric == 'SymKL':
+ return Rflow_SymKL_gradient_I(Frames, Flow)
+ elif metric == 'D2':
+ return Rflow_D2_gradient_I(Frames, Flow)
+ elif metric == 'Dp':
+ return Rflow_Dp_gradient_I(Frames, Flow, p=p)
+ else:
+ return 0.0
+
+def Rflow_gradient_m(Frames, Flow, metric='D2', p=2.0, **kwargs):
+ if metric == 'KL':
+ return Rflow_KL_gradient_m(Frames, Flow)
+ elif metric == 'SymKL':
+ return Rflow_SymKL_gradient_m(Frames, Flow)
+ elif metric == 'D2':
+        return Rflow_D2_gradient_m(Frames, Flow)
+ elif metric == 'Dp':
+ return Rflow_Dp_gradient_m(Frames, Flow, p=p)
+ else:
+ return 0.0
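+
+# Example (hypothetical Frames array and blurring kernel ker): the dispatchers above
+# let the imagers switch distance metrics without changing call sites, e.g.
+#   Rdt(Frames, ker, metric='Dp', p=1.5) and Rdt_gradient(Frames, ker, metric='Dp', p=1.5)
+# evaluate the temporal continuity regularizer and its gradient under an l_1.5 norm.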
+
+####################################
+
+#Rdt Regularizer with relative entropy (Kullback-Leibler Divergence)
+def Rdt_KL(Frames, ker):
+ ep=1e-10
+ N_frame = Frames.shape[0]
+ Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames])
+
+ R = 0.0
+ for i in range(1,N_frame):
+ R += np.sum( Blur_Frames[i]*np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) )
+
+ return R/N_frame
+
+def Rdt_KL_gradient(Frames, ker):
+ #The Jacobian_Factor accounts for the frames being written as log(frame) in the imaging algorithm
+ ep=1e-10
+ N_frame = Frames.shape[0]
+ Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames])
+
+ grad = np.copy(Frames)*0.0
+
+ for i in range(1,len(Frames)):
+ grad[i] = grad[i] + np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) + 1.0
+
+ for i in range(len(Frames)-1):
+ grad[i] = grad[i] - (Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep)
+
+ return np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten()
+
+#Rdt Regularizer with symmetrized relative entropy
+def Rdt_SymKL(Frames, ker):
+ ep=1e-10
+ N_frame = Frames.shape[0]
+ Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames])
+
+ R = 0.0
+ for i in range(1,N_frame):
+ R += np.sum( (Blur_Frames[i] - Blur_Frames[i-1])*np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) )
+
+ return 0.5*R/N_frame
+
+def Rdt_SymKL_gradient(Frames, ker):
+ #The Jacobian_Factor accounts for the frames being written as log(frame) in the imaging algorithm
+ ep=1e-10
+ N_frame = Frames.shape[0]
+ Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames])
+
+ grad = np.copy(Frames)*0.0
+
+ for i in range(1,len(Frames)):
+ grad[i] = grad[i] + 1.0 - (Blur_Frames[i-1]+ep)/(Blur_Frames[i]+ep) + np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep))
+
+ for i in range(len(Frames)-1):
+ grad[i] = grad[i] + 1.0 - (Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep) - np.log((Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep))
+
+ return 0.5*np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten()
+
+#Rdt Regularizer with MSE (or l_p norm)
+def Rdt_Dp(Frames, ker, p=2.0):
+ N_frame = Frames.shape[0]
+ Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames])
+ return np.sum(np.abs(np.diff(Blur_Frames,axis=0))**p)/N_frame
+
+def Rdt_Dp_gradient(Frames, ker, p=2.0):
+ N_frame = Frames.shape[0]
+
+ grad = np.copy(Frames)*0.0
+
+ if p==2.0:
+ for i in range(1,len(Frames)):
+ grad[i] = grad[i] + 2.0*(Frames[i] - Frames[i-1])
+
+ for i in range(len(Frames)-1):
+ grad[i] = grad[i] + 2.0*(Frames[i] - Frames[i+1])
+
+ return np.array([Wrapped_Convolve(Wrapped_Convolve(grad[i],ker),ker)/N_frame*Frames[i] for i in range(len(Frames))]).flatten()
+ else:
+ for i in range(1,len(Frames)):
+ grad_temp = Wrapped_Convolve(Frames[i] - Frames[i-1],ker)
+ grad[i] = grad[i] + p*(np.abs(grad_temp)**(p-1.0)*np.sign(grad_temp))
+
+ for i in range(len(Frames)-1):
+ grad_temp = Wrapped_Convolve(Frames[i] - Frames[i+1],ker)
+ grad[i] = grad[i] + p*(np.abs(grad_temp)**(p-1.0)*np.sign(grad_temp))
+
+ return np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten()
+
+#RdI Regularizer
+def RdI_KL(Frames):
+ N_frame = Frames.shape[0]
+ ep=1e-10
+ avg_Image = np.mean(Frames,axis=0)
+
+ return np.sum(Frames * np.log((Frames + ep)/(avg_Image + ep)))/N_frame
+
+def RdI_KL_gradient(Frames):
+ N_frame = Frames.shape[0]
+ ep=1e-10
+ avg_Image = np.mean(Frames,axis=0)
+ return np.concatenate([(np.log((Frames[i]+ep)/(avg_Image+ep))/N_frame*Frames[i]).ravel() for i in range(N_frame)])
+
+def RdI_SymKL(Frames):
+ N_frame = Frames.shape[0]
+ ep=1e-10
+ avg_Image = np.mean(Frames,axis=0)
+
+ return np.sum(0.5 * (Frames - avg_Image) * np.log((Frames + ep)/(avg_Image + ep)))/N_frame
+
+def RdI_SymKL_gradient(Frames):
+ N_frame = Frames.shape[0]
+ ep=1e-10
+ avg_Image = np.mean(Frames,axis=0)
+ term2 = 1.0/N_frame * np.sum(np.log((Frames + ep)/(avg_Image + ep)), axis=0)
+
+ return np.concatenate([(0.5*( (Frames[i]-avg_Image)/(Frames[i]+ep) + np.log((Frames[i]+ep)/(avg_Image+ep)) - term2)/N_frame*Frames[i]).ravel() for i in range(N_frame)])
+
+def RdI_D2(Frames):
+ N_frame = Frames.shape[0]
+ avg_Image = np.mean(Frames,axis=0)
+ return np.sum((Frames - avg_Image)**2)/N_frame
+
+def RdI_D2_gradient(Frames):
+ N_frame = Frames.shape[0]
+ avg_Image = np.mean(Frames,axis=0)
+ return np.concatenate([(2.0*(Frames[i] - avg_Image)/N_frame*Frames[i]).ravel() for i in range(Frames.shape[0])])
+
+def RdI_Dp(Frames, p=2.0):
+ N_frame = Frames.shape[0]
+ avg_Image = np.mean(Frames,axis=0)
+ return np.sum(np.abs(Frames - avg_Image)**p)/N_frame
+
+def RdI_Dp_gradient(Frames, p=2.0):
+ N_frame = Frames.shape[0]
+ avg_Image = np.mean(Frames,axis=0)
+ term2 = -p/N_frame * np.sum([np.abs(Frames[i] - avg_Image)**(p-1.0)*np.sign(Frames[i] - avg_Image) for i in range(N_frame)],axis=0)
+ return np.concatenate([((p*np.abs(Frames[i] - avg_Image)**(p-1.0)*np.sign(Frames[i] - avg_Image) + term2)/N_frame*Frames[i]).ravel() for i in range(N_frame)])
+
+#Rflow Regularizer
+def Rflow_D2(Frames, Flow):
+ N_frame = Frames.shape[0]
+ val = 0.0
+
+ for j in range(len(Frames)-1):
+ #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow ) #this is not the same as the expanded version (for discrete derivatives)
+ dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow))
+ val = val + np.sum( (Frames[j+1] - (Frames[j] + dI_dt))**2 )
+
+ return val/N_frame
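+
+# Rflow penalizes deviations from the linearized continuity equation
+# I_{j+1} ~= I_j - div(I_j * m), where m is the flow field; the divergence is expanded
+# as (m . grad)I + I*div(m) so the discrete derivatives match the gradients below.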
+
+def Rflow_D2_gradient_I(Frames, Flow):
+ N_frame = Frames.shape[0]
+ grad = 0.0*np.copy(Frames)
+
+ for j in range(1,len(Frames)):
+ #dI_dt = -Wrapped_Divergence( Frames[j-1,:,:,None]*Flow )
+ dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j-1]) + Frames[j-1]*Wrapped_Divergence(Flow))
+ deltaI = Frames[j] - (Frames[j-1] + dI_dt)
+ grad[j] = grad[j] + 2.0*deltaI
+
+ for j in range(0,len(Frames)-1):
+ #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow )
+ dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow))
+ deltaI = Frames[j+1] - (Frames[j] + dI_dt)
+ grad[j] = grad[j] - 2.0*(deltaI + Wrapped_Weighted_Divergence(Flow, deltaI))
+
+ for j in range(len(Frames)):
+ grad[j] = grad[j]*Frames[j]
+
+ return np.concatenate([g.flatten() for g in grad])/N_frame
+
+def Rflow_D2_gradient_m(Frames, Flow):
+ N_frame = Frames.shape[0]
+ grad = 0.0*np.copy(Flow)
+
+ for j in range(len(Frames)-1):
+ #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow )
+ dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow))
+ deltaI = Frames[j+1] - (Frames[j] + dI_dt)
+ grad = grad + 2.0 * deltaI[:,:,None] * Wrapped_Gradient_Reorder( Frames[j] )
+ grad = grad - 2.0 * Wrapped_Gradient_Reorder( deltaI*Frames[j] )
+
+ return np.array(grad).flatten()/N_frame
+
+def Rflow_D2_gradient_m_alt(Frames, Flow): #not sure if this is correct
+ N_frame = Frames.shape[0]
+ grad = 0.0*np.copy(Flow)
+
+ for j in range(len(Frames)-1):
+ #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow )
+ dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow))
+ deltaI = Frames[j+1] - (Frames[j] + dI_dt)
+ grad = -2.0 * Frames[j][:,:,None] * Wrapped_Gradient_Reorder( deltaI )
+
+ return np.array(grad).flatten()/N_frame
+
+#### Helper functions for the flow ####
+
+def squared_gradient_flow(flow):
+ """Total squared gradient of flow"""
+
+ return np.sum(np.array(Wrapped_Gradient(flow[:,:,0]))**2 + np.array(Wrapped_Gradient(flow[:,:,1]))**2)
+
+def squared_gradient_flow_grad(flow):
+    """Gradient of the total squared gradient of the flow, taken with respect to the flow"""
+
+ grad_x = -2.0*Wrapped_Divergence(Wrapped_Gradient_Reorder(flow[:,:,0]))
+ grad_y = -2.0*Wrapped_Divergence(Wrapped_Gradient_Reorder(flow[:,:,1]))
+
+ return np.transpose([grad_x.ravel(),grad_y.ravel()]).ravel()
+
+###### Static Regularizer Master Functions #######
+
+def static_regularizer(Frame_List, Prior_List, embed_mask_List, flux, psize, stype="simple", norm_reg=True, **kwargs):
+ N_frame = Frame_List.shape[0]
+ xdim = int(len(Frame_List[0].ravel())**0.5)
+
+ s = np.sum( regularizer(Frame_List[i].ravel()[embed_mask_List[i]], Prior_List[i].ravel()[embed_mask_List[i]], embed_mask_List[i], flux=flux, xdim=xdim, ydim=xdim, psize=psize, stype=stype, norm_reg=norm_reg, **kwargs) for i in range(N_frame))
+
+ return s/N_frame
+
+def static_regularizer_gradient(Frame_List, Prior_List, embed_mask_List, flux, psize, stype="simple", norm_reg=True, **kwargs):
+ # Note: this function includes Jacobian factor to account for the frames being written as log(frame)
+ N_frame = Frame_List.shape[0]
+ xdim = int(len(Frame_List[0].ravel())**0.5)
+
+ s = np.concatenate([regularizergrad((Frame_List[i].ravel())[embed_mask_List[i]], Prior_List[i].ravel()[embed_mask_List[i]], embed_mask_List[i], flux=flux, xdim=xdim, ydim=xdim, psize=psize, stype=stype, norm_reg=norm_reg, **kwargs)*(Frame_List[i].ravel())[embed_mask_List[i]] for i in range(N_frame)])
+
+ return s/N_frame
+
+##################################################################################################
+# Other Regularization Functions
+##################################################################################################
+
+def centroid(Frame_List, coord):
+ return np.sum(np.sum(im.ravel() * coord[:,0])**2 + np.sum(im.ravel() * coord[:,1])**2 for im in Frame_List)/len(Frame_List)
+
+def centroid_gradient(Frame_List, coord): #Includes Jacobian factor to account for the frames being written as log(frame)
+ return 2.0 * np.concatenate([(np.sum(im.ravel() * coord[:,0])*coord[:,0] + np.sum(im.ravel() * coord[:,1])*coord[:,1])*im.ravel() for im in Frame_List])/len(Frame_List)
+
+def movie_flux_constraint(Frame_List, flux_List):
+ # This is the mean squared *fractional* difference in image total flux density
+ # Negative means ignore
+ norm = float(np.sum([f > 0.0 for f in flux_List]))
+ return np.sum([(np.sum(Frame_List[j]) - flux_List[j])**2/flux_List[j]**2/norm*(flux_List[j] >= 0.0) for j in range(len(Frame_List))])
+
+def movie_flux_constraint_grad(Frame_List, flux_List): #Includes Jacobian factor to account for the frames being written as log(frame)
+ norm = float(np.sum([f > 0.0 for f in flux_List]))
+ return np.concatenate([2.0*(np.sum(Frame_List[j]) - flux_List[j])/flux_List[j]**2/norm*Frame_List[j].ravel()*(flux_List[j] >= 0.0) for j in range(len(Frame_List))])
+
+##################################################################################################
+# chi^2 estimation routines
+##################################################################################################
+
+
+def get_chisq(i, imvec_embed, d1, d2, d3, ttype, mask):
+ global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List
+ chisq1 = chisq2 = chisq3 = 1.0
+
+    if d1 != False and len(data1_List[i])>0:
+        chisq1 = chisq(imvec_embed, A1_List[i], data1_List[i], sigma1_List[i], d1, ttype=ttype, mask=mask)
+
+ if d2 != False and len(data2_List[i])>0:
+ chisq2 = chisq(imvec_embed, A2_List[i], data2_List[i], sigma2_List[i], d2, ttype=ttype, mask=mask)
+
+ if d3 != False and len(data3_List[i])>0:
+ chisq3 = chisq(imvec_embed, A3_List[i], data3_List[i], sigma3_List[i], d3, ttype=ttype, mask=mask)
+
+ return [chisq1, chisq2, chisq3]
+
+def get_chisq_wrap(args):
+ return get_chisq(*args)
+
+
+def get_chisqgrad(i, imvec_embed, d1, d2, d3, ttype, mask):
+ global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List
+ chisqgrad1 = chisqgrad2 = chisqgrad3 = 0.0*imvec_embed
+
+    if d1 != False and len(data1_List[i])>0:
+        chisqgrad1 = chisqgrad(imvec_embed, A1_List[i], data1_List[i], sigma1_List[i], d1, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor
+
+ if d2 != False and len(data2_List[i])>0:
+ chisqgrad2 = chisqgrad(imvec_embed, A2_List[i], data2_List[i], sigma2_List[i], d2, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor
+
+ if d3 != False and len(data3_List[i])>0:
+ chisqgrad3 = chisqgrad(imvec_embed, A3_List[i], data3_List[i], sigma3_List[i], d3, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor
+
+ return [chisqgrad1, chisqgrad2, chisqgrad3]
+
+def get_chisqgrad_wrap(args):
+ return get_chisqgrad(*args)
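+
+# The *_wrap helpers let Pool.map dispatch the per-frame chi^2 terms with a single
+# picklable argument tuple, while the module-level A*/data*/sigma* lists keep the
+# large Fourier matrices out of the per-call arguments sent to each worker process.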
+
+##################################################################################################
+# Imagers
+##################################################################################################
+
+
+def dynamical_imaging_minimal(Obsdata_List, InitIm_List, Prior, flux_List = [],
+d1='vis', d2=False, d3=False,
+alpha_d1=10, alpha_d2=10, alpha_d3=10,
+systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0,
+entropy1="tv2", entropy2="l1",
+alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0,
+R_dt ={'alpha':0.0, 'metric':'SymKL', 'p':2.0},
+maxit=200, J_factor = 0.001, stop=1.0e-10, ipynb=False, refresh_interval = 1000,
+minimizer_method = 'L-BFGS-B', NHIST = 25, update_interval = 1, clipfloor=0.,
+Target_Dynamic_Range = 10000.0, ttype = 'nfft', fft_pad_factor=2):
+
+ global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List
+
+ N_frame = len(Obsdata_List)
+ N_pixel = Prior.xdim #pixel dimension
+
+ # Determine the appropriate final resolution
+ all_res = []
+ for obs in Obsdata_List:
+ if len(obs.data) > 0:
+ all_res.append(obs.res())
+
+ beam_size = np.min(np.array(all_res))
+ print("Maximal Resolution:",beam_size)
+
+ # Find an observation with data
+ for j in range(N_frame):
+ if len(Obsdata_List[j].data) > 0:
+ first_obs = Obsdata_List[j]
+ break
+
+ # Catch problem if uvrange < largest baseline
+ if 1./Prior.psize < np.max(first_obs.unpack(['uvdist'])['uvdist']):
+ raise Exception("pixel spacing is larger than smallest spatial wavelength!")
+
+    if len(flux_List) > 0 and len(flux_List) != N_frame:
+ raise Exception("Number of elements in the list of total flux densities does not match the number of frames!")
+
+ # Make the blurring kernel for R_dt
+ # Note: There are odd problems when sigma_dt is too small. I can't figure out why it causes the convolution to crash.
+ # However, having sigma_dt -> 0 is not a problem in theory. So we'll just set the kernel to be zero in that case and then ignore it later in the convolution.
+
+ B_dt = np.zeros((Prior.ydim,Prior.xdim))
+
+ embed_mask_List = [Prior.imvec > clipfloor for j in range(N_frame)]
+ embed_mask_All = np.array(embed_mask_List).flatten()
+
+ embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List]
+
+ logprior_List = [None,] * N_frame
+ loginit_List = [None,] * N_frame
+
+ nprior_embed_List = [None,] * N_frame
+ nprior_List = [None,] * N_frame
+
+ ninit_embed_List = [InitIm_List[i].imvec for i in range(N_frame)]
+ ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_frame)]
+
+ print ("Calculating lists/matrices for chi-squared terms...")
+ A1_List = [None for _ in range(N_frame)]
+ A2_List = [None for _ in range(N_frame)]
+ A3_List = [None for _ in range(N_frame)]
+ data1_List = [[] for _ in range(N_frame)]
+ data2_List = [[] for _ in range(N_frame)]
+ data3_List = [[] for _ in range(N_frame)]
+ sigma1_List = [None for _ in range(N_frame)]
+ sigma2_List = [None for _ in range(N_frame)]
+ sigma3_List = [None for _ in range(N_frame)]
+
+ # Get data and Fourier matrices for the data terms
+ for i in range(N_frame):
+ pixel_max = np.max(InitIm_List[i].imvec)
+ prior_flux_rescale = 1.0
+ if len(flux_List) > 0:
+ prior_flux_rescale = flux_List[i]/Prior.total_flux()
+
+ nprior_embed_List[i] = Prior.imvec * prior_flux_rescale
+
+ nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]]
+ logprior_List[i] = np.log(nprior_List[i])
+ loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here
+
+ if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data
+ continue
+
+ (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1)
+ (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2)
+ (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3)
+
+ # Define the objective function and gradient
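+    # The minimizer works on x = log(frame pixels), which keeps every pixel strictly
+    # positive; the gradients below therefore carry a Jacobian factor of the frame
+    # itself (d frame / d log frame = frame).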
+ def objfunc(x):
+ # Frames is a list of the *unscattered* frames
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ s1 = s2 = 0.0
+
+ if alpha_s1 != 0.0:
+ s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2
+
+ s_dynamic = 0.0
+
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ chisq = ((np.sum(chisq[:,0])/N_frame - 1.0)*alpha_d1 +
+ (np.sum(chisq[:,1])/N_frame - 1.0)*alpha_d2 +
+ (np.sum(chisq[:,2])/N_frame - 1.0)*alpha_d3)
+
+ return (s1 + s2 + s_dynamic + chisq)*J_factor
+
+ def objgrad(x):
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ s1 = s2 = 0.0
+
+ if alpha_s1 != 0.0:
+ s1 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2
+
+ s_dynamic_grad = 0.0
+ if R_dt['alpha'] != 0.0: s_dynamic_grad += Rdt_gradient(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+
+ chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ # Now add the Jacobian factor and concatenate
+ for j in range(N_frame):
+ chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]]
+
+ chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d1
+ + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d2
+ + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d3)
+
+        return (s1 + s2 + (s_dynamic_grad + chisq_grad)[embed_mask_All])*J_factor
+
+ # Plotting function for each iteration
+ global nit
+ nit = 0
+ def plotcur(x, final=False):
+ global nit
+ nit += 1
+
+ if nit%update_interval == 0 or final == True:
+ print ("iteration %d" % nit)
+
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ s1 = s2 = 0.0
+
+            if alpha_s1 != 0.0:
+                s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(),
+                                        Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(),
+ Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2
+
+ s_dynamic = 0.0
+
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ chisq1_List = chisq[:,0]
+ chisq2_List = chisq[:,1]
+ chisq3_List = chisq[:,2]
+ chisq1 = np.sum(chisq1_List)/N_frame
+ chisq2 = np.sum(chisq2_List)/N_frame
+ chisq3 = np.sum(chisq3_List)/N_frame
+ chisq1_max = np.max(chisq1_List)
+ chisq2_max = np.max(chisq2_List)
+ chisq3_max = np.max(chisq3_List)
+ if d1 != False: print ("chi2_1: %f" % chisq1)
+ if d2 != False: print ("chi2_2: %f" % chisq2)
+ if d3 != False: print ("chi2_3: %f" % chisq3)
+ if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1))
+ if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2))
+ if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3))
+ if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max)
+ if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max)
+ if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max)
+
+ if final == True:
+ if d1 != False: print ("All chisq1:",chisq1_List)
+ if d2 != False: print ("All chisq2:",chisq2_List)
+ if d3 != False: print ("All chisq3:",chisq3_List)
+
+ if s1 != 0.0: print ("weighted s1: %f" % (s1))
+ if s2 != 0.0: print ("weighted s2: %f" % (s2))
+ print ("weighted s_dynamic: %f" % (s_dynamic))
+
+ if nit%refresh_interval == 0:
+ print ("Plotting Functionality Temporarily Disabled...")
+
+ loginit = np.hstack(loginit_List).flatten()
+ x0 = loginit
+
+ print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame))
+ print ("Clipped Pixel #: ",(len(loginit)))
+
+ print ("Initial Values:")
+ plotcur(x0)
+
+ # Minimize
+ optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params
+ tstart = time.time()
+ res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur)
+ tstop = time.time()
+
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ #Impose the prior mask in linear space for the output
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ plotcur(res.x, final=True)
+
+ # Print stats
+ print ("time: %f s" % (tstop - tstart))
+ print ("J: %f" % res.fun)
+ print (res.message)
+
+ outim = [image.Image(Frames[i].reshape(Prior.ydim, Prior.xdim), Prior.psize,
+ Prior.ra, Prior.dec, rf=Obsdata_List[i].rf, source=Prior.source,
+ mjd=Prior.mjd, pulse=Prior.pulse) for i in range(N_frame)]
+
+ return outim
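+
+# Example (hypothetical obs_list, init_list, prior): a minimal call might look like
+#   frames = dynamical_imaging_minimal(obs_list, init_list, prior, d1='vis',
+#                                      R_dt={'alpha': 1.0, 'metric': 'SymKL', 'p': 2.0})
+# which returns one Image object per frame.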
+
+
+def dynamical_imaging(obs_input, init_ims, Prior, Flow_Init = None, flux_List = None,
+d1='vis', d2=False, d3=False,
+alpha_d1=10, alpha_d2=10, alpha_d3=10,
+systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0,
+entropy1="tv2", entropy2="l1",
+alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0,
+R_dI ={'alpha':0.0, 'metric':'SymKL', 'p':2.0},
+R_dt ={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0},
+R_flow={'alpha':0.0, 'metric':'SymKL', 'p':2.0, 'alpha_flow_tv':50.0},
+alpha_centroid=0.0, alpha_flux=0.0, alpha_dF=0.0, alpha_dS1=0.0, alpha_dS2=0.0, #other regularizers
+stochastic_optics=False, scattering_model=False, alpha_phi = 1.e4, #options for scattering
+Target_Dynamic_Range = 10000.0,
+maxit=200, J_factor = 0.001, stop=1.0e-10, ipynb=False, refresh_interval = 1000,
+minimizer_method = 'L-BFGS-B', NHIST = 25, update_interval = 1, clipfloor=0., processes = -1,
+recalculate_chisqdata = True, ttype = 'nfft', fft_pad_factor=2, **kwargs):
+
+ """Run dynamical imaging.
+
+ Args:
+ obs_input (List or Obsdata): Observation. Form can be either:
+ 1. List of Obsdata objects, one per reconstructed frame. Some can have empty data arrays.
+ 2. Single Obsdata object.
+        init_ims (List or Movie): Initial images. Form can be either:
+            1. List of Image objects, one per reconstructed frame.
+            2. A Movie object, whose frames will be used.
+ Prior (Image): The Image object with the prior image
+ Flow_Init: Optional initialization for imaging with R_flow
+ flux_List (List): Optional specification of the total flux density for each frame
+ d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp'
+ d2 (str): The second data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp'
+ d3 (str): The third data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp'
+ systematic_noise1 (float): Systematic noise on the first data term, as a fraction of the visibility amplitude
+ systematic_noise2 (float): Systematic noise on the second data term, as a fraction of the visibility amplitude
+ systematic_noise3 (float): Systematic noise on the third data term, as a fraction of the visibility amplitude
+ entropy1 (str): The first regularizer; options are 'simple', 'gs', 'tv', 'tv2', 'l1', 'patch','compact','compact2','rgauss'
+ entropy2 (str): The second regularizer; options are 'simple', 'gs', 'tv', 'tv2','l1', 'patch','compact','compact2','rgauss'
+ alpha_d1 (float): The first data term weighting
+ alpha_d2 (float): The second data term weighting
+ alpha_s1 (float): The first regularizer term weighting
+ alpha_s2 (float): The second regularizer term weighting
+ alpha_flux (float): The weighting for the total flux constraint
+ alpha_centroid (float): The weighting for the center of mass constraint
+ alpha_dF (float): The weighting for temporal continuity of the total flux density.
+ alpha_dS1 (float): The weighting for temporal continuity of entropy1.
+ alpha_dS2 (float): The weighting for temporal continuity of entropy2.
+
+ maxit (int): Maximum number of minimizer iterations
+ stop (float): The convergence criterion
+ minimizer_method (str): Minimizer method (e.g., 'L-BFGS-B' or 'CG')
+ update_interval (int): Print convergence status every update_interval steps
+ norm_reg (bool): If True, normalizes regularizer terms
+ ttype (str): The Fourier transform type; options are 'fast', 'direct', 'nfft'
+
+ stochastic_optics (bool): If True, stochastic optics imaging is used.
+ scattering_model (ScatteringModel): Optional specification of the ScatteringModel object.
+ alpha_phi (float): Weighting for screen phase regularization in stochastic optics.
+
+ Returns:
+ List or Dictionary: A list of Image objects, one per frame, unless a flow or stochastic optics is used in which case it returns a dictionary {'Frames', 'Flow', 'EpsilonList' }.
+ """
+
+ global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List
+
+ # Make a list of frames if a movie is passed
+ if type(init_ims) == list:
+ InitIm_List = init_ims
+ else:
+ InitIm_List = init_ims.im_list()
+
+ # Make a list of observations if a single Obsdata object is passed
+ if type(obs_input) == list:
+ Obsdata_List = obs_input
+ else:
+ # Create one obsdata object for every frame
+ Obsdata_List = [obs_input.copy() for _ in InitIm_List]
+ # Populate each frame; for now just gather observations by nearest frame
+ tlist = obs_input.tlist()
+ obs_mjds = [obs_input.mjd + o['time'][0]/24.0 for o in tlist]
+ frame_mjds = [im.mjd + im.time/24.0 for im in InitIm_List]
+        idx_list = [np.argmin(np.abs(np.array(frame_mjds) - obs_mjds[j])) for j in range(len(obs_mjds))]
+
+ c = 0
+ for j in range(len(frame_mjds)):
+ Obsdata_List[j].mjd = InitIm_List[j].mjd
+ try:
+                Obsdata_List[j].data = np.concatenate([tlist[k] for k in range(len(tlist)) if idx_list[k] == j])
+ except:
+ Obsdata_List[j].data = []
+ c = c + 1
+ if c > 0:
+ print("%d/%d frames have no data"%(c,len(frame_mjds)))
+
+ N_frame = len(Obsdata_List)
+ N_pixel = Prior.xdim #pixel dimension
+
+ # Determine the appropriate final resolution
+ all_res = []
+ for obs in Obsdata_List:
+ if len(obs.data) > 0:
+ all_res.append(obs.res())
+
+ beam_size = np.min(np.array(all_res))
+ print("Maximal Resolution:",beam_size)
+
+ # Find an observation with data
+ for j in range(N_frame):
+ if len(Obsdata_List[j].data) > 0:
+ first_obs = Obsdata_List[j]
+ break
+
+ # Catch problem if uvrange < largest baseline
+ if 1./Prior.psize < np.max(first_obs.unpack(['uvdist'])['uvdist']):
+ raise Exception("pixel spacing is larger than smallest spatial wavelength!")
+
+ if alpha_flux > 0.0 and len(flux_List) != N_frame:
+ raise Exception("Number of elements in the list of total flux densities does not match the number of frames!")
+
+ # If using stochastic optics, do some preliminary calculations
+ if stochastic_optics == True:
+ # Doesn't yet work with clipping
+ clipfloor = -1.0
+
+ if scattering_model == False:
+ print("No scattering model specified. Assuming the default scattering for Sgr A*.")
+ scattering_model = so.ScatteringModel()
+
+ # First some preliminary definitions
+ N = InitIm_List[0].xdim
+ FOV = InitIm_List[0].psize * N * scattering_model.observer_screen_distance #Field of view, in cm, at the scattering screen
+
+ # The ensemble-average convolution kernel and its gradients
+ wavelength_List = np.array([C/obs.rf*100.0 for obs in Obsdata_List]) #Observing wavelength for each frame [cm]
+ wavelengthbar_List = wavelength_List/(2.0*np.pi) #lambda/(2pi) [cm]
+ rF_List = [scattering_model.rF(wavelength) for wavelength in wavelength_List]
+
+ print("Computing the Ensemble-Average Kernel for Each Frame...")
+ ea_ker = [scattering_model.Ensemble_Average_Kernel(InitIm_List[0], wavelength_cm = wavelength_List[j]) for j in range(N_frame)]
+ ea_ker_gradient = [so.Wrapped_Gradient(ea_ker[j]/(FOV/N)) for j in range(N_frame)]
+ ea_ker_gradient_x = [-ea_ker_gradient[j][1] for j in range(N_frame)]
+ ea_ker_gradient_y = [-ea_ker_gradient[j][0] for j in range(N_frame)]
+
+ # The power spectrum (note: rotation is not currently implemented; the gradients would need to be modified slightly)
+ sqrtQ = np.real(scattering_model.sqrtQ_Matrix(InitIm_List[0],t_hr=0.0))
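+        # sqrtQ is the square root of the scattering power spectrum sampled on the image
+        # grid; together with the per-frame ensemble-average kernels above it specifies
+        # the scattering screen model used when stochastic_optics=True.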
+
+ # Make the blurring kernel for R_dt
+ # Note: There are odd problems when sigma_dt is too small. I can't figure out why it causes the convolution to crash.
+ # However, having sigma_dt -> 0 is not a problem in theory. So we'll just set the kernel to be zero in that case and then ignore it later in the convolution.
+
+ if R_dt['sigma_dt'] > 0.0:
+ B_dt = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt['sigma_dt']**2))
+ for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)]
+ for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)]))
+ if np.max(B_dt) == 0.0 or np.sum(B_dt) == 0.0:
+ raise Exception("Error with the blurring kernel!")
+ B_dt = B_dt / np.sum(B_dt) # normalize to be flux preserving
+ else:
+ B_dt = np.zeros((Prior.ydim,Prior.xdim))
+
+ embed_mask_List = [Prior.imvec > clipfloor for j in range(N_frame)]
+ embed_mask_All = np.array(embed_mask_List).flatten()
+
+ embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List]
+ if len(set(embed_totals)) > 1 and R_flow['alpha'] > 0.0:
+ print ("If a flow is used, then each frame must have the same prior!")
+ return
+
+ logprior_List = [None for _ in range(N_frame)]
+ loginit_List = [None for _ in range(N_frame)]
+
+ nprior_embed_List = [None for _ in range(N_frame)]
+ nprior_List = [None for _ in range(N_frame)]
+
+ ninit_embed_List = [InitIm_List[i].imvec for i in range(N_frame)]
+ ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_frame)]
+
+ if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct':
+ print ("Calculating lists/matrices for chi-squared terms...")
+ A1_List = [None for _ in range(N_frame)]
+ A2_List = [None for _ in range(N_frame)]
+ A3_List = [None for _ in range(N_frame)]
+ data1_List = [[] for _ in range(N_frame)]
+ data2_List = [[] for _ in range(N_frame)]
+ data3_List = [[] for _ in range(N_frame)]
+ sigma1_List = [None for _ in range(N_frame)]
+ sigma2_List = [None for _ in range(N_frame)]
+ sigma3_List = [None for _ in range(N_frame)]
+
+ # Get data and Fourier matrices for the data terms
+ for i in range(N_frame):
+ pixel_max = np.max(InitIm_List[i].imvec)
+ prior_flux_rescale = 1.0
+        if flux_List:
+            prior_flux_rescale = flux_List[i]/Prior.total_flux()
+
+ nprior_embed_List[i] = Prior.imvec * prior_flux_rescale
+
+ nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]]
+ logprior_List[i] = np.log(nprior_List[i])
+ loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here
+
+ if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data
+ continue
+
+ if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct':
+ # Try to create the chisqdata. These can throw errors, for instance when no closure quantities exist in a frame with data
+ try:
+ (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1)
+ except:
+ pass
+ try:
+ (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2)
+ except:
+ pass
+ try:
+ (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3)
+ except:
+ pass
+
+ # Coordinate matrix for COM constraint
+ coord = np.array([[[x,y] for x in np.linspace(Prior.xdim/2,-Prior.xdim/2,Prior.xdim)]
+ for y in np.linspace(Prior.ydim/2,-Prior.ydim/2,Prior.ydim)])
+ coord = coord.reshape(Prior.ydim*Prior.xdim, 2)
+
+ # Make the pool for parallel processing
+ if processes > 0:
+ print("Using Multiprocessing")
+ pool = Pool(processes=processes)
+ elif processes == 0:
+ processes = int(cpu_count())
+ print("Using Multiprocessing with %d Processes" % processes)
+ pool = Pool(processes=processes)
+ else:
+ print("Not Using Multiprocessing")
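+    # Note: Pool.map requires picklable callables and payloads, so the per-frame
+    # chi-squared evaluations below go through the wrappers get_chisq_wrap /
+    # get_chisqgrad_wrap (presumably defined at module scope) instead of local
+    # closures. The call pattern is:
+    #   chisq = np.array(pool.map(get_chisq_wrap,
+    #                             [[j, imvec_j, d1, d2, d3, ttype, mask_j]
+    #                              for j in range(N_frame)]))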
+
+ # Define the objective function and gradient
+ def objfunc(x):
+ # Frames is a list of the *unscattered* frames
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ if R_flow['alpha'] != 0.0:
+ cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding
+ Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2))
+ init_i += 2*cur_len
+
+ if stochastic_optics == True:
+ EpsilonList = x[init_i:(init_i + N**2-1)]
+ im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)]
+ #the list of scattered image vectors
+ scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=so.MakeEpsilonScreenFromList(EpsilonList, N), ea_ker = ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec for j in range(N_frame)]
+ init_i += len(EpsilonList)
+
+ s1 = s2 = 0.0
+
+ if alpha_s1 != 0.0:
+ s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2
+
+ s_dynamic = cm = flux = s_dS = s_dF = 0.0
+
+ if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames, **R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF += RdF_clip(Frames, embed_mask_List)*alpha_dF
+
+ if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid
+
+ if alpha_flux > 0.0:
+ flux = alpha_flux * movie_flux_constraint(Frames, flux_List)
+
+ if stochastic_optics == False:
+ if processes > 0:
+ chisq = np.array(pool.map(get_chisq_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+ else:
+ if processes > 0:
+ chisq = np.array(pool.map(get_chisq_wrap, [[j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq = np.array([get_chisq(j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ chisq = ((np.sum(chisq[:,0])/N_frame - 1.0)*alpha_d1 +
+ (np.sum(chisq[:,1])/N_frame - 1.0)*alpha_d2 +
+ (np.sum(chisq[:,2])/N_frame - 1.0)*alpha_d3)
+
+ if R_flow['alpha'] != 0.0:
+ flow_tv = squared_gradient_flow(Flow)
+ s_dynamic += flow_tv*R_flow['alpha_flow_tv']
+ s_dynamic += Rflow(Frames, Flow, **R_flow)*R_flow['alpha']
+
+ # Scattering screen regularization term
+ regterm_scattering = 0.0
+ if stochastic_optics == True:
+ chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0)
+ regterm_scattering = alpha_phi * (chisq_epsilon - 1.0)
+
+ return (s1 + s2 + s_dF + s_dS + s_dynamic + chisq + cm + flux + regterm_scattering)*J_factor
+
+ def objgrad(x):
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ if R_flow['alpha'] != 0.0:
+ cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding
+ Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2))
+ init_i += 2*cur_len
+
+ if stochastic_optics == True:
+            EpsilonList = x[init_i:(init_i + N**2-1)]
+            Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N)
+            im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)]
+            # the list of scattered image vectors
+            scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=Epsilon_Screen, ea_ker=ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec for j in range(N_frame)]
+            init_i += len(EpsilonList)
+
+ s1 = s2 = 0.0
+
+ if alpha_s1 != 0.0:
+ s1 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2
+
+ s_dynamic_grad = cm_grad = flux_grad = s_dS = s_dF = 0.0
+ if R_dI['alpha'] != 0.0: s_dynamic_grad += RdI_gradient(Frames,**R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic_grad += Rdt_gradient(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS += RdS_gradient(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS += RdS_gradient(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF += RdF_gradient_clip(Frames, embed_mask_List)*alpha_dF
+
+ if alpha_centroid != 0.0: cm_grad = centroid_gradient(Frames, coord) * alpha_centroid
+
+ if alpha_flux > 0.0:
+ flux_grad = alpha_flux * movie_flux_constraint_grad(Frames, flux_List)
+
+ dchisq_dIa_List = []
+
+
+        # TODO: the scattered/unscattered branches below duplicate the dispatch
+        # logic and could be consolidated
+ if stochastic_optics == False:
+ if processes > 0:
+ chisq_grad = np.array(pool.map(get_chisqgrad_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+ else:
+ if processes > 0:
+ chisq_grad = np.array(pool.map(get_chisqgrad_wrap, [[j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq_grad = np.array([get_chisqgrad(j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ # Now, the chi^2 gradient must be modified so that it corresponds to the gradient wrt the unscattered image
+ for j in range(N_frame):
+ rF = rF_List[j]
+ phi = scattering_model.MakePhaseScreen(Epsilon_Screen, im_List[0], obs_frequency_Hz=im_List[j].rf,sqrtQ_init=sqrtQ).imvec.reshape((N, N))
+ phi_Gradient = so.Wrapped_Gradient(phi/(FOV/N))
+ phi_Gradient_x = -phi_Gradient[1]
+ phi_Gradient_y = -phi_Gradient[0]
+ dchisq_dIa_List.append( ((chisq_grad[j,0]*alpha_d1 + chisq_grad[j,1]*alpha_d2 + chisq_grad[j,2]*alpha_d3)/N_frame).reshape((N,N)) )
+
+ dchisq_dIa = chisq_grad[j,0].reshape((N,N))
+ gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten()
+ gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten()
+ chisq_grad[j,0] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy
+
+ dchisq_dIa = chisq_grad[j,1].reshape((N,N))
+ gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten()
+ gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten()
+ chisq_grad[j,1] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy
+
+ dchisq_dIa = chisq_grad[j,2].reshape((N,N))
+ gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten()
+ gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten()
+ chisq_grad[j,2] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy
+
+ # Now add the Jacobian factor and concatenate
+ for j in range(N_frame):
+ chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]]
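+        # Worked note on the Jacobian factor above: with the I = exp(I') change of
+        # variables, the chain rule gives dJ/dI' = I * dJ/dI, so each chi^2 gradient
+        # (taken with respect to I) is multiplied by the frame's pixel values.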
+
+ chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d1
+ + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d2
+ + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d3)
+
+ # Gradient of the data chi^2 wrt to the epsilon screen -- this is the really difficult one
+ chisq_grad_epsilon = np.array([])
+ if stochastic_optics == True:
+ #Preliminary Definitions
+ chisq_grad_epsilon = np.zeros(N**2-1)
+ ell_mat = np.zeros((N,N))
+ m_mat = np.zeros((N,N))
+ for ell in range(0, N):
+ for m in range(0, N):
+ ell_mat[ell,m] = ell
+ m_mat[ell,m] = m
+
+ for j in range(N_frame):
+ rF = rF_List[j]
+ dchisq_dIa = dchisq_dIa_List[j]
+ EA_Image = scattering_model.Ensemble_Average_Blur(im_List[j], ker = ea_ker[j])
+ EA_Gradient = so.Wrapped_Gradient((EA_Image.imvec/(FOV/N)).reshape(N, N))
+ #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top)
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+
+ i_grad = 0
+ #Real part; top row
+ for t in range(1, (N+1)//2):
+ s=0
+ grad_term = so.Wrapped_Gradient(wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+ chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) )
+ i_grad = i_grad + 1
+
+ #Real part; remainder
+ for s in range(1,(N+1)//2):
+ for t in range(N):
+ grad_term = so.Wrapped_Gradient(wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+ chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) )
+ i_grad = i_grad + 1
+
+ #Imaginary part; top row
+ for t in range(1, (N+1)//2):
+ s=0
+ grad_term = so.Wrapped_Gradient(-wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+ chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) )
+ i_grad = i_grad + 1
+
+ #Imaginary part; remainder
+ for s in range(1,(N+1)//2):
+ for t in range(N):
+ grad_term = so.Wrapped_Gradient(-wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N))
+ grad_term_x = -grad_term[1]
+ grad_term_y = -grad_term[0]
+ chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) )
+ i_grad = i_grad + 1
+
+ # Gradients related to the flow
+ flow_grad = np.array([])
+ if R_flow['alpha'] != 0.0:
+ cur_len = np.sum(embed_mask_List[0])
+ Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2))
+ flow_tv_grad = squared_gradient_flow_grad(Flow)
+ s_dynamic_grad_Frames = s_dynamic_grad_Flow = 0.0
+ s_dynamic_grad_Frames = Rflow_gradient_I(Frames, Flow, R_flow)
+ s_dynamic_grad_Flow = Rflow_gradient_m(Frames, Flow, R_flow)
+
+ s_dynamic_grad += s_dynamic_grad_Frames*R_flow['alpha']
+ flow_grad = s_dynamic_grad_Flow*R_flow['alpha'] + flow_tv_grad*R_flow['alpha_flow_tv']
+ # now handle the embedding
+ flow_grad = np.transpose([flow_grad[::2][embed_mask_List[0]], flow_grad[1::2][embed_mask_List[0]]]).ravel()
+
+ # Gradient of the chi^2 regularization term for the epsilon screen
+ chisq_epsilon_grad = np.array([])
+ if stochastic_optics == True:
+ chisq_epsilon_grad = alpha_phi * 2.0*EpsilonList/((N*N-1)/2.0)
+
+        return (np.concatenate((s1 + s2 + s_dF + s_dS + (s_dynamic_grad + chisq_grad + cm_grad + flux_grad)[embed_mask_All], flow_grad, chisq_grad_epsilon + chisq_epsilon_grad))*J_factor)
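+    # A useful debugging sketch (not executed here): once x0 is assembled below,
+    # objgrad can be validated against a central finite difference of objfunc at
+    # a few random coordinates k:
+    #   eps = 1.e-7
+    #   dx = np.zeros(len(x0)); dx[k] = eps
+    #   num_k = (objfunc(x0 + dx) - objfunc(x0 - dx))/(2.*eps)
+    #   assert abs(num_k - objgrad(x0)[k]) < 1.e-4*max(1.0, abs(num_k))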
+
+ # Plotting function for each iteration
+ global nit
+ nit = 0
+ def plotcur(x, final=False):
+ global nit
+ nit += 1
+
+ if nit%update_interval == 0 or final == True:
+ print ("iteration %d" % nit)
+
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ if R_flow['alpha'] != 0.0:
+ cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding
+ Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2))
+ init_i += 2*cur_len
+
+ if stochastic_optics == True:
+ EpsilonList = x[init_i:(init_i + N**2-1)]
+ im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)]
+
+ scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=so.MakeEpsilonScreenFromList(EpsilonList, N), ea_ker = ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec
+ for j in range(N_frame)] #the list of scattered image vectors
+
+ s1 = s2 = 0.0
+
+ if alpha_s1 != 0.0:
+
+ s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(),
+ Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(),
+ Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2
+
+ s_dynamic = cm = s_dS = s_dF = 0.0
+
+ if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames, **R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF += RdF_clip(Frames, embed_mask_List)*alpha_dF
+
+ if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid
+
+ if stochastic_optics == False:
+ if processes > 0:
+
+ chisq = np.array(pool.map(get_chisq_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+ else:
+ if processes > 0:
+ chisq = np.array(pool.map(get_chisq_wrap, [[j, scatt_im_List[j][embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)]))
+ else:
+ chisq = np.array([get_chisq(j, scatt_im_List[j][embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)])
+
+ chisq1_List = chisq[:,0]
+ chisq2_List = chisq[:,1]
+ chisq3_List = chisq[:,2]
+ chisq1 = np.sum(chisq1_List)/N_frame
+ chisq2 = np.sum(chisq2_List)/N_frame
+ chisq3 = np.sum(chisq3_List)/N_frame
+ chisq1_max = np.max(chisq1_List)
+ chisq2_max = np.max(chisq2_List)
+ chisq3_max = np.max(chisq3_List)
+ if d1 != False: print ("chi2_1: %f" % chisq1)
+ if d2 != False: print ("chi2_2: %f" % chisq2)
+ if d3 != False: print ("chi2_3: %f" % chisq3)
+ if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1))
+ if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2))
+ if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3))
+ if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max)
+ if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max)
+ if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max)
+
+ if final == True:
+ if d1 != False: print ("All chisq1:",chisq1_List)
+ if d2 != False: print ("All chisq2:",chisq2_List)
+ if d3 != False: print ("All chisq3:",chisq3_List)
+
+            # Now deal with the flow, if necessary
+ if R_flow['alpha'] != 0.0:
+ flow_tv = squared_gradient_flow(Flow)
+ s_dynamic += flow_tv*R_flow['alpha_flow_tv']
+ print ("Weighted Flow TV: %f" % (flow_tv*R_flow['alpha_flow_tv']))
+ s_dynamic += Rflow(Frames, Flow, **R_flow)*R_flow['alpha']
+ print ("Weighted R_Flow: %f" % (Rflow(Frames, Flow, **R_flow)*R_flow['alpha']))
+
+ if s1 != 0.0: print ("weighted s1: %f" % (s1))
+ if s2 != 0.0: print ("weighted s2: %f" % (s2))
+ if s_dF != 0.0: print ("weighted s_dF: %f" % (s_dF))
+ if s_dS != 0.0: print ("weighted s_dS: %f" % (s_dS))
+ print ("weighted s_dynamic: %f" % (s_dynamic))
+ if alpha_centroid > 0.0: print ("weighted COM: %f" % cm)
+
+ if alpha_flux > 0.0:
+ print ("weighted flux constraint: %f" % (alpha_flux * movie_flux_constraint(Frames, flux_List)))
+
+ if stochastic_optics == True:
+ chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0)
+ regterm_scattering = alpha_phi * (chisq_epsilon - 1.0)
+ print("Epsilon chi^2 : %0.2f " % (chisq_epsilon))
+ print("Weighted Epsilon chi^2 : %0.2f " % (regterm_scattering))
+ print("Max |Epsilon| : %0.2f " % (max(abs(EpsilonList))))
+
+ if nit%refresh_interval == 0:
+ print ("Plotting Functionality Temporarily Disabled...")
+
+ loginit = np.hstack(loginit_List).flatten()
+ if R_flow['alpha'] == 0.0:
+ x0 = loginit
+ else:
+ Flow_Init_embed = np.transpose([Flow_Init.ravel()[::2][embed_mask_List[0]],Flow_Init.ravel()[1::2][embed_mask_List[0]]]).ravel()
+ x0 = np.concatenate( (loginit, Flow_Init_embed) )
+
+ if stochastic_optics == True:
+ x0 = np.concatenate((x0,np.zeros(N**2-1)))
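+    # Layout of the solution vector x0 assembled above:
+    #   [ masked log-frame pixels for all frames |
+    #     flow components, x/y interleaved (only if R_flow['alpha'] != 0) |
+    #     N^2 - 1 scattering-screen coefficients (only if stochastic_optics) ]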
+
+
+ print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame))
+ print ("Clipped Pixel #: ",(len(loginit)))
+
+ print ("Initial Values:")
+ plotcur(x0)
+
+ # Minimize
+ optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params
+ tstart = time.time()
+ res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur)
+ tstop = time.time()
+
+ Frames = np.zeros((N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ #Impose the prior mask in linear space for the output
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ Flow = EpsilonList = False
+
+ if R_flow['alpha'] != 0.0:
+ print ("Collecting Flow...")
+ cur_len = np.sum(embed_mask_List[0])
+ Flow_x = embed(res.x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow_y = embed(res.x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2))
+ init_i += 2*cur_len
+
+ if stochastic_optics == True:
+ EpsilonList = res.x[init_i:(init_i + N**2-1)]
+ init_i += len(EpsilonList)
+
+
+ plotcur(res.x, final=True)
+
+ # Print stats
+ print ("time: %f s" % (tstop - tstart))
+ print ("J: %f" % res.fun)
+ print (res.message)
+
+ #Note: the global variables are *not* released to avoid recalculation
+    if processes > 0:  # a pool exists only when multiprocessing was enabled
+        pool.close()
+
+ #Return Frames
+ outim = [image.Image(Frames[i].reshape(Prior.ydim, Prior.xdim), Prior.psize,
+ Prior.ra, Prior.dec, rf=Obsdata_List[i].rf, source=Prior.source,
+ mjd=InitIm_List[i].mjd, time=InitIm_List[i].time, pulse=Prior.pulse) for i in range(N_frame)]
+
+    if not isinstance(init_ims, list):
+        outim = movie.merge_im_list(outim)
+
+ if R_flow['alpha'] == 0.0 and stochastic_optics == False:
+ return outim
+ else:
+ return {'Movie':outim, 'Frames':outim, 'Flow':Flow, 'EpsilonList':EpsilonList }
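+
+# A hypothetical end-to-end call of the dynamical imager above (a sketch:
+# file names and regularizer weights are illustrative, not recommended values,
+# and the entry point is assumed to be this module's dynamical imaging function):
+#   obs_list = [ehtim.obsdata.load_uvfits(f) for f in frame_files]
+#   prior = ehtim.image.load_fits('prior.fits')
+#   movie = dynamical_imaging(obs_list, [prior]*len(obs_list), prior,
+#                             d1='amp', d2='cphase',
+#                             R_dI={'alpha': 1.0, 'metric': 'SymKL', 'p': 2.0})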
+
+
+def multifreq_dynamical_imaging(Obsdata_Multifreq_List, InitIm_Multifreq_List, Prior, flux_Multifreq_List=[],
+                                d1='vis', d2=False, d3=False,
+                                alpha_d1=10, alpha_d2=10, alpha_d3=10,
+                                systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0,
+                                entropy1="tv2", entropy2="l1",
+                                alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0,
+                                R_dI={'alpha':0.0, 'metric':'SymKL', 'p':2.0},
+                                R_dt={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0},
+                                R_dt_multifreq={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0},
+                                alpha_centroid=0.0, alpha_flux=0.0, alpha_dF=0.0, alpha_dS1=0.0, alpha_dS2=0.0, # other regularizers
+                                Target_Dynamic_Range=10000.0,
+                                maxit=200, J_factor=0.001, stop=1.0e-10, ipynb=False, refresh_interval=1000,
+                                minimizer_method='L-BFGS-B', NHIST=25, update_interval=1, clipfloor=0.,
+                                recalculate_chisqdata=True, ttype='nfft', fft_pad_factor=2):
+    """Run the multifrequency dynamic imager.
+       Uses the I = exp(I') change of variables.
+       Obsdata_Multifreq_List is a list (one entry per frequency) of lists of Obsdata objects;
+       InitIm_Multifreq_List is nested the same way with Image objects; Prior is an Image object.
+       Returns a nested list of Image objects indexed as [frequency][frame].
+       ttype = 'direct' or 'fast' or 'nfft'
+    """
+
+ global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List
+
+ N_freq = len(Obsdata_Multifreq_List)
+ N_frame = len(Obsdata_Multifreq_List[0])
+ N_pixel = Prior.xdim #pixel dimension
+
+ # Flatten the input lists
+ flux_List = [x for y in flux_Multifreq_List for x in y]
+ InitIm_List = [x for y in InitIm_Multifreq_List for x in y]
+ Obsdata_List = [x for y in Obsdata_Multifreq_List for x in y]
+
+ # Determine the appropriate final resolution
+ all_res = [[] for j in range(N_freq)]
+ for j in range(len(Obsdata_Multifreq_List)):
+ for obs in Obsdata_Multifreq_List[j]:
+ if len(obs.data) > 0:
+ all_res[j].append(obs.res())
+
+ # Determine the beam size for each frequency
+ beam_size = [np.min(all_res[j]) for j in range(N_freq)]
+ print("Maximal Resolutions:",beam_size)
+
+ if alpha_flux > 0.0 and len(flux_Multifreq_List[0]) != N_frame:
+ raise Exception("Number of elements in the list of total flux densities does not match the number of frames!")
+
+ # Make the blurring kernel for R_dt
+    # Note: the convolution misbehaves numerically when sigma_dt is very small,
+    # even though sigma_dt -> 0 is harmless in theory. As a workaround, the kernel
+    # is set to zero in that case and ignored later in the convolution.
+
+ if R_dt['sigma_dt'] > 0.0:
+ B_dt = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt['sigma_dt']**2))
+ for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)]
+ for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)]))
+ if np.max(B_dt) == 0.0 or np.sum(B_dt) == 0.0:
+ raise Exception("Error with the blurring kernel!")
+ B_dt = B_dt / np.sum(B_dt) # normalize to be flux preserving
+ else:
+ B_dt = np.zeros((Prior.ydim,Prior.xdim))
+
+ if R_dt_multifreq['sigma_dt'] > 0.0:
+ B_dt_multifreq = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt_multifreq['sigma_dt']**2))
+ for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)]
+ for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)]))
+ if np.max(B_dt_multifreq) == 0.0 or np.sum(B_dt_multifreq) == 0.0:
+ raise Exception("Error with the blurring kernel!")
+ B_dt_multifreq = B_dt_multifreq / np.sum(B_dt_multifreq) # normalize to be flux preserving
+ else:
+ B_dt_multifreq = np.zeros((Prior.ydim,Prior.xdim))
+
+ embed_mask_List = [Prior.imvec > clipfloor for j in range(N_freq*N_frame)]
+ embed_mask_All = np.array(embed_mask_List).flatten()
+
+ embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List]
+
+ logprior_List = [None,] * (N_freq * N_frame)
+ loginit_List = [None,] * (N_freq * N_frame)
+
+ nprior_embed_List = [None,] * (N_freq * N_frame)
+ nprior_List = [None,] * (N_freq * N_frame)
+
+ ninit_embed_List = [InitIm_List[i].imvec for i in range(N_freq * N_frame)]
+ ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_freq * N_frame)]
+
+ if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct':
+ print ("Calculating lists/matrices for chi-squared terms...")
+        A1_List = [None] * (N_freq * N_frame)
+        A2_List = [None] * (N_freq * N_frame)
+        A3_List = [None] * (N_freq * N_frame)
+        data1_List = [[] for _ in range(N_freq * N_frame)]
+        data2_List = [[] for _ in range(N_freq * N_frame)]
+        data3_List = [[] for _ in range(N_freq * N_frame)]
+        sigma1_List = [None] * (N_freq * N_frame)
+        sigma2_List = [None] * (N_freq * N_frame)
+        sigma3_List = [None] * (N_freq * N_frame)
+
+ # Get data and Fourier matrices for the data terms
+ for i in range(N_frame*N_freq):
+ pixel_max = np.max(InitIm_List[i].imvec)
+ prior_flux_rescale = 1.0
+ if len(flux_List) > 0:
+ prior_flux_rescale = flux_List[i]/Prior.total_flux()
+
+ nprior_embed_List[i] = Prior.imvec * prior_flux_rescale
+
+ nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]]
+ logprior_List[i] = np.log(nprior_List[i])
+ loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here
+
+ if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data
+ continue
+
+ if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct':
+ (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1)
+ (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2)
+ (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3)
+
+ # Coordinate matrix for COM constraint
+ coord = np.array([[[x,y] for x in np.linspace(Prior.xdim/2,-Prior.xdim/2,Prior.xdim)]
+ for y in np.linspace(Prior.ydim/2,-Prior.ydim/2,Prior.ydim)])
+ coord = coord.reshape(Prior.ydim*Prior.xdim, 2)
+
+ # Define the objective function and gradient
+ def objfunc(x):
+ Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_freq*N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ s1 = s2 = s_multifreq = s_dynamic = cm = flux = s_dS = s_dF = 0.0
+
+ # Multifrequency part
+ if R_dt_multifreq['alpha'] != 0.0:
+ for j in range(N_frame):
+ s_multifreq += Rdt(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq)*R_dt_multifreq['alpha']
+
+ # Individual frequencies
+ for j in range(N_freq):
+ i1 = j*N_frame
+ i2 = (j+1)*N_frame
+
+ if alpha_s1 != 0.0:
+ s1 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2
+
+ if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames[i1:i2], **R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF += RdF_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF
+
+ if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid
+
+ if alpha_flux > 0.0:
+ flux = alpha_flux * movie_flux_constraint(Frames, flux_List)
+
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame*N_freq)])
+
+ chisq = ((np.sum(chisq[:,0])/(N_freq*N_frame) - 1.0)*alpha_d1 +
+ (np.sum(chisq[:,1])/(N_freq*N_frame) - 1.0)*alpha_d2 +
+ (np.sum(chisq[:,2])/(N_freq*N_frame) - 1.0)*alpha_d3)
+
+ return (s1 + s2 + s_dF + s_dS + s_multifreq + s_dynamic + chisq + cm + flux)*J_factor
+
+ def objgrad(x):
+ Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_freq*N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+        # Allocate each gradient buffer separately: a chained assignment to one
+        # np.zeros call would alias every name to the same array, so the in-place
+        # slice updates below would leak between terms.
+        s1, s2, s_dS, s_dF = (np.zeros(N_freq*N_frame*cur_len) for _ in range(4))
+        s_dynamic_grad, cm_grad, flux_grad = (np.zeros(N_freq*N_frame*N_pixel*N_pixel) for _ in range(3))
+ s_multifreq = 0.0
+
+ # Multifrequency part
+ if R_dt_multifreq['alpha'] != 0.0:
+ s_multifreq = np.zeros((N_freq*N_frame, N_pixel*N_pixel))
+ for j in range(N_frame):
+ s_multifreq[j::N_frame] += Rdt_gradient(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq).reshape((N_freq,N_pixel*N_pixel))*R_dt_multifreq['alpha']
+ s_multifreq = s_multifreq.reshape(N_freq*N_frame*N_pixel*N_pixel)
+
+ # Individual frequencies
+ for j in range(N_freq):
+ i1 = j*N_frame
+ i2 = (j+1)*N_frame
+ f1 = j*N_frame*N_pixel*N_pixel
+ f2 = (j+1)*N_frame*N_pixel*N_pixel
+ mf1 = j*N_frame*cur_len # Note: This assumes that all priors have the same number of masked pixels!
+ mf2 = (j+1)*N_frame*cur_len
+
+
+ if alpha_s1 != 0.0:
+ s1[mf1:mf2] = static_regularizer_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2[mf1:mf2] = static_regularizer_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2
+
+ if R_dI['alpha'] != 0.0: s_dynamic_grad[f1:f2] += RdI_gradient(Frames[i1:i2],**R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic_grad[f1:f2] += Rdt_gradient(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS[mf1:mf2] += RdS_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS[mf1:mf2] += RdS_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF[mf1:mf2] += RdF_gradient_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF
+
+ if alpha_centroid != 0.0: cm_grad = centroid_gradient(Frames, coord) * alpha_centroid
+
+ if alpha_flux > 0.0:
+ flux_grad = alpha_flux * movie_flux_constraint_grad(Frames, flux_List)
+
+ chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_freq*N_frame)])
+
+ # Now add the Jacobian factor and concatenate
+ for j in range(N_freq*N_frame):
+ chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]]
+ chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]]
+
+ chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d1
+ + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d2
+ + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d3)
+
+
+ return ((s1 + s2 + s_dF + s_dS + (s_multifreq + s_dynamic_grad + chisq_grad + cm_grad + flux_grad)[embed_mask_All])*J_factor)
+
+ # Plotting function for each iteration
+ global nit
+ nit = 0
+ def plotcur(x, final=False):
+ global nit
+ nit += 1
+
+ if nit%update_interval == 0 or final == True:
+ print ("iteration %d" % nit)
+
+ Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_freq*N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ s1 = s2 = s_multifreq = s_dynamic = cm = s_dS = s_dF = 0.0
+
+ # Multifrequency part
+ if R_dt_multifreq['alpha'] != 0.0:
+ for j in range(N_frame):
+ s_multifreq += Rdt(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq)*R_dt_multifreq['alpha']
+
+ # Individual frequencies
+ for j in range(N_freq):
+ i1 = j*N_frame
+ i2 = (j+1)*N_frame
+
+ if alpha_s1 != 0.0:
+ s1 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1
+ if alpha_s2 != 0.0:
+ s2 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2
+
+ if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames[i1:i2], **R_dI)*R_dI['alpha']
+ if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha']
+
+ if alpha_dS1 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1
+ if alpha_dS2 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2
+
+ if alpha_dF != 0.0: s_dF += RdF_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF
+
+ if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid
+
+ chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]],
+ d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_freq*N_frame)])
+
+ chisq1_List = chisq[:,0]
+ chisq2_List = chisq[:,1]
+ chisq3_List = chisq[:,2]
+            chisq1 = np.sum(chisq1_List)/len(chisq1_List)
+            chisq2 = np.sum(chisq2_List)/len(chisq2_List)
+            chisq3 = np.sum(chisq3_List)/len(chisq3_List)
+ chisq1_max = np.max(chisq1_List)
+ chisq2_max = np.max(chisq2_List)
+ chisq3_max = np.max(chisq3_List)
+ if d1 != False: print ("chi2_1: %f" % chisq1)
+ if d2 != False: print ("chi2_2: %f" % chisq2)
+ if d3 != False: print ("chi2_3: %f" % chisq3)
+ if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1))
+ if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2))
+ if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3))
+ if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max)
+ if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max)
+ if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max)
+
+ if final == True:
+ if d1 != False: print ("All chisq1:",chisq1_List)
+ if d2 != False: print ("All chisq2:",chisq2_List)
+ if d3 != False: print ("All chisq3:",chisq3_List)
+
+ if s1 != 0.0: print ("weighted s1: %f" % (s1))
+ if s2 != 0.0: print ("weighted s2: %f" % (s2))
+ if s_dF != 0.0: print ("weighted s_dF: %f" % (s_dF))
+ if s_dS != 0.0: print ("weighted s_dS: %f" % (s_dS))
+ print ("weighted s_dynamic: %f" % (s_dynamic))
+ print ("weighted s_multifreq: %f" % (s_multifreq))
+ if alpha_centroid > 0.0: print ("weighted COM: %f" % cm)
+
+ if alpha_flux > 0.0:
+ print ("weighted flux constraint: %f" % (alpha_flux * movie_flux_constraint(Frames, flux_List)))
+
+ if nit%refresh_interval == 0:
+ print ("Plotting Functionality Temporarily Disabled...")
+
+ loginit = np.hstack(loginit_List).flatten()
+
+ x0 = loginit
+
+ print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame*N_freq))
+ print ("Clipped Pixel #: ",(len(loginit)))
+
+ print ("Initial Values:")
+ plotcur(x0)
+
+ # Minimize
+ optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params
+ tstart = time.time()
+ res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur)
+ tstop = time.time()
+
+ Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+ log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel))
+
+ init_i = 0
+ for i in range(N_freq*N_frame):
+ cur_len = np.sum(embed_mask_List[i])
+ log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel))
+ #Impose the prior mask in linear space for the output
+ Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel)))
+ init_i += cur_len
+
+ plotcur(res.x, final=True)
+
+ # Print stats
+ print ("time: %f s" % (tstop - tstart))
+ print ("J: %f" % res.fun)
+ print (res.message)
+
+ #Return Frames
+ outim = [[image.Image(Frames[i + j*N_frame].reshape(Prior.ydim, Prior.xdim), Prior.psize,
+ Prior.ra, Prior.dec, rf=Obsdata_Multifreq_List[j][i].rf, source=Prior.source,
+ mjd=Prior.mjd, pulse=Prior.pulse) for i in range(N_frame)] for j in range(N_freq)]
+
+ return outim
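+
+# Hypothetical multifrequency usage (a sketch): the observation and
+# initialization lists are nested as [frequency][frame], and the returned
+# frames are indexed the same way:
+#   obs_mf = [obs_list_86, obs_list_230]          # one list of Obsdata per frequency
+#   init_mf = [[prior]*len(o) for o in obs_mf]    # matching nested initial images
+#   frames = multifreq_dynamical_imaging(obs_mf, init_mf, prior, d1='vis',
+#                                        R_dt_multifreq={'alpha': 1.0, 'metric': 'SymKL',
+#                                                        'sigma_dt': 2.0, 'p': 2.0})
+#   im_230_frame0 = frames[1][0]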
+
+
+
+
+
+
+##################################################################################################
+# Plotting Functions
+##################################################################################################
+
+def plot_im_List_Set(im_List_List, plot_log_amplitude=False, ipynb=False):
+ plt.ion()
+ plt.clf()
+
+ Prior = im_List_List[0][0]
+
+ xnum = len(im_List_List[0])
+ ynum = len(im_List_List)
+
+ for i in range(xnum*ynum):
+ plt.subplot(ynum, xnum, i+1)
+        im = im_List_List[i//xnum][i%xnum]
+ if plot_log_amplitude == False:
+ plt.imshow(im.imvec.reshape(im.ydim,im.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ else:
+ plt.imshow(np.log(im.imvec.reshape(im.ydim,im.xdim)), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ xticks = ticks(im.xdim, im.psize/RADPERAS/1e-6)
+ yticks = ticks(im.ydim, im.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ if i == 0:
+            plt.xlabel(r'Relative RA ($\mu$as)')
+            plt.ylabel(r'Relative Dec ($\mu$as)')
+ else:
+ plt.xlabel('')
+ plt.ylabel('')
+ plt.title('')
+
+ plt.draw()
+
+def plot_im_List(im_List, plot_log_amplitude=False, ipynb=False):
+
+ plt.ion()
+ plt.clf()
+
+ Prior = im_List[0]
+
+ for i in range(len(im_List)):
+ plt.subplot(1, len(im_List), i+1)
+ if plot_log_amplitude == False:
+ plt.imshow(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ else:
+ plt.imshow(np.log(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim)), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ if i == 0:
+            plt.xlabel(r'Relative RA ($\mu$as)')
+            plt.ylabel(r'Relative Dec ($\mu$as)')
+ else:
+ plt.xlabel('')
+ plt.ylabel('')
+ plt.title('')
+
+
+ plt.draw()
+
+def plot_i_dynamic(im_List, Prior, nit, chi2, s, s_dynamic, ipynb=False):
+
+ plt.ion()
+ plt.clf()
+
+ for i in range(len(im_List)):
+ plt.subplot(1, len(im_List), i+1)
+ plt.imshow(im_List[i].reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ if i == 0:
+            plt.xlabel(r'Relative RA ($\mu$as)')
+            plt.ylabel(r'Relative Dec ($\mu$as)')
+            plt.title(r"step: %i $\chi^2$: %f $s$: %f $s_{t}$: %f" % (nit, chi2, s, s_dynamic), fontsize=20)
+ else:
+ plt.xlabel('')
+ plt.ylabel('')
+ plt.title('')
+
+
+ plt.draw()
+
+##################################################################################################
+#BU blazar CLEAN file loading functions
+##################################################################################################
+
+class MOJAVEHTMLParser(HTMLParser):
+    # standard overriding of the python HTMLParser to suit the format of the MOJAVE library
+    # lists are set per instance; class-level lists would accumulate across parser instances
+    def __init__(self):
+        HTMLParser.__init__(self)
+        self.dates = []
+        self.mod_files = []
+        self.uv_files = []
+
+ #checks for CLEAN files linked on the page by looking for file ending
+ def handle_starttag(self, tag, attrs):
+ for attr in attrs:
+ if len(attr) == 2:
+ attr1s = attr[1].strip().split('.')
+ if (len(attr1s) >1):
+ fileType = attr1s[1]
+ if fileType == 'icn':
+ self.mod_files.append(str(attr[1]).strip())
+ if fileType == 'uvf':
+ self.uv_files.append(str(attr[1]).strip())
+
+    # extracts date information and standardizes the format to YYYYMMDD
+ def handle_data(self,data):
+ ds = data.split(' ')
+ lds = len(ds)
+ if ds[lds-1].isdigit():
+ if lds == 3 or lds == 4:
+ day = str(ds[lds-3])
+ month = str(ds[lds-2])
+ year = str(ds[lds-1])
+ if len(day) == 1:
+ day = '0' + day
+ month = month[0:3]
+ monthNum = str(list(calendar.month_abbr).index(month))
+ if len(monthNum) == 1:
+ monthNum = '0' + monthNum
+ newDate = year +monthNum + day
+ self.dates.append(str(newDate))
+
+class BlazarHTMLParser(HTMLParser):
+    # standard overriding of the python HTMLParser to suit the format of the BU blazar library
+    # lists are set per instance; class-level lists would accumulate across parser instances
+    def __init__(self):
+        HTMLParser.__init__(self)
+        self.dates = []
+        self.mod_files = []
+        self.uv_files = []
+
+ #checks for CLEAN files linked on the page by looking for file ending
+ def handle_starttag(self, tag, attrs):
+ for attr in attrs:
+ if len(attr) == 2:
+ attr1s = attr[1].strip().split('.')
+ if (len(attr1s) >1):
+ fileType = attr1s[1]
+ if fileType == 'MOD' or fileType=='mod':
+ self.mod_files.append(str(attr[1]).strip())
+ if fileType == 'UVP' or fileType=='uvp':
+ self.uv_files.append(str(attr[1]).strip())
+
+    # extracts date information and standardizes the format to YYYYMMDD
+ def handle_data(self,data):
+ ds = data.split(' ')
+ lds = len(ds)
+ if ds[lds-1].isdigit():
+ if lds == 3 or lds == 4:
+ day = str(ds[lds-3])
+ month = str(ds[lds-2])
+ year = str(ds[lds-1])
+ if len(day) == 1:
+ day = '0' + day
+ month = month[0:3]
+ monthNum = str(list(calendar.month_abbr).index(month))
+ if len(monthNum) == 1:
+ monthNum = '0' + monthNum
+ newDate = year +monthNum + day
+ self.dates.append(str(newDate))
+
+def generateMOJAVEdates(url, sourceName, path = './'):
+ #Creates a list of observation dates and .mod filenames for a particular source in the MOJAVE library
+ #Returns the filename of the output file
+ #url is the URL of the MOJAVE library page of files for a particular source
+ r = requests.get(url)
+ parser = MOJAVEHTMLParser()
+ parser.feed(r.text)
+ outputFileName = sourceName+'_dates_and_CLEAN_filenames.txt'
+ outputFile = open(outputFileName, 'w')
+ outputFile.write("#Observation dates, UV files, and CLEAN models obtained from " + url + '\n')
+ for i in range(len(parser.dates)):
+ outputFile.write(parser.dates[i]+','+parser.uv_files[i]+','+parser.mod_files[i]+'\n')
+ outputFile.close()
+ return outputFileName
+
+def generateCLEANdates(url, sourceName, path = './'):
+ #Creates a list of observation dates and .mod filenames for a particular source in the BU blazar library
+ #Returns the filename of the output file
+ #url is the URL of the BU blazar library page of files for a particular source
+ r = requests.get(url)
+ parser = BlazarHTMLParser()
+ parser.feed(r.text)
+ outputFileName = sourceName+'_dates_and_CLEAN_filenames.txt'
+ outputFile = open(outputFileName, 'w')
+ outputFile.write("#Observation dates, UV files, and CLEAN models obtained from " + url + '\n')
+ for i in range(len(parser.dates)):
+ outputFile.write(parser.dates[i]+','+parser.uv_files[i]+','+parser.mod_files[i]+'\n')
+ outputFile.close()
+ return outputFileName
+
+def sourceNameFromURL(url):
+ #returns a string containing the BU url designation for a particular source, i.e. "3c454" for 3c454.3
+ urls = url.split('/')
+ lurls = len(urls)
+ sourcehtmls = urls[lurls-1].split('.')
+ sn = sourcehtmls[0]
+ return sn
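+
+# e.g., sourceNameFromURL("https://www.bu.edu/blazars/VLBA_GLAST/3c454.html")
+# returns "3c454"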
+
+def downloadMOJAVEfiles(url, path = './'):
+ #Downloads data and image from MOJAVE library
+ #url is the URL of the MOJAVE page of files for a particular source
+
+ sn = sourceNameFromURL(url)
+    # Find the base directory in the MOJAVE library from the url by stripping the
+    # source page name ('.html' plus the source name); fragile, but matches the site layout
+ lsn = len(sn)
+ baseurl = url[:-(5+lsn)]
+
+ #Make a new directory for the downloaded files
+ CLEANpath = path + '/' + sn + '_CLEAN_images'
+ if not os.path.exists(CLEANpath):
+ os.mkdir(CLEANpath)
+
+ print ("Downloading CLEAN images to " + CLEANpath)
+
+ UVpath = path + '/' + sn+"_uvf_files"
+ if not os.path.exists(UVpath):
+ os.mkdir(UVpath)
+
+ print ("Downloading uvfits files to " + UVpath)
+
+    # Generate the bookkeeping file and iterate on it, downloading files and saving them in a better format
+    # Files are saved in new directories named after the source, i.e. "3c454_CLEAN_images" and "3c454_uvf_files"
+    guideFileName = generateMOJAVEdates(url, sn, path=path)
+ observations = np.loadtxt(guideFileName, dtype = str, delimiter = ',', skiprows = 1)
+ for obs in observations:
+ date = obs[0]
+ UVbuFileName = obs[1]
+ CLEANbuFileName = obs[2]
+ UVurl = baseurl+UVbuFileName
+ CLEANurl = baseurl + CLEANbuFileName
+ # If it doesn't already exist, download the UV file
+ if os.path.isfile(UVpath+'/'+date+'_'+ sn+".uvf") == False:
+ print ("Downloading " + (UVpath+'/'+date+'_'+ sn+".uvf"))
+ response = requests.get(UVurl, stream=True)
+ with open(UVpath+'/'+date+'_'+ sn+".uvf",'wb') as handle:
+ handle.write(response.raw.read())
+ else:
+ print ("Already Downloaded " + (UVpath+'/'+date+'_'+ sn+".uvf"))
+
+ # If it doesn't already exist, download the CLEAN file
+ if os.path.isfile(CLEANpath+'/'+date+'_'+sn+".icn.fits.gz") == False:
+ print ("Downloading " + (CLEANpath+'/'+date+'_'+sn+".icn.fits.gz"))
+ response = requests.get(CLEANurl, stream=True)
+ with open(CLEANpath+'/'+date+'_'+sn+".icn.fits.gz",'wb') as handle:
+ for chunk in response.iter_content(chunk_size = 128):
+ handle.write(chunk)
+ else:
+ print ("Already Downloaded " + (CLEANpath+'/'+date+'_'+sn+".icn.fits.gz"))
+
+def downloadCLEANfiles(url, path = './'):
+ #Downloads all CLEAN files from a single source in the BU blazar library
+ #url is the URL of the BU blazar library page of files for a particular source
+
+ sn = sourceNameFromURL(url)
+    # Find the base directory in the BU library from the url by stripping the
+    # source page name ('.html' plus the source name); fragile, but matches the site layout
+ lsn = len(sn)
+ baseurl = url[:-(5+lsn)]
+
+ #Make a new directory for the downloaded files
+ CLEANpath = path + '/' + sn + '_CLEAN_files'
+ if not os.path.exists(CLEANpath):
+ os.mkdir(CLEANpath)
+
+ print ("Downloading CLEAN files to " + CLEANpath)
+
+ UVpath = path + '/' + sn+"_UVP.gz_files"
+ if not os.path.exists(UVpath):
+ os.mkdir(UVpath)
+
+ print ("Downloading uvfits files to " + UVpath)
+
+ #Generate the bookkeeping file and iterate on it, downloading files and saving them in a better format
+ #Files are saved in a new directory named after the source i.e. "3c454_CLEAN_files"
+ guideFileName = generateCLEANdates(url, sn, path=path)
+ observations = np.loadtxt(guideFileName, dtype = str, delimiter = ',', skiprows = 1)
+ for obs in observations:
+ date = obs[0]
+ UVbuFileName = obs[1]
+ CLEANbuFileName = obs[2]
+ UVurl = baseurl+UVbuFileName
+ CLEANurl = baseurl + CLEANbuFileName
+ # If it doesn't already exist, download the UV file
+ if os.path.isfile(UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz") == False:
+ print ("Downloading " + (UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz"))
+ response = requests.get(UVurl, stream=True)
+ with open(UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz",'wb') as handle:
+ handle.write(response.raw.read())
+ else:
+ print ("Already Downloaded " + (UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz"))
+
+ # If it doesn't already exist, download the CLEAN file
+ if os.path.isfile(CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod") == False:
+ print ("Downloading " + (CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod"))
+ response = requests.get(CLEANurl, stream=True)
+ with open(CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod",'wb') as handle:
+ for chunk in response.iter_content(chunk_size = 128):
+ handle.write(chunk)
+ else:
+ print ("Already Downloaded " + (CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod"))
+
+def minDeltaMJD(inputMJD, im_List):
+    # returns the index of the image in im_List whose MJD most closely matches inputMJD
+ index = 0
+ minDelta = 10000
+ for i in range(len(im_List)):
+ oldDelta = minDelta
+ minDelta = min(minDelta, abs(inputMJD - im_List[i].mjd))
+ if not(minDelta == oldDelta):
+ index = i
+ return index
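+
+# An equivalent one-liner (a sketch):
+#   index = min(range(len(im_List)), key=lambda i: abs(inputMJD - im_List[i].mjd))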
+
+def ReadCLEAN(nameF, reference_obs, npix, fov=0, beamPar=(0,0,0.)):
+# Loads CLEAN model data in the format of the BU blazar library,
+# e.g. https://www.bu.edu/blazars/VLBA_GLAST/3c454.html
+# nameF - name of the CLEAN model file to load (3 columns: flux in Jy, r in mas, theta in deg)
+# reference_obs - Obsdata object supplying the ra, dec, rf, and mjd of the output image
+# npix - number of pixels in one dimension
+# fov - field of view (radians); if 0, it is set from the maximum component radius
+# beamPar - parameters of the Gaussian beam, same as beamparams in image.blur_gauss
+
+ #read data
+    # read data
+    # first remove multiple models in a single file: truncate at the next header line
+    linesF = open(nameF).readlines()
+    TrigString = '!'
+    for cou in range(len(linesF)):
+        if (linesF[cou].find(TrigString) != -1) and (cou > 8):
+            open(nameF, 'w').writelines(linesF[:cou])
+            break
+
+ #skip headline
+ TableMOD = np.genfromtxt(nameF, skip_header=4)
+ ScaleR = 1.
+ FluxConst = 1.
+ Flux = FluxConst*TableMOD[:,0]
+ xPS = ScaleR*TableMOD[:,1]*np.cos(np.pi/2.-(np.pi/180.)*TableMOD[:,2])*(1.e3)*RADPERUAS #to radians
+ yPS = ScaleR*TableMOD[:,1]*np.sin(np.pi/2.-(np.pi/180.)*TableMOD[:,2])*(1.e3)*RADPERUAS #to radians
+ NumbPoints = np.shape(yPS)[0]
+
+ #set image parameters
+ if fov==0:
+ MaxR = np.amax(TableMOD[:,1]) #in mas
+ fov = 1.*MaxR*(1.e3)*RADPERUAS
+
+ image0 = np.zeros((int(npix),int(npix)))
+ im = image.Image(image0, fov/npix, 0., 0., rf=86e9)
+
+ beamMaj = beamPar[0]
+ if beamMaj==0:
+ beamMaj = 4.*fov/npix
+
+ beamMin = beamPar[1]
+ if beamMin==0:
+ beamMin = 4.*fov/npix
+
+ beamTh = beamPar[2]
+
+ sigma_maj = beamMaj / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_min = beamMin / (2. * np.sqrt(2. * np.log(2.)))
+ cth = np.cos(beamTh)
+ sth = np.sin(beamTh)
+ xfov = im.xdim * im.psize
+ yfov = im.ydim * im.psize
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ gauss = image0
+ for couP in range(NumbPoints):
+ x = xPS[couP]
+ y = yPS[couP]
+ xM, yM = np.meshgrid(xlist, ylist)
+ gaussNew = np.exp(-((yM-y)*cth + (xM-x)*sth)**2/(2*sigma_maj**2) - ((xM-x)*cth - (yM-y)*sth)**2/(2.*sigma_min**2))
+ gauss = gauss + gaussNew*Flux[couP]
+
+ gauss /= (2.0*np.pi*sigma_maj*sigma_min)/(fov/npix)**2 #Normalize the Gaussian
+ gauss = (gauss > 0.)*gauss + 1.e-10
+
+ imageCLEAN = image.Image(gauss, fov/npix, reference_obs.ra, reference_obs.dec, rf=reference_obs.rf)
+ imageCLEAN.mjd = reference_obs.mjd
+ return imageCLEAN
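+
+# Hypothetical usage of ReadCLEAN (file name and parameters are illustrative):
+#   im = ReadCLEAN('3c454_CLEAN.mod', obs, npix=128,
+#                  fov=10.*1.e3*RADPERUAS,   # 10 mas field of view
+#                  beamPar=(0., 0., 0.))     # zeros select the default 4-pixel beam
+#   im.display()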
+
+def Cont(imG):
+# Creates contour plots similar to those at
+# https://www.bu.edu/blazars/VLBA_GLAST/3c454.html
+# for visual comparison
+
+ import matplotlib.pyplot as plt
+ plt.figure()
+ Z = np.reshape(imG.imvec,(imG.xdim,imG.ydim))
+ pov = imG.xdim*imG.psize
+ pov_mas = pov/(RADPERUAS*1.e3)
+ Zmax = np.amax(Z)
+ print(Zmax)
+
+ levels = np.array((-0.00125*Zmax,0.00125*Zmax,0.0025*Zmax, 0.005*Zmax, 0.01*Zmax,
+ 0.02*Zmax, 0.04*Zmax, 0.08*Zmax, 0.16*Zmax, 0.32*Zmax, 0.64*Zmax))
+ CS = plt.contour(Z, levels,
+ origin='lower',
+ linewidths=2,
+ extent=(-pov_mas/2., pov_mas/2., -pov_mas/2., pov_mas/2.))
+ plt.show()
+
+
+
+def ReadSeriesImages(pathCLEAN, Obs, npix,fov,beamPar, obsNumbList):
+
+ listCLEAN = os.listdir(pathCLEAN)
+ listCLEAN = sorted(listCLEAN)
+ listCLEAN = list( listCLEAN[i] for i in obsNumbList )
+
+ im_List = [None]*len(listCLEAN)
+ for cou in range(len(listCLEAN)):
+ nameF = pathCLEAN+listCLEAN[cou]
+ print(nameF)
+ im_List[cou] = ReadCLEAN(nameF,Obs[cou], npix, fov, beamPar)
+
+ return im_List
+
+def SaveSeriesImages(pathCLEAN, im_List, sourceName, outputDirectory='default'):
+ #saves a list of images returned by ReadSeriesImages according to the naming convention in pathCLEAN and the source name
+ if outputDirectory == 'default':
+ outputDirectory = sourceName+'_READ_CLEAN_files'
+ if not os.path.exists(outputDirectory):
+ os.mkdir(outputDirectory)
+ outputSuffix = '_'+sourceName+'_READ_CLEAN.txt'
+
+ listCLEAN = os.listdir(pathCLEAN)
+ listCLEAN = sorted(listCLEAN)
+ datesCLEAN = [filename.split('_')[0] for filename in listCLEAN]
+ for i in range(len(datesCLEAN)):
+ outputName = outputDirectory+'/'+datesCLEAN[i] + outputSuffix
+ im_List[i].save_txt(outputName)
+
+def LoadSeriesImages(sourceName):
+ #loads a list of images saved according to the BU Blazar library convention
+ dirName = sourceName+'_READ_CLEAN_files'
+ fileList = os.listdir(dirName)
+ filenameList = [dirName+'/'+filename for filename in fileList]
+ im_List = [image.load_txt(filename) for filename in filenameList]
+ return im_List
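+
+# Example pipeline (a sketch; `obs_list`, the paths, and the source name are
+# illustrative, and obs_list must match the files selected by obsNumbList):
+#   ims = ReadSeriesImages('./clean_files/', obs_list, 128, 0, (0,0,0.),
+#                          obsNumbList=range(10))
+#   SaveSeriesImages('./clean_files/', ims, '3C454')
+#   ims2 = LoadSeriesImages('3C454')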
diff --git a/imaging/imager_utils.py b/imaging/imager_utils.py
new file mode 100644
index 00000000..5f565c07
--- /dev/null
+++ b/imaging/imager_utils.py
@@ -0,0 +1,4397 @@
+# imager_utils.py
+# General imager functions for total intensity VLBI data
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import time
+import numpy as np
+import scipy.optimize as opt
+import matplotlib.pyplot as plt
+
+import ehtim.image as image
+import ehtim.observing.obs_helpers as obsh
+import ehtim.const_def as ehc
+
+
+##################################################################################################
+# Constants & Definitions
+##################################################################################################
+
+NORM_REGULARIZER = False # ANDREW TODO change this default in the future
+
+MAXLS = 100 # maximum number of line searches in L-BFGS-B
+NHIST = 100 # number of steps to store for hessian approx
+MAXIT = 100 # maximum number of iterations
+STOP = 1.e-8 # convergence criterion
+
+DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag',
+ 'camp', 'logcamp', 'logcamp_diag', 'logamp']
+REGULARIZERS = ['gs', 'tv', 'tvlog', 'tv2', 'tv2log', 'l1', 'l1w', 'lA',
+                'patch', 'simple', 'compact', 'compact2', 'rgauss']
+
+nit = 0 # global variable to track the iteration number in the plotting callback
+
+##################################################################################################
+# Total Intensity Imager
+##################################################################################################
+
+
+def imager_func(Obsdata, InitIm, Prior, flux,
+ d1='vis', d2=False, d3=False,
+ alpha_d1=100, alpha_d2=100, alpha_d3=100,
+ s1='simple', s2=False, s3=False,
+ alpha_s1=1, alpha_s2=1, alpha_s3=1,
+ alpha_flux=500, alpha_cm=500,
+ **kwargs):
+ """Run a general interferometric imager. Only works directly on the image's primary polarization.
+
+ Args:
+ Obsdata (Obsdata): The Obsdata object with VLBI data
+ InitIm (Image): The Image object with the initial image for the minimization
+ Prior (Image): The Image object with the prior image
+ flux (float): The total flux of the output image in Jy
+
+        d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'logamp'
+        d2 (str): The second data term; options are the same as for d1
+        d3 (str): The third data term; options are the same as for d1
+
+        s1 (str): The first regularizer; options are 'simple', 'gs', 'tv', 'tvlog', 'tv2', 'tv2log', 'l1', 'l1w', 'lA', 'patch', 'compact', 'compact2', 'rgauss'
+        s2 (str): The second regularizer; options are the same as for s1
+        s3 (str): The third regularizer; options are the same as for s1
+
+        alpha_d1 (float): The first data term weighting
+        alpha_d2 (float): The second data term weighting
+        alpha_d3 (float): The third data term weighting
+
+ alpha_s1 (float): The first regularizer term weighting
+ alpha_s2 (float): The second regularizer term weighting
+ alpha_s3 (float): The third regularizer term weighting
+
+ alpha_flux (float): The weighting for the total flux constraint
+ alpha_cm (float): The weighting for the center of mass constraint
+
+ maxit (int): Maximum number of minimizer iterations
+ stop (float): The convergence criterion
+ clipfloor (float): The Jy/pixel level above which prior image pixels are varied
+
+ grads (bool): If True, analytic gradients are used
+ logim (bool): If True, uses I = exp(I') change of variables
+ norm_reg (bool): If True, normalizes regularizer terms
+ norm_init (bool): If True, normalizes initial image to given total flux
+ show_updates (bool): If True, displays the progress of the minimizer
+
+ weighting (str): 'natural' or 'uniform'
+ debias (bool): if True then apply debiasing to amplitudes/closure amplitudes
+ systematic_noise (float): a fractional systematic noise tolerance to add to thermal sigmas
+ snrcut (float): a snr cutoff for including data in the chi^2 sum
+ beam_size (float): beam size in radians for normalizing the regularizers
+
+ maxset (bool): if True, use maximal set instead of minimal for closure quantities
+ systematic_cphase_noise (float): a value in degrees to add to the closure phase sigmas
+ cp_uv_min (float): flag baselines shorter than this before forming closure quantities
+
+ ttype (str): The Fourier transform type; options are 'fast', 'direct', 'nfft'
+ fft_pad_factor (float): The FFT will pre-pad the image by this factor x the original size
+ order (int): Interpolation order for sampling the FFT
+ conv_func (str): The convolving function for gridding; options are 'gaussian', 'pill', and 'cubic'
+ p_rad (int): The pixel radius for the convolving function in gridding for FFTs
+
+ Returns:
+ Image: Image object with result
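+
+    Example:
+        A minimal sketch (assumes ``obs`` is an Obsdata object and ``prior`` an
+        Image with matching dimensions; parameter values are illustrative)::
+
+            out = imager_func(obs, prior, prior, flux=1.0,
+                              d1='vis', s1='simple', maxit=100)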
+ """
+
+ # some kwarg default values
+ maxit = kwargs.get('maxit', MAXIT)
+ stop = kwargs.get('stop', STOP)
+ clipfloor = kwargs.get('clipfloor', 0)
+ ttype = kwargs.get('ttype', 'direct')
+
+ grads = kwargs.get('grads', True)
+ logim = kwargs.get('logim', True)
+ norm_init = kwargs.get('norm_init', False)
+ show_updates = kwargs.get('show_updates', True)
+
+ beam_size = kwargs.get('beam_size', Obsdata.res())
+ kwargs['beam_size'] = beam_size
+
+ # Make sure data and regularizer options are ok
+    if not d1 and not d2 and not d3:
+        raise Exception("Must have at least one data term!")
+    if not s1 and not s2 and not s3:
+        raise Exception("Must have at least one regularizer term!")
+    if not all((dt in DATATERMS) or (dt is False) for dt in (d1, d2, d3)):
+        raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS))
+    if not all((st in REGULARIZERS) or (st is False) for st in (s1, s2, s3)):
+        raise Exception("Invalid regularizer: valid regularizers are: " + ' '.join(REGULARIZERS))
+ if (Prior.psize != InitIm.psize) or (Prior.xdim != InitIm.xdim) or (Prior.ydim != InitIm.ydim):
+ raise Exception("Initial image does not match dimensions of the prior image!")
+ if (InitIm.polrep != Prior.polrep):
+ raise Exception(
+ "Initial image pol. representation does not match pol. representation of the prior image!")
+ if (logim and Prior.pol_prim in ['Q', 'U', 'V']):
+ raise Exception(
+ "Cannot image Stokes Q,U,or V with log image transformation! Set logim=False in imager_func")
+
+ pol = Prior.pol_prim
+ print("Generating %s image..." % pol)
+
+ # Catch scale and dimension problems
+ imsize = np.max([Prior.xdim, Prior.ydim]) * Prior.psize
+ uvmax = 1.0/Prior.psize
+ uvmin = 1.0/imsize
+ uvdists = Obsdata.unpack('uvdist')['uvdist']
+ maxbl = np.max(uvdists)
+    minbl = np.min(uvdists[uvdists > 0])
+ maxamp = np.max(np.abs(Obsdata.unpack('amp')['amp']))
+
+ if uvmax < maxbl:
+ print("Warning! Pixel Spacing is larger than smallest spatial wavelength!")
+ if uvmin > minbl:
+ print("Warning! Field of View is smaller than largest nonzero spatial wavelength!")
+ if flux > 1.2*maxamp:
+ print("Warning! Specified flux is > 120% of maximum visibility amplitude!")
+ if flux < .8*maxamp:
+ print("Warning! Specified flux is < 80% of maximum visibility amplitude!")
+
+ # Define embedding mask
+ embed_mask = Prior.imvec > clipfloor
+
+ # Normalize prior image to total flux and limit imager range to prior values > clipfloor
+ if (not norm_init):
+ nprior = Prior.imvec[embed_mask]
+ ninit = InitIm.imvec[embed_mask]
+ else:
+ nprior = (flux * Prior.imvec / np.sum((Prior.imvec)[embed_mask]))[embed_mask]
+ ninit = (flux * InitIm.imvec / np.sum((InitIm.imvec)[embed_mask]))[embed_mask]
+
+ if len(nprior) == 0:
+ raise Exception("clipfloor too large: all prior pixels have been clipped!")
+
+ # Get data and fourier matrices for the data terms
+ (data1, sigma1, A1) = chisqdata(Obsdata, Prior, embed_mask, d1, pol=pol, **kwargs)
+ (data2, sigma2, A2) = chisqdata(Obsdata, Prior, embed_mask, d2, pol=pol, **kwargs)
+ (data3, sigma3, A3) = chisqdata(Obsdata, Prior, embed_mask, d3, pol=pol, **kwargs)
+
+ # Define the chi^2 and chi^2 gradient
+ def chisq1(imvec):
+ return chisq(imvec, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask)
+
+ def chisq1grad(imvec):
+ c = chisqgrad(imvec, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask)
+ return c
+
+ def chisq2(imvec):
+ return chisq(imvec, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask)
+
+ def chisq2grad(imvec):
+ c = chisqgrad(imvec, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask)
+ return c
+
+ def chisq3(imvec):
+ return chisq(imvec, A3, data3, sigma3, d3, ttype=ttype, mask=embed_mask)
+
+ def chisq3grad(imvec):
+ c = chisqgrad(imvec, A3, data3, sigma3, d3, ttype=ttype, mask=embed_mask)
+ return c
+
+ # Define the regularizer and regularizer gradient
+ def reg1(imvec):
+ return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs)
+
+ def reg1grad(imvec):
+ return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs)
+
+ def reg2(imvec):
+ return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs)
+
+ def reg2grad(imvec):
+ return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs)
+
+ def reg3(imvec):
+ return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s3, **kwargs)
+
+ def reg3grad(imvec):
+ return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s3, **kwargs)
+
+ # Define constraint functions
+ def flux_constraint(imvec):
+ return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "flux", **kwargs)
+
+ def flux_constraint_grad(imvec):
+ return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "flux", **kwargs)
+
+ def cm_constraint(imvec):
+ return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "cm", **kwargs)
+
+ def cm_constraint_grad(imvec):
+ return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "cm", **kwargs)
+
+ # Define the objective function and gradient
+ def objfunc(imvec):
+ if logim:
+ imvec = np.exp(imvec)
+
+ datterm = alpha_d1 * (chisq1(imvec) - 1) + alpha_d2 * \
+ (chisq2(imvec) - 1) + alpha_d3 * (chisq3(imvec) - 1)
+ regterm = alpha_s1 * reg1(imvec) + alpha_s2 * reg2(imvec) + alpha_s3 * reg3(imvec)
+ conterm = alpha_flux * flux_constraint(imvec) + alpha_cm * cm_constraint(imvec)
+
+ return datterm + regterm + conterm
+
+ def objgrad(imvec):
+ if logim:
+ imvec = np.exp(imvec)
+
+ datterm = alpha_d1 * chisq1grad(imvec) + alpha_d2 * \
+ chisq2grad(imvec) + alpha_d3 * chisq3grad(imvec)
+ regterm = alpha_s1 * reg1grad(imvec) + alpha_s2 * \
+ reg2grad(imvec) + alpha_s3 * reg3grad(imvec)
+ conterm = alpha_flux * flux_constraint_grad(imvec) + alpha_cm * cm_constraint_grad(imvec)
+
+ grad = datterm + regterm + conterm
+
+ # chain rule term for change of variables
+ if logim:
+ grad *= imvec
+
+ return grad
+
+ # Define plotting function for each iteration
+ global nit
+ nit = 0
+
+ def plotcur(im_step):
+ global nit
+ if logim:
+ im_step = np.exp(im_step)
+ if show_updates:
+ chi2_1 = chisq1(im_step)
+ chi2_2 = chisq2(im_step)
+ chi2_3 = chisq3(im_step)
+ s_1 = reg1(im_step)
+ s_2 = reg2(im_step)
+ s_3 = reg3(im_step)
+ if np.any(np.invert(embed_mask)):
+ im_step = embed(im_step, embed_mask)
+ plot_i(im_step, Prior, nit, {d1: chi2_1, d2: chi2_2, d3: chi2_3}, pol=pol)
+ print("i: %d chi2_1: %0.2f chi2_2: %0.2f chi2_3: %0.2f s_1: %0.2f s_2: %0.2f s_3: %0.2f" % (
+ nit, chi2_1, chi2_2, chi2_3, s_1, s_2, s_3))
+ nit += 1
+
+    # Generate the initial image
+ if logim:
+ xinit = np.log(ninit)
+ else:
+ xinit = ninit
+
+ # Print stats
+ print("Initial S_1: %f S_2: %f S_3: %f" % (reg1(ninit), reg2(ninit), reg3(ninit)))
+ print("Initial Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f" %
+ (chisq1(ninit), chisq2(ninit), chisq3(ninit)))
+ print("Initial Objective Function: %f" % (objfunc(xinit)))
+
+ if d1 in DATATERMS:
+ print("Total Data 1: ", (len(data1)))
+ if d2 in DATATERMS:
+ print("Total Data 2: ", (len(data2)))
+ if d3 in DATATERMS:
+ print("Total Data 3: ", (len(data3)))
+
+ print("Total Pixel #: ", (len(Prior.imvec)))
+ print("Clipped Pixel #: ", (len(ninit)))
+ print()
+ plotcur(xinit)
+
+ # Minimize
+ optdict = {'maxiter': maxit, 'ftol': stop, 'maxcor': NHIST,
+ 'gtol': stop, 'maxls': MAXLS} # minimizer dict params
+ tstart = time.time()
+ if grads:
+ res = opt.minimize(objfunc, xinit, method='L-BFGS-B', jac=objgrad,
+ options=optdict, callback=plotcur)
+ else:
+ res = opt.minimize(objfunc, xinit, method='L-BFGS-B',
+ options=optdict, callback=plotcur)
+
+ tstop = time.time()
+
+ # Format output
+ out = res.x
+ if logim:
+ out = np.exp(res.x)
+ if np.any(np.invert(embed_mask)):
+ out = embed(out, embed_mask)
+
+ outim = image.Image(out.reshape(Prior.ydim, Prior.xdim),
+ Prior.psize, Prior.ra, Prior.dec,
+ rf=Prior.rf, source=Prior.source,
+ polrep=Prior.polrep, pol_prim=pol,
+ mjd=Prior.mjd, time=Prior.time, pulse=Prior.pulse)
+
+ # copy over other polarizations
+ outim.copy_pol_images(InitIm)
+
+ # Print stats
+ print("time: %f s" % (tstop - tstart))
+ print("J: %f" % res.fun)
+ print("Final Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f" %
+ (chisq1(out[embed_mask]), chisq2(out[embed_mask]), chisq3(out[embed_mask])))
+ print(res.message)
+
+ # Return Image object
+ return outim
+
+##################################################################################################
+# Wrapper Functions
+##################################################################################################
+
+
+def chisq(imvec, A, data, sigma, dtype, ttype='direct', mask=None):
+ """return the chi^2 for the appropriate dtype
+ """
+
+ if mask is None:
+ mask = []
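+    # a default of 1 makes the (chisq - 1) terms in the objective function
+    # vanish for any unused data terms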
+ chisq = 1
+ if dtype not in DATATERMS:
+ return chisq
+
+ if ttype not in ['fast', 'direct', 'nfft']:
+ raise Exception("Possible ttype values are 'fast', 'direct'!, 'nfft!'")
+
+ if ttype == 'direct':
+ if dtype == 'vis':
+ chisq = chisq_vis(imvec, A, data, sigma)
+ elif dtype == 'amp':
+ chisq = chisq_amp(imvec, A, data, sigma)
+ elif dtype == 'logamp':
+ chisq = chisq_logamp(imvec, A, data, sigma)
+ elif dtype == 'bs':
+ chisq = chisq_bs(imvec, A, data, sigma)
+ elif dtype == 'cphase':
+ chisq = chisq_cphase(imvec, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisq = chisq_cphase_diag(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisq = chisq_camp(imvec, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisq = chisq_logcamp(imvec, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisq = chisq_logcamp_diag(imvec, A, data, sigma)
+
+ elif ttype == 'fast':
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+
+ if dtype not in ['cphase_diag', 'logcamp_diag']:
+ vis_arr = obsh.fft_imvec(imvec, A[0])
+
+ if dtype == 'vis':
+ chisq = chisq_vis_fft(vis_arr, A, data, sigma)
+ elif dtype == 'amp':
+ chisq = chisq_amp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logamp':
+ chisq = chisq_logamp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'bs':
+ chisq = chisq_bs_fft(vis_arr, A, data, sigma)
+ elif dtype == 'cphase':
+ chisq = chisq_cphase_fft(vis_arr, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisq = chisq_cphase_diag_fft(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisq = chisq_camp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisq = chisq_logcamp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisq = chisq_logcamp_diag_fft(imvec, A, data, sigma)
+
+ elif ttype == 'nfft':
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+
+ if dtype == 'vis':
+ chisq = chisq_vis_nfft(imvec, A, data, sigma)
+ elif dtype == 'amp':
+ chisq = chisq_amp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logamp':
+ chisq = chisq_logamp_nfft(imvec, A, data, sigma)
+ elif dtype == 'bs':
+ chisq = chisq_bs_nfft(imvec, A, data, sigma)
+ elif dtype == 'cphase':
+ chisq = chisq_cphase_nfft(imvec, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisq = chisq_cphase_diag_nfft(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisq = chisq_camp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisq = chisq_logcamp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisq = chisq_logcamp_diag_nfft(imvec, A, data, sigma)
+
+ return chisq
+
+
+def chisqgrad(imvec, A, data, sigma, dtype, ttype='direct', mask=None):
+ """return the chi^2 gradient for the appropriate dtype
+ """
+
+ if mask is None:
+ mask = []
+ chisqgrad = np.zeros(len(imvec))
+ if dtype not in DATATERMS:
+ return chisqgrad
+
+ if ttype not in ['fast', 'direct', 'nfft']:
+ raise Exception("Possible ttype values are 'fast', 'direct', 'nfft'!")
+
+ if ttype == 'direct':
+ if dtype == 'vis':
+ chisqgrad = chisqgrad_vis(imvec, A, data, sigma)
+ elif dtype == 'amp':
+ chisqgrad = chisqgrad_amp(imvec, A, data, sigma)
+ elif dtype == 'logamp':
+ chisqgrad = chisqgrad_logamp(imvec, A, data, sigma)
+ elif dtype == 'bs':
+ chisqgrad = chisqgrad_bs(imvec, A, data, sigma)
+ elif dtype == 'cphase':
+ chisqgrad = chisqgrad_cphase(imvec, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisqgrad = chisqgrad_cphase_diag(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisqgrad = chisqgrad_camp(imvec, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisqgrad = chisqgrad_logcamp(imvec, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisqgrad = chisqgrad_logcamp_diag(imvec, A, data, sigma)
+
+ elif ttype == 'fast':
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+
+ if dtype not in ['cphase_diag', 'logcamp_diag']:
+ vis_arr = obsh.fft_imvec(imvec, A[0])
+
+ if dtype == 'vis':
+ chisqgrad = chisqgrad_vis_fft(vis_arr, A, data, sigma)
+ elif dtype == 'amp':
+ chisqgrad = chisqgrad_amp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logamp':
+ chisqgrad = chisqgrad_logamp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'bs':
+ chisqgrad = chisqgrad_bs_fft(vis_arr, A, data, sigma)
+ elif dtype == 'cphase':
+ chisqgrad = chisqgrad_cphase_fft(vis_arr, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisqgrad = chisqgrad_cphase_diag_fft(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisqgrad = chisqgrad_camp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisqgrad = chisqgrad_logcamp_fft(vis_arr, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisqgrad = chisqgrad_logcamp_diag_fft(imvec, A, data, sigma)
+
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ chisqgrad = chisqgrad[mask]
+
+ elif ttype == 'nfft':
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+
+ if dtype == 'vis':
+ chisqgrad = chisqgrad_vis_nfft(imvec, A, data, sigma)
+ elif dtype == 'amp':
+ chisqgrad = chisqgrad_amp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logamp':
+ chisqgrad = chisqgrad_logamp_nfft(imvec, A, data, sigma)
+ elif dtype == 'bs':
+ chisqgrad = chisqgrad_bs_nfft(imvec, A, data, sigma)
+ elif dtype == 'cphase':
+ chisqgrad = chisqgrad_cphase_nfft(imvec, A, data, sigma)
+ elif dtype == 'cphase_diag':
+ chisqgrad = chisqgrad_cphase_diag_nfft(imvec, A, data, sigma)
+ elif dtype == 'camp':
+ chisqgrad = chisqgrad_camp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logcamp':
+ chisqgrad = chisqgrad_logcamp_nfft(imvec, A, data, sigma)
+ elif dtype == 'logcamp_diag':
+ chisqgrad = chisqgrad_logcamp_diag_nfft(imvec, A, data, sigma)
+
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ chisqgrad = chisqgrad[mask]
+
+ return chisqgrad
+
+
+def regularizer(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs):
+ """return the regularizer value
+ """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ beam_size = kwargs.get('beam_size', psize)
+ alpha_A = kwargs.get('alpha_A', 1.0)
+ epsilon = kwargs.get('epsilon_tv', 0.)
+
+ if stype == "flux":
+ s = -sflux(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "cm":
+ s = -scm(imvec, xdim, ydim, psize, flux, mask, norm_reg=norm_reg, beam_size=beam_size)
+ elif stype == "simple":
+ s = -ssimple(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "l1":
+ s = -sl1(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "l1w":
+ s = -sl1w(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "lA":
+ s = -slA(imvec, nprior, psize, flux, beam_size, alpha_A, norm_reg)
+ elif stype == "gs":
+ s = -sgs(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "patch":
+ s = -spatch(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "tv":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -stv(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon)
+ elif stype == "tvlog":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, clipfloor=epsilon, randomfloor=True)
+ npix = xdim*ydim
+ logvec = np.log(imvec)
+ logflux = npix*np.abs(np.log(flux/npix))
+ s = -stv(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon)
+ elif stype == "tv2":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -stv2(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg, beam_size=beam_size)
+ elif stype == "tv2log":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ npix = xdim*ydim
+ logvec = np.log(imvec)
+ logflux = npix*np.abs(np.log(flux/npix))
+ s = -stv2(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg, beam_size=beam_size)
+ elif stype == "compact":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -scompact(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg)
+ elif stype == "compact2":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -scompact2(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg)
+ elif stype == "rgauss":
+ # additional key words for gaussian regularizer
+ major = kwargs.get('major', 1.0)
+ minor = kwargs.get('minor', 1.0)
+ PA = kwargs.get('PA', 1.0)
+
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -sgauss(imvec, xdim, ydim, psize, major=major, minor=minor, PA=PA)
+ else:
+ s = 0
+
+ return s
+
+
+def regularizergrad(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs):
+ """return the regularizer gradient
+ """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ beam_size = kwargs.get('beam_size', psize)
+ alpha_A = kwargs.get('alpha_A', 1.0)
+ epsilon = kwargs.get('epsilon_tv', 0.)
+
+ if stype == "flux":
+ s = -sfluxgrad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "cm":
+ s = -scmgrad(imvec, xdim, ydim, psize, flux, mask, norm_reg=norm_reg, beam_size=beam_size)
+ elif stype == "simple":
+ s = -ssimplegrad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "l1":
+ s = -sl1grad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "l1w":
+ s = -sl1wgrad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "lA":
+ s = -slAgrad(imvec, nprior, psize, flux, beam_size, alpha_A, norm_reg)
+ elif stype == "gs":
+ s = -sgsgrad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "patch":
+ s = -spatchgrad(imvec, nprior, flux, norm_reg=norm_reg)
+ elif stype == "tv":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -stvgrad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg,
+ beam_size=beam_size, epsilon=epsilon)
+ s = s[mask]
+ elif stype == "tvlog":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, clipfloor=epsilon, randomfloor=True)
+ npix = xdim*ydim
+ logvec = np.log(imvec)
+ logflux = npix*np.abs(np.log(flux/npix))
+ s = -stvgrad(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon)
+ s = s / imvec
+ s = s[mask]
+ elif stype == "tv2":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -stv2grad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg, beam_size=beam_size)
+ s = s[mask]
+ elif stype == "tv2log":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ npix = xdim*ydim
+ logvec = np.log(imvec)
+ logflux = npix*np.abs(np.log(flux/npix))
+ s = -stv2grad(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg, beam_size=beam_size)
+ s = s / imvec
+ s = s[mask]
+ elif stype == "compact":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -scompactgrad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg)
+ s = s[mask]
+ elif stype == "compact2":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -scompact2grad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg)
+ s = s[mask]
+ elif stype == "rgauss":
+ # additional key words for gaussian regularizer
+ major = kwargs.get('major', 1.0)
+ minor = kwargs.get('minor', 1.0)
+ PA = kwargs.get('PA', 1.0)
+
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=True)
+ s = -sgauss_grad(imvec, xdim, ydim, psize, major, minor, PA)
+ s = s[mask]
+ else:
+ s = np.zeros(len(imvec))
+
+ return s
+
+
+def chisqdata(Obsdata, Prior, mask, dtype, pol='I', **kwargs):
+ """Return the data, sigma, and matrices for the appropriate dtype
+ """
+
+ ttype = kwargs.get('ttype', 'direct')
+ (data, sigma, A) = (False, False, False)
+ if ttype not in ['fast', 'direct', 'nfft']:
+ raise Exception("Possible ttype values are 'fast', 'direct', 'nfft'!")
+
+ if ttype == 'direct':
+ if dtype == 'vis':
+ (data, sigma, A) = chisqdata_vis(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'amp' or dtype == 'logamp':
+ (data, sigma, A) = chisqdata_amp(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'bs':
+ (data, sigma, A) = chisqdata_bs(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'cphase':
+ (data, sigma, A) = chisqdata_cphase(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'cphase_diag':
+ (data, sigma, A) = chisqdata_cphase_diag(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'camp':
+ (data, sigma, A) = chisqdata_camp(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'logcamp':
+ (data, sigma, A) = chisqdata_logcamp(Obsdata, Prior, mask, pol=pol, **kwargs)
+ elif dtype == 'logcamp_diag':
+ (data, sigma, A) = chisqdata_logcamp_diag(Obsdata, Prior, mask, pol=pol, **kwargs)
+
+ elif ttype == 'fast':
+ if dtype == 'vis':
+ (data, sigma, A) = chisqdata_vis_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'amp' or dtype == 'logamp':
+ (data, sigma, A) = chisqdata_amp_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'bs':
+ (data, sigma, A) = chisqdata_bs_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'cphase':
+ (data, sigma, A) = chisqdata_cphase_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'cphase_diag':
+ (data, sigma, A) = chisqdata_cphase_diag_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'camp':
+ (data, sigma, A) = chisqdata_camp_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'logcamp':
+ (data, sigma, A) = chisqdata_logcamp_fft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'logcamp_diag':
+ (data, sigma, A) = chisqdata_logcamp_diag_fft(Obsdata, Prior, pol=pol, **kwargs)
+
+ elif ttype == 'nfft':
+ if dtype == 'vis':
+ (data, sigma, A) = chisqdata_vis_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'amp' or dtype == 'logamp':
+ (data, sigma, A) = chisqdata_amp_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'bs':
+ (data, sigma, A) = chisqdata_bs_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'cphase':
+ (data, sigma, A) = chisqdata_cphase_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'cphase_diag':
+ (data, sigma, A) = chisqdata_cphase_diag_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'camp':
+ (data, sigma, A) = chisqdata_camp_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'logcamp':
+ (data, sigma, A) = chisqdata_logcamp_nfft(Obsdata, Prior, pol=pol, **kwargs)
+ elif dtype == 'logcamp_diag':
+ (data, sigma, A) = chisqdata_logcamp_diag_nfft(Obsdata, Prior, pol=pol, **kwargs)
+
+ return (data, sigma, A)
+
+
+##################################################################################################
+# DFT Chi-squared and Gradient Functions
+##################################################################################################
+
+def chisq_vis(imvec, Amatrix, vis, sigma):
+ """Visibility chi-squared"""
+
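+    # chi^2 = (1/2N) sum_i |(A x)_i - V_i|^2 / sigma_i^2, where A is the
+    # DFT matrix taking image pixel values x to sampled visibilities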
+ samples = np.dot(Amatrix, imvec)
+ chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis))
+ return chisq
+
+def chisqgrad_vis(imvec, Amatrix, vis, sigma):
+ """The gradient of the visibility chi-squared"""
+
+ samples = np.dot(Amatrix, imvec)
+ wdiff = (vis - samples)/(sigma**2)
+
+ out = -np.real(np.dot(Amatrix.conj().T, wdiff))/len(vis)
+ return out
+
+
+def chisq_amp(imvec, A, amp, sigma):
+ """Visibility Amplitudes (normalized) chi-squared"""
+
+ amp_samples = np.abs(np.dot(A, imvec))
+ return np.sum(np.abs((amp - amp_samples)/sigma)**2)/len(amp)
+
+
+def chisqgrad_amp(imvec, A, amp, sigma):
+ """The gradient of the amplitude chi-squared"""
+
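+    # uses d|V_i|/dx_j = Re(A_ij conj(V_i))/|V_i|; since conj(V)/|V| = |V|/V,
+    # the amp_samples/i1 factor in pp below encodes this inside Re(pp . A)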
+ i1 = np.dot(A, imvec)
+ amp_samples = np.abs(i1)
+
+ pp = ((amp - amp_samples) * amp_samples) / (sigma**2) / i1
+ out = (-2.0/len(amp)) * np.real(np.dot(pp, A))
+ return out
+
+
+def chisq_bs(imvec, Amatrices, bis, sigma):
+ """Bispectrum chi-squared"""
+
+ bisamples = (np.dot(Amatrices[0], imvec) *
+ np.dot(Amatrices[1], imvec) *
+ np.dot(Amatrices[2], imvec))
+ chisq = np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis))
+ return chisq
+
+
+def chisqgrad_bs(imvec, Amatrices, bis, sigma):
+ """The gradient of the bispectrum chi-squared"""
+
+ bisamples = (np.dot(Amatrices[0], imvec) *
+ np.dot(Amatrices[1], imvec) *
+ np.dot(Amatrices[2], imvec))
+
+ wdiff = ((bis - bisamples).conj())/(sigma**2)
+ pt1 = wdiff * np.dot(Amatrices[1], imvec) * np.dot(Amatrices[2], imvec)
+ pt2 = wdiff * np.dot(Amatrices[0], imvec) * np.dot(Amatrices[2], imvec)
+ pt3 = wdiff * np.dot(Amatrices[0], imvec) * np.dot(Amatrices[1], imvec)
+ out = (np.dot(pt1, Amatrices[0]) +
+ np.dot(pt2, Amatrices[1]) +
+ np.dot(pt3, Amatrices[2]))
+
+ out = -np.real(out) / len(bis)
+ return out
+
+
+def chisq_cphase(imvec, Amatrices, clphase, sigma):
+ """Closure Phases (normalized) chi-squared"""
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ clphase_samples = np.angle(i1 * i2 * i3)
+
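+    # 2*(1 - cos(dtheta)) ~ dtheta^2 for small residuals, so this form reduces
+    # to a standard chi^2 while remaining periodic in the phase differences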
+ chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2))
+ return chisq
+
+
+def chisqgrad_cphase(imvec, Amatrices, clphase, sigma):
+ """The gradient of the closure phase chi-squared"""
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ clphase_samples = np.angle(i1 * i2 * i3)
+
+ pref = np.sin(clphase - clphase_samples)/(sigma**2)
+ pt1 = pref/i1
+ pt2 = pref/i2
+ pt3 = pref/i3
+ out = np.dot(pt1, Amatrices[0]) + np.dot(pt2, Amatrices[1]) + np.dot(pt3, Amatrices[2])
+ out = (-2.0/len(clphase)) * np.imag(out)
+ return out
+
+
+def chisq_cphase_diag(imvec, Amatrices, clphase_diag, sigma):
+ """Diagonalized closure phases (normalized) chi-squared"""
+ clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE
+ sigma = np.concatenate(sigma) * ehc.DEGREE
+
+ A3_diag = Amatrices[0]
+ tform_mats = Amatrices[1]
+
+ clphase_diag_samples = []
+ for iA, A3 in enumerate(A3_diag):
+ clphase_samples = np.angle(np.dot(A3[0], imvec) *
+ np.dot(A3[1], imvec) *
+ np.dot(A3[2], imvec))
+ clphase_diag_samples.append(np.dot(tform_mats[iA], clphase_samples))
+ clphase_diag_samples = np.concatenate(clphase_diag_samples)
+
+ chisq = np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2))
+ chisq *= (2.0/len(clphase_diag))
+ return chisq
+
+
+def chisqgrad_cphase_diag(imvec, Amatrices, clphase_diag, sigma):
+ """The gradient of the diagonalized closure phase chi-squared"""
+ clphase_diag = clphase_diag * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ A3_diag = Amatrices[0]
+ tform_mats = Amatrices[1]
+
+ deriv = np.zeros_like(imvec)
+ for iA, A3 in enumerate(A3_diag):
+
+ i1 = np.dot(A3[0], imvec)
+ i2 = np.dot(A3[1], imvec)
+ i3 = np.dot(A3[2], imvec)
+ clphase_samples = np.angle(i1 * i2 * i3)
+ clphase_diag_samples = np.dot(tform_mats[iA], clphase_samples)
+
+ clphase_diag_measured = clphase_diag[iA]
+ clphase_diag_sigma = sigma[iA]
+
+ term1 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) /
+ (clphase_diag_sigma**2.0)), (tform_mats[iA]/i1)), A3[0])
+ term2 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) /
+ (clphase_diag_sigma**2.0)), (tform_mats[iA]/i2)), A3[1])
+ term3 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) /
+ (clphase_diag_sigma**2.0)), (tform_mats[iA]/i3)), A3[2])
+ deriv += -2.0*np.imag(term1 + term2 + term3)
+
+    deriv *= 1.0/float(len(np.concatenate(clphase_diag)))
+
+ return deriv
+
+
+def chisq_camp(imvec, Amatrices, clamp, sigma):
+ """Closure Amplitudes (normalized) chi-squared"""
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ clamp_samples = np.abs((i1 * i2)/(i3 * i4))
+
+ chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp)
+ return chisq
+
+
+def chisqgrad_camp(imvec, Amatrices, clamp, sigma):
+ """The gradient of the closure amplitude chi-squared"""
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ clamp_samples = np.abs((i1 * i2)/(i3 * i4))
+
+ pp = ((clamp - clamp_samples) * clamp_samples)/(sigma**2)
+ pt1 = pp/i1
+ pt2 = pp/i2
+ pt3 = -pp/i3
+ pt4 = -pp/i4
+ out = (np.dot(pt1, Amatrices[0]) +
+ np.dot(pt2, Amatrices[1]) +
+ np.dot(pt3, Amatrices[2]) +
+ np.dot(pt4, Amatrices[3]))
+    out = (-2.0/len(clamp)) * np.real(out)
+ return out
+
+
+def chisq_logcamp(imvec, Amatrices, log_clamp, sigma):
+ """Log Closure Amplitudes (normalized) chi-squared"""
+
+ a1 = np.abs(np.dot(Amatrices[0], imvec))
+ a2 = np.abs(np.dot(Amatrices[1], imvec))
+ a3 = np.abs(np.dot(Amatrices[2], imvec))
+ a4 = np.abs(np.dot(Amatrices[3], imvec))
+
+ samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4)
+ chisq = np.sum(np.abs((log_clamp - samples)/sigma)**2) / (len(log_clamp))
+ return chisq
+
+
+def chisqgrad_logcamp(imvec, Amatrices, log_clamp, sigma):
+ """The gradient of the Log closure amplitude chi-squared"""
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ log_clamp_samples = (np.log(np.abs(i1)) +
+ np.log(np.abs(i2)) -
+ np.log(np.abs(i3)) -
+ np.log(np.abs(i4)))
+
+ pp = (log_clamp - log_clamp_samples) / (sigma**2)
+ pt1 = pp / i1
+ pt2 = pp / i2
+ pt3 = -pp / i3
+ pt4 = -pp / i4
+ out = (np.dot(pt1, Amatrices[0]) +
+ np.dot(pt2, Amatrices[1]) +
+ np.dot(pt3, Amatrices[2]) +
+ np.dot(pt4, Amatrices[3]))
+ out = (-2.0/len(log_clamp)) * np.real(out)
+ return out
+
+
+def chisq_logcamp_diag(imvec, Amatrices, log_clamp_diag, sigma):
+ """Diagonalized log closure amplitudes (normalized) chi-squared"""
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ A4_diag = Amatrices[0]
+ tform_mats = Amatrices[1]
+
+ log_clamp_diag_samples = []
+ for iA, A4 in enumerate(A4_diag):
+
+ a1 = np.abs(np.dot(A4[0], imvec))
+ a2 = np.abs(np.dot(A4[1], imvec))
+ a3 = np.abs(np.dot(A4[2], imvec))
+ a4 = np.abs(np.dot(A4[3], imvec))
+
+ log_clamp_samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4)
+ log_clamp_diag_samples.append(np.dot(tform_mats[iA], log_clamp_samples))
+
+ log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples)
+
+ chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2)
+ chisq /= (len(log_clamp_diag))
+
+ return chisq
+
+
+def chisqgrad_logcamp_diag(imvec, Amatrices, log_clamp_diag, sigma):
+ """The gradient of the diagonalized log closure amplitude chi-squared"""
+
+ A4_diag = Amatrices[0]
+ tform_mats = Amatrices[1]
+
+ deriv = np.zeros_like(imvec)
+ for iA, A4 in enumerate(A4_diag):
+
+ i1 = np.dot(A4[0], imvec)
+ i2 = np.dot(A4[1], imvec)
+ i3 = np.dot(A4[2], imvec)
+ i4 = np.dot(A4[3], imvec)
+ log_clamp_samples = np.log(np.abs(i1)) + np.log(np.abs(i2)) - \
+ np.log(np.abs(i3)) - np.log(np.abs(i4))
+ log_clamp_diag_samples = np.dot(tform_mats[iA], log_clamp_samples)
+
+ log_clamp_diag_measured = log_clamp_diag[iA]
+ log_clamp_diag_sigma = sigma[iA]
+
+ term1 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) /
+ (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i1)), A4[0])
+ term2 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) /
+ (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i2)), A4[1])
+ term3 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) /
+ (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i3)), A4[2])
+ term4 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) /
+ (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i4)), A4[3])
+ deriv += -2.0*np.real(term1 + term2 - term3 - term4)
+
+    deriv *= 1.0/float(len(np.concatenate(log_clamp_diag)))
+
+ return deriv
+
+
+def chisq_logamp(imvec, A, amp, sigma):
+ """Log Visibility Amplitudes (normalized) chi-squared"""
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ amp_samples = np.abs(np.dot(A, imvec))
+ chisq = np.sum(np.abs((np.log(amp) - np.log(amp_samples))/logsigma)**2)/len(amp)
+ return chisq
+
+def chisqgrad_logamp(imvec, A, amp, sigma):
+ """The gradient of the Log amplitude chi-squared"""
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ i1 = np.dot(A, imvec)
+ amp_samples = np.abs(i1)
+
+ pp = ((np.log(amp) - np.log(amp_samples))) / (logsigma**2) / i1
+ out = (-2.0/len(amp)) * np.real(np.dot(pp, A))
+ return out
+
+##################################################################################################
+# FFT Chi-squared and Gradient Functions
+##################################################################################################
+
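+# these routines sample a padded, gridded FFT of the image via the
+# obsh.sampler / obsh.gridder helpers; each A tuple packs the image padding
+# info with per-baseline sampler and gridder info built by chisqdata_*_fft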
+
+def chisq_vis_fft(vis_arr, A, vis, sigma):
+ """Visibility chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+ samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")
+
+ chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis))
+
+ return chisq
+
+
+def chisqgrad_vis_fft(vis_arr, A, vis, sigma):
+ """The gradient of the visibility chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # samples and gradient FT
+ pulsefac = sampler_info_list[0].pulsefac
+ samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")
+ wdiff_vec = (-1.0/len(vis)*(vis - samples)/(sigma**2)) * pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+
+def chisq_amp_fft(vis_arr, A, amp, sigma):
+ """Visibility amplitude chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+ amp_samples = np.abs(obsh.sampler(vis_arr, sampler_info_list, sample_type="vis"))
+ chisq = np.sum(np.abs((amp_samples-amp)/sigma)**2)/(len(amp))
+ return chisq
+
+
+def chisqgrad_amp_fft(vis_arr, A, amp, sigma):
+ """The gradient of the amplitude chi-kernesquared
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # samples
+ samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")
+ amp_samples = np.abs(samples)
+
+ # gradient FT
+ pulsefac = sampler_info_list[0].pulsefac
+ wdiff_vec = (-2.0/len(amp)*((amp - amp_samples) * amp_samples) /
+ (sigma**2) / samples.conj()) * pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevent cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+
+def chisq_bs_fft(vis_arr, A, bis, sigma):
+ """Bispectrum chi-squared from fft"""
+
+ im_info, sampler_info_list, gridder_info_list = A
+ bisamples = obsh.sampler(vis_arr, sampler_info_list, sample_type="bs")
+
+ return np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis))
+
+
+def chisqgrad_bs_fft(vis_arr, A, bis, sigma):
+ """The gradient of the amplitude chi-squared
+ """
+ im_info, sampler_info_list, gridder_info_list = A
+
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ bisamples = v1*v2*v3
+
+ wdiff = -1.0/len(bis)*(bis - bisamples)/(sigma**2)
+
+ pt1 = wdiff * (v2 * v3).conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = wdiff * (v1 * v3).conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = wdiff * (v1 * v2).conj() * sampler_info_list[2].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+ return out
+
+
+def chisq_cphase_fft(vis_arr, A, clphase, sigma):
+ """Closure Phases (normalized) chi-squared from fft
+ """
+
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ im_info, sampler_info_list, gridder_info_list = A
+ clphase_samples = np.angle(obsh.sampler(vis_arr, sampler_info_list, sample_type="bs"))
+
+ chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2))
+ return chisq
+
+
+def chisqgrad_cphase_fft(vis_arr, A, clphase, sigma):
+ """The gradient of the closure phase chi-squared from fft"""
+
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # sample visibilities and closure phases
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ clphase_samples = np.angle(v1*v2*v3)
+
+ pref = (2.0/len(clphase)) * np.sin(clphase - clphase_samples)/(sigma**2)
+ pt1 = pref/v1.conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = pref/v2.conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = pref/v3.conj() * sampler_info_list[2].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ # TODO or is x<-->y??
+ out = np.imag(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+
+def chisq_cphase_diag_fft(imvec, A, clphase_diag, sigma):
+ """Diagonalized closure phases (normalized) chi-squared from fft
+ """
+
+ clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE
+ sigma = np.concatenate(sigma) * ehc.DEGREE
+
+ A3 = A[0]
+ tform_mats = A[1]
+
+ im_info, sampler_info_list, gridder_info_list = A3
+ vis_arr = obsh.fft_imvec(imvec, A3[0])
+ clphase_samples = np.angle(obsh.sampler(vis_arr, sampler_info_list, sample_type="bs"))
+
+ count = 0
+ clphase_diag_samples = []
+ for tform_mat in tform_mats:
+ clphase_samples_here = clphase_samples[count:count+len(tform_mat)]
+ clphase_diag_samples.append(np.dot(tform_mat, clphase_samples_here))
+ count += len(tform_mat)
+
+ clphase_diag_samples = np.concatenate(clphase_diag_samples)
+
+ chisq = np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2))
+ chisq *= (2.0/len(clphase_diag))
+ return chisq
+
+
+def chisqgrad_cphase_diag_fft(imvec, A, clphase_diag, sigma):
+ """The gradient of the closure phase chi-squared from fft"""
+
+ clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE
+ sigma = np.concatenate(sigma) * ehc.DEGREE
+
+ A3 = A[0]
+ tform_mats = A[1]
+
+ im_info, sampler_info_list, gridder_info_list = A3
+ vis_arr = obsh.fft_imvec(imvec, A3[0])
+
+ # sample visibilities and closure phases
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ clphase_samples = np.angle(v1*v2*v3)
+
+ # gradient vec stuff
+ count = 0
+ pref = np.zeros_like(clphase_samples)
+ for tform_mat in tform_mats:
+
+ clphase_diag_samples = np.dot(tform_mat, clphase_samples[count:count+len(tform_mat)])
+ clphase_diag_measured = clphase_diag[count:count+len(tform_mat)]
+ clphase_diag_sigma = sigma[count:count+len(tform_mat)]
+
+ for j in range(len(clphase_diag_measured)):
+ pref[count:count+len(tform_mat)] += 2.0 * tform_mat[j, :] * np.sin(
+ clphase_diag_measured[j] - clphase_diag_samples[j])/(clphase_diag_sigma[j]**2)
+
+ count += len(tform_mat)
+
+ pt1 = pref/v1.conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = pref/v2.conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = pref/v3.conj() * sampler_info_list[2].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ deriv = np.imag(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+    deriv *= 1.0/float(len(clphase_diag))
+
+ return deriv
+
+
+def chisq_camp_fft(vis_arr, A, clamp, sigma):
+ """Closure Amplitudes (normalized) chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+ clamp_samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="camp")
+ chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp)
+ return chisq
+
+
+def chisqgrad_camp_fft(vis_arr, A, clamp, sigma):
+ """The gradient of the closure amplitude chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # sampled visibility and closure amplitudes
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis")
+ clamp_samples = np.abs((v1 * v2)/(v3 * v4))
+
+ # gradient components
+ pp = (-2.0/len(clamp)) * ((clamp - clamp_samples) * clamp_samples)/(sigma**2)
+ pt1 = pp/v1.conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = pp/v2.conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = -pp/v3.conj() * sampler_info_list[2].pulsefac.conj()
+ pt4 = -pp/v4.conj() * sampler_info_list[3].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+
+def chisq_logcamp_fft(vis_arr, A, log_clamp, sigma):
+ """Closure Amplitudes (normalized) chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+ log_clamp_samples = np.log(obsh.sampler(vis_arr, sampler_info_list, sample_type='camp'))
+
+ chisq = np.sum(np.abs((log_clamp - log_clamp_samples)/sigma)**2) / (len(log_clamp))
+
+ return chisq
+
+
+def chisqgrad_logcamp_fft(vis_arr, A, log_clamp, sigma):
+ """The gradient of the closure amplitude chi-squared from fft
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # sampled visibility and closure amplitudes
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis")
+
+ log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4)))
+
+ # gradient components
+ pp = (-2.0/len(log_clamp)) * (log_clamp - log_clamp_samples) / (sigma**2)
+ pt1 = pp / v1.conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = pp / v2.conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = -pp / v3.conj() * sampler_info_list[2].pulsefac.conj()
+ pt4 = -pp / v4.conj() * sampler_info_list[3].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+
+def chisq_logcamp_diag_fft(imvec, A, log_clamp_diag, sigma):
+ """Diagonalized log closure amplitudes (normalized) chi-squared from fft
+ """
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ A4 = A[0]
+ tform_mats = A[1]
+
+ im_info, sampler_info_list, gridder_info_list = A4
+ vis_arr = obsh.fft_imvec(imvec, A4[0])
+ log_clamp_samples = np.log(obsh.sampler(vis_arr, sampler_info_list, sample_type='camp'))
+
+ count = 0
+ log_clamp_diag_samples = []
+ for tform_mat in tform_mats:
+ log_clamp_samples_here = log_clamp_samples[count:count+len(tform_mat)]
+ log_clamp_diag_samples.append(np.dot(tform_mat, log_clamp_samples_here))
+ count += len(tform_mat)
+ log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples)
+
+ chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2)
+ chisq /= (len(log_clamp_diag))
+ return chisq
+
+
+def chisqgrad_logcamp_diag_fft(imvec, A, log_clamp_diag, sigma):
+ """The gradient of the diagonalized log closure amplitude chi-squared from fft
+ """
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ A4 = A[0]
+ tform_mats = A[1]
+
+ im_info, sampler_info_list, gridder_info_list = A4
+ vis_arr = obsh.fft_imvec(imvec, A4[0])
+
+ # sampled visibility and closure amplitudes
+ v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis")
+ v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis")
+ v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis")
+ v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis")
+ log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4)))
+
+ # gradient vec stuff
+ count = 0
+ pref = np.zeros_like(log_clamp_samples)
+ for tform_mat in tform_mats:
+
+ log_clamp_diag_samples = np.dot(tform_mat, log_clamp_samples[count:count+len(tform_mat)])
+ log_clamp_diag_measured = log_clamp_diag[count:count+len(tform_mat)]
+ log_clamp_diag_sigma = sigma[count:count+len(tform_mat)]
+
+ for j in range(len(log_clamp_diag_measured)):
+ pref[count:count+len(tform_mat)] += -2.0 * tform_mat[j, :] * \
+ (log_clamp_diag_measured[j] - log_clamp_diag_samples[j]) / \
+ (log_clamp_diag_sigma[j]**2)
+
+ count += len(tform_mat)
+
+ pt1 = pref / v1.conj() * sampler_info_list[0].pulsefac.conj()
+ pt2 = pref / v2.conj() * sampler_info_list[1].pulsefac.conj()
+ pt3 = -pref / v3.conj() * sampler_info_list[2].pulsefac.conj()
+ pt4 = -pref / v4.conj() * sampler_info_list[3].pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevant cells and flatten
+ deriv = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+    deriv *= 1.0/float(len(log_clamp_diag))
+
+ return deriv
+
+
+def chisq_logamp_fft(vis_arr, A, amp, sigma):
+ """Visibility amplitude chi-squared from fft
+ """
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ im_info, sampler_info_list, gridder_info_list = A
+ amp_samples = np.abs(obsh.sampler(vis_arr, sampler_info_list, sample_type="vis"))
+ chisq = np.sum(np.abs((np.log(amp_samples)-np.log(amp))/logsigma)**2)/(len(amp))
+ return chisq
+
+
+def chisqgrad_logamp_fft(vis_arr, A, amp, sigma):
+ """The gradient of the amplitude chi-kernesquared
+ """
+
+ im_info, sampler_info_list, gridder_info_list = A
+
+ # samples
+ samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")
+ amp_samples = np.abs(samples)
+
+ # gradient FT
+ logsigma = sigma / amp
+ pulsefac = sampler_info_list[0].pulsefac
+ wdiff_vec = (-2.0/len(amp)*((np.log(amp) - np.log(amp_samples))) /
+ (logsigma**2) / samples.conj()) * pulsefac.conj()
+
+ # Setup and perform the inverse FFT
+ wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list)
+ grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr)))
+ grad_arr = grad_arr * (im_info.npad * im_info.npad)
+
+ # extract relevent cells and flatten
+ # TODO or is x<-->y??
+ out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2,
+ im_info.padvaly1:-im_info.padvaly2].flatten())
+
+ return out
+
+##################################################################################################
+# NFFT Chi-squared and Gradient Functions
+##################################################################################################
+
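+# these routines evaluate the measurement model with precomputed NFFT plans
+# (built by the chisqdata_*_nfft functions, e.g. with the pynfft package)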
+
+def chisq_vis_nfft(imvec, A, vis, sigma):
+ """Visibility chi-squared from nfft
+ """
+
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+
+ # compute chi^2
+ chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis))
+
+ return chisq
+
+
+def chisqgrad_vis_nfft(imvec, A, vis, sigma):
+ """The gradient of the visibility chi-squared from nfft
+ """
+
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+
+ # gradient vec for adjoint FT
+ wdiff_vec = (-1.0/len(vis)*(vis - samples)/(sigma**2)) * pulsefac.conj()
+ plan.f = wdiff_vec
+ plan.adjoint()
+ grad = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim))
+
+ return grad
+
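+# The adjoint NFFT applied to the weighted residual implements the chain rule
+# d(chi^2)/d(imvec) = Re[A^H (-(vis - samples)/sigma^2)] / len(vis), where A is
+# the forward nonuniform DFT (including the pulse factor); the same pattern
+# recurs in all of the nfft gradients below.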
+
+def chisq_amp_nfft(imvec, A, amp, sigma):
+ """Visibility amplitude chi-squared from nfft
+ """
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+
+ # compute chi^2
+ amp_samples = np.abs(samples)
+ chisq = np.sum(np.abs((amp_samples-amp)/sigma)**2)/(len(amp))
+
+ return chisq
+
+
+def chisqgrad_amp_nfft(imvec, A, amp, sigma):
+ """The gradient of the amplitude chi-squared from nfft
+ """
+
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+ amp_samples = np.abs(samples)
+
+ # gradient vec for adjoint FT
+ wdiff_vec = (-2.0/len(amp)*((amp - amp_samples) * samples) /
+ (sigma**2) / amp_samples) * pulsefac.conj()
+ plan.f = wdiff_vec
+ plan.adjoint()
+ out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim))
+
+ return out
+
+
+def chisq_bs_nfft(imvec, A, bis, sigma):
+ """Bispectrum chi-squared from fft"""
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ # compute chi^2
+ bisamples = samples1*samples2*samples3
+ chisq = np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis))
+ return chisq
+
+
+def chisqgrad_bs_nfft(imvec, A, bis, sigma):
+ """The gradient of the amplitude chi-squared from the nfft
+ """
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ # gradient vec for adjoint FT
+ bisamples = v1*v2*v3
+ wdiff = -1.0/len(bis)*(bis - bisamples)/(sigma**2)
+ pt1 = wdiff * (v2 * v3).conj() * pulsefac1.conj()
+ pt2 = wdiff * (v1 * v3).conj() * pulsefac2.conj()
+ pt3 = wdiff * (v1 * v2).conj() * pulsefac3.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ out = out1 + out2 + out3
+ return out
+
+
+def chisq_cphase_nfft(imvec, A, clphase, sigma):
+ """Closure Phases (normalized) chi-squared from nfft
+ """
+
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ # compute chi^2
+ clphase_samples = np.angle(samples1*samples2*samples3)
+ chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2))
+
+ return chisq
+
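+# In the high-SNR limit 1 - cos(x) ~= x**2/2, so the chi^2 above reduces to the
+# familiar (1.0/len(clphase)) * np.sum(((clphase - clphase_samples)/sigma)**2).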
+
+def chisqgrad_cphase_nfft(imvec, A, clphase, sigma):
+ """The gradient of the closure phase chi-squared from nfft"""
+
+ clphase = clphase * ehc.DEGREE
+ sigma = sigma * ehc.DEGREE
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ # gradient vec for adjoint FT
+ clphase_samples = np.angle(v1*v2*v3)
+ pref = (2.0/len(clphase)) * np.sin(clphase - clphase_samples)/(sigma**2)
+ pt1 = pref/v1.conj() * pulsefac1.conj()
+ pt2 = pref/v2.conj() * pulsefac2.conj()
+ pt3 = pref/v3.conj() * pulsefac3.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ out = out1 + out2 + out3
+ return out
+
+
+def chisq_cphase_diag_nfft(imvec, A, clphase_diag, sigma):
+ """Diagonalized closure phases (normalized) chi-squared from nfft
+ """
+
+ clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE
+ sigma = np.concatenate(sigma) * ehc.DEGREE
+
+ A3 = A[0]
+ tform_mats = A[1]
+
+ # get nfft objects
+ nfft_info1 = A3[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A3[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A3[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ clphase_samples = np.angle(samples1*samples2*samples3)
+
+ count = 0
+ clphase_diag_samples = []
+ for tform_mat in tform_mats:
+ clphase_samples_here = clphase_samples[count:count+len(tform_mat)]
+ clphase_diag_samples.append(np.dot(tform_mat, clphase_samples_here))
+ count += len(tform_mat)
+
+ clphase_diag_samples = np.concatenate(clphase_diag_samples)
+
+ # compute chi^2
+ chisq = (2.0/len(clphase_diag)) * \
+ np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2))
+
+ return chisq
+
+
+def chisqgrad_cphase_diag_nfft(imvec, A, clphase_diag, sigma):
+ """The gradient of the diagonalized closure phase chi-squared from nfft"""
+
+ clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE
+ sigma = np.concatenate(sigma) * ehc.DEGREE
+
+ A3 = A[0]
+ tform_mats = A[1]
+
+ # get nfft objects
+ nfft_info1 = A3[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A3[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A3[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ clphase_samples = np.angle(v1*v2*v3)
+
+ # gradient vec for adjoint FT
+ count = 0
+ pref = np.zeros_like(clphase_samples)
+ for tform_mat in tform_mats:
+
+ clphase_diag_samples = np.dot(tform_mat, clphase_samples[count:count+len(tform_mat)])
+ clphase_diag_measured = clphase_diag[count:count+len(tform_mat)]
+ clphase_diag_sigma = sigma[count:count+len(tform_mat)]
+
+ for j in range(len(clphase_diag_measured)):
+ pref[count:count+len(tform_mat)] += 2.0 * tform_mat[j, :] * np.sin(
+ clphase_diag_measured[j] - clphase_diag_samples[j])/(clphase_diag_sigma[j]**2)
+
+ count += len(tform_mat)
+
+ pt1 = pref/v1.conj() * pulsefac1.conj()
+ pt2 = pref/v2.conj() * pulsefac2.conj()
+ pt3 = pref/v3.conj() * pulsefac3.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ deriv = out1 + out2 + out3
+    deriv *= 1.0/float(len(clphase_diag))
+
+ return deriv
+
+
+def chisq_camp_nfft(imvec, A, clamp, sigma):
+ """Closure Amplitudes (normalized) chi-squared from fft
+ """
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ samples4 = plan4.f.copy()*pulsefac4
+
+ # compute chi^2
+ clamp_samples = np.abs((samples1*samples2)/(samples3*samples4))
+ chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp)
+ return chisq
+
+
+def chisqgrad_camp_nfft(imvec, A, clamp, sigma):
+ """The gradient of the closure amplitude chi-squared from fft
+ """
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ v4 = plan4.f.copy()*pulsefac4
+
+ # gradient vec for adjoint FT
+ clamp_samples = np.abs((v1 * v2)/(v3 * v4))
+
+ pp = (-2.0/len(clamp)) * ((clamp - clamp_samples) * clamp_samples)/(sigma**2)
+ pt1 = pp/v1.conj() * pulsefac1.conj()
+ pt2 = pp/v2.conj() * pulsefac2.conj()
+ pt3 = -pp/v3.conj() * pulsefac3.conj()
+ pt4 = -pp/v4.conj() * pulsefac4.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ plan4.f = pt4
+ plan4.adjoint()
+ out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim))
+
+ out = out1 + out2 + out3 + out4
+ return out
+
+
+def chisq_logcamp_nfft(imvec, A, log_clamp, sigma):
+ """Log Closure Amplitudes (normalized) chi-squared from fft
+ """
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ samples4 = plan4.f.copy()*pulsefac4
+
+ # compute chi^2
+ log_clamp_samples = (np.log(np.abs(samples1)) + np.log(np.abs(samples2)) -
+ np.log(np.abs(samples3)) - np.log(np.abs(samples4)))
+ chisq = np.sum(np.abs((log_clamp - log_clamp_samples)/sigma)**2) / (len(log_clamp))
+ return chisq
+
+
+def chisqgrad_logcamp_nfft(imvec, A, log_clamp, sigma):
+ """The gradient of the log closure amplitude chi-squared from fft
+ """
+
+ # get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ v4 = plan4.f.copy()*pulsefac4
+
+ # gradient vec for adjoint FT
+ log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4)))
+
+ pp = (-2.0/len(log_clamp)) * (log_clamp - log_clamp_samples) / (sigma**2)
+ pt1 = pp / v1.conj() * pulsefac1.conj()
+ pt2 = pp / v2.conj() * pulsefac2.conj()
+ pt3 = -pp / v3.conj() * pulsefac3.conj()
+ pt4 = -pp / v4.conj() * pulsefac4.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ plan4.f = pt4
+ plan4.adjoint()
+ out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim))
+
+ out = out1 + out2 + out3 + out4
+ return out
+
+
+def chisq_logcamp_diag_nfft(imvec, A, log_clamp_diag, sigma):
+ """Diagonalized log closure amplitudes (normalized) chi-squared from nfft
+ """
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ A4 = A[0]
+ tform_mats = A[1]
+
+ # get nfft objects
+ nfft_info1 = A4[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A4[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A4[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A4[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ samples4 = plan4.f.copy()*pulsefac4
+
+ log_clamp_samples = (np.log(np.abs(samples1)) + np.log(np.abs(samples2)) -
+ np.log(np.abs(samples3)) - np.log(np.abs(samples4)))
+
+ count = 0
+ log_clamp_diag_samples = []
+ for tform_mat in tform_mats:
+ log_clamp_samples_here = log_clamp_samples[count:count+len(tform_mat)]
+ log_clamp_diag_samples.append(np.dot(tform_mat, log_clamp_samples_here))
+ count += len(tform_mat)
+
+ log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples)
+
+ # compute chi^2
+ chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) / \
+ (len(log_clamp_diag))
+
+ return chisq
+
+
+def chisqgrad_logcamp_diag_nfft(imvec, A, log_clamp_diag, sigma):
+ """The gradient of the diagonalized log closure amplitude chi-squared from fft
+ """
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ A4 = A[0]
+ tform_mats = A[1]
+
+ # get nfft objects
+ nfft_info1 = A4[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A4[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A4[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ nfft_info4 = A4[3]
+ plan4 = nfft_info4.plan
+ pulsefac4 = nfft_info4.pulsefac
+
+ # compute uniform --> nonuniform transforms
+ plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T
+ plan4.trafo()
+ v4 = plan4.f.copy()*pulsefac4
+
+ log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4)))
+
+ # gradient vec for adjoint FT
+ count = 0
+ pp = np.zeros_like(log_clamp_samples)
+ for tform_mat in tform_mats:
+
+ log_clamp_diag_samples = np.dot(tform_mat, log_clamp_samples[count:count+len(tform_mat)])
+ log_clamp_diag_measured = log_clamp_diag[count:count+len(tform_mat)]
+ log_clamp_diag_sigma = sigma[count:count+len(tform_mat)]
+
+ for j in range(len(log_clamp_diag_measured)):
+ pp[count:count+len(tform_mat)] += -2.0 * tform_mat[j, :] * \
+ (log_clamp_diag_measured[j] - log_clamp_diag_samples[j]) / \
+ (log_clamp_diag_sigma[j]**2)
+
+ count += len(tform_mat)
+
+ pt1 = pp / v1.conj() * pulsefac1.conj()
+ pt2 = pp / v2.conj() * pulsefac2.conj()
+ pt3 = -pp / v3.conj() * pulsefac3.conj()
+ pt4 = -pp / v4.conj() * pulsefac4.conj()
+
+    # Setup and perform the adjoint NFFTs
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim))
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim))
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim))
+
+ plan4.f = pt4
+ plan4.adjoint()
+ out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim))
+
+ deriv = out1 + out2 + out3 + out4
+    deriv *= 1.0/float(len(log_clamp_diag))
+
+ return deriv
+
+
+def chisq_logamp_nfft(imvec, A, amp, sigma):
+ """Visibility log amplitude chi-squared from nfft
+ """
+
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ # compute chi^2
+ amp_samples = np.abs(samples)
+ chisq = np.sum(np.abs((np.log(amp_samples)-np.log(amp))/logsigma)**2)/(len(amp))
+
+ return chisq
+
+
+def chisqgrad_logamp_nfft(imvec, A, amp, sigma):
+ """The gradient of the log amplitude chi-squared from nfft
+ """
+
+ # get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ # compute uniform --> nonuniform transform
+ plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T
+ plan.trafo()
+ samples = plan.f.copy()*pulsefac
+ amp_samples = np.abs(samples)
+
+ # gradient vec for adjoint FT
+ logsigma = sigma / amp
+ wdiff_vec = (-2.0/len(amp)*((np.log(amp) - np.log(amp_samples))) /
+ (logsigma**2) / samples.conj()) * pulsefac.conj()
+ plan.f = wdiff_vec
+ plan.adjoint()
+ out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim))
+
+ return out
+
+
+##################################################################################################
+# Regularizer and Gradient Functions
+##################################################################################################
+
+def sflux(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Total flux constraint
+ """
+ if norm_reg:
+ norm = flux**2
+ else:
+ norm = 1
+
+ out = -(np.sum(imvec) - flux)**2
+ return out/norm
+
+
+def sfluxgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Total flux constraint gradient
+ """
+ if norm_reg:
+ norm = flux**2
+ else:
+ norm = 1
+
+ out = -2*(np.sum(imvec) - flux)*np.ones(len(imvec))
+ return out / norm
+
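+# A quick finite-difference check (illustrative, with arbitrary test values;
+# commented out so it does not run on import) that sfluxgrad is the gradient
+# of sflux:
+#
+#   imvec, priorvec, flux = np.array([0.2, 0.5, 0.1]), np.ones(3)/3., 1.0
+#   eps = 1.e-6
+#   for k in range(3):
+#       dv = np.zeros(3)
+#       dv[k] = eps
+#       fd = (sflux(imvec + dv, priorvec, flux) -
+#             sflux(imvec - dv, priorvec, flux)) / (2.*eps)
+#       assert np.isclose(fd, sfluxgrad(imvec, priorvec, flux)[k])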
+
+def scm(imvec, nx, ny, psize, flux, embed_mask, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Center-of-mass constraint
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = beam_size**2 * flux**2
+ else:
+ norm = 1
+
+ xx, yy = np.meshgrid(range(nx//2, -nx//2, -1), range(ny//2, -ny//2, -1))
+ xx = psize*xx.flatten()[embed_mask]
+ yy = psize*yy.flatten()[embed_mask]
+
+ out = -(np.sum(imvec*xx)**2 + np.sum(imvec*yy)**2)
+ return out/norm
+
+
+def scmgrad(imvec, nx, ny, psize, flux, embed_mask, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Center-of-mass constraint gradient
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = beam_size**2 * flux**2
+ else:
+ norm = 1
+
+ xx, yy = np.meshgrid(range(nx//2, -nx//2, -1), range(ny//2, -ny//2, -1))
+ xx = psize*xx.flatten()[embed_mask]
+ yy = psize*yy.flatten()[embed_mask]
+
+ out = -2*(np.sum(imvec*xx)*xx + np.sum(imvec*yy)*yy)
+ return out/norm
+
+
+def ssimple(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Simple entropy
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ entropy = -np.sum(imvec*np.log(imvec/priorvec))
+ return entropy/norm
+
+
+def ssimplegrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Simple entropy gradient
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ entropygrad = -np.log(imvec/priorvec) - 1
+ return entropygrad/norm
+
+
+def sl1(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """L1 norm regularizer
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ # l1 = -np.sum(np.abs(imvec - priorvec))
+ l1 = -np.sum(np.abs(imvec))
+ return l1/norm
+
+
+def sl1grad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """L1 norm gradient
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ # l1grad = -np.sign(imvec - priorvec)
+ l1grad = -np.sign(imvec)
+ return l1grad/norm
+
+
+def sl1w(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER, epsilon=ehc.EP):
+ """Weighted L1 norm regularizer a la SMILI
+ """
+
+ if norm_reg:
+ norm = 1 # should be ok?
+ # This is SMILI normalization
+ # norm = np.sum((np.sqrt(priorvec**2 + epsilon) + epsilon)/np.sqrt(priorvec**2 + epsilon))
+ else:
+ norm = 1
+
+ num = np.sqrt(imvec**2 + epsilon)
+ denom = np.sqrt(priorvec**2 + epsilon) + epsilon
+
+ l1w = -np.sum(num/denom)
+ return l1w/norm
+
+
+def sl1wgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER, epsilon=ehc.EP):
+ """Weighted L1 norm gradient
+ """
+ if norm_reg:
+ norm = 1 # should be ok?
+ # This is SMILI normalization
+ # norm = np.sum((np.sqrt(priorvec**2 + epsilon) + epsilon)/np.sqrt(priorvec**2 + epsilon))
+ else:
+ norm = 1
+
+ num = imvec / np.sqrt(imvec**2 + epsilon)
+ denom = np.sqrt(priorvec**2 + epsilon) + epsilon
+
+ l1wgrad = - num / denom
+ return l1wgrad/norm
+
+
+def fA(imvec, I_ref=1.0, alpha_A=1.0):
+ """Function to take imvec to itself in the limit alpha_A -> 0
+ and to a binary representation in the limit alpha_A -> infinity
+ """
+ return 2.0/np.pi * (1.0 + alpha_A)/alpha_A * np.arctan(np.pi*alpha_A/2.0*np.abs(imvec)/I_ref)
+
+
+def fAgrad(imvec, I_ref=1.0, alpha_A=1.0):
+ """Function to take imvec to itself in the limit alpha_A -> 0
+ and to a binary representation in the limit alpha_A -> infinity
+ """
+ return (1.0 + alpha_A) / (I_ref * (1.0 + (np.pi*alpha_A/2.0*imvec/I_ref)**2))
+
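+# Limits of fA (take I_ref = 1): as alpha_A -> 0, fA(x) -> |x|, recovering an
+# L1 norm; as alpha_A -> infinity, (2/pi)*arctan(...) -> 1, so fA(x) -> 1 for
+# any x != 0 and -np.sum(fA(imvec)) approaches an L0 (pixel-counting) penalty.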
+
+def slA(imvec, priorvec, psize, flux, beam_size=None, alpha_A=1.0, norm_reg=NORM_REGULARIZER):
+ """l_A regularizer
+ """
+
+ # The appropriate I_ref is something like the total flux divided by the # of pixels per beam
+ if beam_size is None:
+ beam_size = psize
+ I_ref = flux
+
+ if norm_reg:
+ norm_l1 = 1.0 # as alpha_A ->0
+ norm_l0 = (beam_size/psize)**2 # as alpha_A ->\infty
+ weight_l1 = 1.0/(1.0 + alpha_A)
+ weight_l0 = alpha_A
+ norm = (norm_l1 * weight_l1 + norm_l0 * weight_l0)/(weight_l0 + weight_l1)
+ else:
+ norm = 1
+
+ return -np.sum(fA(imvec, I_ref, alpha_A))/norm
+
+
+def slAgrad(imvec, priorvec, psize, flux, beam_size=None, alpha_A=1.0, norm_reg=NORM_REGULARIZER):
+ """l_A gradient
+ """
+
+ # The appropriate I_ref is something like the total flux divided by the # of pixels per beam
+ if beam_size is None:
+ beam_size = psize
+ I_ref = flux
+
+ if norm_reg:
+ norm_l1 = 1.0 # as alpha_A ->0
+ norm_l0 = (beam_size/psize)**2 # as alpha_A ->\infty
+ weight_l1 = 1.0/(1.0 + alpha_A)
+ weight_l0 = alpha_A
+ norm = (norm_l1 * weight_l1 + norm_l0 * weight_l0)/(weight_l0 + weight_l1)
+ else:
+ norm = 1
+
+ return -fAgrad(imvec, I_ref, alpha_A)/norm
+
+
+def sgs(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Gull-skilling entropy
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ entropy = np.sum(imvec - priorvec - imvec*np.log(imvec/priorvec))
+ return entropy/norm
+
+
+def sgsgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Gull-Skilling gradient
+ """
+ if norm_reg:
+ norm = flux
+ else:
+ norm = 1
+
+ entropygrad = -np.log(imvec/priorvec)
+ return entropygrad/norm
+
+# TODO: epsilon is 0 by default for backwards compatibility
+def stv(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.):
+ """Total variation regularizer
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux*psize / beam_size
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ out = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2 + epsilon))
+ return out/norm
+
+# TODO: epsilon is 0 by default for backwards compatibility
+def stvgrad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.):
+ """Total variation gradient
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux*psize / beam_size
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+
+    # diagonally shifted images for the cross terms
+ im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1]
+ im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1]
+
+ # add together terms and return
+ g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon)
+ g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon)
+ g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon)
+
+    # mask the first row/column gradient terms that don't exist
+ mask1 = np.zeros(im.shape)
+ mask2 = np.zeros(im.shape)
+ mask1[0, :] = 1
+ mask2[:, 0] = 1
+ g2[mask1.astype(bool)] = 0
+ g3[mask2.astype(bool)] = 0
+
+ # add terms together and return
+ out = -(g1 + g2 + g3).flatten()
+ return out/norm
+
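+# The isotropic total variation above is TV(I) = sum_ij sqrt((I[i+1,j]-I[i,j])**2
+# + (I[i,j+1]-I[i,j])**2 + epsilon); a small epsilon > 0 smooths the kink at zero
+# image gradient so that stvgrad stays finite there.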
+
+def stv2(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Squared Total variation regularizer
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = psize**4 * flux**2 / beam_size**4
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ out = -np.sum((im_l1 - im)**2 + (im_l2 - im)**2)
+ return out/norm
+
+
+def stv2grad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Squared Total variation gradient
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = psize**4 * flux**2 / beam_size**4
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+
+ g1 = (2*im - im_l1 - im_l2)
+ g2 = (im - im_r1)
+ g3 = (im - im_r2)
+
+    # mask the first row/column gradient terms that don't exist
+ mask1 = np.zeros(im.shape)
+ mask2 = np.zeros(im.shape)
+ mask1[0, :] = 1
+ mask2[:, 0] = 1
+ g2[mask1.astype(bool)] = 0
+ g3[mask2.astype(bool)] = 0
+
+ # add together terms and return
+ out = -2*(g1 + g2 + g3).flatten()
+ return out/norm
+
+
+def spatch(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Patch prior regularizer
+ """
+ if norm_reg:
+ norm = flux**2
+ else:
+ norm = 1
+
+ out = -0.5*np.sum((imvec - priorvec) ** 2)
+ return out/norm
+
+
+def spatchgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER):
+ """Patch prior gradient
+ """
+ if norm_reg:
+ norm = flux**2
+ else:
+ norm = 1
+
+ out = -(imvec - priorvec)
+ return out/norm
+
+# TODO FIGURE OUT NORMALIZATIONS FOR COMPACT 1 & 2 REGULARIZERS
+
+
+def scompact(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """I r^2 source size regularizer
+ """
+
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux * (beam_size**2)
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+
+ xx, yy = np.meshgrid(range(nx), range(ny))
+ xx = xx - (nx-1)/2.0
+ yy = yy - (ny-1)/2.0
+ xxpsize = xx * psize
+ yypsize = yy * psize
+
+ x0 = np.sum(np.sum(im * xxpsize))/flux
+ y0 = np.sum(np.sum(im * yypsize))/flux
+
+ out = -np.sum(np.sum(im * ((xxpsize - x0)**2 + (yypsize - y0)**2)))
+ return out/norm
+
+
+def scompactgrad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Gradient for I r^2 source size regularizer
+ """
+
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux * beam_size**2
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+
+ xx, yy = np.meshgrid(range(nx), range(ny))
+ xx = xx - (nx-1)/2.0
+ yy = yy - (ny-1)/2.0
+ xxpsize = xx * psize
+ yypsize = yy * psize
+
+ x0 = np.sum(np.sum(im * xxpsize))/flux
+ y0 = np.sum(np.sum(im * yypsize))/flux
+
+ term1 = np.sum(np.sum(im * ((xxpsize - x0))))
+ term2 = np.sum(np.sum(im * ((yypsize - y0))))
+
+ grad = -2*xxpsize*term1 - 2*yypsize*term2 + (xxpsize - x0)**2 + (yypsize - y0)**2
+
+ return -grad.reshape(-1)/norm
+
+
+def scompact2(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """I^2r^2 source size regularizer
+ """
+
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux**2 * beam_size**2
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+
+ xx, yy = np.meshgrid(range(nx), range(ny))
+ xx = xx - (nx-1)/2.0
+ yy = yy - (ny-1)/2.0
+ xxpsize = xx * psize
+ yypsize = yy * psize
+
+ out = -np.sum(np.sum(im**2 * (xxpsize**2 + yypsize**2)))
+ return out/norm
+
+
+def scompact2grad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Gradient for I^2r^2 source size regularizer
+ """
+
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = flux**2 * beam_size**2
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+
+ xx, yy = np.meshgrid(range(nx), range(ny))
+ xx = xx - (nx-1)/2.0
+ yy = yy - (ny-1)/2.0
+ xxpsize = xx * psize
+ yypsize = yy * psize
+
+ grad = -2*im*(xxpsize**2 + yypsize**2)
+
+ return grad.reshape(-1)/norm
+
+
+def sgauss(imvec, xdim, ydim, psize, major, minor, PA):
+ """Gaussian source size regularizer
+ """
+
+ # major, minor and PA are all in radians
+ phi = PA
+
+ # eigenvalues of covariance matrix
+ lambda1 = minor**2./(8.*np.log(2.))
+ lambda2 = major**2./(8.*np.log(2.))
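+    # (FWHM**2 = 8*ln(2)*variance for a Gaussian, hence the 8.*np.log(2.) factor)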
+
+ # now compute covariance matrix elements from user inputs
+ sigxx_prime = lambda1*(np.cos(phi)**2.) + lambda2*(np.sin(phi)**2.)
+ sigyy_prime = lambda1*(np.sin(phi)**2.) + lambda2*(np.cos(phi)**2.)
+ sigxy_prime = (lambda2 - lambda1)*np.cos(phi)*np.sin(phi)
+
+ # we get the dimensions and image vector
+    im = imvec.reshape(ydim, xdim)
+ xlist, ylist = np.meshgrid(range(xdim), range(ydim))
+ xlist = xlist - (xdim-1)/2.0
+ ylist = ylist - (ydim-1)/2.0
+
+ xx = xlist * psize
+ yy = ylist * psize
+
+ # the centroid parameters
+ x0 = np.sum(xx*im) / np.sum(im)
+ y0 = np.sum(yy*im) / np.sum(im)
+
+ # we calculate the elements of the covariance matrix
+ sigxx = (np.sum((xx - x0)**2.*im)/np.sum(im))
+ sigyy = (np.sum((yy - y0)**2.*im)/np.sum(im))
+ sigxy = (np.sum((xx - x0)*(yy - y0)*im)/np.sum(im))
+
+    # we calculate the regularizer
+ rgauss = -((sigxx - sigxx_prime)**2. + (sigyy - sigyy_prime)**2. + 2*(sigxy - sigxy_prime)**2.)
+ # normalization will need to be redone, right now requires alpha~1000
+ rgauss = rgauss/(major**2. * minor**2.)
+ return rgauss
+
+
+def sgauss_grad(imvec, xdim, ydim, psize, major, minor, PA):
+ """Gradient for Gaussian source size regularizer
+ """
+
+ # major, minor and PA are all in radians
+ phi = PA
+
+ # computing eigenvalues of the covariance matrix
+ lambda1 = (minor**2.)/(8.*np.log(2.))
+ lambda2 = (major**2.)/(8.*np.log(2.))
+
+ # now compute covariance matrix elements from user inputs
+
+ sigxx_prime = lambda1*(np.cos(phi)**2.) + lambda2*(np.sin(phi)**2.)
+ sigyy_prime = lambda1*(np.sin(phi)**2.) + lambda2*(np.cos(phi)**2.)
+ sigxy_prime = (lambda2 - lambda1)*np.cos(phi)*np.sin(phi)
+
+ # we get the dimensions and image vector
+    im = imvec.reshape(ydim, xdim)
+ xlist, ylist = np.meshgrid(range(xdim), range(ydim))
+ xlist = xlist - (xdim-1)/2.0
+ ylist = ylist - (ydim-1)/2.0
+
+ xx = xlist * psize
+ yy = ylist * psize
+
+ # the centroid parameters
+ x0 = np.sum(xx*im) / np.sum(im)
+ y0 = np.sum(yy*im) / np.sum(im)
+
+ # we calculate the elements of the covariance matrix of the image
+ sigxx = (np.sum((xx - x0)**2.*im)/np.sum(im))
+ sigyy = (np.sum((yy - y0)**2.*im)/np.sum(im))
+ sigxy = (np.sum((xx - x0)*(yy - y0)*im)/np.sum(im))
+
+ # now we compute the gradients of all quantities
+ # gradient of centroid
+ dx0 = (xx - x0) / np.sum(im)
+ dy0 = (yy - y0) / np.sum(im)
+
+ # gradients of covariance matrix elements
+ dxx = (((xx - x0)**2. - 2.*(xx - x0)*dx0*im) - sigxx) / np.sum(im)
+
+    dyy = (((yy - y0)**2. - 2.*(yy - y0)*dy0*im) - sigyy) / np.sum(im)
+
+ dxy = (((xx - x0)*(yy - y0) - (yy - y0)*dx0*im - (xx - x0)*dy0*im) - sigxy) / np.sum(im)
+
+    # gradient of the regularizer
+ drgauss = (2.*(sigxx - sigxx_prime)*dxx +
+ 2.*(sigyy - sigyy_prime)*dyy +
+ 4.*(sigxy - sigxy_prime)*dxy)
+
+ # normalization will need to be redone, right now requires alpha~1000
+ drgauss = drgauss/(major**2. * minor**2.)
+
+ return -drgauss.reshape(-1)
+
+
+##################################################################################################
+# Chi^2 Data functions
+##################################################################################################
+def apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol):
+ """apply systematic noise to VISIBILITIES or AMPLITUDES
+ data_arr should have fields 't1','t2','u','v','vis','amp','sigma'
+
+ returns: (uv, vis, amp, sigma)
+ """
+
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+
+ t1 = data_arr['t1']
+ t2 = data_arr['t2']
+
+ sigma = data_arr[etype]
+ amp = data_arr[atype]
+ try:
+ vis = data_arr[vtype]
+    except (KeyError, ValueError):
+ vis = amp.astype('c16')
+
+ snrmask = np.abs(amp/sigma) >= snrcut
+
+ if type(systematic_noise) is dict:
+ sys_level = np.zeros(len(t1))
+ for i in range(len(t1)):
+ if t1[i] in systematic_noise.keys():
+ t1sys = systematic_noise[t1[i]]
+ else:
+ t1sys = 0.
+ if t2[i] in systematic_noise.keys():
+ t2sys = systematic_noise[t2[i]]
+ else:
+ t2sys = 0.
+
+ if t1sys < 0 or t2sys < 0:
+ sys_level[i] = -1
+ else:
+ sys_level[i] = np.sqrt(t1sys**2 + t2sys**2)
+ else:
+ sys_level = np.sqrt(2)*systematic_noise*np.ones(len(t1))
+
+ mask = sys_level >= 0.
+ mask = snrmask * mask
+
+ sigma = np.linalg.norm([sigma, sys_level*np.abs(amp)], axis=0)[mask]
+ vis = vis[mask]
+ amp = amp[mask]
+ uv = np.hstack((data_arr['u'].reshape(-1, 1), data_arr['v'].reshape(-1, 1)))[mask]
+ return (uv, vis, amp, sigma)
+
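+# systematic_noise may be a single fractional value applied to every baseline,
+# or a per-station dict (station codes below are illustrative), e.g.
+#   systematic_noise = {'ALMA': 0.01, 'APEX': 0.05}
+# A negative entry flags every baseline to that station for removal through the
+# sys_level < 0 mask above.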
+
+def chisqdata_vis(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrix for visibilities
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrix
+ A = obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask)
+
+ return (vis, sigma, A)
+
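+# The dense Fourier matrix A returned above maps the (masked) image vector to
+# model visibilities via samples = np.dot(A, imvec); the *_fft and *_nfft data
+# functions later in this module return gridder or plan objects instead.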
+
+def chisqdata_amp(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrix for visibility amplitudes
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I':
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed amplitude table in amplitude chi^2!")
+ if not type(Obsdata.amp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed amplitude table is not a numpy rec array!")
+ data_arr = Obsdata.amp
+
+ # apply systematic noise and SNR cut
+ # TODO -- after pre-computed??
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrix
+ A = obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask)
+
+ return (amp, sigma, A)
+
+
+def chisqdata_bs(Obsdata, Prior, mask, pol='I', **kwargs):
+ """return the data, sigmas, and fourier matrices for bispectra
+ """
+
+ # unpack keyword args
+ # systematic_noise = kwargs.get('systematic_noise',0.)
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I':
+ biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut)
+
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed bispectrum table in cphase chi^2!")
+ if not type(Obsdata.bispec) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed bispectrum table is not a numpy rec array!")
+ biarr = Obsdata.bispec
+ # reduce to a minimal set
+ if count != 'max':
+ biarr = obsh.reduce_tri_minimal(Obsdata, biarr)
+
+ uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ # add systematic noise
+ # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrices
+ A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask)
+ )
+
+ return (bi, sigma, A3)
+
+
+def chisqdata_cphase(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for closure phases
+ """
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I':
+ clphasearr = Obsdata.c_phases(mode="all", vtype=vtype,
+ count=count, uv_min=uv_min, snrcut=snrcut)
+ else: # TODO precomputed with not Stokes I
+ print("Using pre-computed cphase table in cphase chi^2!")
+ if not type(Obsdata.cphase) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure phase table is not a numpy rec array!")
+ clphasearr = Obsdata.cphase
+ # reduce to a minimal set
+ if count != 'max':
+ clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr)
+
+ uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1)))
+ clphase = clphasearr['cphase']
+ sigma = clphasearr['sigmacp']
+
+ # add systematic cphase noise (in DEGREES)
+ sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrices
+ A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask)
+ )
+ return (clphase, sigma, A3)
+
+
+def chisqdata_cphase_diag(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for diagonalized closure phases
+ """
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min)
+
+ # loop over timestamps
+ clphase_diag = []
+ sigma_diag = []
+ A3_diag = []
+ tform_mats = []
+ for ic, cl in enumerate(clphasearr):
+
+ # get diagonalized closure phases and errors
+ clphase_diag.append(cl[0]['cphase'])
+ sigma_diag.append(cl[0]['sigmacp'])
+
+ # get uv arrays
+ u1 = cl[2][:, 0].astype('float')
+ v1 = cl[3][:, 0].astype('float')
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+
+ u2 = cl[2][:, 1].astype('float')
+ v2 = cl[3][:, 1].astype('float')
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+
+ u3 = cl[2][:, 2].astype('float')
+ v3 = cl[3][:, 2].astype('float')
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+
+ # compute Fourier matrices
+ A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask)
+ )
+ A3_diag.append(A3)
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (np.array(A3_diag), np.array(tform_mats))
+
+ return (np.array(clphase_diag), np.array(sigma_diag), Amatrices)
+
+
+def chisqdata_camp(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data & mask low snr points
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut)
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed closure amplitude table in closure amplitude chi^2!")
+ if not type(Obsdata.camp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.camp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrices
+ A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask)
+ )
+
+ return (clamp, sigma, A4)
+
+
+def chisqdata_logcamp(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+
+ # unpack data & mask low snr points
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut)
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!")
+ if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed log closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.logcamp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # make fourier matrices
+ A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask)
+ )
+
+ return (clamp, sigma, A4)
+
+
+def chisqdata_logcamp_diag(Obsdata, Prior, mask, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for diagonalized log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+
+ # unpack data & mask low snr points
+ vtype = ehc.vis_poldict[pol]
+ clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count, debias=debias, snrcut=snrcut)
+
+ # loop over timestamps
+ clamp_diag = []
+ sigma_diag = []
+ A4_diag = []
+ tform_mats = []
+ for ic, cl in enumerate(clamparr):
+
+ # get diagonalized log closure amplitudes and errors
+ clamp_diag.append(cl[0]['camp'])
+ sigma_diag.append(cl[0]['sigmaca'])
+
+ # get uv arrays
+ u1 = cl[2][:, 0].astype('float')
+ v1 = cl[3][:, 0].astype('float')
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+
+ u2 = cl[2][:, 1].astype('float')
+ v2 = cl[3][:, 1].astype('float')
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+
+ u3 = cl[2][:, 2].astype('float')
+ v3 = cl[3][:, 2].astype('float')
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+
+ u4 = cl[2][:, 3].astype('float')
+ v4 = cl[3][:, 3].astype('float')
+ uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1)))
+
+ # compute Fourier matrices
+ A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask),
+ obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask)
+ )
+ A4_diag.append(A4)
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (np.array(A4_diag), np.array(tform_mats))
+
+ return (np.array(clamp_diag), np.array(sigma_diag), Amatrices)
+
+##################################################################################################
+# FFT Chi^2 Data functions
+##################################################################################################
+
+
+def chisqdata_vis_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for visibilities
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info = obsh.make_gridder_and_sampler_info(
+ im_info, uv, conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info[0]]
+ gridder_info_list = [gs_info[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (vis, sigma, A)
+
+
+def chisqdata_amp_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for visibility amplitudes
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I':
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed amplitude table in amplitude chi^2!")
+ if not type(Obsdata.amp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed amplitude table is not a numpy rec array!")
+ data_arr = Obsdata.amp
+
+ # apply systematic noise
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info = obsh.make_gridder_and_sampler_info(im_info, uv,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info[0]]
+ gridder_info_list = [gs_info[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (amp, sigma, A)
+
+
+def chisqdata_bs_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for bispectra
+ """
+
+ # unpack keyword args
+ # systematic_noise = kwargs.get('systematic_noise',0.)
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I':
+ biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed bispectrum tables that are not Stokes I?
+ print("Using pre-computed bispectrum table in cphase chi^2!")
+ if not type(Obsdata.bispec) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed bispectrum table is not a numpy rec array!")
+ biarr = Obsdata.bispec
+ # reduce to a minimal set
+ if count != 'max':
+ biarr = obsh.reduce_tri_minimal(Obsdata, biarr)
+
+ uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ # add systematic noise
+ # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (bi, sigma, A)
+
+
+def chisqdata_cphase_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for closure phases
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+ systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I':
+ clphasearr = Obsdata.c_phases(mode="all", vtype=vtype,
+ count=count, uv_min=uv_min, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed closure phase tables that are not Stokes I?
+ print("Using pre-computed cphase table in cphase chi^2!")
+ if not type(Obsdata.cphase) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure phase table is not a numpy rec array!")
+ clphasearr = Obsdata.cphase
+ # reduce to a minimal set
+ if count != 'max':
+ clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr)
+
+ uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1)))
+ clphase = clphasearr['cphase']
+ sigma = clphasearr['sigmacp']
+
+ # add systematic cphase noise (in DEGREES)
+ sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (clphase, sigma, A)
+
+
+def chisqdata_cphase_diag_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for diagonalized closure phases
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min)
+
+ # loop over timestamps
+ clphase_diag = []
+ sigma_diag = []
+ tform_mats = []
+ u1 = []
+ v1 = []
+ u2 = []
+ v2 = []
+ u3 = []
+ v3 = []
+ for ic, cl in enumerate(clphasearr):
+
+ # get diagonalized closure phases and errors
+ clphase_diag.append(cl[0]['cphase'])
+ sigma_diag.append(cl[0]['sigmacp'])
+
+ # get u and v values
+ u1.append(cl[2][:, 0].astype('float'))
+ v1.append(cl[3][:, 0].astype('float'))
+ u2.append(cl[2][:, 1].astype('float'))
+ v2.append(cl[3][:, 1].astype('float'))
+ u3.append(cl[2][:, 2].astype('float'))
+ v3.append(cl[3][:, 2].astype('float'))
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # fix formatting of arrays
+ u1 = np.concatenate(u1)
+ v1 = np.concatenate(v1)
+ u2 = np.concatenate(u2)
+ v2 = np.concatenate(v2)
+ u3 = np.concatenate(u3)
+ v3 = np.concatenate(v3)
+
+ # get uv arrays
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]]
+ A3 = (im_info, sampler_info_list, gridder_info_list)
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (A3, np.array(tform_mats))
+
+ return (np.array(clphase_diag), np.array(sigma_diag), Amatrices)
+
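+# Editor's note (sketch): for the diagonalized closure quantities, the third
+# element returned is itself a 2-tuple: the usual FFT info for the three (or
+# four) uv legs, plus one transformation matrix per timestamp that maps model
+# closure quantities into the diagonalized basis.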
+
+def chisqdata_camp_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for closure amplitudes
+ """
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed closure amplitude tables that are not Stokes I?
+ print("Using pre-computed closure amplitude table in closure amplitude chi^2!")
+ if not type(Obsdata.camp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.camp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (clamp, sigma, A)
+
+
+def chisqdata_logcamp_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed log closure amplitude tables that are not Stokes I?
+ print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!")
+ if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed log closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.logcamp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ return (clamp, sigma, A)
+
+
+def chisqdata_logcamp_diag_fft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the data, sigmas, uv points, and FFT info for diagonalized log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+ order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT)
+
+ # unpack data & mask low snr points
+ vtype = ehc.vis_poldict[pol]
+ clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count,
+ debias=debias, snrcut=snrcut)
+
+ # loop over timestamps
+ clamp_diag = []
+ sigma_diag = []
+ tform_mats = []
+ u1 = []
+ v1 = []
+ u2 = []
+ v2 = []
+ u3 = []
+ v3 = []
+ u4 = []
+ v4 = []
+ for ic, cl in enumerate(clamparr):
+
+ # get diagonalized log closure amplitudes and errors
+ clamp_diag.append(cl[0]['camp'])
+ sigma_diag.append(cl[0]['sigmaca'])
+
+ # get u and v values
+ u1.append(cl[2][:, 0].astype('float'))
+ v1.append(cl[3][:, 0].astype('float'))
+ u2.append(cl[2][:, 1].astype('float'))
+ v2.append(cl[3][:, 1].astype('float'))
+ u3.append(cl[2][:, 2].astype('float'))
+ v3.append(cl[3][:, 2].astype('float'))
+ u4.append(cl[2][:, 3].astype('float'))
+ v4.append(cl[3][:, 3].astype('float'))
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # fix formatting of arrays
+ u1 = np.concatenate(u1)
+ v1 = np.concatenate(v1)
+ u2 = np.concatenate(u2)
+ v2 = np.concatenate(v2)
+ u3 = np.concatenate(u3)
+ v3 = np.concatenate(v3)
+ u4 = np.concatenate(u4)
+ v4 = np.concatenate(v4)
+
+ # get uv arrays
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+ uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1)))
+
+ # prepare image and fft info objects
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse)
+ gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+ gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4,
+ conv_func=conv_func, p_rad=p_rad, order=order)
+
+ sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]]
+ gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]]
+ A = (im_info, sampler_info_list, gridder_info_list)
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (A, np.array(tform_mats))
+
+ return (np.array(clamp_diag), np.array(sigma_diag), Amatrices)
+
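+# Editor's usage sketch (illustrative, not part of the module API): every
+# chisqdata_*_fft function above shares the same return convention. The helper
+# below is hypothetical; `obs` and `prior` stand in for an Obsdata object and
+# an Image object.
+def _example_unpack_fft_chisqdata(obs, prior):
+    """Sketch: unpack the (data, sigma, A) tuple from a chisqdata_*_fft call."""
+    amp, sigma, A = chisqdata_amp_fft(obs, prior, pol='I')
+    im_info, sampler_info_list, gridder_info_list = A
+    # one sampler/gridder pair per uv set: 1 for visibilities and amplitudes,
+    # 3 for bispectra and closure phases, 4 for closure amplitudes
+    return amp, sigma, len(sampler_info_list)
+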
+##################################################################################################
+# NFFT Chi^2 Data functions
+##################################################################################################
+
+
+def chisqdata_vis_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the visibilities, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv)
+ A = [A1]
+
+ return (vis, sigma, A)
+
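+# Editor's note (sketch): unlike the *_fft functions, the NFFT variants return
+# A as a plain list of NFFTInfo objects, one per uv set (one for visibilities
+# and amplitudes, three for bispectra and closure phases, four for closure
+# amplitudes).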
+
+def chisqdata_amp_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the amplitudes, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise', 0.)
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ atype = ehc.amp_poldict[pol]
+ etype = ehc.sig_poldict[pol]
+ if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I':
+ data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias)
+    else:  # TODO -- handle pre-computed amplitude tables that are not Stokes I?
+ print("Using pre-computed amplitude table in amplitude chi^2!")
+ if not type(Obsdata.amp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed amplitude table is not a numpy rec array!")
+ data_arr = Obsdata.amp
+
+ # apply systematic noise
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv)
+ A = [A1]
+
+ return (amp, sigma, A)
+
+
+def chisqdata_bs_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the bispectra, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ # systematic_noise = kwargs.get('systematic_noise',0.)
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I':
+ biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed bispectrum tables that are not Stokes I?
+ print("Using pre-computed bispectrum table in cphase chi^2!")
+ if not type(Obsdata.bispec) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed bispectrum table is not a numpy rec array!")
+ biarr = Obsdata.bispec
+ # reduce to a minimal set
+ if count != 'max':
+ biarr = obsh.reduce_tri_minimal(Obsdata, biarr)
+
+ uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ # add systematic noise
+ # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A = [A1, A2, A3]
+
+ return (bi, sigma, A)
+
+
+def chisqdata_cphase_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the closure phases, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ weighting = kwargs.get('weighting', 'natural')
+ systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I':
+ clphasearr = Obsdata.c_phases(mode="all", vtype=vtype,
+ count=count, uv_min=uv_min, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed closure phase tables that are not Stokes I?
+ print("Using pre-computed cphase table in cphase chi^2!")
+ if not type(Obsdata.cphase) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure phase table is not a numpy rec array!")
+ clphasearr = Obsdata.cphase
+ # reduce to a minimal set
+ if count != 'max':
+ clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr)
+
+ uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1)))
+ clphase = clphasearr['cphase']
+ sigma = clphasearr['sigmacp']
+
+ # add systematic cphase noise (in DEGREES)
+ sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0)
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A = [A1, A2, A3]
+
+ return (clphase, sigma, A)
+
+
+def chisqdata_cphase_diag_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the diagonalized closure phases, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min)
+
+ # loop over timestamps
+ clphase_diag = []
+ sigma_diag = []
+ tform_mats = []
+ u1 = []
+ v1 = []
+ u2 = []
+ v2 = []
+ u3 = []
+ v3 = []
+ for ic, cl in enumerate(clphasearr):
+
+ # get diagonalized closure phases and errors
+ clphase_diag.append(cl[0]['cphase'])
+ sigma_diag.append(cl[0]['sigmacp'])
+
+ # get u and v values
+ u1.append(cl[2][:, 0].astype('float'))
+ v1.append(cl[3][:, 0].astype('float'))
+ u2.append(cl[2][:, 1].astype('float'))
+ v2.append(cl[3][:, 1].astype('float'))
+ u3.append(cl[2][:, 2].astype('float'))
+ v3.append(cl[3][:, 2].astype('float'))
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # fix formatting of arrays
+ u1 = np.concatenate(u1)
+ v1 = np.concatenate(v1)
+ u2 = np.concatenate(u2)
+ v2 = np.concatenate(v2)
+ u3 = np.concatenate(u3)
+ v3 = np.concatenate(v3)
+
+ # get uv arrays
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A = [A1, A2, A3]
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (A, np.array(tform_mats))
+
+ return (np.array(clphase_diag), np.array(sigma_diag), Amatrices)
+
+
+def chisqdata_camp_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the closure amplitudes, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed closure amplitude tables that are not Stokes I?
+ print("Using pre-computed closure amplitude table in closure amplitude chi^2!")
+ if not type(Obsdata.camp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.camp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4)
+ A = [A1, A2, A3, A4]
+
+ return (clamp, sigma, A)
+
+
+def chisqdata_logcamp_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the log closure amplitudes, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ weighting = kwargs.get('weighting', 'natural')
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ vtype = ehc.vis_poldict[pol]
+ if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count,
+ vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut)
+    else:  # TODO -- handle pre-computed log closure amplitude tables that are not Stokes I?
+ print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!")
+ if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed log closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.logcamp
+ # reduce to a minimal set
+ if count != 'max':
+ clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting == 'uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4)
+ A = [A1, A2, A3, A4]
+
+ return (clamp, sigma, A)
+
+
+def chisqdata_logcamp_diag_nfft(Obsdata, Prior, pol='I', **kwargs):
+ """Return the diagonalized log closure amplitudes, sigmas, uv points, and nfft info
+ """
+ if (Prior.xdim % 2 or Prior.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset', False)
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut = kwargs.get('snrcut', 0.)
+ debias = kwargs.get('debias', True)
+ fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data & mask low snr points
+ vtype = ehc.vis_poldict[pol]
+ clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count, debias=debias, snrcut=snrcut)
+
+ # loop over timestamps
+ clamp_diag = []
+ sigma_diag = []
+ tform_mats = []
+ u1 = []
+ v1 = []
+ u2 = []
+ v2 = []
+ u3 = []
+ v3 = []
+ u4 = []
+ v4 = []
+ for ic, cl in enumerate(clamparr):
+
+ # get diagonalized log closure amplitudes and errors
+ clamp_diag.append(cl[0]['camp'])
+ sigma_diag.append(cl[0]['sigmaca'])
+
+ # get u and v values
+ u1.append(cl[2][:, 0].astype('float'))
+ v1.append(cl[3][:, 0].astype('float'))
+ u2.append(cl[2][:, 1].astype('float'))
+ v2.append(cl[3][:, 1].astype('float'))
+ u3.append(cl[2][:, 2].astype('float'))
+ v3.append(cl[3][:, 2].astype('float'))
+ u4.append(cl[2][:, 3].astype('float'))
+ v4.append(cl[3][:, 3].astype('float'))
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # fix formatting of arrays
+ u1 = np.concatenate(u1)
+ v1 = np.concatenate(v1)
+ u2 = np.concatenate(u2)
+ v2 = np.concatenate(v2)
+ u3 = np.concatenate(u3)
+ v3 = np.concatenate(v3)
+ u4 = np.concatenate(u4)
+ v4 = np.concatenate(v4)
+
+ # get uv arrays
+ uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1)))
+ uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1)))
+ uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1)))
+ uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1)))
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4)
+ A = [A1, A2, A3, A4]
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ Amatrices = (A, np.array(tform_mats))
+
+ return (np.array(clamp_diag), np.array(sigma_diag), Amatrices)
+
+##################################################################################################
+# Restoring, Embedding, and Plotting Functions
+##################################################################################################
+
+
+def plot_i(im, Prior, nit, chi2_dict, **kwargs):
+ """Plot the total intensity image at each iteration
+ """
+ cmap = kwargs.get('cmap', 'afmhot')
+ interpolation = kwargs.get('interpolation', 'gaussian')
+ pol = kwargs.get('pol', '')
+ scale = kwargs.get('scale', None)
+ dynamic_range = kwargs.get('dynamic_range', 1.e5)
+    gamma = kwargs.get('gamma', .5)
+
+ plt.ion()
+ plt.pause(1.e-6)
+ plt.clf()
+
+ imarr = im.reshape(Prior.ydim, Prior.xdim)
+
+ if scale == 'log':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = np.log(imarr + np.max(imarr)/dynamic_range)
+
+ if scale == 'gamma':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma)
+
+ plt.imshow(imarr, cmap=plt.get_cmap(cmap), interpolation=interpolation)
+ xticks = obsh.ticks(Prior.xdim, Prior.psize/ehc.RADPERAS/1e-6)
+ yticks = obsh.ticks(Prior.ydim, Prior.psize/ehc.RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel(r'Relative RA ($\mu$as)')
+ plt.ylabel(r'Relative Dec ($\mu$as)')
+ plotstr = str(pol) + " : step: %i " % nit
+ for key in chi2_dict.keys():
+ plotstr += r"$\chi^2_{%s}$: %0.2f " % (key, chi2_dict[key])
+ plt.title(plotstr, fontsize=18)
+
+
+def embed(im, mask, clipfloor=0., randomfloor=False):
+ """Embeds a 1d image array into the size of boolean embed mask
+ """
+
+ out = np.zeros(len(mask))
+
+ # Here's a much faster version than before
+ out[mask.nonzero()] = im
+
+    # note: the clipfloor != 0.0 guard is disabled here, so masked pixels are always filled
+    if randomfloor:  # prevent total variation gradient singularities
+        out[np.logical_not(mask).nonzero()] = clipfloor * \
+            np.abs(np.random.normal(size=len(np.logical_not(mask).nonzero()[0])))
+    else:
+        out[np.logical_not(mask).nonzero()] = clipfloor
+
+ return out
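+
+
+# Editor's usage sketch (illustrative only; the arrays below are hypothetical):
+def _example_embed():
+    """Sketch: embed a masked 1d image vector back into the full pixel grid."""
+    mask = np.array([True, False, True, True])
+    im_masked = np.array([1.0, 2.0, 3.0])  # values at the True pixels
+    return embed(im_masked, mask)          # -> array([1., 0., 2., 3.])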
diff --git a/imaging/linearize_energy.py b/imaging/linearize_energy.py
new file mode 100644
index 00000000..1ebb1ef6
--- /dev/null
+++ b/imaging/linearize_energy.py
@@ -0,0 +1,111 @@
+from __future__ import division
+
+import numpy as np
+import ehtim.image as image
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+
+
+def linearized_bi(x0, A3, bispec, sigs, nPixels, alpha=100, reg="patch"):
+
+ Alin, blin = computeLinTerms_bi(x0, A3, bispec, sigs, nPixels, alpha=alpha, reg=reg)
+ gradient = 2*np.dot(Alin.T, np.dot(Alin, x0) - blin)
+
+ return -gradient
+
+def linearizedSol_bs(Obsdata, currImage, Prior, alpha=100, beta=100, reg="patch"):
+    # beta weights the regularizer term relative to the data term in the linear solve below
+
+ # normalize the prior
+ # TODO: SHOULD THIS BE DONE??
+ zbl = np.nanmax(np.abs(Obsdata.unpack(['vis'])['vis']))
+ nprior = zbl * Prior.imvec / np.sum(Prior.imvec)
+
+ if reg == "patch":
+ linRegTerm, constRegTerm = spatchlingrad(currImage.imvec, nprior)
+
+ # Get bispectra data
+ biarr = Obsdata.bispectra(mode="all", count="max")
+
+ bispec = biarr['bispec']
+ sigs = biarr['sigmab']
+
+    nans = np.isnan(sigs)
+    bispec = bispec[~nans]
+    sigs = sigs[~nans]
+    biarr = biarr[~nans]
+
+    uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1)))
+    uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1)))
+    uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1)))
+
+    # Compute the Fourier matrices
+    A3 = (ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv1, pulse=currImage.pulse),
+          ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv2, pulse=currImage.pulse),
+          ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv3, pulse=currImage.pulse))
+
+    Alin, blin = computeLinTerms_bi(currImage.imvec, A3, bispec, sigs,
+                                    currImage.xdim*currImage.ydim, alpha=alpha, reg=reg)
+
+    out = np.linalg.solve(Alin + beta*linRegTerm, blin + beta*constRegTerm)
+
+    return image.Image(out.reshape((currImage.ydim, currImage.xdim)), currImage.psize,
+                       currImage.ra, currImage.dec, rf=currImage.rf,
+                       source=currImage.source, mjd=currImage.mjd, pulse=currImage.pulse)
+
+
+def computeLinTerms_bi(x0, A3, bispec, sigs, nPixels, alpha=100, reg="patch"):
+
+    sigmaR = sigmaI = 1.0/(sigs**2)
+
+    rA = np.real(A3)
+    iA = np.imag(A3)
+
+    rA1, rA2, rA3 = rA[0], rA[1], rA[2]
+    iA1, iA2, iA3 = iA[0], iA[1], iA[2]
+
+    rA1x0 = np.dot(rA1, x0)
+    rA2x0 = np.dot(rA2, x0)
+    rA3x0 = np.dot(rA3, x0)
+    iA1x0 = np.dot(iA1, x0)
+    iA2x0 = np.dot(iA2, x0)
+    iA3x0 = np.dot(iA3, x0)
+
+    fR = (rA1x0*rA2x0*rA3x0 - rA1x0*iA2x0*iA3x0 - iA1x0*rA2x0*iA3x0 - iA1x0*iA2x0*rA3x0)
+    fI = (rA1x0*rA2x0*iA3x0 + rA1x0*iA2x0*rA3x0 + iA1x0*rA2x0*rA3x0 - iA1x0*iA2x0*iA3x0)
+
+    yR = np.real(bispec)
+    yI = np.imag(bispec)
+
+    # f = np.sum(sigmaR*(fR - yR)**2 + sigmaI*(fI - yI)**2 + 2.0*sigmaRI*(fR - yR)*(fI - yI)) / 2.0
+
+    # Jacobian of the real part: size (number of bispectra) x nPixels
+    dxReal = (rA1*np.tile(rA2x0*rA3x0, [nPixels, 1]).T + rA2*np.tile(rA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(rA2x0*rA1x0, [nPixels, 1]).T
+              - (rA1*np.tile(iA2x0*iA3x0, [nPixels, 1]).T + iA2*np.tile(rA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(iA2x0*rA1x0, [nPixels, 1]).T)
+              - (iA1*np.tile(rA2x0*iA3x0, [nPixels, 1]).T + rA2*np.tile(iA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(rA2x0*iA1x0, [nPixels, 1]).T)
+              - (iA1*np.tile(iA2x0*rA3x0, [nPixels, 1]).T + iA2*np.tile(iA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(iA2x0*iA1x0, [nPixels, 1]).T))
+
+    # Jacobian of the imaginary part: size (number of bispectra) x nPixels
+    dxImag = (rA1*np.tile(rA2x0*iA3x0, [nPixels, 1]).T + rA2*np.tile(rA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(rA2x0*rA1x0, [nPixels, 1]).T
+              + (rA1*np.tile(iA2x0*rA3x0, [nPixels, 1]).T + iA2*np.tile(rA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(iA2x0*rA1x0, [nPixels, 1]).T)
+              + (iA1*np.tile(rA2x0*rA3x0, [nPixels, 1]).T + rA2*np.tile(iA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(rA2x0*iA1x0, [nPixels, 1]).T)
+              - (iA1*np.tile(iA2x0*iA3x0, [nPixels, 1]).T + iA2*np.tile(iA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(iA2x0*iA1x0, [nPixels, 1]).T))
+
+    # constant offsets in the linearization: size (number of bispectra) x 1
+    betaR = fR - np.dot(dxReal, x0)
+    betaI = fI - np.dot(dxImag, x0)
+
+    blin = (np.dot(np.transpose(np.dot(np.diag(alpha*sigmaR), dxReal)), (yR - betaR))
+            + np.dot(np.transpose(np.dot(np.diag(alpha*sigmaI), dxImag)), (yI - betaI)))
+    Alin = (np.dot(np.transpose(np.dot(np.diag(alpha*sigmaR), dxReal)), dxReal)
+            + np.dot(np.transpose(np.dot(np.diag(alpha*sigmaI), dxImag)), dxImag))
+
+    return (Alin, blin)
+
+def spatchlingrad(imvec, priorvec):
+    """Return the (identity) linear regularizer matrix and constant term for the patch prior."""
+    return (np.diag(np.ones(len(priorvec))), priorvec)
+
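+# Editor's note (sketch of the math above): computeLinTerms_bi expands the
+# bispectrum forward model f(x) to first order around x0 with Jacobian
+# J = df/dx|_{x0}, so the weighted data term |y - f(x)|^2 becomes
+# |y - (J x + c)|^2 with c = f(x0) - J x0. Minimizing gives the normal
+# equations Alin x = blin, with Alin = J^T W J and blin = J^T W (y - c),
+# which linearizedSol_bs solves together with the regularizer term.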
diff --git a/imaging/multifreq_imager_utils.py b/imaging/multifreq_imager_utils.py
new file mode 100644
index 00000000..f2ce28b7
--- /dev/null
+++ b/imaging/multifreq_imager_utils.py
@@ -0,0 +1,278 @@
+# multifreq_imager_utils.py
+# imager functions for multifrequency VLBI data
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+from builtins import range
+
+import string
+import time
+import numpy as np
+import scipy.optimize as opt
+import scipy.ndimage as nd
+import scipy.ndimage.filters as filt
+import matplotlib.pyplot as plt
+try:
+ from pynfft.nfft import NFFT
+except ImportError:
+ pass
+ #print("Warning: No NFFT installed! Cannot use nfft functions")
+
+import ehtim.image as image
+from . import linearize_energy as le
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+from ehtim.statistics.dataframes import *
+
+NORM_REGULARIZER = True
+EPSILON = 1.e-12
+
+##################################################################################################
+# Multifrequency regularizers
+##################################################################################################
+
+def regularizer_mf(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs):
+ """return the regularizer value on spectral index or curvature
+ """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ beam_size = kwargs.get('beam_size', psize)
+
+ if stype == "l2_alpha" or stype=="l2_beta":
+ s = -l2_alpha(imvec, nprior, norm_reg=norm_reg)
+ elif stype == "tv_alpha" or stype=="tv_beta":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=False)
+ s = -tv_alpha(imvec, xdim, ydim, psize, norm_reg=norm_reg, beam_size=beam_size)
+ else:
+ s = 0
+
+ return s
+
+def regularizergrad_mf(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs):
+ """return the regularizer gradient
+ """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ beam_size = kwargs.get('beam_size', psize)
+
+ if stype == "l2_alpha" or stype=="l2_beta":
+ s = -l2_alpha_grad(imvec, nprior)
+ elif stype == "tv_alpha" or stype=="tv_beta":
+ if np.any(np.invert(mask)):
+ imvec = embed(imvec, mask, randomfloor=False)
+ s = -tv_alpha_grad(imvec, xdim, ydim, psize, norm_reg=norm_reg, beam_size=beam_size)
+ s = s[mask]
+ else:
+ s = 0
+
+ return s
+
+
+def l2_alpha(imvec, priorvec, norm_reg=NORM_REGULARIZER):
+ """L2 norm on spectral index w/r/t prior
+ """
+
+ if norm_reg:
+ norm = float(len(imvec))
+ else:
+ norm = 1
+
+ out = -(np.sum((imvec - priorvec)**2))
+ return out/norm
+
+def l2_alpha_grad(imvec, priorvec, norm_reg=NORM_REGULARIZER):
+    """Gradient of the L2 norm on spectral index w/r/t prior
+    """
+
+    if norm_reg:
+        norm = float(len(imvec))
+    else:
+        norm = 1
+
+    # the gradient of -sum((imvec - priorvec)**2) is elementwise, not summed
+    out = -2*(imvec - priorvec)
+    return out/norm
+
+
+def tv_alpha(imvec, nx, ny, psize, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Total variation regularizer
+ """
+ if beam_size is None: beam_size = psize
+ if norm_reg:
+ norm = len(imvec)*psize / beam_size
+ else:
+ norm = 1
+
+ im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ out = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2 + EPSILON))
+
+ return out/norm
+
+def tv_alpha_grad(imvec, nx, ny, psize, norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Total variation gradient
+ """
+ if beam_size is None: beam_size = psize
+ if norm_reg:
+ norm = len(imvec)*psize / beam_size
+ else:
+ norm = 1
+
+    im = imvec.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+
+    # rotate images
+    im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1]
+    im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1]
+
+    # compute the gradient terms
+    g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + EPSILON)
+    g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + EPSILON)
+    g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + EPSILON)
+
+    # mask gradient terms in the first row and column that don't exist
+    mask1 = np.zeros(im.shape)
+    mask2 = np.zeros(im.shape)
+    mask1[0, :] = 1
+    mask2[:, 0] = 1
+    g2[mask1.astype(bool)] = 0
+    g3[mask2.astype(bool)] = 0
+
+    # add terms together and return
+    out = -(g1 + g2 + g3).flatten()
+ return out/norm
+
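+# Editor's sketch: a quick finite-difference check that the regularizer
+# gradients above are consistent with their values (shown for l2_alpha; the
+# same pattern applies to tv_alpha / tv_alpha_grad). Illustrative only.
+def _example_check_l2_alpha_grad(k=3, eps=1.e-7):
+    rng = np.random.default_rng(0)
+    imvec = rng.normal(size=16)
+    prior = np.zeros(16)
+    grad = l2_alpha_grad(imvec, prior)
+    pert = np.zeros(16)
+    pert[k] = eps
+    fdiff = (l2_alpha(imvec + pert, prior) - l2_alpha(imvec - pert, prior)) / (2*eps)
+    return grad[k], fdiff  # the two should agree to roughly machine precision
+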
+##################################################################################################
+
+##################################################################################################
+def unpack_mftuple(imvec, inittuple, nimage, mf_solve=(1, 1, 0)):
+ """unpack imvec into tuple,
+ replaces quantities not iterated with their initial values
+ """
+ init0 = inittuple[0]
+ init1 = inittuple[1]
+ init2 = inittuple[2]
+
+ imct = 0
+ if mf_solve[0] == 0:
+ im0 = init0
+ else:
+ im0 = imvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+
+ if mf_solve[1] == 0:
+ im1 = init1
+ else:
+ im1 = imvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+
+ if mf_solve[2] == 0:
+ im2 = init2
+ else:
+ im2 = imvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+ return np.array((im0, im1, im2))
+
+def pack_mftuple(mftuple, mf_solve=(1, 1, 0)):
+    """pack multifreq data into image vector,
+    ignore quantities not iterated
+    """
+
+    vec = np.array([])
+    if mf_solve[0] != 0:
+        vec = np.hstack((vec, mftuple[0]))
+    if mf_solve[1] != 0:
+        vec = np.hstack((vec, mftuple[1]))
+    if mf_solve[2] != 0:
+        vec = np.hstack((vec, mftuple[2]))
+
+ return vec
+
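+# Editor's roundtrip sketch (illustrative only): pack_mftuple and
+# unpack_mftuple are inverses on the components flagged in mf_solve.
+def _example_mf_pack_roundtrip():
+    nimage = 4
+    inittuple = np.array((np.ones(nimage),    # reference image
+                          np.zeros(nimage),   # spectral index
+                          np.zeros(nimage)))  # curvature
+    vec = pack_mftuple(inittuple, mf_solve=(1, 1, 0))  # length 2*nimage
+    out = unpack_mftuple(vec, inittuple, nimage, mf_solve=(1, 1, 0))
+    return out  # out[2] comes from inittuple since curvature is not solved for
+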
+def embed(im, mask, clipfloor=0., randomfloor=False):
+ """Embeds a 1d image array into the size of boolean embed mask
+ """
+
+ out = np.zeros(len(mask))
+
+ # Here's a much faster version than before
+ out[mask.nonzero()] = im
+
+    if clipfloor != 0.0:
+        if randomfloor:  # prevent total variation gradient singularities
+            out[np.logical_not(mask).nonzero()] = clipfloor * \
+                np.abs(np.random.normal(size=len(np.logical_not(mask).nonzero()[0])))
+        else:
+            out[np.logical_not(mask).nonzero()] = clipfloor
+
+ return out
+
+def embed_mf(imtuple, mask, clipfloor=0., randomfloor=False):
+ """Embeds a multifrequency image tuple into the size of boolean embed mask
+ """
+    out0 = np.zeros(len(mask))
+    out1 = np.zeros(len(mask))
+    out2 = np.zeros(len(mask))
+
+ # Here's a much faster version than before
+ out0[mask.nonzero()] = imtuple[0]
+ out1[mask.nonzero()] = imtuple[1]
+ out2[mask.nonzero()] = imtuple[2]
+
+    if clipfloor != 0.0:
+        if randomfloor:  # prevent total variation gradient singularities
+            nmask = np.logical_not(mask).nonzero()
+            out0[nmask] = clipfloor * np.abs(np.random.normal(size=len(nmask[0])))
+            out1[nmask] = clipfloor * np.abs(np.random.normal(size=len(nmask[0])))
+            out2[nmask] = clipfloor * np.abs(np.random.normal(size=len(nmask[0])))
+        else:
+            out0[np.logical_not(mask).nonzero()] = clipfloor
+            out1[np.logical_not(mask).nonzero()] = clipfloor
+            out2[np.logical_not(mask).nonzero()] = clipfloor
+
+ return (out0, out1, out2)
+
+def imvec_at_freq(mftuple, log_freqratio):
+ """Get the image at a frequency given by ref_freq*e(log_freqratio)
+ Remember spectral index is defined with a + sign!
+ """
+ imvec_ref_log = np.log(mftuple[0])
+ spectral_index = mftuple[1]
+ curvature = mftuple[2]
+
+ logimvec = imvec_ref_log + spectral_index*log_freqratio + curvature*log_freqratio*log_freqratio
+ imvec = np.exp(logimvec)
+ return imvec
+
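+# Editor's worked note: imvec_at_freq evaluates the log-log spectral model
+#     log I(nu) = log I(nu0) + alpha*x + beta*x**2,  with x = log(nu/nu0),
+# i.e. I(nu) = I(nu0) * (nu/nu0)**(alpha + beta*log(nu/nu0)). For example,
+# with alpha = -0.7, beta = 0, and nu = 2*nu0, each pixel scales by
+# 2**(-0.7), about 0.62.
+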
+def mf_all_grads_chain(funcgrad, imvec_cur, imvec_ref, log_freqratio):
+ """Get the gradients of the reference image, spectral index, and curvature
+ w/r/t the gradient of a function funcgrad to the image given frequency ref_freq*e(log_freqratio)
+ """
+
+ dfunc_dI0 = funcgrad * imvec_cur / imvec_ref
+ dfunc_dalpha = funcgrad * imvec_cur * log_freqratio
+ dfunc_dbeta = funcgrad * imvec_cur * log_freqratio * log_freqratio
+
+ return np.array((dfunc_dI0, dfunc_dalpha, dfunc_dbeta))
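+
+# Editor's note (sketch): the chain rule above follows from
+# I = exp(log(I0) + alpha*x + beta*x**2) with x = log_freqratio, so
+#     dI/dI0 = I/I0,  dI/dalpha = I*x,  dI/dbeta = I*x**2,
+# and a data-term gradient dF/dI at the shifted frequency pulls back to the
+# three solved-for quantities exactly as in the returned array.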
diff --git a/imaging/patch_prior.py b/imaging/patch_prior.py
new file mode 100644
index 00000000..d2dd6b1d
--- /dev/null
+++ b/imaging/patch_prior.py
@@ -0,0 +1,150 @@
+# patch_prior.py
+#
+# Create a "prior" image for vlbi imaging by cleaning the input image.
+# Image cleaning is done by breaking the image into patches and cleaning each one
+# individually by assigning it a cluster in the input Gaussian mixture model and
+# using a weiner filter to denoise each patch.
+# These ideas are based on Expected Patch Log Likelihood patch prior work
+#
+# Code Author: Katie Bouman
+# Date: June 1, 2016
+
+from __future__ import division
+from builtins import map
+from builtins import range
+
+from matplotlib import pyplot as plt
+import ehtim.image as image
+import scipy.io
+import numpy as np
+
+def patchPrior(im, beta, patchPriorFile='naturalPrior.mat', patchSize=8):
+
+ # load data
+ ldata = scipy.io.loadmat(patchPriorFile)
+
+ # reassign and reshape data
+ nmodels = ldata['nmodels'].ravel()
+ nmodels = nmodels[0]
+ mixweights = ldata['mixweights'].ravel()
+ covs = np.array(ldata['covs'])
+ means = np.array(ldata['means'])
+
+    # reshape image
+    img = np.reshape(im.imvec, (im.ydim, im.xdim))
+
+    I1, counts = cleanImage(img, beta, nmodels, covs, mixweights, means, patchSize)
+
+    if not all(counts[0][0] == item for item in np.reshape(counts, (-1))):
+        raise TypeError("The counts are not the same for every pixel in the image")
+
+    I1 = I1/counts[0][0]
+    out = image.Image(I1, im.psize, im.ra, im.dec,
+                      rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+
+    return (out, counts[0][0])
+
+def cleanImage(img, beta, nmodels, covs, mixweights, means, patchSize=8):
+
+    # pad images with 0's
+    validRegion = np.lib.pad(np.ones(img.shape), (patchSize-1, patchSize-1),
+                             'constant', constant_values=(0, 0))
+    cleanIPad = np.lib.pad(img, (patchSize-1, patchSize-1),
+                           'constant', constant_values=(0, 0))
+
+    # adjust the dynamic range of the image to be in the range 0 to 1
+    minCleanI = min(np.reshape(cleanIPad, (-1)))
+    cleanIPad = cleanIPad - minCleanI
+    maxCleanI = max(np.reshape(cleanIPad, (-1)))
+    cleanIPad = cleanIPad / maxCleanI
+
+    # extract all overlapping patches from the image
+    Z = im2col(np.transpose(cleanIPad), patchSize)
+
+    # clean each patch by Wiener filtering
+    meanZ = np.mean(Z, 0)
+    Z = Z - np.tile(meanZ, [patchSize**2, 1])
+    cleanZ = cleanPatches(Z, patchSize, (beta)**(-0.5), nmodels, covs, mixweights, means)
+    cleanZ = cleanZ + np.tile(meanZ, [patchSize**2, 1])
+
+    # join all patches together
+    mm = validRegion.shape[0]
+    nn = validRegion.shape[1]
+    t = np.reshape(list(range(0, mm*nn, 1)), (mm, nn))
+    temp = im2col(t, patchSize)
+    patch_idx = np.array(list(map(int, np.reshape(temp, (-1)))))
+    I1 = np.transpose(np.bincount(patch_idx, weights=np.reshape(cleanZ, (-1))))
+    counts = np.transpose(np.bincount(patch_idx, weights=np.reshape(np.ones(cleanZ.shape), (-1))))
+
+    # normalize and put back in the original scale
+    I1 = I1/counts
+    I1 = (I1*maxCleanI) + minCleanI
+    I1 = I1*counts
+
+    # set all negative entries to 0 (hacky)
+    I1[I1 < 0] = 0
+
+    # crop out the center valid region
+    I1 = np.extract(np.reshape(validRegion, (-1)), I1)
+    counts = np.extract(np.reshape(validRegion, (-1)), counts)
+
+    # reshape
+    I1 = np.transpose(np.reshape(I1, (img.shape[1], img.shape[0])))
+    counts = np.transpose(np.reshape(counts, (img.shape[1], img.shape[0])))
+
+    return I1, counts
+
+def im2col(im, patchSize):
+
+    # extract all overlapping patches from the image
+    M, N = im.shape
+    col_extent = N - patchSize + 1
+    row_extent = M - patchSize + 1
+    # get starting block indices
+    start_idx = np.arange(patchSize)[:, None]*N + np.arange(patchSize)
+    # get offsetted indices across the height and width of the input array
+    offset_idx = np.arange(row_extent)[:, None]*N + np.arange(col_extent)
+    # get all actual indices & index into the input array for the final output
+    Z = np.take(im, start_idx.ravel()[:, None] + offset_idx.ravel())
+    return Z
+
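+# Editor's shape sketch (illustrative): on an M x N image with patch size p,
+# im2col returns a (p*p) x ((M-p+1)*(N-p+1)) array, one column per patch.
+def _example_im2col_shape():
+    im = np.arange(25.).reshape(5, 5)
+    Z = im2col(im, 2)
+    return Z.shape  # (4, 16): 2x2 patches over a 4x4 grid of start positions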
+
+def cleanPatches(Y, patchSize, noiseSD, nmodels, covs, mixweights, means):
+
+    SigmaNoise = noiseSD**2 * np.eye(patchSize**2)
+
+    # remove DC component
+    meanY = np.mean(Y, 0)
+    Y = Y - np.tile(meanY, [Y.shape[0], 1])
+
+    # calculate assignment probabilities for each mixture component for all patches
+    PYZ = np.zeros((nmodels, Y.shape[1]))
+    for i in range(nmodels):
+        PYZ[i, :] = np.log(mixweights[i]) + loggausspdf2(Y, covs[:, :, i] + SigmaNoise)
+
+    # find the most likely component for each patch
+    ks = PYZ.argmax(axis=0)
+
+    # and now perform Wiener filtering
+    Xhat = np.zeros(Y.shape)
+    for i in range(nmodels):
+        inds = np.array(np.where(ks == i)).ravel()
+        meanterm = np.transpose(np.tile(np.transpose(means[:, i]), [inds.shape[0], 1]))
+        invcov = np.linalg.inv(covs[:, :, i] + SigmaNoise)
+        Xhat[:, inds] = (np.dot(covs[:, :, i], np.dot(invcov, Y[:, inds]))
+                         + np.dot(SigmaNoise, np.dot(invcov, meanterm)))
+
+    Xhat = Xhat + np.tile(meanY, [Xhat.shape[0], 1])
+    return Xhat
+
+
+def loggausspdf2(X, sigma):
+    """Log pdf of a Gaussian with zero mean.
+    Based on code written by Mo Chen (mochen@ie.cuhk.edu.hk), March 2009.
+    """
+    d = X.shape[0]
+
+    R = np.linalg.cholesky(sigma).T
+    # todo: check that sigma is psd
+
+    q = np.sum((np.dot(np.linalg.inv(np.transpose(R)), X))**2, 0)  # quadratic term (M distance)
+    c = d*np.log(2*np.pi) + 2*np.sum(np.log(np.diagonal(R)), 0)    # normalization constant
+    y = -(c+q)/2.0
+
+    return y
+
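+# Editor's check (sketch, not executed here): loggausspdf2(X, sigma) should
+# agree with scipy.stats.multivariate_normal(mean=np.zeros(X.shape[0]),
+# cov=sigma).logpdf(X.T) for column-stacked samples X of shape (d, n).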
+
diff --git a/imaging/pol_imager_utils.py b/imaging/pol_imager_utils.py
new file mode 100644
index 00000000..e34f4d6e
--- /dev/null
+++ b/imaging/pol_imager_utils.py
@@ -0,0 +1,2098 @@
+# pol_imager_utils.py
+# General imager functions for polarimetric VLBI data
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+#TODO
+# FIX NFFTS for m and for p -- offsets? imag?
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+from builtins import range
+
+import string
+import time
+import copy
+import numpy as np
+import scipy.optimize as opt
+import scipy.ndimage as ndF
+import scipy.ndimage.filters as filt
+import matplotlib.pyplot as plt
+try:
+ from pynfft.nfft import NFFT
+except ImportError:
+ pass
+ #print("Warning: No NFFT installed! Cannot use nfft functions")
+from scipy.special import jv
+
+import ehtim.image as image
+from . import linearize_energy as le
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+
+TANWIDTH_M = 0.5
+TANWIDTH_V = 1
+
+##################################################################################################
+# Constants & Definitions
+##################################################################################################
+
+NORM_REGULARIZER = False  # ANDREW TODO change this default in the future
+
+MAXLS = 100 # maximum number of line searches in L-BFGS-B
+NHIST = 100 # number of steps to store for hessian approx
+MAXIT = 100 # maximum number of iterations
+STOP = 1.e-100 # convergence criterion
+
+DATATERMS_POL = ['pvis', 'm', 'pbs','vvis']
+REGULARIZERS_POL = ['msimple', 'hw', 'ptv','l1v','l2v','vtv','v2tv2','vflux']
+
+nit = 0 # global variable to track the iteration number in the plotting callback
+
+def qimage(iimage, mimage, chiimage):
+ """Return the Q image from m and chi"""
+ return iimage * mimage * np.cos(2*chiimage)
+
+def uimage(iimage, mimage, chiimage):
+ """Return the U image from m and chi"""
+ return iimage * mimage * np.sin(2*chiimage)
+
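+
+# Editor's consistency sketch (illustrative): qimage/uimage can be inverted
+# back to (m, chi) with the standard polarization relations.
+def _example_qu_roundtrip():
+    iimage = np.array([1.0, 2.0])
+    mimage = np.array([0.3, 0.1])
+    chiimage = np.array([0.2, 1.0])
+    q = qimage(iimage, mimage, chiimage)
+    u = uimage(iimage, mimage, chiimage)
+    m_rec = np.sqrt(q**2 + u**2) / iimage  # fractional polarization |m|
+    chi_rec = 0.5 * np.arctan2(u, q)       # EVPA chi (mod pi)
+    return m_rec, chi_rec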
+
+##################################################################################################
+# Polarimetric Imager
+##################################################################################################
+def pol_imager_func(Obsdata, InitIm, Prior,
+ pol_trans=True, pol_solve = (0,1,1),
+ d1='pvis', d2=False,
+ s1='msimple', s2=False,
+ alpha_d1=100, alpha_d2=100,
+ alpha_s1=1, alpha_s2=1,
+ **kwargs):
+
+ """Run a polarimetric imager.
+
+ Args:
+ Obsdata (Obsdata): The Obsdata object with polarimetric VLBI data
+ Prior (Image): The Image object with the prior image
+ InitIm (Image): The Image object with the initial image for the minimization
+
+ pol_trans (bool): polarization scaled to physical values or not
+ pol_solve (tuple): len 3 tuple, solve for the corresponding pol image when not zero
+
+        d1 (str): The first data term; options are 'pvis', 'm', 'pbs', 'vvis'
+        d2 (str): The second data term; options are 'pvis', 'm', 'pbs', 'vvis'
+        s1 (str): The first regularizer; options are 'msimple', 'hw', 'ptv',
+                  'l1v', 'l2v', 'vtv', 'v2tv2', 'vflux'
+        s2 (str): The second regularizer; same options as s1
+ alpha_d1 (float): The first data term weighting
+ alpha_d2 (float): The second data term weighting
+ alpha_s1 (float): The first regularizer term weighting
+ alpha_s2 (float): The second regularizer term weighting
+
+       ttype (str): The Fourier transform type; options are 'direct' and 'nfft'
+ fft_pad_factor (float): The FFT will pre-pad the image by this factor x the original size
+ p_rad (float): The pixel radius for the convolving function in gridding for FFTs
+
+ clipfloor (float): The Stokes I Jy/pixel level above which prior image pixels are varied
+ grads (bool): If True, analytic gradients are used
+ norm_reg (bool): If True, normalizes regularizer terms
+ beam_size (float): beam size in radians for normalizing the regularizers
+ flux (float): The total flux of the output image in Jy
+
+ maxit (int): Maximum number of minimizer iterations
+ stop (float): The convergence criterion
+ show_updates (bool): If True, displays the progress of the minimizer
+
+ Returns:
+ Image: Image object with result
+ """
+
+ # some kwarg default values
+ maxit = kwargs.get('maxit', MAXIT)
+ stop = kwargs.get('stop', STOP)
+ clipfloor = kwargs.get('clipfloor', -1)
+ ttype = kwargs.get('ttype','direct')
+ grads = kwargs.get('grads',True)
+ norm_init = kwargs.get('norm_init',False)
+ show_updates = kwargs.get('show_updates',True)
+
+ beam_size = kwargs.get('beam_size',Obsdata.res())
+ flux = kwargs.get('flux',InitIm.total_flux())
+ kwargs['beam_size'] = beam_size
+
+ # Make sure data and regularizer options are ok
+ if not d1 and not d2:
+ raise Exception("Must have at least one data term!")
+ if not s1 and not s2:
+ raise Exception("Must have at least one regularizer term!")
+ if (not ((d1 in DATATERMS_POL) or d1==False)) or (not ((d2 in DATATERMS_POL) or d2==False)):
+ raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS_POL))
+ if (not ((s1 in REGULARIZERS_POL) or s1==False)) or (not ((s2 in REGULARIZERS_POL) or s2==False)):
+ raise Exception("Invalid regularizer: valid regularizers are: " + ' '.join(REGULARIZERS_POL))
+ if (Prior.psize != InitIm.psize) or (Prior.xdim != InitIm.xdim) or (Prior.ydim != InitIm.ydim):
+ raise Exception("Initial image does not match dimensions of the prior image!")
+ if (ttype not in ["direct","nfft"]):
+ raise Exception("FFT no yet implemented in polarimetric imaging -- use NFFT!")
+ if (not pol_trans):
+ raise Exception("Only pol_trans==True supported!")
+ if (len(pol_solve)!=3):
+ raise Exception("pol_solve tuple must have 3 entries!")
+
+ # Catch scale and dimension problems
+ imsize = np.max([Prior.xdim, Prior.ydim]) * Prior.psize
+ uvmax = 1.0/Prior.psize
+ uvmin = 1.0/imsize
+ uvdists = Obsdata.unpack('uvdist')['uvdist']
+ maxbl = np.max(uvdists)
+    minbl = np.min(uvdists[uvdists > 0])
+ maxamp = np.max(np.abs(Obsdata.unpack('amp')['amp']))
+
+ if uvmax < maxbl:
+ print("Warning! Pixel Spacing is larger than smallest spatial wavelength!")
+ if uvmin > minbl:
+ print("Warning! Field of View is smaller than largest nonzero spatial wavelength!")
+
+ # convert polrep to stokes
+ # ANDREW todo -- make more general??
+ Prior = Prior.switch_polrep(polrep_out='stokes', pol_prim_out='I')
+ InitIm = InitIm.switch_polrep(polrep_out='stokes', pol_prim_out='I')
+
+ # embedding mask
+ embed_mask = Prior.imvec > clipfloor
+
+ # initial Stokes I image
+ iimage = InitIm.imvec[embed_mask]
+ nimage = len(iimage)
+
+ # initial pol image
+ if pol_trans:
+ if len(InitIm.qvec) and (np.any(InitIm.qvec!=0) or np.any(InitIm.uvec!=0)): #TODO right? or should it be if any=0
+ init1 = (np.abs(InitIm.qvec + 1j*InitIm.uvec) / InitIm.imvec)[embed_mask]
+ init2 = (np.arctan2(InitIm.uvec, InitIm.qvec) / 2.0)[embed_mask]
+ else:
+ # !AC TODO get the actual zero baseline pol. frac from the data!??
+ print("No polarimetric image in the initial image!")
+ init1 = 0.2 * (np.ones(len(iimage)) + 1e-2 * np.random.rand(len(iimage)))
+ init2 = np.zeros(len(iimage)) + 1e-2 * np.random.rand(len(iimage))
+
+ # Change of variables
+ inittuple = np.array((iimage, init1, init2))
+ xtuple = mcv_r(inittuple)
+
+ # Get data and fourier matrices for the data terms
+ (data1, sigma1, A1) = polchisqdata(Obsdata, Prior, embed_mask, d1, **kwargs)
+ (data2, sigma2, A2) = polchisqdata(Obsdata, Prior, embed_mask, d2, **kwargs)
+
+ # Define the chi^2 and chi^2 gradient
+ def chisq1(imtuple):
+ return polchisq(imtuple, A1, data1, sigma1, d1, ttype=ttype,
+ mask=embed_mask, pol_trans=pol_trans)
+
+ def chisq1grad(imtuple):
+ return polchisqgrad(imtuple, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask,
+ pol_trans=pol_trans, pol_solve=pol_solve)
+
+ def chisq2(imtuple):
+ return polchisq(imtuple, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask,
+ pol_trans=pol_trans)
+
+ def chisq2grad(imtuple):
+ return polchisqgrad(imtuple, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask,
+ pol_trans=pol_trans,pol_solve=pol_solve)
+
+ # Define the regularizer and regularizer gradient
+ def reg1(imtuple):
+ return polregularizer(imtuple, embed_mask, flux, flux, flux,
+ Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs)
+
+ def reg1grad(imtuple):
+ return polregularizergrad(imtuple, embed_mask, flux, flux, flux,
+ Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs)
+
+ def reg2(imtuple):
+ return polregularizer(imtuple, embed_mask, flux, flux, flux,
+ Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs)
+
+ def reg2grad(imtuple):
+ return polregularizergrad(imtuple, embed_mask, flux, flux, flux,
+ Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs)
+
+
+ # Define the objective function and gradient
+ def objfunc(allvec):
+ # unpack allvec into image tuple
+ cvtuple = unpack_poltuple(allvec, xtuple, nimage, pol_solve)
+
+ # change of variables
+ if pol_trans:
+ imtuple = mcv(cvtuple)
+ else:
+            raise Exception("Only pol_trans==True is currently supported!")
+
+ datterm = alpha_d1 * (chisq1(imtuple) - 1) + alpha_d2 * (chisq2(imtuple) - 1)
+ regterm = alpha_s1 * reg1(imtuple) + alpha_s2 * reg2(imtuple)
+
+ return datterm + regterm
+
+ def objgrad(allvec):
+ # unpack allvec into image tuple
+ cvtuple = unpack_poltuple(allvec, xtuple, nimage, pol_solve)
+
+ # change of variables
+ if pol_trans:
+ imtuple = mcv(cvtuple)
+ else:
+            raise Exception("Only pol_trans==True is currently supported!")
+
+ datterm = alpha_d1 * chisq1grad(imtuple) + alpha_d2 * chisq2grad(imtuple)
+ regterm = alpha_s1 * reg1grad(imtuple) + alpha_s2 * reg2grad(imtuple)
+ gradarr = datterm + regterm
+
+ # chain rule
+ if pol_trans:
+ chainarr = mchain(cvtuple)
+ gradarr = gradarr*chainarr
+
+ # repack grad into single vector
+ grad = pack_poltuple(gradarr, pol_solve)
+
+ return grad
+
+ # Define plotting function for each iteration
+ global nit
+ nit = 0
+ def plotcur(im_step):
+ global nit
+ cvtuple = unpack_poltuple(im_step, xtuple, nimage, pol_solve)
+ if pol_trans:
+ imtuple = mcv(cvtuple) #change of variables
+ else:
+            raise Exception("Only pol_trans==True is currently supported!")
+
+ if show_updates:
+ chi2_1 = chisq1(imtuple)
+ chi2_2 = chisq2(imtuple)
+ s_1 = reg1(imtuple)
+ s_2 = reg2(imtuple)
+ if np.any(np.invert(embed_mask)):
+ imtuple = embed_pol(imtuple, embed_mask)
+ plot_m(imtuple, Prior, nit, {d1:chi2_1, d2:chi2_2})
+ print("i: %d chi2_1: %0.2f chi2_2: %0.2f s_1: %0.2f s_2: %0.2f" % (nit, chi2_1, chi2_2,s_1,s_2))
+ nit += 1
+
+ # Print stats
+ print("Initial S_1: %f S_2: %f" % (reg1(inittuple), reg2(inittuple)))
+ print("Initial Chi^2_1: %f Chi^2_2: %f" % (chisq1(inittuple), chisq2(inittuple)))
+ if d1 in DATATERMS_POL:
+ print("Total Data 1: ", (len(data1)))
+ if d2 in DATATERMS_POL:
+ print("Total Data 2: ", (len(data2)))
+ print("Total Pixel #: ", (len(Prior.imvec)))
+ print("Clipped Pixel #: ", nimage)
+ print()
+
+ # Plot Initial
+ xinit = pack_poltuple(xtuple.copy(), pol_solve)
+ plotcur(xinit)
+
+ # Minimize
+ optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST,'maxls':MAXLS,'gtol':stop,'maxfun':1.e100} # minimizer dict params
+ tstart = time.time()
+ if grads:
+ res = opt.minimize(objfunc, xinit, method='L-BFGS-B', jac=objgrad,callback=plotcur,
+ options=optdict)
+ else:
+ res = opt.minimize(objfunc, xinit, method='L-BFGS-B',
+ options=optdict, callback=plotcur)
+
+ tstop = time.time()
+
+ # Format output
+ outcv = unpack_poltuple(res.x, xtuple, nimage, pol_solve)
+ if pol_trans:
+ outcut = mcv(outcv) #change of variables
+ else:
+ outcut = outcv
+
+    if np.any(np.invert(embed_mask)):
+        out = embed_pol(outcut, embed_mask) #embed
+    else:
+        out = outcut
+
+ iimage = out[0]
+ qimage = make_q_image(out, pol_trans)
+ uimage = make_u_image(out, pol_trans)
+
+ outim = image.Image(iimage.reshape(Prior.ydim, Prior.xdim), Prior.psize,
+ Prior.ra, Prior.dec, rf=Prior.rf, source=Prior.source,
+ mjd=Prior.mjd, pulse=Prior.pulse)
+ outim.add_qu(qimage.reshape(Prior.ydim, Prior.xdim), uimage.reshape(Prior.ydim, Prior.xdim))
+
+ # Print stats
+ print("time: %f s" % (tstop - tstart))
+ print("J: %f" % res.fun)
+ print("Final Chi^2_1: %f Chi^2_2: %f" % (chisq1(outcut), chisq2(outcut)))
+ print(res.message)
+
+ # Return Image object
+ return outim
+
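+# Example call (illustrative only: 'obs', 'initim', and 'prior' stand for
+# pre-built Obsdata/Image objects and are not defined in this module):
+#
+#   out = pol_imager_func(obs, initim, prior,
+#                         d1='pvis', s1='hw',
+#                         alpha_d1=100, alpha_s1=1,
+#                         ttype='nfft', maxit=250)
+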
+##################################################################################################
+# Linear Polarimetric image representations and Change of Variables
+##################################################################################################
+def pack_poltuple(poltuple, pol_solve = (0,1,1)):
+ """pack polvec into image vector,
+ ignore quantities not iterated
+ """
+
+ vec = np.array([])
+ if pol_solve[0] != 0:
+ vec = np.hstack((vec,poltuple[0]))
+ if pol_solve[1] != 0:
+ vec = np.hstack((vec,poltuple[1]))
+ if pol_solve[2] != 0:
+ vec = np.hstack((vec,poltuple[2]))
+ if len(pol_solve)==4 and pol_solve[3] != 0:
+ vec = np.hstack((vec,poltuple[3]))
+
+ return vec
+
+
+def unpack_poltuple(polvec, inittuple, nimage, pol_solve = (0,1,1)):
+ """unpack polvec into image tuple,
+ replaces quantities not iterated with initial values
+ """
+ init0 = inittuple[0]
+ init1 = inittuple[1]
+ init2 = inittuple[2]
+ if len(pol_solve)==4:
+ init3 = inittuple[3]
+
+ imct = 0
+ if pol_solve[0] == 0:
+ im0 = init0
+ else:
+ im0 = polvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+
+ if pol_solve[1] == 0:
+ im1 = init1
+ else:
+ im1 = polvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+
+ if pol_solve[2] == 0:
+ im2 = init2
+ else:
+ im2 = polvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+
+ if len(pol_solve)==4:
+ if pol_solve[3] == 0:
+ im3 = init3
+ else:
+ im3 = polvec[imct*nimage:(imct+1)*nimage]
+ imct += 1
+ out = np.array((im0, im1, im2, im3))
+ else:
+ out = np.array((im0, im1, im2))
+
+ return out
+
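+def _example_pack_unpack():
+    """Illustrative sanity check (not part of the imaging pipeline): for a fixed
+       pol_solve, unpack_poltuple inverts pack_poltuple, with quantities that
+       are not solved for restored from the initial tuple.
+    """
+    npix = 4
+    tup = np.array((np.ones(npix), 0.1*np.ones(npix), np.zeros(npix)))
+    vec = pack_poltuple(tup, pol_solve=(0, 1, 1))       # stacks m and chi only
+    tup2 = unpack_poltuple(vec, tup, npix, pol_solve=(0, 1, 1))
+    assert np.allclose(tup, tup2)
+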
+def make_p_image(imtuple, pol_trans=True):
+ """construct a polarimetric image P = Q + iU
+ """
+
+ if pol_trans:
+ pimage = imtuple[0] * imtuple[1] * np.exp(2j*imtuple[2])
+ else:
+ pimage = imtuple[1] + 1j*imtuple[2]
+
+ return pimage
+
+def make_m_image(imtuple, pol_trans=True):
+ """construct a polarimetric ratrio image abs(P/I) = abs(Q + iU)/I
+ """
+
+ if pol_trans:
+ mimage = imtuple[1]
+ else:
+ mimage = np.abs((imtuple[1] + 1j*imtuple[2])/imtuple[0])
+ return mimage
+
+def make_chi_image(imtuple, pol_trans=True):
+ """construct a polarimetric angle image
+ """
+
+ if pol_trans:
+ chiimage = imtuple[2]
+ else:
+ chiimage = 0.5*np.angle((imtuple[1] + 1j*imtuple[2])/imtuple[0])
+
+ return chiimage
+
+def make_q_image(imtuple, pol_trans=True):
+ """construct an image of stokes Q
+ """
+
+ if pol_trans:
+ qimage = imtuple[0] * imtuple[1] * np.cos(2*imtuple[2])
+ else:
+ qimage = imtuple[1]
+
+ return qimage
+
+def make_u_image(imtuple, pol_trans=True):
+ """construct an image of stokes U
+ """
+
+ if pol_trans:
+ uimage = imtuple[0] * imtuple[1] * np.sin(2*imtuple[2])
+ else:
+ uimage = imtuple[2]
+
+ return uimage
+
+def make_v_image(imtuple, pol_trans=True):
+ """construct an image of stokes V
+ """
+
+ if len(imtuple)==4:
+ if pol_trans:
+ vimage = imtuple[0] * imtuple[3]
+ else:
+ vimage = imtuple[3]
+ else:
+ vimage = np.zeros(imtuple[0].shape)
+
+ return vimage
+
+def make_vfrac_image(imtuple, pol_trans=True):
+ """construct an image of stokes V
+ """
+ if len(imtuple)==4:
+ if pol_trans:
+ vfimage = imtuple[3]
+ else:
+ vfimage = imtuple[3]/imtuple[0]
+ else:
+ vfimage = np.zeros(imtuple[0].shape)
+
+ return vfimage
+
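+def _example_stokes_consistency():
+    """Illustrative sanity check: by construction of the helpers above,
+       P = Q + iU and abs(P) = m*I in either representation.
+    """
+    tup = np.array((np.ones(4), 0.4*np.ones(4), 0.2*np.ones(4)))  # (I, m, chi)
+    pim = make_p_image(tup)
+    assert np.allclose(pim, make_q_image(tup) + 1j*make_u_image(tup))
+    assert np.allclose(np.abs(pim), make_m_image(tup)*tup[0])
+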
+
+# these change of variables only apply to polarimetric ratios
+# !AC In these pol. changes of variables, might be useful to
+# take m -> m/100 by adjusting B (function becomes less steep around m' = 0)
+
+def mcv(imtuple):
+ """Change of pol. ratio from range (-inf, inf) to (0,1)
+ """
+
+ iimage = imtuple[0]
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ mtrans = 0.5 + np.arctan(mimage/TANWIDTH_M)/np.pi
+ if len(imtuple)==4:
+ vfimage = imtuple[3]
+ vtrans = 2*np.arctan(vfimage/TANWIDTH_V)/np.pi
+ out = np.array((iimage, mtrans, chiimage, vtrans))
+ else:
+ out = np.array((iimage, mtrans, chiimage))
+
+
+ return out
+
+def mcv_r(imtuple):
+ """Change of pol. ratio from range (0,1) to (-inf,inf)
+ """
+ iimage = imtuple[0]
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ mtrans = TANWIDTH_M*np.tan(np.pi*(mimage - 0.5))
+ if len(imtuple)==4:
+ vfimage = imtuple[3]
+ vtrans = TANWIDTH_V*np.tan(0.5*np.pi*(vfimage))
+ out = np.array((iimage, mtrans, chiimage, vtrans))
+ else:
+ out = np.array((iimage, mtrans, chiimage))
+
+ return out
+
+def mchain(imtuple):
+ """The gradient change of variables, dm/dm'
+ """
+ iimage = imtuple[0]
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ ichain = np.ones(len(iimage))
+ mmchain = 1 / (TANWIDTH_M*np.pi*(1 + (mimage/TANWIDTH_M)**2))
+ chichain = np.ones(len(chiimage))
+
+ if len(imtuple)==4:
+ vfimage = imtuple[3]
+ vchain = 2. / (TANWIDTH_V*np.pi*(1 + (vfimage/TANWIDTH_V)**2))
+ out = np.array((ichain, mmchain, chichain, vchain))
+ else:
+ out = np.array((ichain, mmchain, chichain))
+
+ return np.array(out)
+
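+def _example_mcv_roundtrip():
+    """Illustrative sanity check: mcv undoes mcv_r for m strictly inside (0,1),
+       since arctan(tan(x)) = x on (-pi/2, pi/2).
+    """
+    tup = np.array((np.ones(4), np.array([0.1, 0.3, 0.5, 0.7]), np.zeros(4)))
+    assert np.allclose(mcv(mcv_r(tup)), tup)
+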
+
+##################################################################################################
+# Wrapper Functions
+##################################################################################################
+
+def polchisq(imtuple, A, data, sigma, dtype, ttype='direct', mask=[], pol_trans=True):
+ """return the chi^2 for the appropriate dtype
+ """
+
+ chisq = 1
+    if dtype not in DATATERMS_POL:
+        return chisq
+    if ttype not in ['fast','direct','nfft']:
+        raise Exception("Possible ttype values are 'fast', 'direct', and 'nfft'!")
+
+ if ttype == 'direct':
+ # linear
+ if dtype == 'pvis':
+ chisq = chisq_p(imtuple, A, data, sigma, pol_trans)
+ elif dtype == 'm':
+ chisq = chisq_m(imtuple, A, data, sigma, pol_trans)
+ elif dtype == 'pbs':
+ chisq = chisq_pbs(imtuple, A, data, sigma, pol_trans)
+
+ # circular
+ elif dtype == 'vvis':
+ chisq = chisq_vvis(imtuple, A, data, sigma, pol_trans)
+
+ elif ttype== 'fast':
+ raise Exception("FFT not yet implemented in polchisq!")
+
+ elif ttype== 'nfft':
+ if len(mask)>0 and np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+
+ # linear
+ if dtype == 'pvis':
+ chisq = chisq_p_nfft(imtuple, A, data, sigma, pol_trans)
+ elif dtype == 'm':
+ chisq = chisq_m_nfft(imtuple, A, data, sigma, pol_trans)
+ elif dtype == 'pbs':
+ chisq = chisq_pbs_nfft(imtuple, A, data, sigma, pol_trans)
+
+ # circular
+ elif dtype == 'vvis':
+ chisq = chisq_vvis_nfft(imtuple, A, data, sigma, pol_trans)
+
+ return chisq
+
+def polchisqgrad(imtuple, A, data, sigma, dtype, ttype='direct',
+ mask=[], pol_trans=True,pol_solve=(0,1,1)):
+
+ """return the chi^2 gradient for the appropriate dtype
+ """
+
+    chisqgrad = np.zeros((len(imtuple), len(imtuple[0])))
+    if dtype not in DATATERMS_POL:
+        return chisqgrad
+    if ttype not in ['fast','direct','nfft']:
+        raise Exception("Possible ttype values are 'fast', 'direct', and 'nfft'!")
+
+ if ttype == 'direct':
+ # linear
+ if dtype == 'pvis':
+ chisqgrad = chisqgrad_p(imtuple, A, data, sigma, pol_trans,pol_solve)
+ elif dtype == 'm':
+ chisqgrad = chisqgrad_m(imtuple, A, data, sigma, pol_trans,pol_solve)
+ elif dtype == 'pbs':
+ chisqgrad = chisqgrad_pbs(imtuple, A, data, sigma, pol_trans,pol_solve)
+
+ # circular
+ elif dtype == 'vvis':
+ chisqgrad = chisqgrad_vvis(imtuple, A, data, sigma, pol_trans,pol_solve)
+
+ elif ttype== 'fast':
+ raise Exception("FFT not yet implemented in polchisqgrad!")
+
+ elif ttype== 'nfft':
+ if len(mask)>0 and np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+
+ # linear
+ if dtype == 'pvis':
+ chisqgrad = chisqgrad_p_nfft(imtuple, A, data, sigma, pol_trans,pol_solve)
+ elif dtype == 'm':
+ chisqgrad = chisqgrad_m_nfft(imtuple, A, data, sigma, pol_trans,pol_solve)
+ elif dtype == 'pbs':
+ chisqgrad = chisqgrad_pbs_nfft(imtuple, A, data, sigma, pol_trans,pol_solve)
+
+ # circular
+ elif dtype == 'vvis':
+ chisqgrad = chisqgrad_vvis_nfft(imtuple, A, data, sigma, pol_trans,pol_solve)
+
+ if len(mask)>0 and np.any(np.invert(mask)):
+ if len(chisqgrad)==4:
+ chisqgrad = np.array((chisqgrad[0][mask],chisqgrad[1][mask],chisqgrad[2][mask],chisqgrad[3][mask]))
+ else:
+ chisqgrad = np.array((chisqgrad[0][mask],chisqgrad[1][mask],chisqgrad[2][mask]))
+
+ return chisqgrad
+
+
+def polregularizer(imtuple, mask, flux, pflux, vflux, xdim, ydim, psize, stype, **kwargs):
+    """Return the regularizer value for the appropriate stype
+    """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ pol_trans = kwargs.get('pol_trans', True)
+ beam_size = kwargs.get('beam_size',1)
+
+ # linear
+ if stype == "msimple":
+ reg = -sm(imtuple, flux, pol_trans, norm_reg=norm_reg)
+ elif stype == "hw":
+ reg = -shw(imtuple, flux, pol_trans, norm_reg=norm_reg)
+ elif stype == "ptv":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reg = -stv_pol(imtuple, flux, xdim, ydim, psize, pol_trans,
+ norm_reg=norm_reg, beam_size=beam_size)
+ # circular
+ elif stype == 'vflux':
+ reg = -svflux(imtuple, vflux, norm_reg=norm_reg)
+ elif stype == "l1v":
+ reg = -sl1v(imtuple, vflux, pol_trans, norm_reg=norm_reg)
+ elif stype == "l2v":
+ reg = -sl2v(imtuple, vflux, pol_trans, norm_reg=norm_reg)
+ elif stype == "vtv":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reg = -stv_v(imtuple, vflux, xdim, ydim, psize, pol_trans,
+ norm_reg=norm_reg, beam_size=beam_size)
+ elif stype == "vtv2":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reg = -stv2_v(imtuple, vflux, xdim, ydim, psize, pol_trans,
+ norm_reg=norm_reg, beam_size=beam_size)
+ else:
+ reg = 0
+
+ return reg
+
+def polregularizergrad(imtuple, mask, flux, pflux, vflux, xdim, ydim, psize, stype, **kwargs):
+    """Return the regularizer gradient for the appropriate stype
+    """
+
+ norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER)
+ pol_trans = kwargs.get('pol_trans', True)
+ pol_solve = kwargs.get('pol_solve', (0,1,1))
+ beam_size = kwargs.get('beam_size',1)
+
+ # linear
+ if stype == "msimple":
+ reggrad = -smgrad(imtuple, flux, pol_trans, pol_solve, norm_reg=norm_reg)
+ elif stype == "hw":
+ reggrad = -shwgrad(imtuple, flux, pol_trans, pol_solve, norm_reg=norm_reg)
+ elif stype == "ptv":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reggrad = -stv_pol_grad(imtuple, flux, xdim, ydim, psize, pol_trans, pol_solve,
+ norm_reg=norm_reg, beam_size=beam_size)
+ if np.any(np.invert(mask)):
+ if len(reggrad)==4:
+ reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask]))
+ else:
+ reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask]))
+
+ # circular
+ elif stype == 'vflux':
+ reggrad = -svfluxgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg)
+ elif stype == "l1v":
+ reggrad = -sl1vgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg)
+ elif stype == "l2v":
+ reggrad = -sl2vgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg)
+ elif stype == "vtv":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reggrad = -stv_v_grad(imtuple, vflux, xdim, ydim, psize, pol_trans,
+ pol_solve=pol_solve, norm_reg=norm_reg, beam_size=beam_size)
+ if np.any(np.invert(mask)):
+ reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask]))
+ elif stype == "vtv2":
+ if np.any(np.invert(mask)):
+ imtuple = embed_pol(imtuple, mask, randomfloor=True)
+ reggrad = -stv2_v_grad(imtuple, vflux, xdim, ydim, psize, pol_trans,
+ pol_solve=pol_solve, norm_reg=norm_reg, beam_size=beam_size)
+ if np.any(np.invert(mask)):
+ reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask]))
+
+ else:
+ reggrad = np.zeros((len(imtuple),len(imtuple[0])))
+
+ return reggrad
+
+
+def polchisqdata(Obsdata, Prior, mask, dtype, **kwargs):
+
+ """Return the data, sigma, and matrices for the appropriate dtype
+ """
+
+ ttype=kwargs.get('ttype','direct')
+
+ (data, sigma, A) = (False, False, False)
+ if ttype not in ['fast','direct','nfft']:
+ raise Exception("Possible ttype values are 'fast' and 'direct' and 'nfft'!")
+ if ttype=='direct':
+ if dtype == 'pvis':
+ (data, sigma, A) = chisqdata_pvis(Obsdata, Prior, mask)
+ elif dtype == 'm':
+ (data, sigma, A) = chisqdata_m(Obsdata, Prior, mask)
+ elif dtype == 'pbs':
+ (data, sigma, A) = chisqdata_pbs(Obsdata, Prior, mask)
+ elif dtype == 'vvis':
+ (data, sigma, A) = chisqdata_vvis(Obsdata, Prior, mask)
+ elif ttype=='fast':
+ raise Exception("FFT not yet implemented in polchisqdata!")
+
+ elif ttype=='nfft':
+ if dtype == 'pvis':
+ (data, sigma, A) = chisqdata_pvis_nfft(Obsdata, Prior, mask, **kwargs)
+ elif dtype == 'm':
+ (data, sigma, A) = chisqdata_m_nfft(Obsdata, Prior, mask, **kwargs)
+ elif dtype == 'pbs':
+ (data, sigma, A) = chisqdata_pbs_nfft(Obsdata, Prior, mask, **kwargs)
+ elif dtype == 'vvis':
+ (data, sigma, A) = chisqdata_vvis_nfft(Obsdata, Prior, mask, **kwargs)
+
+ return (data, sigma, A)
+
+
+##################################################################################################
+# DFT Chi-squared and Gradient Functions
+##################################################################################################
+
+def chisq_p(imtuple, Amatrix, p, sigmap, pol_trans=True):
+ """Polarimetric visibility chi-squared
+ """
+
+ pimage = make_p_image(imtuple, pol_trans)
+ psamples = np.dot(Amatrix, pimage)
+ chisq = np.sum(np.abs((p - psamples))**2/(sigmap**2)) / (2*len(p))
+ return chisq
+
+def chisqgrad_p(imtuple, Amatrix, p, sigmap, pol_trans=True,pol_solve=(0,1,1)):
+ """Polarimetric visibility chi-squared gradient
+ """
+
+
+ iimage = imtuple[0]
+ pimage = make_p_image(imtuple, pol_trans)
+ psamples = np.dot(Amatrix, pimage)
+ pdiff = (p - psamples) / (sigmap**2)
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ if pol_solve[0]!=0:
+ gradi = -np.real(mimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, pdiff)) / len(p)
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = -np.real(iimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, pdiff)) / len(p)
+ else:
+ gradm = zeros
+
+ if pol_solve[2]!=0:
+ gradchi = -2 * np.imag(pimage.conj() * np.dot(Amatrix.conj().T, pdiff)) / len(p)
+ else:
+ gradchi = zeros
+
+ # output tuple can be length 3 or 4 depending on V
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
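+def _example_check_chisqgrad_p():
+    """Illustrative finite-difference check of chisqgrad_p against chisq_p,
+       using small synthetic inputs (random A, data, and image -- not drawn
+       from a real observation).
+    """
+    rng = np.random.RandomState(0)
+    npix, nvis = 9, 5
+    A = rng.randn(nvis, npix) + 1j*rng.randn(nvis, npix)
+    p = rng.randn(nvis) + 1j*rng.randn(nvis)
+    sig = np.ones(nvis)
+    tup = np.array((1. + np.abs(rng.randn(npix)),
+                    0.3*np.ones(npix), 0.1*np.ones(npix)))
+    grad = chisqgrad_p(tup, A, p, sig, pol_solve=(1, 1, 1))
+    eps = 1.e-7
+    for k in range(3):
+        for j in range(npix):
+            tup2 = tup.copy()
+            tup2[k][j] += eps
+            num = (chisq_p(tup2, A, p, sig) - chisq_p(tup, A, p, sig)) / eps
+            assert np.abs(num - grad[k][j]) < 1.e-4
+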
+def chisq_m(imtuple, Amatrix, m, sigmam, pol_trans=True):
+ """Polarimetric ratio chi-squared
+ """
+
+ iimage = imtuple[0]
+ pimage = make_p_image(imtuple, pol_trans)
+ msamples = np.dot(Amatrix, pimage) / np.dot(Amatrix, iimage)
+ return np.sum(np.abs((m - msamples))**2/(sigmam**2)) / (2*len(m))
+
+def chisqgrad_m(imtuple, Amatrix, m, sigmam, pol_trans=True,pol_solve=(0,1,1)):
+ """The gradient of the polarimetric ratio chisq
+ """
+ iimage = imtuple[0]
+ pimage = make_p_image(imtuple, pol_trans)
+
+ isamples = np.dot(Amatrix, iimage)
+ psamples = np.dot(Amatrix, pimage)
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ msamples = psamples/isamples
+ mdiff = (m - msamples) / (isamples.conj() * sigmam**2)
+
+ if pol_solve[0]!=0:
+ gradi = (-np.real(mimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, mdiff)) / len(m) +
+ np.real(np.dot(Amatrix.conj().T, msamples.conj() * mdiff)) / len(m))
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = -np.real(iimage*np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, mdiff)) / len(m)
+ else:
+ gradm = zeros
+
+ if pol_solve[2]!=0:
+ gradchi = -2 * np.imag(pimage.conj() * np.dot(Amatrix.conj().T, mdiff)) / len(m)
+ else:
+ gradchi = zeros
+
+ # output tuple can be length 3 or 4 depending on V
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
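+# Note on chisqgrad_m above: since msamples = (A.P)/(A.I), the Stokes I gradient
+# picks up two terms -- one from the numerator through P = I*m*exp(2i*chi), and
+# one from the denominator through the sampled Stokes I visibilities.
+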
+def chisq_pbs(imtuple, Amatrices, bis_p, sigma, pol_trans=True):
+ """Polarimetric bispectrum chi-squared
+ """
+
+ iimage = imtuple[0]
+ pimage = make_p_image(imtuple, pol_trans)
+ bisamples_p = np.dot(Amatrices[0], pimage) * np.dot(Amatrices[1], pimage) * np.dot(Amatrices[2], pimage)
+ chisq = np.sum(np.abs((bis_p - bisamples_p)/sigma)**2) / (2.*len(bis_p))
+ return chisq
+
+def chisqgrad_pbs(imtuple, Amatrices, bis_p, sigma, pol_trans=True,pol_solve=(0,1,1)):
+ """Polarimetric bispectrum chi-squared gradient
+ """
+ pimage = make_p_image(imtuple, pol_trans)
+ bisamples_p = np.dot(Amatrices[0], pimage) * np.dot(Amatrices[1], pimage) * np.dot(Amatrices[2], pimage)
+
+ wdiff = ((bis_p - bisamples_p).conj()) / (sigma**2)
+ pt1 = wdiff * np.dot(Amatrices[1],pimage) * np.dot(Amatrices[2],pimage)
+ pt2 = wdiff * np.dot(Amatrices[0],pimage) * np.dot(Amatrices[2],pimage)
+ pt3 = wdiff * np.dot(Amatrices[0],pimage) * np.dot(Amatrices[1],pimage)
+ ptsum = np.dot(pt1, Amatrices[0]) + np.dot(pt2, Amatrices[1]) + np.dot(pt3, Amatrices[2])
+
+ if pol_trans:
+        iimage = imtuple[0]
+        mimage = imtuple[1]
+        chiimage = imtuple[2]
+        zeros = np.zeros(len(iimage))
+
+ if pol_solve[0]!=0:
+ gradi = -np.real(ptsum * mimage * np.exp(2j*chiimage)) / len(bis_p)
+ else:
+ gradi = zeros
+ if pol_solve[1]!=0:
+ gradm = -np.real(ptsum * iimage * np.exp(2j*chiimage)) / len(bis_p)
+ else:
+ gradm = zeros
+ if pol_solve[2]!=0:
+ gradchi = 2 * np.imag(ptsum * pimage) / len(bis_p)
+ else:
+ gradchi = zeros
+
+ # output tuple can be length 3 or 4 depending on V
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+# stokes v
+def chisq_vvis(imtuple, Amatrix, v, sigmav, pol_trans=True):
+ """V visibility chi-squared
+ """
+
+ vimage = make_v_image(imtuple, pol_trans)
+ vsamples = np.dot(Amatrix, vimage)
+ chisq = np.sum(np.abs((v - vsamples))**2/(sigmav**2)) / (2*len(v))
+ return chisq
+
+def chisqgrad_vvis(imtuple, Amatrix, v, sigmav, pol_trans=True,pol_solve=(0,1,1)):
+ """V visibility chi-squared gradient
+ """
+
+
+ iimage = imtuple[0]
+ vimage = make_v_image(imtuple, pol_trans)
+ vsamples = np.dot(Amatrix, vimage)
+ vdiff = (v - vsamples) / (sigmav**2)
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+ vfimage = imtuple[3] # fractional
+ if pol_solve[0]!=0:
+ gradi = -np.real(vfimage * np.dot(Amatrix.conj().T, vdiff)) / len(v)
+ else:
+ gradi = zeros
+
+        if pol_solve[3]!=0:
+            gradv = -np.real(iimage * np.dot(Amatrix.conj().T, vdiff)) / len(v)
+        else:
+            gradv = zeros
+
+
+ gradout = np.array((gradi, zeros, zeros, gradv))
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+##################################################################################################
+# NFFT Chi-squared and Gradient Functions
+##################################################################################################
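+# In the NFFT versions below, each entry of A is an NFFTInfo object wrapping a
+# pynfft plan for one set of uv points: setting plan.f_hat and calling trafo()
+# samples the gridded image at those uv points (with 'pulsefac' applying the
+# pixel pulse function), while setting plan.f and calling adjoint() applies the
+# transposed operation needed for the gradients.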
+def chisq_p_nfft(imtuple, A, p, sigmap, pol_trans=True):
+ """P visibility chi-squared
+ """
+
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ pimage = make_p_image(imtuple, pol_trans)
+ plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ psamples = plan.f.copy()*pulsefac
+
+ #compute chi^2
+ chisq = np.sum(np.abs((p - psamples))**2/(sigmap**2)) / (2*len(p))
+
+ return chisq
+
+def chisqgrad_p_nfft(imtuple, A, p, sigmap, pol_trans=True,pol_solve=(0,1,1)):
+ """P visibility chi-squared gradient
+ """
+
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ iimage = imtuple[0]
+ pimage = make_p_image(imtuple, pol_trans)
+ plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ psamples = plan.f.copy()*pulsefac
+
+
+ pdiff_vec = (-1.0/len(p)) * (p - psamples) / (sigmap**2) * pulsefac.conj()
+ plan.f = pdiff_vec
+ plan.adjoint()
+ ppart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)
+
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ if pol_solve[0]!=0:
+
+ gradi = np.real(mimage * np.exp(-2j*chiimage) * ppart)
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = np.real(iimage*np.exp(-2j*chiimage) *ppart)
+ else:
+ gradm = zeros
+
+ if pol_solve[2]!=0:
+ gradchi = 2 * np.imag(pimage.conj() * ppart)
+ else:
+ gradchi = zeros
+
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+
+def chisq_m_nfft(imtuple, A, m, sigmam, pol_trans=True):
+ """Polarimetric ratio chi-squared
+ """
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ iimage = imtuple[0]
+ plan.f_hat = iimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ isamples = plan.f.copy()*pulsefac
+
+ pimage = make_p_image(imtuple, pol_trans)
+ plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ psamples = plan.f.copy()*pulsefac
+
+ #compute chi^2
+ msamples = psamples/isamples
+ chisq = np.sum(np.abs((m - msamples))**2/(sigmam**2)) / (2*len(m))
+ return chisq
+
+def chisqgrad_m_nfft(imtuple, A, m, sigmam, pol_trans=True,pol_solve=(0,1,1)):
+ """Polarimetric ratio chi-squared gradient
+ """
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ iimage = imtuple[0]
+ plan.f_hat = iimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ isamples = plan.f.copy()*pulsefac
+
+ pimage = make_p_image(imtuple, pol_trans)
+ plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ psamples = plan.f.copy()*pulsefac
+
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+
+ mimage = imtuple[1]
+ chiimage = imtuple[2]
+
+ msamples = psamples/isamples
+ mdiff_vec = (-1./len(m))*(m - msamples) / (isamples.conj() * sigmam**2) * pulsefac.conj()
+ plan.f = mdiff_vec
+ plan.adjoint()
+ mpart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)
+
+ if pol_solve[0]!=0: #TODO -- not right??
+ plan.f = mdiff_vec * msamples.conj()
+ plan.adjoint()
+ mpart2 = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)
+
+ gradi = (np.real(mimage * np.exp(-2j*chiimage) * mpart) - np.real(mpart2))
+
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = np.real(iimage*np.exp(-2j*chiimage) * mpart)
+ else:
+ gradm = zeros
+
+ if pol_solve[2]!=0:
+ gradchi = 2 * np.imag(pimage.conj() * mpart)
+ else:
+ gradchi = zeros
+
+ # output tuple can be length 3 or 4 depending on V
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+
+def chisq_pbs_nfft(imtuple, A, bis_p, sigma, pol_trans=True):
+ """Polarimetric bispectrum chi-squared
+ """
+
+ #get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ #compute uniform --> nonuniform transforms
+ pimage = make_p_image(imtuple, pol_trans)
+
+ plan1.f_hat = pimage.copy().reshape((nfft_info1.ydim,nfft_info1.xdim)).T
+ plan1.trafo()
+ samples1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = pimage.copy().reshape((nfft_info2.ydim,nfft_info2.xdim)).T
+ plan2.trafo()
+ samples2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = pimage.copy().reshape((nfft_info3.ydim,nfft_info3.xdim)).T
+ plan3.trafo()
+ samples3 = plan3.f.copy()*pulsefac3
+
+ #compute chi^2
+ bisamples_p = samples1*samples2*samples3
+ chisq = np.sum(np.abs((bis_p - bisamples_p)/sigma)**2) / (2.*len(bis_p))
+
+ return chisq
+
+def chisqgrad_pbs_nfft(imtuple, A, bis_p, sigma, pol_trans=True,pol_solve=(0,1,1)):
+ """Polarimetric bispectrum chi-squared gradient
+ """
+
+ #get nfft objects
+ nfft_info1 = A[0]
+ plan1 = nfft_info1.plan
+ pulsefac1 = nfft_info1.pulsefac
+
+ nfft_info2 = A[1]
+ plan2 = nfft_info2.plan
+ pulsefac2 = nfft_info2.pulsefac
+
+ nfft_info3 = A[2]
+ plan3 = nfft_info3.plan
+ pulsefac3 = nfft_info3.pulsefac
+
+ #compute uniform --> nonuniform transforms
+ pimage = make_p_image(imtuple, pol_trans)
+
+ plan1.f_hat = pimage.copy().reshape((nfft_info1.ydim,nfft_info1.xdim)).T
+ plan1.trafo()
+ v1 = plan1.f.copy()*pulsefac1
+
+ plan2.f_hat = pimage.copy().reshape((nfft_info2.ydim,nfft_info2.xdim)).T
+ plan2.trafo()
+ v2 = plan2.f.copy()*pulsefac2
+
+ plan3.f_hat = pimage.copy().reshape((nfft_info3.ydim,nfft_info3.xdim)).T
+ plan3.trafo()
+ v3 = plan3.f.copy()*pulsefac3
+
+ # gradient vec for adjoint fft
+ bisamples_p = v1*v2*v3
+ wdiff = (-1./len(bis_p)) * ((bis_p - bisamples_p).conj()) / (sigma**2)
+ pt1 = wdiff * (v2 * v3).conj() * pulsefac1.conj()
+ pt2 = wdiff * (v1 * v3).conj() * pulsefac2.conj()
+ pt3 = wdiff * (v1 * v2).conj() * pulsefac3.conj()
+
+ plan1.f = pt1
+ plan1.adjoint()
+ out1 = (plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)
+
+ plan2.f = pt2
+ plan2.adjoint()
+ out2 = (plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)
+
+ plan3.f = pt3
+ plan3.adjoint()
+ out3 = (plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)
+
+ ptsum = out1 + out2 + out3
+
+ if pol_trans:
+        iimage = imtuple[0]
+        mimage = imtuple[1]
+        chiimage = imtuple[2]
+        zeros = np.zeros(len(iimage))
+
+ if pol_solve[0]!=0:
+ gradi = np.real(ptsum * mimage * np.exp(2j*chiimage))
+ else:
+ gradi = zeros
+ if pol_solve[1]!=0:
+ gradm = np.real(ptsum * iimage * np.exp(2j*chiimage))
+ else:
+ gradm = zeros
+ if pol_solve[2]!=0:
+ gradchi = -2 * np.imag(ptsum * pimage)
+ else:
+ gradchi = zeros
+
+ # output tuple can be length 3 or 4 depending on V
+ if len(imtuple)==4:
+ gradout = np.array((gradi, gradm, gradchi, zeros))
+ else:
+ gradout = np.array((gradi, gradm, gradchi))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+# stokes v
+def chisq_vvis_nfft(imtuple, A, v, sigmav, pol_trans=True):
+ """V visibility chi-squared
+ """
+
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ vimage = make_v_image(imtuple, pol_trans)
+ plan.f_hat = vimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ vsamples = plan.f.copy()*pulsefac
+
+ #compute chi^2
+ chisq = np.sum(np.abs((v - vsamples))**2/(sigmav**2)) / (2*len(v))
+
+ return chisq
+
+def chisqgrad_vvis_nfft(imtuple, A, v, sigmav, pol_trans=True,pol_solve=(0,1,1)):
+ """V visibility chi-squared gradient
+ """
+
+ #get nfft object
+ nfft_info = A[0]
+ plan = nfft_info.plan
+ pulsefac = nfft_info.pulsefac
+
+ #compute uniform --> nonuniform transform
+ iimage = imtuple[0]
+ vimage = make_v_image(imtuple, pol_trans)
+ plan.f_hat = vimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T
+ plan.trafo()
+ vsamples = plan.f.copy()*pulsefac
+
+
+ vdiff_vec = (-1.0/len(v)) * (v - vsamples) / (sigmav**2) * pulsefac.conj()
+ plan.f = vdiff_vec
+ plan.adjoint()
+ vpart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)
+
+ zeros = np.zeros(len(iimage))
+
+ if pol_trans:
+
+ vfimage = imtuple[3] #fractional
+
+ if pol_solve[0]!=0:
+ gradi = np.real(vfimage*vpart)
+ else:
+ gradi = zeros
+
+ if pol_solve[3]!=0:
+ gradv = np.real(iimage*vpart)
+ else:
+ gradv = zeros
+
+
+ gradout = np.array((gradi, zeros, zeros, gradv))
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return gradout
+
+##################################################################################################
+# Polarimetric Entropy and Gradient Functions
+##################################################################################################
+
+def sm(imtuple, flux, pol_trans=True,
+ norm_reg=NORM_REGULARIZER):
+ """I log m entropy
+ """
+ if norm_reg: norm = flux
+ else: norm = 1
+
+ iimage = imtuple[0]
+ mimage = make_m_image(imtuple, pol_trans)
+ S = -np.sum(iimage * np.log(mimage))
+ return S/norm
+
+def smgrad(imtuple, flux, pol_trans=True,pol_solve=(0,1,1),
+ norm_reg=NORM_REGULARIZER):
+ """I log m entropy gradient
+ """
+
+ if norm_reg: norm = flux
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ mimage = make_m_image(imtuple, pol_trans)
+
+ if pol_trans:
+
+
+ if pol_solve[0]!=0:
+ gradi = -np.log(mimage)
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = -iimage / mimage
+ else:
+ gradm = zeros
+
+ gradchi = zeros
+
+ if len(imtuple)==4:
+ out = np.array((gradi, gradm, gradchi,zeros))
+ else:
+ out = np.array((gradi, gradm, gradchi))
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return np.array(out)/norm
+
+def shw(imtuple, flux, pol_trans=True, norm_reg=NORM_REGULARIZER):
+ """Holdaway-Wardle polarimetric entropy
+ """
+
+ if norm_reg: norm = flux
+ else: norm = 1
+
+ iimage = imtuple[0]
+ mimage = make_m_image(imtuple, pol_trans)
+ S = -np.sum(iimage * (((1+mimage)/2) * np.log((1+mimage)/2) + ((1-mimage)/2) * np.log((1-mimage)/2)))
+ return S/norm
+
+def shwgrad(imtuple, flux, pol_trans=True,pol_solve=(0,1,1),
+ norm_reg=NORM_REGULARIZER):
+ """Gradient of the Holdaway-Wardle polarimetric entropy
+ """
+ if norm_reg: norm = flux
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ mimage = make_m_image(imtuple, pol_trans)
+ if pol_trans:
+
+ if pol_solve[0]!=0:
+ gradi = -(((1+mimage)/2) * np.log((1+mimage)/2) + ((1-mimage)/2) * np.log((1-mimage)/2))
+ else:
+ gradi = zeros
+
+ if pol_solve[1]!=0:
+ gradm = -iimage * np.arctanh(mimage)
+ else:
+ gradm = zeros
+
+ gradchi = zeros
+
+ if len(imtuple)==4:
+ out = np.array((gradi, gradm, gradchi,zeros))
+ else:
+ out = np.array((gradi, gradm, gradchi))
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return np.array(out)/norm
+
+def stv_pol(imtuple, flux, nx, ny, psize, pol_trans=True,
+ norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Total variation of I*m*exp(2Ichi)"""
+
+ if beam_size is None: beam_size = psize
+ if norm_reg: norm = flux*psize / beam_size
+ else: norm = 1
+
+ pimage = make_p_image(imtuple, pol_trans)
+ im = pimage.reshape(ny, nx)
+
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ S = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2))
+ return S/norm
+
+def stv_pol_grad(imtuple, flux, nx, ny, psize, pol_trans=True, pol_solve=(0,1,1),
+ norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Total variation entropy gradient"""
+
+ if beam_size is None: beam_size = psize
+ if norm_reg: norm = flux*psize / beam_size
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ pimage = make_p_image(imtuple, pol_trans)
+
+ im = pimage.reshape(ny, nx)
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1]
+    im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1]
+
+ # Denominators
+ d1 = np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2)
+ d2 = np.sqrt(np.abs(im_r1 - im)**2 + np.abs(im_r1l2 - im_r1)**2)
+ d3 = np.sqrt(np.abs(im_r2 - im)**2 + np.abs(im_l1r2 - im_r2)**2)
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ m1 = 2*np.abs(im*im) - np.abs(im*im_l1)*np.cos(2*(np.angle(im_l1) - np.angle(im))) - np.abs(im*im_l2)*np.cos(2*(np.angle(im_l2) - np.angle(im)))
+ m2 = np.abs(im*im) - np.abs(im*im_r1)*np.cos(2*(np.angle(im) - np.angle(im_r1)))
+ m3 = np.abs(im*im) - np.abs(im*im_r2)*np.cos(2*(np.angle(im) - np.angle(im_r2)))
+ igrad = -(1./iimage)*(m1/d1 + m2/d2 + m3/d3).flatten()
+ else:
+ igrad = zeros
+
+ # dS/dm numerators
+ if pol_solve[1]!=0:
+ m1 = 2*np.abs(im) - np.abs(im_l1)*np.cos(2*(np.angle(im_l1) - np.angle(im))) - np.abs(im_l2)*np.cos(2*(np.angle(im_l2) - np.angle(im)))
+ m2 = np.abs(im) - np.abs(im_r1)*np.cos(2*(np.angle(im) - np.angle(im_r1)))
+ m3 = np.abs(im) - np.abs(im_r2)*np.cos(2*(np.angle(im) - np.angle(im_r2)))
+ mgrad = -iimage*(m1/d1 + m2/d2 + m3/d3).flatten()
+ else:
+ mgrad=zeros
+
+ # dS/dchi numerators
+ if pol_solve[2]!=0:
+ c1 = -2*np.abs(im*im_l1)*np.sin(2*(np.angle(im_l1) - np.angle(im))) - 2*np.abs(im*im_l2)*np.sin(2*(np.angle(im_l2) - np.angle(im)))
+ c2 = 2*np.abs(im*im_r1)*np.sin(2*(np.angle(im) - np.angle(im_r1)))
+ c3 = 2*np.abs(im*im_r2)*np.sin(2*(np.angle(im) - np.angle(im_r2)))
+ chigrad = -(c1/d1 + c2/d2 + c3/d3).flatten()
+ else:
+ chigrad = zeros
+
+ if len(imtuple)==4:
+ out = np.array((igrad, mgrad, chigrad,zeros))
+ else:
+ out = np.array((igrad, mgrad, chigrad))
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+# circular polarization
+# TODO check!!
+def svflux(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER):
+ """Total flux constraint
+ """
+ if norm_reg: norm = np.abs(vflux)**2
+ else: norm = 1
+
+ vimage = make_v_image(imtuple, pol_trans)
+
+ out = -(np.sum(vimage) - vflux)**2
+ return out/norm
+
+
+def svfluxgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER):
+ """Total flux constraint gradient
+ """
+ if norm_reg: norm = np.abs(vflux)**2
+ else: norm = 1
+
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ vimage = make_v_image(imtuple, pol_trans)
+ grad = -2*(np.sum(vimage) - vflux)*np.ones(len(vimage))
+
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ igrad = (vimage/iimage)*grad
+ else:
+ igrad = zeros
+
+ # dS/dv numerators
+ if pol_solve[3]!=0:
+ vgrad = iimage*grad
+ else:
+ vgrad=zeros
+
+
+ out = np.array((igrad, zeros, zeros, vgrad))
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+def sl1v(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER):
+ """L1 norm regularizer on V
+ """
+ if norm_reg: norm = np.abs(vflux)
+ else: norm = 1
+
+ vimage = make_v_image(imtuple, pol_trans)
+ l1 = -np.sum(np.abs(vimage))
+ return l1/norm
+
+
+def sl1vgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER):
+ """L1 norm gradient
+ """
+ if norm_reg: norm = np.abs(vflux)
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ vimage = make_v_image(imtuple, pol_trans)
+ grad = -np.sign(vimage)
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ igrad = (vimage/iimage)*grad
+ else:
+ igrad = zeros
+
+ # dS/dv numerators
+ if pol_solve[3]!=0:
+ vgrad = iimage*grad
+ else:
+ vgrad=zeros
+
+
+ out = np.array((igrad, zeros, zeros, vgrad))
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+def sl2v(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER):
+ """L1 norm regularizer on V
+ """
+ if norm_reg: norm = np.abs(vflux**2)
+ else: norm = 1
+
+ iimage = imtuple[0]
+ vimage = make_v_image(imtuple, pol_trans)
+ l2 = -np.sum((vimage)**2)
+ return l2/norm
+
+
+def sl2vgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER):
+ """L2 norm gradient
+ """
+ if norm_reg: norm = np.abs(vflux**2)
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ vimage = make_v_image(imtuple, pol_trans)
+ grad = -2*vimage
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ igrad = (vimage/iimage)*grad
+ else:
+ igrad = zeros
+
+ # dS/dv numerators
+ if pol_solve[3]!=0:
+ vgrad = iimage*grad
+ else:
+ vgrad=zeros
+
+
+ out = np.array((igrad, zeros, zeros, vgrad))
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+def stv_v(imtuple, vflux, nx, ny, psize, pol_trans=True,
+ norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.):
+ """Total variation of I*vfrac"""
+
+ if beam_size is None: beam_size = psize
+ if norm_reg: norm = np.abs(vflux)*psize / beam_size
+ else: norm = 1
+
+ vimage = make_v_image(imtuple, pol_trans)
+ im = vimage.reshape(ny, nx)
+
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ S = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2+epsilon))
+ return S/norm
+
+def stv_v_grad(imtuple, vflux, nx, ny, psize, pol_trans=True, pol_solve=(0,0,0,1),
+ norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.):
+ """Total variation gradient"""
+
+ if beam_size is None: beam_size = psize
+ if norm_reg: norm = np.abs(vflux)*psize / beam_size
+ else: norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ vimage = make_v_image(imtuple, pol_trans)
+
+ im = vimage.reshape(ny, nx)
+
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+
+ # rotate images
+ im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1]
+ im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1]
+
+ # add together terms and return
+ g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon)
+ g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon)
+ g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon)
+
+ # mask the first row column gradient terms that don't exist
+ mask1 = np.zeros(im.shape)
+ mask2 = np.zeros(im.shape)
+ mask1[0, :] = 1
+ mask2[:, 0] = 1
+ g2[mask1.astype(bool)] = 0
+ g3[mask2.astype(bool)] = 0
+
+ # add terms together and return
+ grad = -(g1 + g2 + g3).flatten()
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ igrad = (vimage/iimage)*grad
+ else:
+ igrad = zeros
+
+ # dS/dv numerators
+ if pol_solve[3]!=0:
+ vgrad = iimage*grad
+ else:
+ vgrad=zeros
+
+
+ out = np.array((igrad, zeros, zeros, vgrad))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+def stv2_v(imtuple, vflux, nx, ny, psize, pol_trans=True,
+ norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Squared Total variation of I*vfrac
+ """
+
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = psize**4 * np.abs(vflux**2) / beam_size**4
+ else:
+ norm = 1
+
+ vimage = make_v_image(imtuple, pol_trans)
+ im = vimage.reshape(ny, nx)
+
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ out = -np.sum((im_l1 - im)**2 + (im_l2 - im)**2)
+ return out/norm
+
+def stv2_v_grad(imtuple, vflux, nx, ny, psize, pol_trans=True, pol_solve=(0,0,0,1),
+ norm_reg=NORM_REGULARIZER, beam_size=None):
+ """Squared Total variation gradient
+ """
+ if beam_size is None:
+ beam_size = psize
+ if norm_reg:
+ norm = psize**4 * np.abs(vflux**2) / beam_size**4
+ else:
+ norm = 1
+
+ iimage = imtuple[0]
+ zeros = np.zeros(len(iimage))
+ vimage = make_v_image(imtuple, pol_trans)
+ im = vimage.reshape(ny, nx)
+
+ impad = np.pad(im, 1, mode='constant', constant_values=0)
+ im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1]
+ im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1]
+ im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1]
+ im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1]
+
+ g1 = (2*im - im_l1 - im_l2)
+ g2 = (im - im_r1)
+ g3 = (im - im_r2)
+
+ # mask the first row column gradient terms that don't exist
+ mask1 = np.zeros(im.shape)
+ mask2 = np.zeros(im.shape)
+ mask1[0, :] = 1
+ mask2[:, 0] = 1
+ g2[mask1.astype(bool)] = 0
+ g3[mask2.astype(bool)] = 0
+
+ # add together terms and return
+ grad = -2*(g1 + g2 + g3).flatten()
+
+ if pol_trans:
+
+ # dS/dI Numerators
+ if pol_solve[0]!=0:
+ igrad = (vimage/iimage)*grad
+ else:
+ igrad = zeros
+
+ # dS/dv numerators
+ if pol_solve[3]!=0:
+ vgrad = iimage*grad
+ else:
+ vgrad=zeros
+
+
+ out = np.array((igrad, zeros, zeros, vgrad))
+
+
+ else:
+ raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans)
+
+ return out/norm
+
+##################################################################################################
+# Embedding and Chi^2 Data functions
+##################################################################################################
+def embed_pol(imtuple, mask, clipfloor=0., randomfloor=False):
+ """Embeds a polarimetric image tuple into the size of boolean embed mask
+ """
+
+ out0=np.zeros(len(mask))
+ out1=np.zeros(len(mask))
+ out2=np.zeros(len(mask))
+ if len(imtuple)==4: out3=np.zeros(len(mask))
+
+    # embed the unmasked pixels with vectorized indexing
+ out0[mask.nonzero()] = imtuple[0]
+ out1[mask.nonzero()] = imtuple[1]
+ out2[mask.nonzero()] = imtuple[2]
+ if len(imtuple)==4:
+ out3[mask.nonzero()] = imtuple[3]
+
+    if clipfloor != 0.0:
+        offmask = np.invert(mask).nonzero()
+        out0[offmask] = clipfloor
+        out1[offmask] = 0
+        out2[offmask] = 0
+        if len(imtuple)==4: out3[offmask] = 0
+        if randomfloor: # prevent total variation gradient singularities
+            out0[offmask] *= np.abs(np.random.normal(size=len(offmask[0])))
+
+ if len(imtuple)==4:
+ out = (out0, out1, out2, out3)
+ else:
+ out = (out0, out1, out2)
+
+ return out
+
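+def _example_embed_pol():
+    """Illustrative sanity check: embed a 2-pixel image tuple back into a
+       4-pixel grid with a boolean mask.
+    """
+    mask = np.array([True, False, True, False])
+    tup = np.array((np.ones(2), 0.3*np.ones(2), np.zeros(2)))
+    out = embed_pol(tup, mask)
+    assert np.allclose(out[0], [1., 0., 1., 0.])
+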
+def chisqdata_pvis(Obsdata, Prior, mask):
+ """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask
+ """
+
+ data_arr = Obsdata.unpack(['u','v','pvis','psigma'], conj=True)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ vis = data_arr['pvis']
+ sigma = data_arr['psigma']
+ A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask)
+
+ return (vis, sigma, A)
+
+def chisqdata_pvis_nfft(Obsdata, Prior, mask, **kwargs):
+ """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask
+ """
+
+ # unpack keyword args
+ fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ data_arr = Obsdata.unpack(['u','v','pvis','psigma'], conj=True)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ vis = data_arr['pvis']
+ sigma = data_arr['psigma']
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv)
+ A = [A1]
+
+ return (vis, sigma, A)
+
+
+def chisqdata_m(Obsdata, Prior, mask):
+ """Return the pol ratios, sigmas, and fourier matrix for and observation, prior, mask
+ """
+
+ mdata = Obsdata.unpack(['u','v','m','msigma'], conj=True)
+ uv = np.hstack((mdata['u'].reshape(-1,1), mdata['v'].reshape(-1,1)))
+ m = mdata['m']
+ sigmam = mdata['msigma']
+ A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask)
+
+ return (m, sigmam, A)
+
+def chisqdata_m_nfft(Obsdata, Prior, mask, **kwargs):
+ """Return the pol ratios, sigmas, and fourier matrix for an observation, prior, mask
+ """
+
+ # unpack keyword args
+ fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ mdata = Obsdata.unpack(['u','v','m','msigma'], conj=True)
+ uv = np.hstack((mdata['u'].reshape(-1,1), mdata['v'].reshape(-1,1)))
+ m = mdata['m']
+ sigmam = mdata['msigma']
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv)
+ A = [A1]
+
+ return (m, sigmam, A)
+
+
+def chisqdata_pbs(Obsdata, Prior, mask):
+ """return the bispectra, sigmas, and fourier matrices for and observation, prior, mask
+ """
+
+ biarr = Obsdata.bispectra(mode="all", vtype='rlvis', count="min") #TODO CONJ??
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ A3 = (ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask),
+ ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask),
+ ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask)
+ )
+
+ return (bi, sigma, A3)
+
+def chisqdata_pbs_nfft(Obsdata, Prior, mask, **kwargs):
+    """return the bispectra, sigmas, and fourier matrices for an observation, prior, mask
+ """
+
+ # unpack keyword args
+ fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ biarr = Obsdata.bispectra(mode="all", vtype='rlvis', count="min")
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1)
+ A2 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2)
+ A3 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3)
+ A = [A1,A2,A3]
+
+ return (bi, sigma, A)
+
+def chisqdata_vvis(Obsdata, Prior, mask):
+ """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask
+ """
+
+ data_arr = Obsdata.unpack(['u','v','vvis','vsigma'], conj=False)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ vis = data_arr['vvis']
+ sigma = data_arr['vsigma']
+ A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask)
+
+ return (vis, sigma, A)
+
+def chisqdata_vvis_nfft(Obsdata, Prior, mask, **kwargs):
+ """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask
+ """
+
+ # unpack keyword args
+ fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT)
+ p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT)
+
+ # unpack data
+ data_arr = Obsdata.unpack(['u','v','vvis','vsigma'], conj=False)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ vis = data_arr['vvis']
+ sigma = data_arr['vsigma']
+
+ # get NFFT info
+ npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim)))
+ A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv)
+ A = [A1]
+
+ return (vis, sigma, A)
+
+##################################################################################################
+# Plotting
+##################################################################################################
+
+#TODO this plot only applies to the pol_trans (I, m, chi) representation
+def plot_m(imtuple, Prior, nit, chi2_dict, **kwargs):
+    """Plot the Stokes I image and fractional polarization of the current solution.
+    """
+
+    cmap = kwargs.get('cmap','afmhot')
+    interpolation = kwargs.get('interpolation', 'gaussian')
+    pcut = kwargs.get('pcut', 0.05)
+    nvec = kwargs.get('nvec', 15)
+    scale = kwargs.get('scale',None)
+    dynamic_range = kwargs.get('dynamic_range',1.e5)
+    gamma = kwargs.get('gamma',.5)
+
+ plt.ion()
+ plt.pause(1.e-6)
+ plt.clf()
+
+ # unpack
+ im = imtuple[0]
+ mim = imtuple[1]
+ chiim = imtuple[2]
+ imarr = im.reshape(Prior.ydim,Prior.xdim)
+
+ if scale=='log':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr<0.0] = 0.0
+ imarr = np.log(imarr + np.max(imarr)/dynamic_range)
+ #unit = 'log(' + cbar_unit[0] + ' per ' + cbar_unit[1] + ')'
+
+ if scale=='gamma':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr<0.0] = 0.0
+ imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma)
+ #unit = '(' + cbar_unit[0] + ' per ' + cbar_unit[1] + ')^gamma'
+
+ # Mask for low flux points
+ thin = int(round(Prior.xdim/nvec))
+ mask = imarr > pcut * np.max(im)
+ mask2 = mask[::thin, ::thin]
+
+ # Get vectors and ratio from current image
+ x = np.array([[i for i in range(Prior.xdim)] for j in range(Prior.ydim)])[::thin, ::thin][mask2]
+ y = np.array([[j for i in range(Prior.xdim)] for j in range(Prior.ydim)])[::thin, ::thin][mask2]
+ q = qimage(im, mim, chiim)
+ u = uimage(im, mim, chiim)
+ a = -np.sin(np.angle(q+1j*u)/2).reshape(Prior.ydim, Prior.xdim)[::thin, ::thin][mask2]
+ b = np.cos(np.angle(q+1j*u)/2).reshape(Prior.ydim, Prior.xdim)[::thin, ::thin][mask2]
+ m = (np.abs(q + 1j*u)/im).reshape(Prior.ydim, Prior.xdim)
+ m[~mask] = 0
+
+ # Stokes I plot
+ plt.subplot(121)
+    plt.imshow(imarr, cmap=plt.get_cmap(cmap), interpolation=interpolation)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01*Prior.xdim, units='x', pivot='mid', color='k', angles='uv', scale=1.0/thin)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005*Prior.xdim, units='x', pivot='mid', color='w', angles='uv', scale=1.1/thin)
+
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+    plt.xlabel(r'Relative RA ($\mu$as)')
+    plt.ylabel(r'Relative Dec ($\mu$as)')
+ plt.title('Stokes I')
+
+ # Ratio plot
+ plt.subplot(122)
+    plt.imshow(m, cmap=plt.get_cmap('winter'), interpolation=interpolation, vmin=0, vmax=1)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01*Prior.xdim, units='x', pivot='mid', color='k', angles='uv', scale=1.0/thin)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005*Prior.xdim, units='x', pivot='mid', color='w', angles='uv', scale=1.1/thin)
+
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+    plt.xticks(xticks[0], xticks[1])
+    plt.yticks(yticks[0], yticks[1])
+    plt.xlabel(r'Relative RA ($\mu$as)')
+    plt.ylabel(r'Relative Dec ($\mu$as)')
+ plt.title('m (above %i %% max flux)' % int(pcut*100))
+
+ # Create title
+ plotstr = "step: %i " % nit
+ for key in chi2_dict.keys():
+        plotstr += r"$\chi^2_{%s}$: %0.2f " % (key, chi2_dict[key])
+ plt.suptitle(plotstr, fontsize=18)
+
+
diff --git a/imaging/starwarps.py b/imaging/starwarps.py
new file mode 100644
index 00000000..16eb3d45
--- /dev/null
+++ b/imaging/starwarps.py
@@ -0,0 +1,1713 @@
+# See example_starwarps.py for an example of how to use these methods
+# Contact Katie Bouman (klbouman@caltech.edu) for any questions
+#
+# The methods/techniques used in this, referred to as StarWarps, are described in
+# "Reconstructing Video from Interferometric Measurements of Time-Varying Sources"
+# by Katherine L. Bouman, Michael D. Johnson, Adrian V. Dalca,
+# Andrew Chael, Freek Roelofs, Sheperd S. Doeleman, and William T. Freeman
+
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+#import ehtim as eh
+import ehtim.image as image
+import ehtim.observing.pulses
+from ehtim.observing.obs_helpers import *
+from ehtim.imaging.imager_utils import chisqdata
+
+import scipy.stats as st
+import scipy
+import copy
+import sys
+
+import matplotlib.pyplot as plt
+
+PROPERROR = True
+
+##################################################################################################
+
+
+def solve_singleImage(mu, Lambda_orig, obs, measurement={'vis':1}, numLinIters=5, mask=[], normalize=False):
+    """Solve for the posterior mean and covariance of a single image under a
+    Gaussian prior N(mu.imvec, Lambda_orig), iteratively re-linearizing the
+    measurement function around the current estimate.
+    """
+
+    if len(mask):
+        Lambda = Lambda_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+    else:
+        Lambda = Lambda_orig
+        mask = np.ones(mu.imvec.shape)>0
+
+    # the visibility-only problem is linear, so one linearization iteration suffices
+    if list(measurement.keys()) == ['vis']:
+        numLinIters = 1
+
+ z_List_t_t = mu.copy()
+ z_List_lin = mu.copy()
+
+ for k in range(0,numLinIters):
+ meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs, z_List_lin, measurement=measurement, mask=mask, normalize=normalize)
+ if valid:
+ z_List_t_t.imvec[mask], P_List_t_t = prodGaussiansLem2(F, measCov, meas, mu.imvec[mask], Lambda)
+
+ if k < numLinIters-1:
+ z_List_lin = z_List_t_t.copy()
+ else:
+ z_List_t_t = mu.copy()
+ P_List_t_t = copy.deepcopy(Lambda)
+
+ return (z_List_t_t, P_List_t_t, z_List_lin)
+
+
+##################################################################################################
+
+def forwardUpdates_apxImgs(mu, Lambda_orig, obs_List, A_orig, Q_orig, init_images, measurement={'vis':1}, lightcurve=None, numLinIters=5, interiorPriors=False, mask=[], normalize=False):
+ '''
+ Gaussian image prior:
+ :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image
+               at a given timestep. If the list length is one, the mean image is duplicated for all time steps.
+ :param Lambda_orig: (list - len(Lambda_orig)=num_time_steps or 1): original unmasked covariance matrix.
+ Every element is a 2D numpy array which contains the covariance at a given timestep.
+ If list length is one, the cov image is duplicated for all time steps.
+                        If the list length is one, the covariance is duplicated for all time steps.
+ Observations:
+ :param obs_List: list of observations, for each time step
+
+ Dynamical Evolution Model
+ :param A_orig: original unmasked A matrix - time-invariant mean of warp field for dynamical evolution
+ :param Q_orig: original unmasked Q matrix - time-invariant covariance matrix of dynamical evolution model,
+ describing the amount of allowed intensity deviation
+
+ Other Parameters:
+ :param init_images: option to provide initialization for the forward updates.
+ If none provided, then use the initialization from StarWarps paper
+ :param measurement: data products used
+    :param lightcurve: light curve time series, needed if imposing a flux constraint
+ :param numLinIters: number of linearized iterations. We have non-linear measurement function f_{t}(x_{t}),
+ and we linearize the solution around \tilde{x_{t}} by taking the first order Taylor series expansion of f.
+ To improve the solution of the forward and backward terms, each step in the forward pass can be
+ iteratively re-solved and \tilde{x}_{t} can be updated at each iteration.
+ The values of \tilde{x}_{t} are fixed for the backward pass.
+ Note that if f is linear, only a single iteration will be enough to converge to the optimal solution.
+ Thus, if the only measurement is visibility, numLinIters = 1 should be set.
+ :param interiorPriors: flag for whether to use interior priors
+    :param mask: boolean array selecting which pixels of the image to use. Defaults to True for all pixels,
+                 since getMeasurementTerms does not yet support a nontrivial mask.
+ :param normalize: flag for whether to normalize sigma in getMeasurementTerms
+ '''
+
+    # linear case: the measurement function is linear and the problem is convex,
+    # so a single linearization iteration suffices
+    if list(measurement.keys()) == ['vis']:
+        numLinIters = 1
+
+ # apply mask
+ if len(mask):
+ A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+ Q = Q_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+ Lambda = []
+ for t in range(0, len(Lambda_orig)):
+ Lambda.append( Lambda_orig[t][mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) )
+ else:
+ Lambda = Lambda_orig
+ A = A_orig
+ Q = Q_orig
+
+ #if measurement=='bispectrum':
+ # print 'WARNING: check the loglikelihood for non-linear functions'
+
+ # create an image of 0's
+ zero_im = mu[0].copy()
+ zero_im.imvec = 0.0*zero_im.imvec
+
+    # initialize the prediction and update means and covariances
+ z_List_t_t = [] # Mean of the hidden (state) image at t given data up to time t
+ P_List_t_t = [] # Covariance of the hidden (state) image at t given data up to time t
+
+    # initialize the z and P lists to all zeros
+ for t in range(0,len(obs_List)):
+ z_List_t_t.append(zero_im.copy())
+ P_List_t_t.append(np.zeros(Lambda[0].shape))
+
+ # prediction 1 list: prediction at time t given all information up to t-1
+ z_List_t_tm1 = copy.deepcopy(z_List_t_t)
+ P_List_t_tm1 = copy.deepcopy(P_List_t_t)
+
+ # prediction 2 list: deep copy of prediction 1 list (possibly intermediate state variable)
+ z_star_List_t_tm1 = copy.deepcopy(z_List_t_t)
+ P_star_List_t_tm1 = copy.deepcopy(P_List_t_t)
+ # initialize linearization z's
+ z_List_lin = copy.deepcopy(z_List_t_t)
+
+ loglikelihood_prior = 0.0
+ loglikelihood_data = 0.0
+
+ # for each forward timestep...
+ for t in range(0,len(obs_List)):
+ sys.stdout.write('\rForward timestep %i of %i total timesteps...' % (t,len(obs_List)))
+ sys.stdout.flush()
+
+
+ # Duplicate mean and covariance if needed
+ if len(mu) == 1:
+ mu_t = mu[0]
+ Lambda_t = Lambda[0]
+ # get current image's mean and covariance
+ else:
+ mu_t = mu[t]
+ Lambda_t = Lambda[t]
+
+        # use lightcurve data if provided
+        if lightcurve is not None:
+ tot = np.sum(mu_t.imvec)
+ mu_t.imvec = lightcurve[t]/tot * mu_t.imvec
+ Lambda_t = (lightcurve[t]/tot)**2 * Lambda_t
+
+
+ # predict
+ # Initialization of hidden state mean and covariance for t = 0
+ if t==0:
+ z_star_List_t_tm1[t].imvec = copy.deepcopy(mu_t.imvec)
+ P_star_List_t_tm1[t] = copy.deepcopy(Lambda_t)
+ # StarWarps initialization of hidden state mean and covariance
+ else:
+ z_List_t_tm1[t].imvec[mask] = np.dot( A, z_List_t_t[t-1].imvec[mask] )
+ if PROPERROR:
+ P_List_t_tm1[t] = Q + np.dot( np.dot( A, P_List_t_t[t-1] ), np.transpose(A) )
+ else:
+ print('no prop error')
+ P_List_t_tm1[t] = Q
+
+ # main predict step, using Lemma 1 from StarWarps supplementary doc (also see eq 29-30), using interior priors
+ if interiorPriors:
+ z_star_List_t_tm1[t].imvec[mask], P_star_List_t_tm1[t] = prodGaussiansLem1( mu_t.imvec[mask], Lambda_t, z_List_t_tm1[t].imvec[mask], P_List_t_tm1[t] )
+ else:
+ z_star_List_t_tm1[t] = z_List_t_tm1[t].copy()
+ P_star_List_t_tm1[t] = copy.deepcopy(P_List_t_tm1[t])
+
+
+ # update
+ # either go with user-provided initialization (if given) or take z_star_List_t_tm1 as an initialization
+ if init_images is None:
+ init_images_t = z_star_List_t_tm1[t].copy()
+ elif len(init_images) == 1:
+ init_images_t = init_images[0]
+ else:
+ init_images_t = init_images[t]
+ z_List_lin[t] = init_images_t.copy()
+
+ # Do the linearized iterations
+ for k in range(0,numLinIters):
+ # F is the derivative of the Forward model with respect to the unknown parameters
+            if lightcurve is not None:
+ meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], z_List_lin[t], measurement=measurement, tot_flux=lightcurve[t], mask=mask, normalize=normalize)
+ else:
+ meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], z_List_lin[t], measurement=measurement, tot_flux=None, mask=mask, normalize=normalize)
+
+ # main update step, using Lemma 2 from StarWarps supplementary doc (also see eq 30-31)
+ if valid:
+ z_List_t_t[t].imvec[mask], P_List_t_t[t] = prodGaussiansLem2(F, measCov, meas, z_star_List_t_tm1[t].imvec[mask], P_star_List_t_tm1[t])
+
+ if k < numLinIters-1:
+ z_List_lin[t] = z_List_t_t[t].copy()
+ else:
+ z_List_t_t[t] = z_star_List_t_tm1[t].copy()
+ P_List_t_t[t] = copy.deepcopy(P_star_List_t_tm1[t])
+
+ # update the prior log likelihood, using interior priors
+ if t>0 and interiorPriors:
+ loglikelihood_prior = loglikelihood_prior + evaluateGaussianDist_log( z_List_t_tm1[t].imvec[mask], mu_t.imvec[mask], Lambda_t + P_List_t_tm1[t] )
+
+ # update the data log likelihood
+ if valid:
+ loglikelihood_data = loglikelihood_data + evaluateGaussianDist_log( np.dot(F , z_star_List_t_tm1[t].imvec[mask]), meas, measCov + np.dot( F, np.dot(P_star_List_t_tm1[t], F.T)) )
+
+
+ # compute the log likelihood (equation 27 in StarWarps paper)
+ loglikelihood = loglikelihood_prior + loglikelihood_data
+ return ((loglikelihood_data, loglikelihood_prior, loglikelihood), z_List_t_tm1, P_List_t_tm1, z_List_t_t, P_List_t_t, z_List_lin)
+
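+# A minimal, hedged usage sketch for the forward pass above (the names
+# `mean_im` and `obs_frames` are assumptions: an ehtim image object and a
+# per-scan list of Obsdata objects). With the identity warp and a small
+# evolution covariance, this reduces to Kalman-style filtering with weak
+# frame-to-frame coupling.
+#
+#   npix = mean_im.xdim * mean_im.ydim
+#   mask = np.ones(npix, dtype=bool)
+#   Lambda = [np.eye(npix)]              # one prior covariance, reused at every t
+#   A = np.eye(npix)                     # identity warp: no mean motion
+#   Q = 0.01 * np.eye(npix)              # allowed frame-to-frame intensity change
+#   (ll, z_pred, P_pred, z_filt, P_filt, z_lin) = forwardUpdates_apxImgs(
+#       [mean_im], Lambda, obs_frames, A, Q, init_images=None,
+#       measurement={'vis': 1}, numLinIters=1, mask=mask)
+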
+
+###################################### EXTENDED MESSAGE PASSING ########################################
+
+def backwardUpdates(mu, Lambda_orig, obs_List, A_orig, Q_orig, measurement={'vis':1}, lightcurve=None, apxImgs=False, mask=[], normalize=False):
+ '''
+ Gaussian image prior:
+ :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image
+               at a given timestep. If the list length is one, the mean image is duplicated for all time steps.
+ :param Lambda_orig: (list - len(Lambda_orig)=num_time_steps or 1): original unmasked covariance matrix.
+ Every element is a 2D numpy array which contains the covariance at a given timestep.
+                        If the list length is one, the covariance is duplicated for all time steps.
+
+ Observations:
+ :param obs_List: list of observations, for each time step
+
+ Dynamical Evolution Model
+ :param A_orig: original unmasked A matrix - time-invariant mean of warp field for dynamical evolution
+ :param Q_orig: original unmasked Q matrix - time-invariant covariance matrix of dynamical evolution model,
+ describing the amount of allowed intensity deviation
+
+ Other Parameters
+ :param measurement: data products used
+    :param lightcurve: light curve time series, needed if imposing a flux constraint
+    :param apxImgs: list of linearization images from the forward pass (z_List_lin); the backward
+                    updates linearize the measurement function around these images
+    :param mask: boolean array selecting which pixels of the image to use. Defaults to True for all pixels,
+                 since getMeasurementTerms does not yet support a nontrivial mask.
+ :param normalize: flag for whether to normalize sigma in getMeasurementTerms
+ '''
+
+ if len(mask):
+ A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+ Q = Q_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+ Lambda = []
+ for t in range(0, len(Lambda_orig)):
+ Lambda.append( Lambda_orig[t][mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) )
+ else:
+ Lambda = Lambda_orig
+ A = A_orig
+ Q = Q_orig
+
+ # create an image of 0's
+ zero_im = mu[0].copy()
+ zero_im.imvec = 0.0*zero_im.imvec
+
+    # initialize the prediction and update means and covariances
+ z_t_t = []
+ P_t_t = []
+ # update list
+ for t in range(0,len(obs_List)):
+ z_t_t.append(zero_im.copy())
+ P_t_t.append(np.zeros(Lambda[0].shape))
+ # prediction 1 list
+ z_star_t_tp1 = copy.deepcopy(z_t_t)
+ P_star_t_tp1 = copy.deepcopy(P_t_t)
+
+
+ lastidx = len(obs_List)-1
+ for t in range(lastidx,-1,-1):
+ sys.stdout.write('\rBackward timestep %i of %i total timesteps...' % (t,len(obs_List)))
+ sys.stdout.flush()
+
+ if len(mu) == 1:
+ mu_t = mu[0]
+ Lambda_t = Lambda[0]
+ else:
+ mu_t = mu[t]
+ Lambda_t = Lambda[t]
+ if lightcurve is not None:
+ tot = np.sum(mu_t.imvec)
+ mu_t.imvec = lightcurve[t]/tot * mu_t.imvec
+ Lambda_t = (lightcurve[t]/tot)**2 * Lambda_t
+
+ # predict
+ if t==lastidx:
+ z_star_t_tp1[t].imvec = copy.deepcopy(mu_t.imvec)
+ P_star_t_tp1[t] = copy.deepcopy(Lambda_t)
+ else:
+ if PROPERROR:
+ z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t] = prodGaussiansLem2( A, Q + P_t_t[t+1], z_t_t[t+1].imvec[mask], mu_t.imvec[mask], Lambda_t)
+ else:
+ print('no prop error')
+ z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t] = prodGaussiansLem2( A, Q, z_t_t[t+1].imvec[mask], mu_t.imvec[mask], Lambda_t)
+
+ # update
+        if lightcurve is not None:
+ meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], apxImgs[t], measurement=measurement, tot_flux=lightcurve[t], mask=mask, normalize=normalize)
+ else:
+ meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], apxImgs[t], measurement=measurement, tot_flux=None, mask=mask, normalize=normalize)
+
+ if valid:
+ z_t_t[t].imvec[mask], P_t_t[t] = prodGaussiansLem2(F, measCov, meas, z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t])
+
+ else:
+ z_t_t[t] = z_star_t_tp1[t].copy()
+ P_t_t[t] = copy.deepcopy(P_star_t_tp1[t])
+
+ return (z_t_t, P_t_t)
+
+
+def smoothingUpdates(z_t_t, P_t_t, z_t_tm1, P_t_tm1, A_orig, mask=[]):
+    '''
+    Backward (Rauch-Tung-Striebel) smoothing: combine the forward filtered estimates
+    (z_t_t, P_t_t) with the one-step predictions (z_t_tm1, P_t_tm1) to obtain the
+    posterior mean and covariance at each time given all of the data.
+    '''
+ z = copy.deepcopy(z_t_t)
+ P = copy.deepcopy(P_t_t)
+ backwardsA = copy.deepcopy(P_t_t)
+
+ if len(mask):
+ A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1])
+
+ lastidx = len(z)-1
+ for t in range(lastidx,-1,-1):
+
+ if t < lastidx:
+ backwardsA[t] = np.dot( np.dot(P_t_t[t], A.T ), np.linalg.inv(P_t_tm1[t+1]) )
+ z[t].imvec[mask] = z_t_t[t].imvec[mask] + np.dot( backwardsA[t], z[t+1].imvec[mask] - z_t_tm1[t+1].imvec[mask] )
+ P[t] = np.dot( np.dot( backwardsA[t] , P[t+1] - P_t_tm1[t+1]), backwardsA[t].T ) + P_t_t[t]
+
+ return (z, P, backwardsA)
+
+
+
+def computeSuffStatistics(mu, Lambda, obs_List, Upsilon, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, init_images=None, method='phase', measurement={'vis':1}, lightcurve=None, interiorPriors=False, numLinIters=1, compute_expVal_tm1_t=True, mask=[], normalize=False):
+ '''
+    Gaussian image prior:
+    :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image
+               at a given timestep. If the list length is one, the mean image is duplicated for all time steps.
+    :param Lambda: (list - len(Lambda)=num_time_steps or 1): every element is a 2D numpy array which contains the
+                   prior covariance at a given timestep. If the list length is one, the covariance is duplicated
+                   for all time steps.
+
+    Observations:
+    :param obs_List: list of observations, one for each time step
+
+    Dynamical Evolution Model:
+    :param Upsilon: time-invariant covariance matrix of the dynamical evolution model, describing the amount
+                    of allowed intensity deviation (the Q matrix)
+    :param theta: warp-field parameters; together with init_x, init_y, flowbasis_x, flowbasis_y, and initTheta
+                  they define the time-invariant mean warp field (the A matrix) via calcWarpMtx
+
+    Other Parameters:
+    :param init_images: option to provide an initialization for the forward updates.
+                        If none is provided, the initialization from the StarWarps paper is used.
+    :param measurement: data products used
+    :param lightcurve: light curve time series, needed if imposing a flux constraint
+    :param numLinIters: number of linearized iterations. We have a non-linear measurement function f_{t}(x_{t}),
+                        and we linearize the solution around \tilde{x}_{t} by taking the first order Taylor
+                        series expansion of f. To improve the solution of the forward and backward terms, each
+                        step in the forward pass can be iteratively re-solved and \tilde{x}_{t} can be updated
+                        at each iteration. The values of \tilde{x}_{t} are fixed for the backward pass.
+                        Note that if f is linear, a single iteration converges to the optimal solution;
+                        thus, if the only measurement is the visibility, numLinIters = 1 should be set.
+    :param interiorPriors: flag for whether to use interior priors
+    :param compute_expVal_tm1_t: flag for whether to compute the second sufficient statistic, E[x_{t-1}x_{t}^{T}]
+    :param mask: boolean array selecting which pixels of the image to use. Defaults to True for all pixels,
+                 since getMeasurementTerms does not yet support a nontrivial mask.
+    :param normalize: flag for whether to normalize sigma in getMeasurementTerms
+ '''
+
+ # if mask not provided, create default mask
+ if not len(mask):
+ mask = np.ones(mu[0].imvec.shape)>0
+
+ # check if first mean image is square
+    if mu[0].xdim != mu[0].ydim:
+        raise ValueError('This has only been checked thus far on square images!')
+
+ # lightcurve and flux constraint go together
+    if lightcurve is None and 'flux' in measurement.keys(): #KATIE ADDED FEB 1 2021
+        raise ValueError('if you are using a flux constraint you must specify a lightcurve')
+
+    # the visibility-only problem is linear, so one linearization iteration suffices
+    if list(measurement.keys()) == ['vis']:
+        numLinIters = 1
+
+ # calculate matrix to represent warp field, from first mean image
+ warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+
+    # Parameterize the dynamical evolution model as a Gaussian
+ # the time-invariant mean warp field
+ A = warpMtx
+ # time-invariant covariance matrix describing amount of allowed intensity variation in the warp field
+ Q = Upsilon
+
+ # Do forward passes
+ loglikelihood, z_t_tm1, P_t_tm1, z_t_t, P_t_t, apxImgs = forwardUpdates_apxImgs(mu, Lambda, obs_List, A, Q, init_images=init_images, measurement=measurement, lightcurve=lightcurve, interiorPriors=interiorPriors, numLinIters=numLinIters, mask=mask, normalize=normalize)
+
+ # Extended message passing with backward passes, using interior priors
+ if interiorPriors:
+ # Do backward passes
+ z_backward_t_t, P_backward_t_t = backwardUpdates(mu, Lambda, obs_List, A, Q, measurement=measurement, lightcurve=lightcurve, apxImgs=apxImgs, mask=mask, normalize=normalize)
+
+ z = copy.deepcopy(z_backward_t_t)
+ P = copy.deepcopy(P_backward_t_t)
+ for t in range(0,len(obs_List)):
+ if t==0:
+ z[t] = z_backward_t_t[t].copy()
+ P[t] = copy.deepcopy(P_backward_t_t[t])
+ else:
+ z[t].imvec[mask], P[t] = prodGaussiansLem1(z_t_tm1[t].imvec[mask], P_t_tm1[t], z_backward_t_t[t].imvec[mask], P_backward_t_t[t])
+
+ # Use smoothing updates instead
+ else:
+ z, P, backwardsA = smoothingUpdates(z_t_t, P_t_t, z_t_tm1, P_t_tm1, A, mask=mask)
+
+
+
+ expVal_t = copy.deepcopy(z)
+    # initialize the lists
+ expVal_t_t = copy.deepcopy(P)
+ expVal_tm1_t = copy.deepcopy(P)
+ for t in range(0,len(obs_List)):
+ # expected value of xx^T for each x
+ z_t_hvec = np.array([z[t].imvec[mask]])
+ expVal_t_t[t] = np.dot(z_t_hvec.T, z_t_hvec) + P[t]
+
+ # expected value of x_t x_t-1^T for each x except for the first one
+ if t>0 and interiorPriors==False and compute_expVal_tm1_t:
+ z_tm1_hvec = np.array([z[t-1].imvec[mask]])
+ expVal_tm1_t[t] = np.dot(z_tm1_hvec.T, z_t_hvec) + np.dot(backwardsA[t-1], P[t])
+
+ # expected value of x_t x_t-1^T, using interior priors
+ if interiorPriors and compute_expVal_tm1_t:
+ expVal_tm1_t = JointDist(z, z_t_t, P_t_t, z_backward_t_t, P_backward_t_t, A, Q)
+
+ return (expVal_t, expVal_t_t, expVal_tm1_t, loglikelihood, apxImgs)
+
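+# A hedged sketch of the EM loop this function is built for (illustration
+# only): alternate the E-step above with an M-step over the warp
+# parameters theta (maximizeTheta_multiIter, defined below). `mean_im`,
+# `obs_frames`, `Lambda0`, `Upsilon`, and `n_em_iters` are assumed inputs.
+#
+#   init_x, init_y, fb_x, fb_y, theta0 = affineMotionBasis(mean_im)
+#   theta = theta0.copy()
+#   for it in range(n_em_iters):
+#       Ex, Exx, Exm1x, ll, apx = computeSuffStatistics(
+#           [mean_im], [Lambda0], obs_frames, Upsilon, theta,
+#           init_x, init_y, fb_x, fb_y, theta0, measurement={'vis': 1})
+#       theta = maximizeTheta_multiIter(Exx, Exm1x, mean_im, Upsilon, theta,
+#                                       init_x, init_y, fb_x, fb_y, theta0)
+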
+
+
+
+def JointDist(z, z_List_t_t_forward, P_List_t_t_forward, z_List_t_t_backward, P_List_t_t_backward, A, Q):
+ '''
+ Calculate the joint distribution p(x_{t},x_{t-1} | y_{1:N})
+ See section 2.2 in StarWarps supplementary doc, starting from eq 60
+ '''
+
+ expVal_tm1_t = []
+ expVal_tm1_t.append(0.0)
+
+ # section 2.2
+ for t in range(1, len(z_List_t_t_forward) ):
+
+ Sigma = Q + P_List_t_t_backward[t]
+ Sigma_inv = np.linalg.inv(Sigma)
+
+ # eq 76
+ M = np.dot(P_List_t_t_backward[t], np.dot(Sigma_inv, A) )
+
+ # eq 77, 78
+ (m, C) = prodGaussiansLem2(A, Sigma, z_List_t_t_backward[t].imvec, z_List_t_t_forward[t-1].imvec, P_List_t_t_forward[t-1])
+
+ # eq 79
+ D_tmp1 = np.dot(M, np.dot(C, M.T))
+ D_tmp2 = np.dot(Q, np.dot( Sigma_inv, P_List_t_t_backward[t] ) )
+ D = np.dot(C, np.dot(M.T, np.linalg.inv(D_tmp1 + D_tmp2) ) )
+
+ # eq 81
+ F = C - np.dot(D, np.dot(M, C))
+
+ # E[x_{t}]
+ z_t_hvec = np.array([z[t].imvec])
+ # E[x_{t-1}]
+ z_tm1_hvec = np.array([z[t-1].imvec])
+
+ # eq 88
+ expVal_tm1_t.append( np.dot(F, np.linalg.inv(D.T)) + np.dot(z_tm1_hvec.T, z_t_hvec) )
+
+ return expVal_tm1_t
+
+
+
+#################################
+
+
+def maximizeWarpMtx(expVal_t_t, expVal_tm1_t, expVal_t=0, B=0):
+    """M-step update for an unconstrained warp matrix A: the maximizer of the expected
+    log likelihood is A = (sum_t E[x_t x_{t-1}^T]) (sum_t E[x_{t-1} x_{t-1}^T])^{-1}.
+    """
+
+ M1 = np.zeros(expVal_tm1_t[1].shape)
+ M2 = np.zeros(expVal_t_t[1].shape)
+
+ for t in range(1,len(expVal_t_t)):
+ #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T
+ M1 = M1 + expVal_tm1_t[t].T
+ if B !=0:
+ M1 = M1 + np.dot(B, expVal_t[t])
+ M2 = M2 + expVal_t_t[t-1]
+
+ warpMtx = np.dot( M1, np.linalg.inv(M2) )
+ return warpMtx
+
+def maximizeTheta_multiIter(expVal_t_t, expVal_tm1_t, dummy_im, Q, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase', nIter=10):
+
+    newTheta = centerTheta
+    for i in range(0, nIter):
+        # maximizeTheta returns (thetaNew, secondDeriv, D1, D2); keep only the new parameters
+        newTheta = maximizeTheta(expVal_t_t, expVal_tm1_t, dummy_im, Q, newTheta,
+                                 init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)[0]
+
+    return newTheta
+
+
+def maximizeTheta(expVal_t_t, expVal_tm1_t, dummy_im, Q, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'):
+
+ if method == 'phase' or method == 'approx_phase':
+ dWarp_dTheta = calc_dWarp_dTheta(dummy_im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase')
+    else:
+        raise ValueError('ERROR: WE ONLY HANDLE PHASE WARP MINIMIZATION RIGHT NOW')
+
+ warpMtx = calcWarpMtx(dummy_im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+ invQ = np.linalg.inv(Q)
+
+ nbasis = len(initTheta)
+ thetaNew = np.zeros( initTheta.shape )
+
+ G1 = np.zeros(expVal_tm1_t[1].shape)
+ for t in range(1,len(expVal_t_t)):
+ #G1 = G1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] )
+ G1 = G1 + expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] )
+ #G1 = G1 + expVal_tm1_t[t] - np.dot( warpMtx, expVal_t_t[t-1] )
+ for b in range(0, nbasis):
+ G1 = G1 + np.dot( dWarp_dTheta[b], expVal_t_t[t-1] )*centerTheta[b]
+ G1 = np.dot(invQ, G1)
+
+ G2 = []
+ for b in range (0,nbasis):
+ G2.append(np.zeros(expVal_t_t[1].shape))
+ for t in range(1,len(expVal_t_t)):
+ G2[b] = G2[b] + np.dot( dWarp_dTheta[b], expVal_t_t[t-1] )
+ G2[b] = np.dot(invQ, G2[b])
+
+ D1 = np.zeros(initTheta.shape)
+ for b1 in range(0, nbasis):
+ for p in range(0,dWarp_dTheta[b1].shape[0]):
+ for q in range(0,dWarp_dTheta[b1].shape[1]):
+ D1[b1] = D1[b1] + G1[p,q]*dWarp_dTheta[b1][p,q]
+
+ D2 = np.zeros([nbasis, nbasis])
+ for b1 in range(0, nbasis):
+ for b2 in range(0, nbasis):
+ for p in range(0,dWarp_dTheta[b1].shape[0]):
+ for q in range(0,dWarp_dTheta[b1].shape[1]):
+ D2[b1,b2] = D2[b1,b2] + G2[b2][p,q]*dWarp_dTheta[b1][p,q]
+
+
+ thetaNew = np.dot(np.linalg.inv(D2), D1)
+
+
+
+ secondDeriv = np.zeros((nbasis, nbasis))
+ for b in range(0, nbasis):
+ thetaNew_tmp = copy.deepcopy(thetaNew)
+ thetaNew_tmp[b] = 1.0
+ secondDeriv[:,b] = D1 - np.dot(D2, thetaNew_tmp)
+ eigvals,_ = np.linalg.eig(secondDeriv)
+ if all(eigvals>0):
+ print('local min')
+ elif all(eigvals<0):
+ print('local max')
+ elif any(eigvals==0.0):
+ print('inconclusive')
+ else:
+ print('saddle point: ' + str(np.sum(eigvals<0)) + ' negative eigs of ' + str(len(eigvals)))
+
+
+ return (thetaNew, secondDeriv, D1, D2)
+
+
+def negloglikelihood(theta, mu, Lambda, obs_List, Upsilon, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, mask=[]):
+
+ warpMtx = calcWarpMtx(mu, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+
+ A = warpMtx
+ B = np.zeros(mu.imvec.shape)
+ Q = Upsilon
+
+ loglike, z_t_tm1, P_t_tm1, z_t_t, P_t_t = forwardUpdates(mu, Lambda, obs_List, A, B, Q, measurement=measurement, interiorPriors=interiorPriors, mask=mask)
+
+ return -loglike[2]
+
+def expnegloglikelihood_full(theta, expectation_theta, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, numLinIters, apxImgs):
+
+    # computeSuffStatistics returns five values; the linearization images initialize the forward pass
+    expVal_t, expVal_t_t, expVal_tm1_t, _, _ = computeSuffStatistics(mu, Lambda, obs_List, Q, expectation_theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method, measurement=measurement, interiorPriors=interiorPriors, numLinIters=numLinIters, init_images=apxImgs)
+    neg_expll = expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method)
+    print(neg_expll)
+    return neg_expll
+
+
+def expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Upsilon, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method):
+
+ #if interiorPriors:
+ #TODO: print 'WARNING: not sure if this works with interior priors because of the derivation of the E[xMx] terms may be different'
+
+ warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+ A = warpMtx
+ B = np.zeros(mu[0].imvec.shape)
+ Q = Upsilon
+ invQ = np.linalg.inv(Q)
+
+ value = 0.0
+ for t in range(1, len(expVal_t)):
+ x_t = np.array([expVal_t[t].imvec]).T
+ x_tm1 = np.array([expVal_t[t-1].imvec]).T
+
+ P_tm1_t = expVal_tm1_t[t] - np.dot(x_tm1, x_t.T)
+ P_tm1_tm1 = expVal_t_t[t-1] - np.dot(x_tm1, x_tm1.T)
+
+ term1 = exp_xtm1_M_xt(P_tm1_t.T, x_t, x_tm1, np.dot(invQ, A) )
+ term2 = exp_xtm1_M_xt(P_tm1_t, x_tm1, x_t, np.dot(A.T, invQ) )
+ term3 = exp_xtm1_M_xt(P_tm1_tm1, x_tm1, x_tm1, np.dot(A.T, np.dot(invQ, A) ) )
+ term4 = np.dot(B.T, np.dot(invQ, np.dot(A, x_tm1)))
+
+ value = value - 0.5*( -term1 - term2 + term3 + term4 + term4.T )
+
+ value = -value
+ return value
+
+def exp_xtm1_M_xt(P, z1, z2, M):
+    """Expectation of a bilinear form of jointly Gaussian variables:
+    E[x1^T M x2] = tr(P M^T) + z1^T M z2, where z1 = E[x1], z2 = E[x2],
+    and P = Cov(x1, x2)."""
+    value = np.trace( np.dot(P, M.T) ) + np.dot(z1.T, np.dot(M, z2 ) )
+    return value
+
+
+
+def deriv_expnegloglikelihood_full(theta, expectation_theta, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, numLinIters, apxImgs):
+
+    # computeSuffStatistics returns five values; the linearization images initialize the forward pass
+    expVal_t, expVal_t_t, expVal_tm1_t, _, _ = computeSuffStatistics(mu, Lambda, obs_List, Q, expectation_theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method, measurement=measurement, interiorPriors=interiorPriors, numLinIters=numLinIters, init_images=apxImgs)
+ return deriv_expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method)
+
+
+def deriv_expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method):
+
+ if method == 'phase':
+ dWarp_dTheta = calc_dWarp_dTheta(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+ else:
+ print('WARNING: WE ONLY HANDLE PHASE WARP MINIMIZATION RIGHT NOW')
+
+ warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method)
+
+ invQ = np.linalg.inv(Q)
+ M1 = np.zeros(expVal_tm1_t[1].shape)
+ for t in range(1,len(expVal_t_t)):
+ #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] )
+ M1 = M1 + expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] )
+ M1 = np.dot( invQ , M1)
+
+ deriv = np.zeros( initTheta.shape )
+ for b in range(0,len(initTheta)):
+ for p in range(0,dWarp_dTheta[b].shape[0]):
+ for q in range(0,dWarp_dTheta[b].shape[1]):
+ deriv[b] = deriv[b] + M1[p,q]*dWarp_dTheta[b][p,q]
+
+ # the derivative computed is for the ll but we want the derivative of the neg ll
+ deriv = -deriv
+ return deriv
+
+
+def maximizeBrightness(expVal_t_t, expVal_tm1_t, dummy_im, Q):
+
+ dWarp_dTheta = np.eye(dummy_im.xdim*dummy_im.ydim)
+ invQ = np.linalg.inv(Q)
+
+ M1 = np.zeros(expVal_tm1_t[1].shape)
+ M2 = np.zeros(expVal_t_t[1].shape)
+
+ top = 0.0
+ bottom = 0.0
+
+ for t in range(1,len(expVal_t_t)):
+ #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T
+ M1 = M1 + expVal_tm1_t[t].T
+ M2 = M2 + expVal_t_t[t-1]
+ M1 = np.dot(invQ, M1)
+ M2 = np.dot(invQ, M2)
+
+ for p in range(0,dWarp_dTheta.shape[0]):
+ for q in range(0,dWarp_dTheta.shape[1]):
+ top = top + M1[p,q]*dWarp_dTheta[p,q]
+ bottom = bottom + M2[p,q]*dWarp_dTheta[p,q]
+
+ thetaNew = top/bottom
+
+ return thetaNew
+
+
+
+def evaluateGaussianDist_log(y, x, Sigma):
+    """Log density of a multivariate Gaussian N(x, Sigma) evaluated at y:
+    -(n/2) log(2 pi) - (1/2) log|Sigma| - (1/2) (y-x)^T Sigma^{-1} (y-x)."""
+
+ n = len(x)
+ if len(x) != np.prod(x.shape):
+ raise AssertionError()
+
+ diff = x - y
+ (sign, logdet) = np.linalg.slogdet(Sigma)
+ expval_log = - (n/2.0)*np.log( 2.0*np.pi ) - 0.5*(sign*logdet) - 0.5*np.dot( diff.T, np.dot( np.linalg.inv(Sigma), diff ) )
+
+ return expval_log
+
+def evaluateGaussianDist(y, x, Sigma):
+
+ expval_log = evaluateGaussianDist_log(y, x, Sigma)
+ expval = np.exp( expval_log )
+ return expval
+
+def prodGaussiansLem1(m1, S1, m2, S2):
+    """Product of two Gaussian densities N(m1, S1) * N(m2, S2) is (up to scale) a
+    Gaussian with covariance S1 (S1+S2)^{-1} S2 = (S1^{-1} + S2^{-1})^{-1} and
+    mean S1 (S1+S2)^{-1} m2 + S2 (S1+S2)^{-1} m1 (Lemma 1 of the StarWarps
+    supplementary doc)."""
+
+ K = np.linalg.inv(S1 + S2)
+
+ covariance = np.dot( S1, np.dot( K, S2 ) )
+ mean = np.dot(S1, np.dot(K, m2) ) + np.dot(S2, np.dot(K, m1) )
+
+ return (mean, covariance)
+
+def prodGaussiansLem2(A, Sigma, y, mu, Q):
+    """Posterior of x ~ N(mu, Q) given a linear-Gaussian measurement y = A x + n,
+    n ~ N(0, Sigma) (Lemma 2 of the StarWarps supplementary doc):
+    mean = mu + K (y - A mu) and covariance = Q - K A Q, with the gain
+    K = Q A^T (Sigma + A Q A^T)^{-1}."""
+
+ K1 = np.linalg.inv( Sigma + np.dot(A, np.dot(Q, np.transpose(A))) )
+ K2 = np.dot( Q, np.dot( A.T, K1 ) )
+
+ covariance = Q - np.dot( K2, np.dot( A, Q ) )
+ mean = mu + np.dot( K2, y - np.dot(A, mu) )
+
+ return (mean, covariance)
+
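+# A small self-check sketch for Lemma 2 (added for illustration; never
+# called by the module). For a scalar linear-Gaussian model y = A x + n,
+# n ~ N(0, Sigma), prior x ~ N(mu, Q), the posterior is the textbook
+# Kalman update; with A = Sigma = Q = 1, mu = 0, y = 2 the posterior is
+# N(1, 0.5).
+def _check_prodGaussiansLem2():
+    A = np.array([[1.0]])
+    Sigma = np.array([[1.0]])
+    Q = np.array([[1.0]])
+    mu = np.array([0.0])
+    y = np.array([2.0])
+    mean, cov = prodGaussiansLem2(A, Sigma, y, mu, Q)
+    return np.allclose(mean, 1.0) and np.allclose(cov, 0.5)
+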
+
+def getMeasurementTerms(obs, im, measurement={'vis': 1}, tot_flux=None, mask=[], normalize=False):
+    if not np.sum(mask)==len(mask):
+        raise ValueError("The code doesn't currently work with a nontrivial mask!")
+
+    # initialize the concatenated data terms
+ measdiff_all = []
+ ideal_all = []
+ F_all = []
+ Cov_all = []
+ data_all = []
+
+ count = 0
+ # loop through data products we want to constrain
+ for dname in list(measurement.keys()):
+
+        # ignore data terms that have 0 weight
+ if np.allclose(measurement[dname],0.0):
+ continue
+
+        # check to see if you have data in the current obs
+        if dname=='flux':
+            if tot_flux is None:
+                raise ValueError('if you are using a flux constraint you must specify a total flux (via the lightcurve)')
+            data = np.array([tot_flux])
+            sigma = np.array([1])
+            count = count + 1
+        else:
+            try:
+                data, sigma, A = chisqdata(obs, im, mask, dtype=dname, ttype='direct')
+                count = count + 1
+            except:
+                continue
+
+ #compute the derivative matrix and the ideal measurements if im was the true image
+ if dname == 'vis':
+ F = A
+ ideal = np.dot(A,im.imvec)
+ elif dname == 'bs':
+ F = grad_bs(im.imvec, A)
+ ideal = bs(im.imvec,A)
+ elif dname == 'cphase':
+ F = grad_cphase(im.imvec, A)
+ ideal = cphase(im.imvec,A)
+ elif dname == 'amp':
+ F = grad_amp(im.imvec, A)
+ ideal = amp(im.imvec,A)
+ elif dname == 'logcamp':
+ F = grad_logcamp(im.imvec, A)
+ ideal = logcamp(im.imvec,A)
+ elif dname == 'flux':
+ F = grad_flux(im.imvec)
+ ideal = flux(im.imvec)
+
+        # stack real and imaginary parts so that all quantities are real
+ if not np.allclose(data.imag,0):
+ F = realimagStack(F)
+ data = realimagStack(data)
+ ideal = realimagStack(ideal)
+ sigma = np.concatenate((sigma,sigma), axis=0)
+
+ # change the error bars based upon which elements we want to constrain more
+ weight = measurement[dname]
+ if normalize:
+ sigma = sigma / np.sqrt(np.sum(sigma ** 2))
+ sigma = sigma / np.sqrt(weight)
+ Cov = np.diag(sigma ** 2)
+
+ data_all = np.concatenate((data_all, data.reshape(-1)), axis=0).reshape(-1)
+ ideal_all = np.concatenate((ideal_all, ideal.reshape(-1)), axis=0).reshape(-1)
+ measdiff_all = np.concatenate(
+ (measdiff_all, data.reshape(-1) + np.dot(F, im.imvec[mask]) - ideal.reshape(-1)), axis=0)
+ F_all = np.concatenate((F_all, F), axis=0) if len(F_all) else F
+ Cov_all = scipy.linalg.block_diag(Cov_all, Cov) if len(Cov_all) else Cov
+
+ if len(data_all):
+ return (measdiff_all, ideal_all, F_all, Cov_all, True)
+ else:
+ return (-1, -1, -1, -1, False)
+
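+# A hedged usage sketch for getMeasurementTerms (illustration only):
+# build the linearized measurement for one frame and apply the Lemma 2
+# update. `obs`, `im`, `mu`, and `Lambda` are assumed to be an Obsdata,
+# an ehtim image, a prior mean vector, and a prior covariance matrix.
+#
+#   mask = np.ones(im.imvec.shape, dtype=bool)
+#   meas, ideal, F, Cov, valid = getMeasurementTerms(
+#       obs, im, measurement={'vis': 1, 'cphase': 1}, mask=mask)
+#   if valid:
+#       post_mean, post_cov = prodGaussiansLem2(F, Cov, meas, mu, Lambda)
+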
+def bs(imvec, Amatrices):
+ """the bispectrum"""
+ out = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[1],imvec) * np.dot(Amatrices[2],imvec)
+ return out
+
+def grad_bs(imvec, Amatrices):
+ """The gradient of the bispectrum"""
+ pt1 = np.dot(Amatrices[1],imvec) * np.dot(Amatrices[2],imvec)
+ pt2 = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[2],imvec)
+ pt3 = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[1],imvec)
+ out = pt1[:,None] * Amatrices[0] + pt2[:,None] * Amatrices[1] + pt3[:,None] * Amatrices[2]
+ return out
+
+def flux(imvec):
+ return np.sum(imvec)
+
+def grad_flux(imvec):
+ return np.ones((1, len(imvec)))
+
+def cphase(imvec, Amatrices):
+ """the closure phase"""
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ clphase_samples = np.angle(i1 * i2 * i3)
+ out = np.exp(1j * clphase_samples)
+ return out
+
+def grad_cphase(imvec, Amatrices):
+ """The gradient of the closure phase"""
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ clphase_samples = np.angle(i1 * i2 * i3)
+ pt1 = 1.0/i1
+ pt2 = 1.0/i2
+ pt3 = 1.0/i3
+ dphi = (pt1[:,None]*Amatrices[0]) + (pt2[:,None] * Amatrices[1]) + \
+ (pt3[:,None]* Amatrices[2])
+ out = 1j * np.imag(dphi) * np.exp(1j * clphase_samples[:,None])
+ return out
+
+def amp(imvec, A):
+ """the amplitude"""
+ i1 = np.dot(A, imvec)
+ out = np.abs(i1)
+ return out
+
+def grad_amp(imvec, A):
+ """The gradient of the amplitude """
+ i1 = np.dot(A, imvec)
+ pp = np.abs(i1) / i1
+ out = np.real(pp[:,None] * A)
+ return out
+
+def camp(imvec, Amatrices):
+ """the closure amplitude"""
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ out = np.abs((i1 * i2)/(i3 * i4))
+ return out
+
+
+def grad_camp(imvec, Amatrices):
+ """The gradient of the closure amplitude """
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ clamp_samples = np.abs((i1 * i2)/(i3 * i4))
+
+ pt1 = clamp_samples/i1
+ pt2 = clamp_samples/i2
+ pt3 = -clamp_samples/i3
+ pt4 = -clamp_samples/i4
+ out = (pt1[:,None]*Amatrices[0]) + (pt2[:,None]*Amatrices[1]) + (pt3[:,None]*Amatrices[2]) + (pt4[:,None]*Amatrices[3])
+ return np.real(out)
+
+
+def logcamp(imvec, Amatrices):
+ """The Log closure amplitude """
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+ out = np.log(np.abs(i1)) + np.log(np.abs(i2)) - np.log(np.abs(i3)) - np.log(np.abs(i4))
+ return out
+
+def grad_logcamp(imvec, Amatrices):
+ """The gradient of the Log closure amplitude """
+
+ i1 = np.dot(Amatrices[0], imvec)
+ i2 = np.dot(Amatrices[1], imvec)
+ i3 = np.dot(Amatrices[2], imvec)
+ i4 = np.dot(Amatrices[3], imvec)
+
+ pt1 = 1/i1
+ pt2 = 1/i2
+ pt3 = -1/i3
+ pt4 = -1/i4
+ out = np.real(pt1[:,None] * Amatrices[0] + pt2[:,None] * Amatrices[1] + \
+ pt3[:,None] * Amatrices[2] + pt4[:,None] * Amatrices[3])
+ return out
+
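+# A finite-difference spot check for the analytic gradients above (added
+# for illustration; never called by the module). Perturbing one pixel of
+# `imvec` should reproduce the corresponding column of grad_amp.
+def _check_grad_amp(imvec, A, eps=1.e-8):
+    d = np.zeros(len(imvec))
+    d[0] = eps
+    numeric = (amp(imvec + d, A) - amp(imvec, A)) / eps
+    analytic = grad_amp(imvec, A)[:, 0]
+    return np.allclose(numeric, analytic, rtol=1.e-3)
+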
+
+def mergeObs(obs_List):
+    """Merge a list of observations into a single observation by concatenating their data tables.
+    """
+
+ obs = obs_List[0].copy()
+ data = obs.data
+ for t in range(1,len(obs_List)):
+ data = np.concatenate((data, obs_List[t].data))
+ obs.data = data
+ return obs
+
+
+def splitObs(obs):
+ """Split single observation into multiple observation files, one per scan
+ """
+
+ print("Splitting Observation File into " + str(len(obs.tlist())) + " scans")
+
+ #Note that the tarr of the output includes all sites, even those that don't participate in the scan
+ obs_List = [ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, tdata, obs.tarr, source=obs.source,
+ mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal)
+ for tdata in obs.tlist()
+ ]
+ return obs_List
+
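+# Typical round trip (illustration only): split a multi-scan observation
+# into per-scan frames for the filter, then merge results back if needed.
+#
+#   obs_frames = splitObs(obs)
+#   obs_merged = mergeObs(obs_frames)
+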
+
+def movie(im_List, out='movie.mp4', fps=10, dpi=120):
+ import matplotlib
+ matplotlib.use('agg')
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ fig = plt.figure()
+ frame = im_List[0].imvec #read_auto(filelist[len(filelist)/2])
+ fov = im_List[0].psize*im_List[0].xdim
+ extent = fov * np.array((1,-1,-1,1)) / 2.
+ maxi = np.max(frame)
+ im = plt.imshow( np.reshape(frame,[im_List[0].xdim, im_List[0].xdim]) , cmap='hot', extent=extent) #inferno
+ plt.colorbar()
+ im.set_clim([0,maxi])
+ fig.set_size_inches([5,5])
+ plt.tight_layout()
+
+ def update_img(n):
+ sys.stdout.write('\rprocessing image %i of %i ...' % (n,len(im_List)) )
+ sys.stdout.flush()
+ im.set_data(np.reshape(im_List[n].imvec, [im_List[n].xdim, im_List[n].xdim]) )
+ return im
+
+ ani = animation.FuncAnimation(fig,update_img,len(im_List),interval=1e3/fps)
+    writer = animation.writers['ffmpeg'](fps=max(20, fps), bitrate=int(1e6))
+ ani.save(out,writer=writer,dpi=dpi)
+
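+# Example call (illustration only; requires ffmpeg to be available to
+# matplotlib):
+#
+#   movie(im_List, out='starwarps_recon.mp4', fps=15)
+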
+
+def dirtyImage(im, obs_List, init_x=[], init_y=[], flowbasis_x=[], flowbasis_y=[], initTheta=[]):
+
+ if len(initTheta)==0:
+ init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im)
+
+ im_List = [];
+ for t in range(0,len(obs_List)):
+ im_List.append(im.copy())
+
+ for t in range(0,len(obs_List)):
+ A = genPhaseShiftMtx_obs(obs_List[t],init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, pulse=ehtim.observing.pulses.deltaPulse2D)
+ #im_List[t].imvec = np.real( np.dot( np.linalg.inv( np.dot(np.transpose(A),A) ), np.dot( np.transpose(A), obs_List[t].data['vis'] ) ) )
+ im_List[t].imvec = np.real( np.dot( np.transpose(np.conj(A) ), obs_List[t].data['vis']) )
+
+ return im_List
+
+def weinerFiltering(meanImg, covImg, obs_List, mask=[]):
+    """Wiener (linear minimum mean-square error) filtering: apply an independent
+    linear-Gaussian measurement update of the static prior N(meanImg, covImg) to
+    each observation frame.
+    """
+
+ if type(obs_List) != list:
+ obs_List = [obs_List]
+
+ im_List = []
+ cov_List = []
+ exp_tm1_t = []
+
+ for t in range(0,len(obs_List)):
+ im_List.append(meanImg.copy())
+ cov_List.append(np.zeros(covImg.shape))
+ exp_tm1_t.append(np.zeros(covImg.shape))
+
+ for t in range(0,len(obs_List)):
+
+ meas, idealmeas, A, measCov, valid = getMeasurementTerms(obs_List[t], meanImg, measurement={'vis':1}, mask=mask)
+
+ if valid==False:
+ im_List[t] = meanImg.copy()
+ cov_List[t] = copy.deepcopy(covImg)
+ else:
+ im_List[t].imvec, cov_List[t] = newDensity(meanImg.imvec, covImg, A, meas, idealmeas, measCov)
+
+ if t>0:
+ exp_tm1_t[t] = np.dot( np.array([im_List[t-1].imvec]).T , np.array([im_List[t].imvec]) )
+
+ return (im_List, cov_List, exp_tm1_t)
+
+
+
+def newDensity(X, covX, A, Y, idealY, covY):
+
+ measresidual = Y - idealY
+ residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY
+ G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) )
+ Xnew = X + np.dot( G, measresidual )
+ covXnew = covX - np.dot(G, np.dot( A, covX ) )
+
+ return (Xnew, covXnew)
+
+
+def newDensity_linearize(X, covX, A, Y, idealY, covY, Xlin):
+
+ measresidual = Y - idealY + np.dot(A,Xlin) - np.dot(A,X)
+ residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY
+ G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) )
+ Xnew = X + np.dot( G, measresidual )
+ covXnew = covX - np.dot(G, np.dot( A, covX ) )
+
+ return (Xnew, covXnew)
+
+
+def newDensity3(X0, covX0, X1, covX1, X2, Y, idealY_X2, A_X2, covY):
+
+ invCovX0 = np.linalg.inv(covX0)
+ invCovX1 = np.linalg.inv(covX1)
+ invCovY = np.linalg.inv(covY)
+ At_CovY = np.dot( np.transpose(A_X2), invCovY )
+ At_CovY_A = np.dot( At_CovY , A_X2 )
+
+ covXnew = np.linalg.inv( At_CovY_A + invCovX0 + invCovX1 )
+ Xnew = np.dot( covXnew , ( np.dot(At_CovY, Y - idealY_X2) + np.dot(At_CovY_A, X2) + np.dot(invCovX0,X0) + np.dot(invCovX1, X1) ) )
+
+ return (Xnew, covXnew)
+
+
+def newDensity2(X, covX, A, Y, covY):
+
+ measresidual = Y - np.dot( A, X )
+ residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY
+ G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) )
+ Xnew = X + np.dot( G, measresidual )
+ covXnew = covX - np.dot(G, np.dot( A, covX ) )
+
+ return (Xnew, covXnew)
+
+
+
+def gaussImgCovariance_2(im, powerDropoff=1.0, frac=1.):
+
+ eps = 0.001
+
+ init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im)
+ ufull, vfull = genFreqComp(im)
+ shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+ uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) ) + eps
+ uvdist = uvdist / np.min(uvdist)
+    uvdist[0] = np.inf
+
+ #shiftMtx = np.dot(shiftMtx, np.diag(im.imvec) )
+ shiftMtx_exp = realimagStack(shiftMtx)
+ uvdist_exp = np.concatenate( (uvdist, uvdist), axis=0)
+
+ imCov = np.dot( np.transpose(shiftMtx_exp) , np.dot( np.diag( 1/(uvdist_exp**powerDropoff) ), shiftMtx_exp ) )
+ imCov = frac**2 * np.dot( np.diag(im.imvec).T, np.dot(imCov/imCov[0,0], np.diag(im.imvec) ) );
+ return imCov
+
+def gaussImgCovariance(im, pixelStdev=1.0, powerDropoff=1.0, filter='none', kernsig=3.0):
+
+ eps = 0.001
+
+ init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im)
+ ufull, vfull = genFreqComp(im)
+ shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+ uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) ) + eps
+ uvdist = uvdist / np.min(uvdist)
+    uvdist[0] = np.inf
+
+ if filter == 'hamming':
+ hammingwindow = np.dot(np.reshape(np.hamming(im.xdim), (im.xdim,1)), np.reshape(np.hamming(im.ydim) , (1, im.ydim)) )
+ shiftMtx = np.dot(shiftMtx, np.diag(np.reshape(hammingwindow, (im.xdim*im.ydim))) )
+ if filter == 'gaussian':
+ gausswindow = gkern(kernlen=im.xdim, nsig=kernsig)
+ shiftMtx = np.dot(shiftMtx, np.diag(np.reshape(gausswindow, (im.xdim*im.ydim))) )
+
+ shiftMtx_exp = realimagStack(shiftMtx)
+ uvdist_exp = np.concatenate( (uvdist, uvdist), axis=0)
+
+ imCov = np.dot( np.transpose(shiftMtx_exp) , np.dot( np.diag( 1/(uvdist_exp**powerDropoff) ), shiftMtx_exp ) )
+ imCov = pixelStdev**2 * (imCov/imCov[0,0]);
+ return imCov
+
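+# A hedged usage sketch (illustration only): build a spatially correlated
+# Gaussian prior around an assumed mean image `mean_im` and draw a sample
+# from it.
+#
+#   Lambda0 = gaussImgCovariance(mean_im, pixelStdev=0.1, powerDropoff=2.0)
+#   sample_imvec = np.random.multivariate_normal(mean_im.imvec, Lambda0)
+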
+###################################### BASIS ########################################
+
+def affineMotionBasis(im):
+
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ init_x = np.array([[ [0] for i in xlist] for j in ylist])
+ init_y = np.array([[ [0] for i in xlist] for j in ylist])
+
+ flowbasis_x = np.array([[ [i, j ,im.psize, 0, 0, 0] for i in xlist] for j in ylist])
+ flowbasis_y = np.array([[ [0, 0, 0, i, j, im.psize] for i in xlist] for j in ylist])
+ initTheta = np.array([1, 0, 0, 0, 1, 0])
+
+ return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+def affineMotionBasis_noTranslation(im):
+
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ init_x = np.array([[ [0] for i in xlist] for j in ylist])
+ init_y = np.array([[ [0] for i in xlist] for j in ylist])
+
+ flowbasis_x = np.array([[ [i, j, 0, 0] for i in xlist] for j in ylist])
+ flowbasis_y = np.array([[ [0, 0, i, j] for i in xlist] for j in ylist])
+ initTheta = np.array([1, 0, 0, 1])
+
+ return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+def translationBasis(im):
+
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ init_x = np.array([[ [i] for i in xlist] for j in ylist])
+ init_y = np.array([[ [j] for i in xlist] for j in ylist])
+
+ flowbasis_x = np.array([[ [im.psize, 0.0] for i in xlist] for j in ylist])
+ flowbasis_y = np.array([[ [0.0, im.psize] for i in xlist] for j in ylist])
+ initTheta = np.array([0.0, 0.0])
+
+ return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+def xTranslationBasis(im):
+
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ init_x = np.array([[ [i] for i in xlist] for j in ylist])
+ init_y = np.array([[ [j] for i in xlist] for j in ylist])
+
+ flowbasis_x = np.array([[ [im.psize] for i in xlist] for j in ylist])
+ flowbasis_y = np.array([[ [0.0] for i in xlist] for j in ylist])
+ initTheta = np.array([0.0])
+
+ return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+def dftMotionBasis(im):
+
+ DFTBASIS_THRESH = 0.1
+ print('WARNING SOMETHING ISNT RIGHT WITH THE DFT GEN')
+
+ init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im)
+ ufull, vfull = genFreqComp(im)
+ shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+ uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) )
+ uvdist_norm = uvdist / np.max(uvdist)
+
+ halfbasis = shiftMtx[uvdist_norm < DFTBASIS_THRESH,:]
+ fullbasis = np.concatenate( ( np.real( halfbasis[1:,:] ) , np.imag( halfbasis[1:,:] ) ), axis = 0)
+ fullbasis = np.reshape(fullbasis.T, [im.xdim, im.ydim, fullbasis.shape[0]]) * im.psize
+
+ flowbasis_x = np.concatenate( (fullbasis, np.zeros(fullbasis.shape)), axis=2 )
+ flowbasis_y = np.concatenate( (np.zeros(fullbasis.shape), fullbasis), axis=2 )
+
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ init_x = np.array([[ [i] for i in xlist] for j in ylist])
+ init_y = np.array([[ [j] for i in xlist] for j in ylist])
+
+ initTheta = np.zeros(flowbasis_x.shape[2])
+
+ return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+
+def applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta):
+
+ imsize = flowbasis_x.shape[0:2]
+ nbasis = theta.shape[0]
+ npixels = np.prod(imsize)
+
+ flow_x = init_x[:,:,0] + np.reshape( np.dot( np.reshape(flowbasis_x, (npixels, nbasis), order ='F'), theta ), imsize, order='F')
+ flow_y = init_y[:,:,0] + np.reshape( np.dot( np.reshape(flowbasis_y, (npixels, nbasis), order ='F'), theta ), imsize, order='F')
+
+ return (flow_x, flow_y)
+
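+# Sanity-check sketch (illustration only): with the affine basis, theta
+# equal to initTheta = [1, 0, 0, 0, 1, 0] reproduces the identity flow,
+# i.e. every pixel maps to its own (x, y) coordinate.
+#
+#   init_x, init_y, fb_x, fb_y, theta0 = affineMotionBasis(im)
+#   fx, fy = applyMotionBasis(init_x, init_y, fb_x, fb_y, theta0)
+#   # fx[j, i] == xlist[i] and fy[j, i] == ylist[j] for all pixels
+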
+
+###################################### FULL WARPING ########################################
+
+def applyWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'):
+
+ if method=='phase':
+ outim = applyPhaseWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+ else:
+ outim = applyImageWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+ return outim
+
+
+def calcWarpMtx(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase', normalize=False):
+
+ npixels = im.xdim*im.ydim
+
+ if method=='phase':
+ ufull, vfull = genFreqComp(im)
+ shiftMtx0 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, theta, im.psize, im.pulse)
+ shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+
+ #outMtx = np.real( np.dot(np.transpose(np.conj(shiftMtx1)), shiftMtx0 ) / (npixels) )
+ outMtx = np.real( np.dot( np.linalg.inv(shiftMtx1), shiftMtx0 ) )
+
+ elif method=='img':
+ probeim = im.copy()
+ outMtx = np.zeros((npixels, npixels))
+ for i in range(0,npixels):
+ probeim.imvec = np.zeros(probeim.imvec.shape)
+ probeim.imvec[i] = 1.0
+ outprobeim = applyImageWarp(probeim, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, pad=True)
+ outMtx[:,i] = outprobeim.imvec
+
+ outMtx = np.nan_to_num(outMtx)
+
+ if normalize:
+ for i in range(0,npixels):
+ outMtx[i,:] = outMtx[i,:]/np.sum(outMtx[i,:])
+
+ return outMtx
+
+def applyImageWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, pad=False):
+
+ flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+ flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta)
+
+ from_pts = np.concatenate( ( reshapeFlowbasis(flow_x_new), reshapeFlowbasis(flow_y_new) ), axis=1)
+ to_pts = np.concatenate( ( reshapeFlowbasis(flow_x_orig), reshapeFlowbasis(flow_y_orig) ), axis=1)
+ im_pts = im.imvec
+
+ # add padding of 0's around the warped image so that the griddata function works properly
+ if pad:
+ # add padding on the x axis
+ for i in range(0,im.xdim):
+
+ vec_x = flow_x_new[1,i] - flow_x_new[0,i]
+ vec_y = flow_y_new[1,i] - flow_y_new[0,i]
+ from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[0,i]-vec_x, flow_y_new[0,i]-vec_y ] ) ) )
+ im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 )
+
+ vec_x = flow_x_new[im.ydim-2,i] - flow_x_new[im.ydim-1,i]
+ vec_y = flow_y_new[im.ydim-2,i] - flow_y_new[im.ydim-1,i]
+ from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[im.ydim-1,i]-vec_x, flow_y_new[im.ydim-1,i]-vec_y ] ) ) )
+ im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 )
+ # add padding on the y axis
+ for i in range(0,im.ydim):
+
+ vec_x = flow_x_new[i,1] - flow_x_new[i,0]
+ vec_y = flow_y_new[i,1] - flow_y_new[i,0]
+ from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[i,0]-vec_x, flow_y_new[i,0]-vec_y ] ) ) )
+ im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 )
+
+ vec_x = flow_x_new[i,im.xdim-2] - flow_x_new[i,im.xdim-1]
+ vec_y = flow_y_new[i,im.xdim-2] - flow_y_new[i,im.xdim-1]
+ from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[i,im.xdim-1]-vec_x, flow_y_new[i,im.xdim-1]-vec_y ] ) ) )
+ im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 )
+
+ #npixels = flowbasis_x.shape[0]*flowbasis_x.shape[1]
+ out = scipy.interpolate.griddata( from_pts , im_pts, to_pts , method='linear', fill_value=0.0 )
+ #out = scipy.interpolate.griddata( np.concatenate( ( np.reshape(flow_x_new, (npixels, -1)), np.reshape(flow_y_new, (npixels, -1)) ), axis=1) , im.imvec, np.concatenate( ( np.reshape(flow_x_orig, (npixels, -1)), np.reshape(flow_y_orig, (npixels, -1)) ), axis=1), method='linear', fill_value=0.0 )
+ outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+ return outim
+
+
+def applyPhaseWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta):
+
+ ufull, vfull = genFreqComp(im)
+
+ shiftMtx0 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, theta, im.psize, im.pulse)
+ shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+
+ #out = np.real( np.dot(np.transpose(np.conj(shiftMtx1)), np.dot(shiftMtx0, im.imvec) ) ) / (im.xdim * im.ydim)
+ out = np.real( np.dot( np.linalg.inv(shiftMtx1) , np.dot(shiftMtx0, im.imvec) ) )
+ outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+ return outim
+
+def genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse=ehtim.observing.pulses.deltaPulse2D):
+
+ flow_x, flow_y = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta)
+
+ imsize = flow_x.shape
+ npixels = np.prod(imsize)
+
+ flow_x_vec = np.reshape(flow_x, (npixels))
+ flow_y_vec = np.reshape(flow_y, (npixels))
+
+ shiftMtx_y = np.exp( [-1j * 2.0 * np.pi * flow_y_vec * v for v in vlist] )
+ shiftMtx_x = np.exp( [-1j * 2.0 * np.pi * flow_x_vec * u for u in ulist] )
+
+ uvlist = np.transpose(np.squeeze(np.array([ulist, vlist])))
+ uvlist = np.reshape(uvlist, (vlist.shape[0], 2))
+ pulseVec = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") for uv in uvlist ]
+
+ shiftMtx = np.dot( np.diag(pulseVec) , np.reshape( np.squeeze(shiftMtx_x * shiftMtx_y), (vlist.shape[0], npixels) ) )
+ return shiftMtx
+
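+# Note (illustration only): genPhaseShiftMtx is a direct (nonuniform)
+# Fourier matrix evaluated at the warped pixel positions, so with
+# theta == initTheta it reduces to a plain DFT of the image grid. Model
+# visibilities at (ulist, vlist) are then a single matrix product:
+#
+#   shiftMtx = genPhaseShiftMtx(ulist, vlist, init_x, init_y,
+#                               fb_x, fb_y, theta0, im.psize, im.pulse)
+#   vis_model = np.dot(shiftMtx, im.imvec)
+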
+################################### APPROXIMATE SHIFTING ###########################################
+
+def calc_dWarp_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'):
+
+ if method == 'phase':
+
+ ufull, vfull = genFreqComp(im)
+
+ derivShiftMtx_x, derivShiftMtx_y = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=False)
+
+ shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+ invShiftMtx1 = np.linalg.inv(shiftMtx1)
+
+ flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0)
+
+ reshape_flowbasis_x = reshapeFlowbasis(flowbasis_x)
+ reshape_flowbasis_y = reshapeFlowbasis(flowbasis_y)
+
+
+ dWarp_dTheta = [];
+ for b in range(0, flowbasis.shape[1]):
+ K = np.dot( derivShiftMtx_x, np.diag( reshape_flowbasis_x[:,b] ) ) + np.dot( derivShiftMtx_y, np.diag( reshape_flowbasis_y[:,b] ) )
+ dWarp_dTheta.append( np.real( np.dot( invShiftMtx1 , K ) ) )
+
+
+ else:
+ print('WARNING: we do not handle this method yet')
+
+ return dWarp_dTheta
+
+
+
+def applyAppxWarp(im, theta, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase'):
+
+ centerIm, dImg_dTheta = calAppxWarpTerms(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1=method1, method2=method2)
+ out = centerIm.imvec + np.dot(dImg_dTheta, theta - centerTheta)
+ outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+ return outim
+
+def calAppxWarpTerms(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase'):
+
+ centerIm = applyWarp(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method1)
+ dImg_dTheta = calc_dImage_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method2)
+ return (centerIm, dImg_dTheta)
+
+def calc_dImage_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'):
+
+ if method == 'phase':
+ ufull, vfull = genFreqComp(im)
+ thetaDerivShiftMtx = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=True)
+ shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+ #dImg_dTheta = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , thetaDerivShiftMtx) / (im.xdim * im.ydim) )
+ dImg_dTheta = np.real( np.dot( np.linalg.inv(shiftMtx1) , thetaDerivShiftMtx) )
+ else:
+ dImg_dTheta = calcDerivShiftMtx_image(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+ return dImg_dTheta
+
+def calcDerivWarpMtx_noimg_noshift(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'):
+
+ if method == 'phase':
+
+ ufull, vfull = genFreqComp(im)
+ #npixels = im.xdim*im.ydim
+
+ freqShiftMtx_x, freqShiftMtx_y = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=False)
+ shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse)
+
+ #derivShiftMtx_y = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , freqShiftMtx_y) / npixels )
+ #derivShiftMtx_x = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , freqShiftMtx_x) / npixels )
+ invShiftMtx1 = np.linalg.inv(shiftMtx1)
+ derivShiftMtx_y = np.real( np.dot( invShiftMtx1 , freqShiftMtx_y) )
+ derivShiftMtx_x = np.real( np.dot( invShiftMtx1 , freqShiftMtx_x) )
+ else:
+ if (centerTheta != initTheta).any():
+ raise ValueError('Can only take the optical flow derivative around no shift')
+
+ derivShiftMtx_x = -gradMtx(im.xdim, im.ydim, dir='x')/im.psize
+ derivShiftMtx_y = -gradMtx(im.xdim, im.ydim, dir='y')/im.psize
+
+ return (derivShiftMtx_x, derivShiftMtx_y)
+
+
+def calcDerivShiftMtx_freq(ulist, vlist, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=True):
+
+ shiftMtx = genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, centerTheta, im.psize, im.pulse)
+
+ shiftVec_y = np.array( [-1j * 2.0 * np.pi * v * np.ones(im.xdim*im.ydim) for v in vlist] )
+ shiftVec_x = np.array( [-1j * 2.0 * np.pi * u * np.ones(im.xdim*im.ydim) for u in ulist] )
+ derivShiftMtx_x = shiftVec_x * shiftMtx
+ derivShiftMtx_y = shiftVec_y * shiftMtx
+
+
+ if includeImgFlow:
+        # TODO: verify that init_x and init_y are handled correctly here
+        print('WARNING: check that init_x and init_y are handled correctly')
+
+ flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0)
+ derivShiftMtx = np.concatenate( ( np.dot(derivShiftMtx_x, np.diag(im.imvec)) , np.dot(derivShiftMtx_y, np.diag(im.imvec)) ) , axis=1)
+ thetaDerivShiftMtx = np.dot(derivShiftMtx, flowbasis)
+ return thetaDerivShiftMtx
+ else:
+ return (derivShiftMtx_x, derivShiftMtx_y)
+
+def calcDerivShiftMtx_image(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta):
+# out = im.imvec + np.dot(gradIm_x, theta - initTheta ) + np.dot(gradIm_y, theta - initTheta )
+
+ centerImg = applyImageWarp(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+ Gx = -gradMtx(im.xdim, im.ydim, dir='x')/im.psize
+ Gy = -gradMtx(im.xdim, im.ydim, dir='y')/im.psize
+
+ gradIm_x = np.diag( np.dot(Gx, centerImg.imvec) )
+ gradIm_y = np.diag( np.dot(Gy, centerImg.imvec) )
+
+    # TODO: verify that init_x and init_y are handled correctly here
+    print('WARNING: check that init_x and init_y are handled correctly')
+
+ flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0)
+ derivShiftMtx = np.concatenate( (gradIm_x, gradIm_y), axis=1)
+
+ thetaDerivShiftMtx = np.dot(derivShiftMtx, flowbasis)
+ return thetaDerivShiftMtx
+
+def calcAppxWarpMtx_image(im, init_x, init_y, flowbasis_x, flowbasis_y, theta, initTheta):
+
+ flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+ flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta)
+
+ flow_x = flow_x_new - flow_x_orig
+ flow_y = flow_y_new - flow_y_orig
+
+ Gx = -gradMtx(im.xdim, im.ydim, dir='x')
+ Gy = -gradMtx(im.xdim, im.ydim, dir='y')
+
+ derivx = np.dot( np.diag( vec(flow_x) ), Gx/im.psize)
+ derivy = np.dot( np.diag( vec(flow_y) ), Gy/im.psize)
+
+ appxMtx = np.eye(im.xdim*im.ydim) + derivx + derivy
+ return appxMtx
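+
+# calcAppxWarpMtx_image uses the first-order expansion I(x - f) ~ I(x) - f . grad(I)(x):
+# the returned matrix is the identity plus flow-weighted finite-difference terms, so
+#   np.dot(appxMtx, im.imvec)
+# approximates warping im by the residual flow between theta and initTheta.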
+
+
+###################################### FREQUENCY SPACE ########################################
+
+def shiftVisibilities(obs, shiftX, shiftY):
+ obs.data['vis'] = obs.data['vis']*np.exp(-1j*2.0*np.pi*( obs.data['u']*shiftX + obs.data['v']*shiftY ))
+ return obs
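+
+# shiftVisibilities applies the Fourier shift theorem: translating an image by
+# (shiftX, shiftY) radians multiplies each visibility by exp(-2*pi*i*(u*shiftX + v*shiftY)).
+# Note that it modifies obs.data in place. Illustrative sketch (hedged):
+#   import ehtim.const_def as ehc
+#   obs_shifted = shiftVisibilities(obs.copy(), 10*ehc.RADPERUAS, 0.0)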
+
+def genAppxShiftMtx(ulist, vlist, npixels, shiftMtx):
+
+ derivShiftMtx_y = np.array( [-1j * 2.0 * np.pi * np.ones((npixels)) * v for v in vlist] ) * shiftMtx
+ derivShiftMtx_x = np.array( [-1j * 2.0 * np.pi * np.ones((npixels)) * u for u in ulist] ) * shiftMtx
+
+ return (derivShiftMtx_x, derivShiftMtx_y)
+
+def genFreqComp(im):
+
+    fN2 = int(np.floor(im.xdim/2))  # TODO: this does not work for odd image sizes
+ fM2 = int(np.floor(im.ydim/2))
+
+ ulist = (np.array([np.concatenate((np.linspace(0, fN2 - 1, fN2), np.linspace(-fN2, -1, fN2)), axis=0)]) / im.xdim ) / im.psize
+ vlist = (np.array([np.concatenate((np.linspace(0, fM2 - 1, fM2), np.linspace(-fM2, -1, fM2)), axis=0)]) / im.ydim ) / im.psize
+
+ ufull, vfull = np.meshgrid(ulist, vlist)
+
+ ufull = np.reshape(ufull, (im.xdim*im.ydim, -1), order='F')
+ vfull = np.reshape(vfull, (im.xdim*im.ydim, -1), order='F')
+
+ return (ufull, vfull)
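+
+# The 1D frequency pattern above, [0, 1, ..., N/2-1, -N/2, ..., -1] / (N * psize),
+# matches np.fft.fftfreq(N, d=psize) for even N (odd sizes are unhandled, per the TODO).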
+
+def genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse=ehtim.observing.pulses.deltaPulse2D):
+ ulist = obs.unpack('u')['u']
+ vlist = obs.unpack('v')['v']
+
+ shiftMtx = genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse)
+ return shiftMtx
+
+
+def calcDerivShiftMtx_obs(obs, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y):
+
+ ulist = obs.unpack('u')['u']
+ vlist = obs.unpack('v')['v']
+
+ thetaDerivShiftMtx = calcDerivShiftMtx_freq(ulist, vlist, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y)
+ return thetaDerivShiftMtx
+
+def cmpFreqExtraction_phaseWarp(obs, im_true, im_canonical, theta, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta):
+
+ data = obs.unpack(['u','v','vis','sigma'])
+ uv = np.hstack((data['u'].reshape(-1,1), data['v'].reshape(-1,1)))
+ A = ftmatrix(im_true.psize, im_true.xdim, im_true.ydim, uv, pulse=im_true.pulse)
+
+ shiftMtx_true = genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, theta, im_canonical.psize, im_canonical.pulse)
+
+ shiftMtx_center = genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, centerTheta, im_canonical.psize, im_canonical.pulse)
+ thetaDerivShiftMtx = calcDerivShiftMtx_obs(obs, im_canonical, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y)
+
+ centerim, dImg_dTheta = calAppxWarpTerms(im_canonical, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase')
+
+ chiSq_shift = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(shiftMtx_true, im_canonical.imvec ) )/ obs.data['sigma'] )**2 )
+ chiSq_appxshift = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(shiftMtx_center, im_canonical.imvec) - np.dot(thetaDerivShiftMtx, theta-centerTheta) ) / obs.data['sigma'] )**2 )
+ chiSq_true = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(A, im_true.imvec) ) / obs.data['sigma'] )**2 )
+
+ return (chiSq_true, chiSq_shift, chiSq_appxshift)
+
+
+################################# HELPER FUNCTIONS #############################################
+
+
+def gkern(kernlen=21, nsig=3):
+ """Returns a 2D Gaussian kernel array."""
+
+ interval = (2*nsig+1.)/(kernlen)
+ x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
+ kern1d = np.diff(st.norm.cdf(x))
+ kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
+ kernel = kernel_raw/kernel_raw.sum()
+ return kernel
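+
+# NOTE: gkern assumes scipy.stats is available at module level as `st`
+# (e.g. `import scipy.stats as st`). Sketch: a 21x21 unit-sum kernel out to +/-3 sigma:
+#   kernel = gkern(kernlen=21, nsig=3)
+#   assert np.isclose(kernel.sum(), 1.0)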
+
+def gradMtx(w, h, dir='x'):
+
+ G = np.eye(h*w)
+ if dir=='x':
+ G = G - np.diag(np.ones((h*w-1)), k=1)
+ delrows = (np.linspace(w,h*w,h)-1).astype(int)
+ G[delrows,:] = 0
+ else:
+ G = G - np.diag(np.ones((h*w-w)), k=h)
+ delrows = range(h*w-w,h*w)
+ G[delrows,:] = 0
+ return G
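+
+# gradMtx builds a one-sided finite-difference operator on the vectorized image;
+# rows at the trailing image edge are zeroed so differences never wrap around.
+# Sketch: a constant image has zero gradient everywhere:
+#   G = gradMtx(8, 8, dir='x')
+#   assert np.allclose(np.dot(G, np.ones(64)), 0.0)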
+
+
+def realimagStack(mtx):
+ stack = np.concatenate( ( np.real(mtx), np.imag(mtx) ), axis=0 )
+ return stack
+
+def reshapeFlowbasis(flowbasis):
+ npixels = flowbasis.shape[0] * flowbasis.shape[1]
+ #nbasis = flowbasis.shape[2]
+ flowbasis = np.reshape(flowbasis, (npixels, -1) ) #, order ='F')
+ return flowbasis
+
+def vec(x, order='F'):
+ x = np.reshape(x, (np.prod(x.shape)), order = order)
+ return x
+
+def listconcatenate(*lists):
+ new_list = []
+ for i in lists:
+ new_list.extend(i)
+ return new_list
+
+def padNewFOV(im, fov_arcseconds):
+
+ oldfov = im.psize * im.xdim
+ newfov = fov_arcseconds * ehtim.RADPERUAS
+ tnpixels = np.ceil(im.xdim * newfov/oldfov).astype('int')
+
+ origimg = np.reshape(im.imvec, [im.xdim, im.xdim])
+ padimg = np.pad(origimg, ((0,tnpixels-im.xdim), (0,tnpixels-im.xdim)), 'constant')
+ padimg = np.roll(padimg, np.floor((tnpixels-im.xdim)/2.).astype('int'), axis=0)
+ padimg = np.roll(padimg, np.floor((tnpixels-im.xdim)/2.).astype('int'), axis=1)
+
+ return image.Image(padimg.reshape((tnpixels, tnpixels)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
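+
+# NOTE: despite its name, fov_arcseconds is interpreted in microarcseconds
+# (it is scaled by ehtim.RADPERUAS above). Sketch: pad a square image to a 200 uas FOV:
+#   im_padded = padNewFOV(im, 200.0)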
+
+
+def flipImg(im, flip_lr, flip_ud):
+
+ img = np.reshape(im.imvec, [im.xdim, im.xdim])
+ if flip_lr:
+ img = np.fliplr(img)
+ if flip_ud:
+ img = np.flipud(img)
+    im.imvec = img.reshape((im.xdim*im.xdim))
+ return image.Image(img, im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+
+
+def rotateImg(im, k):
+
+ img = np.reshape(im.imvec, [im.xdim, im.xdim])
+ img = np.rot90(img, k=k)
+    im.imvec = img.reshape((im.xdim*im.xdim))
+ return image.Image(img, im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse)
+
+
+####################################### MICHAEL'S STUFF #######################################
+
+def plot_im_List(im_List, title_List=[], ipynb=False):
+
+ plt.title("Test", fontsize=20)
+ plt.ion()
+ plt.clf()
+
+ Prior = im_List[0]
+
+ for i in range(len(im_List)):
+ plt.subplot(1, len(im_List), i+1)
+ plt.imshow(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ plt.axis('off')
+ xticks = ticks(Prior.xdim, Prior.psize/ehtim.RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/ehtim.RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ if i == 0:
+            plt.xlabel(r'Relative RA ($\mu$as)')
+            plt.ylabel(r'Relative Dec ($\mu$as)')
+ else:
+ plt.xlabel('')
+ plt.ylabel('')
+ #plt.title('')
+ if len(title_List)==len(im_List):
+ plt.title(title_List[i], fontsize = 5)
+
+
+ plt.draw()
+
+
+def plot_Flow(Im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, step=4, ipynb=False):
+
+ # Get vectors and ratio from current image
+ x = np.array([[i for i in range(Im.xdim)] for j in range(Im.ydim)])
+ y = np.array([[j for i in range(Im.xdim)] for j in range(Im.ydim)])
+
+ flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta)
+ flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta)
+
+ vx = -(flow_x_new - flow_x_orig)
+ vy = -(flow_y_new - flow_y_orig)
+
+ # Create figure and title
+ plt.ion()
+ plt.clf()
+
+ # Stokes I plot
+ plt.subplot(111)
+ plt.imshow(Im.imvec.reshape(Im.ydim, Im.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ plt.quiver(x[::step,::step], y[::step,::step], vx[::step,::step], vy[::step,::step],
+ headaxislength=3, headwidth=7, headlength=5, minlength=0, minshaft=1,
+ width=.005*Im.xdim/30., pivot='mid', color='w', angles='xy')
+
+ xticks = ticks(Im.xdim, Im.psize/ehtim.RADPERAS/1e-6)
+ yticks = ticks(Im.ydim, Im.psize/ehtim.RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+    plt.xlabel(r'Relative RA ($\mu$as)')
+    plt.ylabel(r'Relative Dec ($\mu$as)')
+ plt.title('Flow Map')
+ #plt.ylim(plt.ylim()[::-1])
+ # Display
+ plt.draw()
+
diff --git a/io/__init__.py b/io/__init__.py
new file mode 100644
index 00000000..b7ba4935
--- /dev/null
+++ b/io/__init__.py
@@ -0,0 +1,11 @@
+"""
+.. module:: ehtim.io
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: I/O functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from . import load
+
+from ..const_def import *
diff --git a/io/load.py b/io/load.py
new file mode 100644
index 00000000..8e38a4f5
--- /dev/null
+++ b/io/load.py
@@ -0,0 +1,1837 @@
+# load.py
+# functions to load observation & image data from files
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import astropy.io.fits as fits
+import datetime
+import os
+import copy
+import sys
+import time as ttime
+import h5py
+
+import ehtim.obsdata
+import ehtim.image
+import ehtim.array
+import ehtim.movie
+import ehtim.vex
+import ehtim.observing
+
+import ehtim.io.oifits
+import ehtim.const_def as ehc
+
+import warnings
+warnings.filterwarnings("ignore", message="Mean of empty slice")
+warnings.filterwarnings("ignore", message="invalid value encountered in true_divide")
+
+##################################################################################################
+# Vex IO
+##################################################################################################
+
+
+def load_vex(fname):
+ """Read in .vex files.
+ Assumes there is only 1 MODE in vex file
+ Hotaka Shiokawa - 2017
+
+ Args:
+ fname (str): path to input .vex file
+ Returns:
+ vex (Vex): Vex file object
+ """
+ print("Loading vexfile: ", fname)
+ return ehtim.vex.Vex(fname)
+
+
+##################################################################################################
+# Image IO
+##################################################################################################
+def load_im_txt(filename, pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim='I', zero_pol=True):
+ """Read in an image from a text file.
+
+ Args:
+ filename (str): path to input text file
+ pulse (function): The function convolved with the pixel values for continuous image.
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+
+ Returns:
+ (Image): loaded image object
+ """
+
+ print("Loading text image: ", filename)
+
+ # Read the header
+ file = open(filename)
+ src = ' '.join(file.readline().split()[2:])
+ ra = file.readline().split()
+ ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0
+ dec = file.readline().split()
+ dec = np.sign(float(dec[2])) * (abs(float(dec[2])) +
+ float(dec[4]) / 60.0 + float(dec[6]) / 3600.0)
+ mjd_float = float(file.readline().split()[2])
+ mjd = int(mjd_float)
+ time = (mjd_float - mjd) * 24
+ rf = float(file.readline().split()[2]) * 1e9
+ xdim = file.readline().split()
+ xdim_p = int(xdim[2])
+ psize_x = float(xdim[4]) * ehc.RADPERAS / xdim_p
+ ydim = file.readline().split()
+ ydim_p = int(ydim[2])
+ psize_y = float(ydim[4]) * ehc.RADPERAS / ydim_p
+ file.close()
+
+ if psize_x != psize_y:
+ raise Exception("Pixel dimensions in x and y are inconsistent!")
+
+ # Load the data, convert to list format, make object
+ datatable = np.loadtxt(filename, dtype=float)
+ image = datatable[:, 2].reshape(ydim_p, xdim_p)
+ outim = ehtim.image.Image(image, psize_x, ra, dec,
+ rf=rf, source=src, mjd=mjd, time=time, pulse=pulse,
+ polrep='stokes', pol_prim='I')
+
+ # Look for Stokes Q and U
+ qimage = uimage = vimage = np.zeros(image.shape)
+ if datatable.shape[1] == 6:
+ qimage = datatable[:, 3].reshape(ydim_p, xdim_p)
+ uimage = datatable[:, 4].reshape(ydim_p, xdim_p)
+ vimage = datatable[:, 5].reshape(ydim_p, xdim_p)
+ elif datatable.shape[1] == 5:
+ qimage = datatable[:, 3].reshape(ydim_p, xdim_p)
+ uimage = datatable[:, 4].reshape(ydim_p, xdim_p)
+
+ if np.any((qimage != 0) + (uimage != 0)) and np.any((vimage != 0)):
+ # print('Loaded Stokes I, Q, U, and V Images')
+ outim.add_qu(qimage, uimage)
+ outim.add_v(vimage)
+ elif np.any((vimage != 0)):
+ # print('Loaded Stokes I and V Images')
+ outim.add_v(vimage)
+ if zero_pol:
+ outim.add_qu(0 * vimage, 0 * vimage)
+ elif np.any((qimage != 0) + (uimage != 0)):
+ # print('Loaded Stokes I, Q, and U Images')
+ outim.add_qu(qimage, uimage)
+ if zero_pol:
+ outim.add_v(0 * qimage)
+ else:
+ if zero_pol:
+ outim.add_qu(0 * image, 0 * image)
+ outim.add_v(0 * image)
+ # print('Loaded Stokes I Image Only')
+
+ # Transform to desired pol rep
+ if not (polrep == 'stokes' and pol_prim == 'I'):
+ outim = outim.switch_polrep(polrep_out=polrep, pol_prim_out=pol_prim)
+
+ return outim
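+
+# Illustrative usage (hedged; 'image.txt' is a placeholder path):
+#   im = load_im_txt('image.txt', polrep='stokes')
+#   print(im.total_flux())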
+
+
+def load_im_hdf5(filename):
+ """Read in an image from an hdf5 file.
+ Args:
+ filename (str): path to input hdf5 file
+ Returns:
+ (Image): loaded image object
+ """
+ print("Loading hdf5 image: ", filename)
+
+ # Load information from hdf5 file
+
+ hfp = h5py.File(filename,'r')
+ dsource = hfp['header']['dsource'][()] # distance to source in cm
+ jyscale = hfp['header']['scale'][()] # convert cgs intensity -> Jy flux density
+ rf = hfp['header']['freqcgs'][()] # in cgs
+ tunit = hfp['header']['units']['T_unit'][()] # in seconds
+ lunit = hfp['header']['units']['L_unit'][()] # in cm
+ DX = hfp['header']['camera']['dx'][()] # in GM/c^2
+ nx = hfp['header']['camera']['nx'][()] # width in pixels
+ time = hfp['header']['t'][()] * tunit / 3600. # time in hours
+ if 'pol' in hfp:
+ poldat = np.copy(hfp['pol'])[:, :, :4] # NX,NY,{I,Q,U,V}
+ else: # unpolarized data only
+ unpoldat = np.copy(hfp['unpol']) # NX,NY
+ poldat = np.zeros(list(unpoldat.shape)+[4])
+ poldat[:,:,0] = unpoldat
+ hfp.close()
+
+ # Correct image orientation
+ # unpoldat = np.flip(unpoldat.transpose((1, 0)), axis=0)
+ poldat = np.flip(poldat.transpose((1, 0, 2)), axis=0)
+
+ # Make a guess at the source based on distance and optionally fall back on mass
+ src = ehc.SOURCE_DEFAULT
+ if dsource > 4.e25 and dsource < 6.2e25:
+ src = "M87"
+ elif dsource > 2.45e22 and dsource < 2.6e22:
+ src = "SgrA"
+
+ # Fill in information according to the source
+ ra = ehc.RA_DEFAULT
+ dec = ehc.DEC_DEFAULT
+ if src == "SgrA":
+ ra = 17.76112247
+ dec = -28.992189444
+ elif src == "M87":
+ ra = 187.70593075
+ dec = 12.391123306
+
+ # Process image to set proper dimensions
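+    # (DX is the camera width in GM/c^2; multiplying by lunit converts it to cm,
+    #  and 2.06265e11 is the number of microarcseconds per radian)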
+ fovmuas = DX / dsource * lunit * 2.06265e11
+ psize_x = ehc.RADPERUAS * fovmuas / nx
+
+ Iim = poldat[:, :, 0] * jyscale
+ Qim = poldat[:, :, 1] * jyscale
+ Uim = poldat[:, :, 2] * jyscale
+ Vim = poldat[:, :, 3] * jyscale
+
+ outim = ehtim.image.Image(Iim, psize_x, ra, dec, rf=rf, source=src,
+ polrep='stokes', pol_prim='I', time=time)
+ outim.add_qu(Qim, Uim)
+ outim.add_v(Vim)
+
+ return outim
+
+
+def load_im_fits(filename, aipscc=False, pulse=ehc.PULSE_DEFAULT,
+ punit="deg", polrep='stokes', pol_prim=None, zero_pol=True):
+ """Read in an image from a FITS file.
+
+ Args:
+        filename (str): path to input fits file
+ aipscc (bool): if True, then AIPS CC table will be loaded
+ pulse (function): The function convolved with the pixel values for continuous image.
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U,V for Stokes, RR,LL,LR,RL for circ
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+
+ Returns:
+ (Image): loaded image object
+ """
+
+ print("Loading fits image: ", filename)
+
+ # Radian or Degree?
+ if punit == "deg":
+ pscl = ehc.DEGREE
+ elif punit == "rad":
+ pscl = 1.0
+ elif punit == "uas":
+ pscl = ehc.RADPERUAS
+ elif punit == "mas":
+ pscl = ehc.RADPERUAS * 1000.0
+
+ # Open the FITS file
+ hdulist = fits.open(filename)
+
+ # Assume stokes I is the primary hdu
+ header = hdulist[0].header
+
+ # Read some header values
+ xdim_p = header['NAXIS1']
+ psize_x = np.abs(header['CDELT1']) * pscl
+ ydim_p = header['NAXIS2']
+ psize_y = np.abs(header['CDELT2']) * pscl
+
+ if 'OBSRA' in list(header.keys()):
+ ra = header['OBSRA'] * 12 / 180.
+ elif 'CRVAL1' in list(header.keys()):
+ ra = header['CRVAL1'] * 12 / 180.
+ else:
+ ra = 0.
+
+ if 'OBSDEC' in list(header.keys()):
+ dec = header['OBSDEC']
+ elif 'CRVAL2' in list(header.keys()):
+ dec = header['CRVAL2']
+ else:
+ dec = 0.
+
+ if 'FREQ' in list(header.keys()):
+ rf = header['FREQ']
+ elif 'CRVAL3' in list(header.keys()):
+ rf = header['CRVAL3']
+ else:
+ rf = 0.
+
+ if 'MJD' in list(header.keys()):
+ mjd_float = header['MJD']
+ else:
+ mjd_float = 0.
+ mjd = int(mjd_float)
+ time = (mjd_float - mjd) * 24
+
+ if 'OBJECT' in list(header.keys()):
+ src = header['OBJECT']
+ else:
+ src = ''
+
+ # Get the image and create the object
+ data = hdulist[0].data
+
+ # Check for multiple stokes in top hdu
+ stokes_in_hdu0 = False
+ if len(data.shape) == 4:
+ print("reading all stokes images from top HDU -- assuming IQUV order")
+ stokesdata = data[:4,:,:,:] # ignore fields after the first 4
+ data = stokesdata[0, 0]
+ stokes_in_hdu0 = True
+
+ elif len(data.shape) == 3:
+ # ANDREW added this for BHAC models 3/22/23
+ print("reading all stokes images from top HDU -- assuming IQUV order")
+ stokesdata = data[:4,:,:] # ignore fields after the first 4
+ stokesdata = stokesdata.reshape(4,-1,stokesdata.shape[-2],stokesdata.shape[-1])
+ data = stokesdata[0, 0]
+ stokes_in_hdu0 = True
+
+ data = data.reshape((data.shape[-2], data.shape[-1]))
+
+ # Update the image using the AIPS CC table
+ if aipscc:
+ try:
+ aipscctab = hdulist["AIPS CC"]
+ except BaseException:
+ print("Input FITS file does not have an AIPS CC table. Loading image instead.")
+ aipscc = False
+
+ if aipscc:
+
+ print("loading the AIPS CC table.")
+ print("force the pulse function to be the delta function.")
+ pulse = ehtim.observing.pulses.deltaPulse2D
+
+        # get the source brightness information
+ flux = aipscctab.data["FLUX"]
+ deltax = aipscctab.data["DELTAX"]
+ deltay = aipscctab.data["DELTAY"]
+
+ # check to make sure all the source types are point sources and gaussian components
+ try:
+ checkmtype = np.abs(np.unique(aipscctab.data["TYPE OBJ"])) < 2.0
+ if False in checkmtype.tolist():
+ errmsg = "The primary AIPS CC table in the input FITS file has non point-source"
+ errmsg += " or Gaussian Source CC components, which are not currently supported."
+ raise ValueError(errmsg)
+ point_src = aipscctab.data["TYPE OBJ"] == 0
+ gaussian_src = aipscctab.data["TYPE OBJ"] == 1
+        except KeyError:
+ print("Cannot load AIPS CC Table OBJ data -- assuming all CC components are point sources!")
+ point_src = np.ones(aipscctab.data.shape).astype(bool)
+ gaussian_src = np.zeros(aipscctab.data.shape).astype(bool)
+ print("%d CC components are loaded." % (len(flux)))
+
+ # compile the point source aipscc info
+ flux_ps = flux[point_src]
+ deltax_ps = deltax[point_src]
+ deltay_ps = deltay[point_src]
+
+ # compile the gaussian aipscc info, if any
+ if np.any(gaussian_src):
+ flux_gs = flux[gaussian_src]
+ deltax_gs = deltax[gaussian_src]
+ deltay_gs = deltay[gaussian_src]
+ maj_gs = aipscctab.data["MAJOR AX"][gaussian_src]
+ min_gs = aipscctab.data["MINOR AX"][gaussian_src]
+ pa_gs = aipscctab.data["POSANGLE"][gaussian_src]
+ else:
+ flux_gs = []
+
+ # the map_coordinates delta x / delta y of each delta CC component are
+ # relative to the reference pixel which is defined by CRPIX1 and CRPIX2.
+ try:
+ Nxref = header.get("CRPIX1")
+ except BaseException:
+ Nxref = header.get("NAXIS1") // 2 + 1
+ try:
+ Nyref = header.get("CRPIX2")
+ except BaseException:
+ Nyref = header.get("NAXIS2") // 2 + 1
+
+ # compute the corresponding index of pixel for each deltax_ps / deltay_ps
+ ix = np.array(np.int64(np.round(deltax_ps / header.get("CDELT1") + Nxref - 1)))
+ iy = np.array(np.int64(np.round(deltay_ps / header.get("CDELT2") + Nyref - 1)))
+
+ # reset the image and input flux information
+ data[:, :] = 0.
+ Noutcomp = 0
+ for i in range(len(flux_ps)):
+ try:
+ data[iy[i], ix[i]] += flux_ps[i]
+ except BaseException:
+ Noutcomp += 1
+ print("added %d CC delta components." % (len(flux_ps)))
+ if Noutcomp > 0:
+ print("%d CC delta components are outside of the FoV and ignored." % (Noutcomp))
+
+ # flip y-axis!
+ image = data[::-1, :]
+
+ # normalize the flux
+ normalizer = 1.0
+ if 'BUNIT' in list(header.keys()):
+ if header['BUNIT'].lower() == 'JY/BEAM'.lower():
+
+ print("converting Jy/Beam --> Jy/pixel")
+ bmaj = bmin = 1.0 # default values
+
+ if 'BMAJ' in list(header.keys()):
+ bmaj = header['BMAJ']
+ bmin = header['BMIN']
+
+ elif 'HISTORY' in list(header.keys()): # Alternate option, to read AIPS fits images
+ print("No beam info in header; reading from AIPS HISTORY instead...")
+ for line in header['HISTORY']:
+ if 'BMAJ' in line and len(line.split())>6:
+ bmaj = float(line.split()[3])
+ bmin = float(line.split()[5])
+
+            if bmaj == 1.0 and bmin == 1.0:
+                print("No beam info found! Assuming nominal values for conversion.")
+
+ beamarea = (2.0 * np.pi * bmaj * bmin / (8.0 * np.log(2)))
+ normalizer = (header['CDELT2'])**2 / beamarea
+
+ if aipscc:
+ print("the computed normalizer will not be applied since we are loading the AIPS CC table")
+ else:
+ image *= normalizer
+
+ # make image object in Stokes I
+ outim = ehtim.image.Image(image, psize_x, ra, dec,
+ rf=rf, source=src, mjd=mjd, time=time, pulse=pulse,
+ polrep='stokes', pol_prim='I')
+
+ # add gaussian components to the image from the aipscc table
+ if aipscc and len(flux_gs):
+ Noutcomp = 0
+ for i in range(len(flux_gs)):
+ # make sure the aipscc table gaussian is within the FOV
+ if ((deltax_gs[i] * ehc.DEGREE < outim.fovx() / 2.0) and
+ (deltay_gs[i] * ehc.DEGREE < outim.fovy() / 2.0) and
+ (maj_gs[i] * ehc.DEGREE * 3 < (np.min([outim.fovx(), outim.fovy()]) / 2.0))):
+ # add a gaussian component with the specified flux and location
+ outim = outim.add_gauss(flux_gs[i], (maj_gs[i] * ehc.DEGREE, min_gs[i] * ehc.DEGREE,
+ pa_gs[i] * ehc.DEGREE,
+ deltax_gs[i] * ehc.DEGREE,
+ deltay_gs[i] * ehc.DEGREE))
+ else:
+ Noutcomp += 1
+ print("added %d CC gaussian components." % (len(flux_gs)))
+ if Noutcomp > 0:
+ print("%d CC gaussian components are outside of the FoV and ignored." % (Noutcomp))
+
+ # Look for Stokes Q and U and V
+ qimage = uimage = vimage = np.array([])
+
+ if stokes_in_hdu0: # stokes in top HDU
+ try:
+ qdata = stokesdata[1, 0].reshape((data.shape[-2], data.shape[-1]))
+ qimage = normalizer * qdata[::-1, :] # flip y-axis!
+ except IndexError:
+ pass
+ try:
+ udata = stokesdata[2, 0].reshape((data.shape[-2], data.shape[-1]))
+ uimage = normalizer * udata[::-1, :] # flip y-axis!
+ except IndexError:
+ pass
+ try:
+ vdata = stokesdata[3, 0].reshape((data.shape[-2], data.shape[-1]))
+ vimage = normalizer * vdata[::-1, :] # flip y-axis!
+ except IndexError:
+ pass
+
+ else: # stokes in different HDUS
+ for hdu in hdulist[1:]:
+ header = hdu.header
+ data = hdu.data
+ try:
+ data = data.reshape((data.shape[-2], data.shape[-1]))
+ except IndexError:
+ continue
+
+ if 'STOKES' in list(header.keys()) and header['STOKES'] == 'Q':
+ qimage = normalizer * data[::-1, :] # flip y-axis!
+ if 'STOKES' in list(header.keys()) and header['STOKES'] == 'U':
+ uimage = normalizer * data[::-1, :] # flip y-axis!
+ if 'STOKES' in list(header.keys()) and header['STOKES'] == 'V':
+ vimage = normalizer * data[::-1, :] # flip y-axis!
+
+ if qimage.shape == uimage.shape == vimage.shape == image.shape:
+ # print('Loaded Stokes I, Q, U, and V Images')
+ outim.add_qu(qimage, uimage)
+ outim.add_v(vimage)
+ elif vimage.shape == image.shape:
+ # print('Loaded Stokes I and V Images')
+ outim.add_v(vimage)
+ if zero_pol:
+ outim.add_qu(0 * vimage, 0 * vimage)
+ elif qimage.shape == uimage.shape == image.shape:
+ # print('Loaded Stokes I, Q, and U Images')
+ outim.add_qu(qimage, uimage)
+ if zero_pol:
+ outim.add_v(0 * qimage)
+ else:
+ if zero_pol:
+ outim.add_qu(0 * image, 0 * image)
+ outim.add_v(0 * image)
+ # print('Loaded Stokes I Image Only')
+
+ # Transform to desired pol rep
+ if not (polrep == 'stokes' and pol_prim == 'I'):
+ outim = outim.switch_polrep(polrep_out=polrep, pol_prim_out=pol_prim)
+
+ return outim
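+
+# Illustrative usage (hedged; the file names are placeholders):
+#   im = load_im_fits('image.fits', punit='deg')
+#   im_cc = load_im_fits('model.fits', aipscc=True)  # rebuild image from AIPS CC components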
+
+##################################################################################################
+# Movie IO
+##################################################################################################
+
+# Old version for arizona datasets
+# def load_movie_hdf5(file_name, framedur_sec=1, psize=-1,
+# ra=17.761122472222223, dec=-28.992189444444445, rf=230e9, source='SgrA',
+# pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None, zero_pol=True):
+
+
+# """Read in a movie from an hdf5 file and create a Movie object.
+
+# Args:
+# file_name (str): The name of the hdf5 file.
+# framedur_sec (float): The frame duration in seconds
+# psize (float): Pixel size in radian
+# ra (float): The movie right ascension
+# dec (float): The movie declination
+# rf (float): The movie frequency
+# pulse (function): The function convolved with the pixel values for continuous image
+# polrep (str): polarization representation, either 'stokes' or 'circ'
+# pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+# zero_pol (bool): If True, loads any missing polarizations as zeros
+
+# Returns:
+# Movie: a Movie object
+# """
+
+# # TODO: Currently only supports one polarization!
+# file = h5py.File(file_name, 'r')
+# name = list(file.keys())[0]
+# d = file[str(name)]
+# frames = d[:]
+# file.close()
+
+# # TODO: currently no frame times stored in hdf5!
+# framedur_hr = framedur/3600.
+# mjd0 = ehc.MJD_DEFAULT
+# hour0 = 0
+# nframes = len(frames)
+# tstart = hour0
+# tstop = hour0 + framedur_hr*nframes
+# times = np.linspace(tstart, tstop, nframes)
+
+# movie = Movie(frames, times,
+# psize, ra, dec, rf=rf,
+# polrep=polrep, pol_prim=pol_prim,
+# source=source, mjd=ehc.MJD_DEFAULT, pulse=pulse)
+
+# if zero_pol:
+# for pol in list(movie._movdict.keys()):
+# if pol==movie.pol_prim: continue
+# polframes = np.zeros(frames.shape)
+# newmov.add_pol_movie(polframes, pol)
+
+# return movie
+
+
+def load_movie_hdf5(file_name, pulse=ehc.PULSE_DEFAULT, interp=ehc.INTERP_DEFAULT,
+ bounds_error=ehc.BOUNDS_ERROR):
+ """Read in a movie from an hdf5 file and create a Movie object.
+
+ Args:
+ file_name (str): The name of the hdf5 file.
+ framedur_sec (float): The frame duration in seconds (overwrites internal timestamps)
+ pulse (function): The function convolved with the pixel values for continuous image
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ # TODO: Currently only supports one polarization!
+ with h5py.File(file_name, 'r') as file:
+
+ head = file['header']
+
+ mjd = int(head.attrs['mjd'].astype(str))
+ psize = float(head.attrs['psize'].astype(str))
+ source = head.attrs['source'].astype(str)
+ ra = float(head.attrs['ra'].astype(str))
+ dec = float(head.attrs['dec'].astype(str))
+ rf = float(head.attrs['rf'].astype(str))
+ polrep = head.attrs['polrep'].astype(str)
+ pol_prim = head.attrs['pol_prim'].astype(str)
+
+ times = file['times'][:]
+ frames = file[pol_prim][:]
+
+        movie = ehtim.movie.Movie(frames, times,
+                                  psize, ra, dec, rf=rf,
+                                  interp=interp, bounds_error=bounds_error,
+                                  polrep=polrep, pol_prim=pol_prim,
+                                  source=source, mjd=mjd, pulse=pulse)
+
+        if polrep == 'stokes':
+            keys = ['I', 'Q', 'U', 'V']
+        elif polrep == 'circ':
+            keys = ['RR', 'LL', 'RL', 'LR']
+        else:
+            raise Exception("hdf5 polrep is not 'circ' or 'stokes'!")
+
+        # read any remaining polarizations while the file is still open;
+        # the with statement closes the file automatically on exit
+        for pol in keys:
+            if pol == movie.pol_prim:
+                continue
+            if pol in file.keys():
+                polframes = file[pol][:]
+                if len(polframes):
+                    movie.add_pol_movie(polframes, pol)
+
+ return movie
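+
+# Illustrative usage (hedged; 'movie.hdf5' is a placeholder path):
+#   mov = load_movie_hdf5('movie.hdf5')
+#   frame0 = mov.get_image(mov.start_hr)  # sample the movie at its first timestamp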
+
+
+def load_movie_txt(basename, nframes, framedur=-1, pulse=ehc.PULSE_DEFAULT,
+ polrep='stokes', pol_prim=None, zero_pol=True,
+ interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR):
+ """Read in a movie from text files and create a Movie object.
+
+ Args:
+ basename (str): The base name of individual movie frames.
+ Files should have names basename + 00001, etc.
+ nframes (int): The total number of frames
+ framedur (float): The frame duration in seconds
+ (default = -1, and framedur is taken from file headers)
+ pulse (function): The function convolved with the pixel values for continuous image
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ imlist = []
+
+ for i in range(nframes):
+ filename = basename + "%05d" % i
+
+ sys.stdout.write('\rReading Movie Image %i/%i...' % (i, nframes))
+ sys.stdout.flush()
+
+ im = load_im_txt(filename, pulse=pulse, polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol)
+ imlist.append(im)
+
+ if i == 0:
+ hour0 = im.time
+ times = [hour0]
+ else:
+ times.append(hour0)
+
+ if framedur != -1:
+
+ framedur_hr = framedur / 3600.
+ nframes = len(imlist)
+ tstart = hour0
+ tstop = hour0 + framedur_hr * nframes
+ times = np.linspace(tstart, tstop, nframes)
+ for kk in range(len(imlist)):
+ imlist[kk].time = times[kk]
+
+ out_mov = ehtim.movie.merge_im_list(imlist, framedur=framedur,
+ interp=interp, bounds_error=bounds_error)
+
+ return out_mov
+
+
+def load_movie_fits(basename, nframes, framedur=-1,
+ interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR,
+ pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None, zero_pol=True):
+ """Read in a movie from fits files and create a Movie object.
+
+ Args:
+ basename (str): The base name of individual movie frames.
+ Files should have names basename + 00001, etc.
+ nframes (int): The total number of frames
+ framedur (float): The frame duration in seconds
+ (default = -1, and framedur is taken from file headers)
+ pulse (function): The function convolved with the pixel values for continuous image
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ imlist = []
+
+ for i in range(nframes):
+ sys.stdout.write('\rReading Movie Image %i/%i...' % (i, nframes))
+ sys.stdout.flush()
+ for tag in ["%02d" % i, "%03d" % i, "%04d" % i, "%05d" % i]:
+
+ try:
+ filename = basename + tag + '.fits'
+
+ im = load_im_fits(filename, pulse=pulse, polrep=polrep,
+ pol_prim=pol_prim, zero_pol=zero_pol)
+ imlist.append(im)
+ break
+ except BaseException:
+ continue
+
+ if i == 0:
+ hour0 = im.time
+ else:
+ pass
+
+ if framedur != -1:
+
+ framedur_hr = framedur / 3600.
+ nframes = len(imlist)
+ tstart = hour0
+ tstop = hour0 + framedur_hr * nframes
+ times = np.linspace(tstart, tstop, nframes)
+ for kk in range(len(imlist)):
+ imlist[kk].time = times[kk]
+
+ out_mov = ehtim.movie.merge_im_list(imlist, framedur=framedur,
+ interp=interp, bounds_error=bounds_error)
+
+ return out_mov
+
+
+def load_movie_dat(basename, nframes, startframe=0, framedur_sec=1, psize=-1,
+ interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR,
+ ra=ehc.RA_DEFAULT, dec=ehc.DEC_DEFAULT, rf=ehc.RF_DEFAULT,
+ pulse=ehc.PULSE_DEFAULT):
+ """Read in a movie from dat files and create a Movie object.
+
+ Args:
+ basename (str): The base name of individual movie frames.
+ Files should have names basename + 000001, etc.
+ nframes (int): The total number of frames
+ startframe (int): The index of the first frame to load
+ framedur_sec (float): The frame duration in seconds (default = 1)
+ psize (float): The pixel size in radian
+ ra (float): the right ascension of the source (default for SgrA*)
+ dec (float): the declination of the source (default for SgrA*)
+        rf (float): The reference frequency of the observation
+ pulse (function): The function convolved with the pixel values for continuous image
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ for i in range(startframe, startframe + nframes):
+
+ filename = basename + "%04d" % i + '.dat'
+
+ sys.stdout.write('\rReading Movie Image %i/%i...' % (i - startframe, nframes))
+ sys.stdout.flush()
+
+ datatable = np.loadtxt(filename, dtype=np.float64)
+
+ if i == startframe:
+ sim = np.zeros([nframes, datatable.shape[0]])
+
+ sim[i - startframe, :] = datatable[:, 2]
+
+ npix = np.sqrt(sim.shape[1]).astype('int')
+ sim = np.reshape(sim, [sim.shape[0], npix, npix])
+ sim = np.array([im.transpose()[::-1, :] for im in sim])
+
+ # TODO: read frame times from files?
+ hour0 = 0
+ framedur_hr = framedur_sec / 3600.
+ nframes = len(sim)
+ tstart = hour0
+ tstop = hour0 + framedur_hr * nframes
+ times = np.linspace(tstart, tstop, nframes)
+
+ return(ehtim.movie.Movie(sim, times, psize, ra, dec, rf,
+ interp=interp, bounds_error=bounds_error))
+
+
+###################################################################################################
+# Array IO
+###################################################################################################
+def load_array_txt(filename, ephemdir='ephemeris'):
+ """Read an array from a text file and return an Array object
+ Sites with x=y=z=0 are spacecraft - TLE ephemeris loaded from ephemdir
+
+ Args:
+ filename (str): path to input text file
+ ephemdir (str): directory with TLE files for spacecraft
+
+ Returns:
+ arr (Array): Array object loaded from file
+ """
+
+
+ tdata = np.loadtxt(filename, dtype=bytes, comments='#').astype(str)
+ if tdata[0][0].lower() == 'site':
+ tdata = tdata[1:]
+
+ path = os.path.dirname(filename)
+
+ tdataout = []
+    if (tdata.shape[1] != 5 and tdata.shape[1] != 13):
+        raise Exception("Array file should have format: " +
+                        "(name, x, y, z, SEFDR, SEFDL, " +
+                        "FR_PAR_ANGLE, FR_ELEV_ANGLE, FR_OFFSET, " +
+                        "DR_RE, DR_IM, DL_RE, DL_IM)")
+
+ elif tdata.shape[1] == 5:
+ tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[4]),
+ 0.0, 0.0,
+ 0.0, 0.0, 0.0),
+ dtype=ehc.DTARR) for x in tdata]
+ elif tdata.shape[1] == 13:
+ tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[5]),
+ float(x[9]) + 1j * float(x[10]), float(x[11]) + 1j * float(x[12]),
+ float(x[6]), float(x[7]), float(x[8])),
+ dtype=ehc.DTARR) for x in tdata]
+
+ # load spacecraft
+ tdataout = np.array(tdataout)
+ edata = {}
+ for line in tdataout:
+ if np.all(np.array([line['x'], line['y'], line['z']]) == (0., 0., 0.)):
+ sitename = str(line['site'])
+
+            # TODO: ephempath shouldn't always start with the array file path
+            try:
+                ephempath = path + '/' + ephemdir + '/' + sitename + '.tle'
+                edata[sitename] = np.loadtxt(ephempath, dtype=bytes,
+                                             comments='#', delimiter='/').astype(str)
+                print('loaded spacecraft ephemeris %s' % ephempath)
+            except IOError:
+                # fall back to an ephemeris file without the .tle extension
+                try:
+                    ephempath = path + '/' + ephemdir + '/' + sitename
+                    edata[sitename] = np.loadtxt(ephempath, dtype=bytes,
+                                                 comments='#', delimiter='/').astype(str)
+                    print('loaded spacecraft ephemeris %s' % ephempath)
+                except IOError:
+                    raise Exception('no ephemeris file %s !' % ephempath)
+
+ return ehtim.array.Array(tdataout, ephem=edata)
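+
+# Illustrative usage (hedged; 'arrays/EHT2017.txt' is a placeholder path):
+#   arr = load_array_txt('arrays/EHT2017.txt')
+#   print(arr.tarr['site'])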
+
+##################################################################################################
+# Observation IO
+##################################################################################################
+
+
+def load_obs_txt(filename, polrep='stokes'):
+ """Read an observation from a text file.
+ Args:
+        filename (str): path to input text file
+ polrep (str): load data as either 'stokes' or 'circ'
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ if not(polrep in ['stokes', 'circ']):
+ raise Exception("polrep should be 'stokes' or 'circ' in load_uvfits")
+ print("Loading text observation: ", filename)
+
+ # Read the header parameters
+ file = open(filename)
+ src = ' '.join(file.readline().split()[2:])
+ ra = file.readline().split()
+ ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0
+ dec = file.readline().split()
+ dec = np.sign(float(dec[2])) * (abs(float(dec[2])) +
+ float(dec[4]) / 60.0 + float(dec[6]) / 3600.0)
+ mjd = float(file.readline().split()[2])
+ rf = float(file.readline().split()[2]) * 1e9
+ bw = float(file.readline().split()[2]) * 1e9
+ phasecal = bool(file.readline().split()[2])
+ ampcal = bool(file.readline().split()[2])
+
+ # New Header Parameters
+ x = file.readline().split()
+ if x[1] == 'OPACITYCAL:':
+ opacitycal = bool(x[2])
+ dcal = bool(file.readline().split()[2])
+ frcal = bool(file.readline().split()[2])
+ file.readline()
+ else:
+ opacitycal = True
+ dcal = True
+ frcal = True
+ file.readline()
+
+ # read the tarr
+ line = file.readline().split()
+ tarr = []
+ while line[1][0] != "-":
+ if len(line) == 6:
+ tarr.append(np.array((line[1], line[2], line[3], line[4], line[5], line[5],
+ 0, 0, 0, 0, 0), dtype=ehc.DTARR))
+ elif len(line) == 14:
+ tarr.append(np.array((line[1], line[2], line[3], line[4], line[5], line[6],
+ float(line[10]) + 1j * float(line[11]),
+ float(line[12]) + 1j * float(line[13]),
+ line[7], line[8], line[9]), dtype=ehc.DTARR))
+ else:
+ raise Exception("Telescope header doesn't have the right number of fields!")
+ line = file.readline().split()
+ tarr = np.array(tarr, dtype=ehc.DTARR)
+
+ # read the polrep
+ line = file.readline().split()
+ if line[12] == 'RRamp':
+ polrep_orig = 'circ'
+ elif line[12] == 'Iamp':
+ polrep_orig = 'stokes'
+ else:
+ raise Exception("cannot determine original polrep from observation text file!")
+ file.close()
+
+ # Load the data, convert to list format, return object
+ datatable = np.loadtxt(filename, dtype=bytes).astype(str)
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ tint = float(row[1])
+ t1 = row[2]
+ t2 = row[3]
+
+ # Old datatable formats
+ if datatable.shape[1] < 20:
+ tau1 = float(row[6])
+ tau2 = float(row[7])
+ u = float(row[8])
+ v = float(row[9])
+ vis1 = float(row[10]) * np.exp(1j * float(row[11]) * ehc.DEGREE)
+ if datatable.shape[1] == 19:
+ vis2 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE)
+ vis3 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE)
+ vis4 = float(row[16]) * np.exp(1j * float(row[17]) * ehc.DEGREE)
+ sigma1 = sigma2 = sigma3 = sigma4 = float(row[18])
+ elif datatable.shape[1] == 17:
+ vis2 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE)
+ vis3 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE)
+ vis4 = 0 + 0j
+ sigma1 = sigma2 = sigma3 = sigma4 = float(row[16])
+ elif datatable.shape[1] == 15:
+ vis2 = 0 + 0j
+ vis3 = 0 + 0j
+ vis4 = 0 + 0j
+ sigma1 = sigma2 = sigma3 = sigma4 = float(row[12])
+ else:
+ raise Exception('Text file does not have the right number of fields!')
+
+ # Current datatable format
+ elif datatable.shape[1] == 20:
+ tau1 = float(row[4])
+ tau2 = float(row[5])
+ u = float(row[6])
+ v = float(row[7])
+ vis1 = float(row[8]) * np.exp(1j * float(row[9]) * ehc.DEGREE)
+ vis2 = float(row[10]) * np.exp(1j * float(row[11]) * ehc.DEGREE)
+ vis3 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE)
+ vis4 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE)
+ sigma1 = float(row[16])
+ sigma2 = float(row[17])
+ sigma3 = float(row[18])
+ sigma4 = float(row[19])
+
+ else:
+ raise Exception('Text file does not have the right number of fields!')
+
+ if polrep_orig == 'stokes':
+ datatable2.append(np.array((time, tint, t1, t2, tau1, tau2,
+ u, v, vis1, vis2, vis3, vis4,
+ sigma1, sigma2, sigma3, sigma4), dtype=ehc.DTPOL_STOKES))
+ elif polrep_orig == 'circ':
+ datatable2.append(np.array((time, tint, t1, t2, tau1, tau2,
+ u, v, vis1, vis2, vis3, vis4,
+ sigma1, sigma2, sigma3, sigma4), dtype=ehc.DTPOL_CIRC))
+
+ # Return the data object
+ datatable2 = np.array(datatable2)
+ out = ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable2, tarr, polrep=polrep_orig,
+ source=src, mjd=mjd,
+ ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
+ dcal=dcal, frcal=frcal)
+ out = out.switch_polrep(polrep_out=polrep)
+ return out
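+
+# Illustrative usage (hedged; 'obs.txt' is a placeholder path):
+#   obs = load_obs_txt('obs.txt', polrep='stokes')
+#   print(obs.tstart, obs.tstop)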
+
+# TODO can we save new telescope array terms and flags to uvfits and load them?
+# TODO uv coordinates, multiply by IF freqs and not header FREQ?
+def load_obs_uvfits(filename, polrep='stokes', flipbl=False,
+ allow_singlepol=True, force_singlepol=None,
+ channel=all, IF=all, remove_nan=False,
+ ignore_pzero_date=True,
+ trial_speedups=False):
+ """Load observation data from a uvfits file.
+
+ Args:
+        filename (str or HDUList): path to an input uvfits file, or an astropy HDUList object
+ polrep (str): load data as either 'stokes' or 'circ'
+ flipbl (bool): flip baseline phases if True.
+ allow_singlepol (bool): If True and polrep='stokes',
+ treat single-polarization data as Stokes I
+ force_singlepol (str): 'R' or 'L' to load only 1 polarization and treat as Stokes I
+ channel (list): list of channels to average in the import. channel=all averages all
+ IF (list): list of IFs to average in the import. IF=all averages all
+ remove_nan (bool): whether or not to remove entries with nan data
+
+        ignore_pzero_date (bool): if True, ignore the offset (PZERO) parameters in the DATE field
+                                  TODO: what is the correct behavior per AIPS memo 117?
+        trial_speedups (bool): if True, use experimental faster code paths when building tables
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ if not(polrep in ['stokes', 'circ']):
+ raise Exception("polrep should be 'stokes' or 'circ' in load_uvfits")
+ if not(force_singlepol is None or force_singlepol is False) and polrep != 'stokes':
+ raise Exception(
+ "force_singlepol is incompatible with polrep!='stokes' in load_uvfits")
+
+ # Load the uvfits file
+ if isinstance(filename, fits.HDUList):
+ hdulist = filename.copy()
+ else:
+ print("Loading uvfits: ", filename)
+ hdulist = fits.open(filename)
+ header = hdulist[0].header
+ data = hdulist[0].data
+
+ # Load the array data
+ tnames = hdulist['AIPS AN'].data['ANNAME']
+ tnums = hdulist['AIPS AN'].data['NOSTA'] - 1
+ xyz = np.real(hdulist['AIPS AN'].data['STABXYZ'])
+ try:
+ sefdr = np.real(hdulist['AIPS AN'].data['SEFD'])
+ sefdl = np.real(hdulist['AIPS AN'].data['SEFD']) # TODO add sefdl to uvfits?
+ except KeyError:
+ sefdr = np.zeros(len(tnames))
+ sefdl = np.zeros(len(tnames))
+
+ # TODO - get the *actual* values of these telescope parameters from the uvfits file?
+ fr_par = np.zeros(len(tnames))
+ fr_el = np.zeros(len(tnames))
+ fr_off = np.zeros(len(tnames))
+ dr = np.zeros(len(tnames)) + 1j * np.zeros(len(tnames))
+ dl = np.zeros(len(tnames)) + 1j * np.zeros(len(tnames))
+
+ tarr = [np.array((
+ str(tnames[i]), xyz[i][0], xyz[i][1], xyz[i][2],
+ sefdr[i], sefdl[i], dr[i], dl[i],
+ fr_par[i], fr_el[i], fr_off[i]),
+ dtype=ehc.DTARR) for i in range(len(tnames))]
+
+ tarr = np.array(tarr)
+
+ # Various header parameters
+ try:
+ ra = header['OBSRA'] * 12. / 180.
+ dec = header['OBSDEC']
+ except KeyError:
+ if header['CTYPE6'] == 'RA':
+ ra = header['CRVAL6'] * 12. / 180.
+ else:
+ raise Exception('Cannot find RA!')
+ if header['CTYPE7'] == 'DEC':
+ dec = header['CRVAL7']
+ else:
+ raise Exception('Cannot find DEC!')
+
+ src = header['OBJECT']
+ rf = hdulist['AIPS AN'].header['FREQ']
+
+ if header['CTYPE4'] == 'FREQ':
+ ch1_freq = header['CRVAL4']
+ ch_bw = header['CDELT4']
+ nchan = header['NAXIS4']
+
+ else:
+ raise Exception('Cannot find observing frequencies!')
+
+ nif = 1
+ try:
+ if header['CTYPE5'] == 'IF':
+ nif = header['NAXIS5']
+ except KeyError:
+ print("no IF in uvfits header!")
+
+ try:
+ if header['CTYPE3'] == 'STOKES':
+ if header['CRVAL3'] == 1:
+ polrep_uvfits = 'stokes'
+ elif header['CRVAL3'] == -1:
+ polrep_uvfits = 'circ'
+ else:
+ raise Exception("header[CRVAL3] not a recognized polarization basis!")
+ except BaseException:
+ raise Exception("STOKES field not in expected header position 'CTYPE3'!")
+ print('POLREP_UVFITS:', polrep_uvfits)
+
+ if polrep_uvfits == 'stokes' and not(force_singlepol is None):
+ raise Exception("force_singlepole not implemented on native Stokes uvfits files!")
+
+ # determine the bandwidth
+ bw = ch_bw * nchan * nif
+
+ # Determine the number of correlation products in the data
+ num_corr = data['DATA'].shape[5]
+ print("Number of uvfits Correlation Products:", num_corr)
+ if num_corr == 1 and force_singlepol is not None:
+ print("Cannot force single polarization when file is not full polarization.")
+ force_singlepol = None
+
+ # If the user selects force_singlepol, then we must allow_singlepol for stokes conversion
+ if force_singlepol is not None and polrep == 'stokes':
+ allow_singlepol = True
+
+ # Mask to screen bad data
+ # Reducing to single frequency
+
+ # prepare the arrays of if and channels that will be extracted from the data.
+ nvis = data['DATA'].shape[0]
+ full_nchannels = data['DATA'].shape[4]
+ full_nifs = data['DATA'].shape[3]
+ if channel == all:
+ channel = np.arange(0, full_nchannels, 1)
+ nchannels = full_nchannels
+ else:
+ try:
+ nchannels = len(np.array(channel))
+ channel = np.array(channel).reshape(-1)
+ except TypeError:
+ channel = np.array([channel]).reshape(-1)
+ nchannels = len(np.array(channel))
+
+ if IF == all:
+ IF = np.arange(0, full_nifs, 1)
+ nifs = full_nifs
+ else:
+ try:
+ nifs = len(IF)
+ IF = np.array(IF).reshape(-1)
+ except TypeError:
+ IF = np.array([IF]).reshape(-1)
+ nifs = len(np.array(IF))
+
+ if (np.max(channel) >= full_nchannels) or (np.min(channel) < 0):
+ raise Exception('The specified channel does not exist')
+ if (np.max(IF) >= full_nifs) or (np.min(IF) < 0):
+ raise Exception('The specified IF does not exist')
+
+ # NOTE: here we are assuming data is in RR, LL, RL, LR basis with the variable names
+ # BUT: polrep_uvfits will correctly interpret these data as IQUV if necessary
+ # TODO: change the variable names!
+ rrweight = data['DATA'][:, 0, 0, IF, channel, 0, 2].reshape(nvis, nifs, nchannels)
+ if num_corr >= 2:
+ llweight = data['DATA'][:, 0, 0, IF, channel, 1, 2].reshape(nvis, nifs, nchannels)
+ else:
+ llweight = rrweight * 0.0
+ if num_corr >= 3:
+ rlweight = data['DATA'][:, 0, 0, IF, channel, 2, 2].reshape(nvis, nifs, nchannels)
+ else:
+ rlweight = rrweight * 0.0
+ if num_corr >= 4:
+ lrweight = data['DATA'][:, 0, 0, IF, channel, 3, 2].reshape(nvis, nifs, nchannels)
+ else:
+ lrweight = rrweight * 0.0
+
+ # If necessary, enforce single polarization
+ if polrep_uvfits == 'circ':
+        if force_singlepol in ['L', 'LL']:
+ rrweight = rrweight * 0.0
+ rlweight = rlweight * 0.0
+ lrweight = lrweight * 0.0
+        elif force_singlepol in ['R', 'RR']:
+ llweight = llweight * 0.0
+ rlweight = rlweight * 0.0
+ lrweight = lrweight * 0.0
+ elif force_singlepol == 'LR':
+ print('WARNING: Putting LR data in Stokes I')
+ rrweight = copy.deepcopy(lrweight)
+ llweight = llweight * 0.0
+ rlweight = rlweight * 0.0
+ lrweight = lrweight * 0.0
+ elif force_singlepol == 'RL':
+ print('WARNING: Putting RL data in Stokes I')
+ rrweight = copy.deepcopy(rlweight)
+ llweight = llweight * 0.0
+ rlweight = rlweight * 0.0
+ lrweight = lrweight * 0.0
+
+ # first, catch nans
+ rrnanmask_2d = (np.isnan(rrweight))
+ llnanmask_2d = (np.isnan(llweight))
+ rlnanmask_2d = (np.isnan(rlweight))
+ lrnanmask_2d = (np.isnan(lrweight))
+
+ rrweight[rrnanmask_2d] = 0.
+ llweight[llnanmask_2d] = 0.
+ rlweight[rlnanmask_2d] = 0.
+ lrweight[lrnanmask_2d] = 0.
+
+    # keep only data with positive weights
+ rrmask_2d = (rrweight > 0.)
+ llmask_2d = (llweight > 0.)
+ rlmask_2d = (rlweight > 0.)
+ lrmask_2d = (lrweight > 0.)
+
+ # if there is any unmasked data in the frequency column, use it
+ rrmask = np.any(np.any(rrmask_2d, axis=2), axis=1)
+ llmask = np.any(np.any(llmask_2d, axis=2), axis=1)
+ rlmask = np.any(np.any(rlmask_2d, axis=2), axis=1)
+ lrmask = np.any(np.any(lrmask_2d, axis=2), axis=1)
+
+ # Total intensity mask
+ if polrep_uvfits == 'circ':
+ mask = rrmask + llmask
+ elif polrep_uvfits == 'stokes':
+ mask = rrmask # remember rr is really I when polrep_uvfits=='stokes'!
+
+ if not np.any(mask):
+ raise Exception("No unflagged RR or LL data in uvfits file!")
+ if np.any(~(rrmask * llmask)):
+ print("Warning: removing flagged data present!")
+
+ # Obs Times
+ paridx = data.parnames.index("DATE")+1
+ if "PSCAL%d"%(paridx) in header.keys():
+ jd1scal = header["PSCAL%d"%(paridx)]
+ else:
+ jd1scal = 1.0
+ if "PZERO%d"%(paridx) in header.keys():
+ jd1zero = header["PZERO%d"%(paridx)]
+ else:
+ jd1zero = 0.0
+ if "PSCAL%d"%(paridx+1) in header.keys():
+ jd2scal = header["PSCAL%d"%(paridx+1)]
+ else:
+ jd2scal = 1.0
+ if "PZERO%d"%(paridx+1) in header.keys():
+ jd2zero = header["PZERO%d"%(paridx+1)]
+ else:
+ jd2zero = 0.0
+
+ if ignore_pzero_date:
+ if jd1zero!=0. or jd2zero!=0.:
+ print("Warning! ignoring nonzero header PZERO values for DATE. Check your observation mjd/times!")
+ jd1zero = 0.
+ jd2zero = 0.
+
+ jds = jd1scal * data['DATE'][mask].astype('d') + jd1zero
+ jds += jd2scal * data['_DATE'][mask].astype('d') + jd2zero
+
+ mjd = int(np.min(jds) - 2400000.5)
+ times = (jds - 2400000.5 - mjd) * 24.0
+
+ try:
+ scantable = []
+ nxtable = hdulist['AIPS NX']
+ for scan in nxtable.data:
+ scan_start = scan['TIME'] # in days since reference date
+ scan_dur = scan['TIME INTERVAL']
+ startvis = scan['START VIS'] - 1
+ endvis = scan['END VIS'] - 1
+ scantable.append([scan_start - 0.5 * scan_dur,
+ scan_start + 0.5 * scan_dur])
+ scantable = np.array(scantable) * 24
+
+ except BaseException:
+ print("No NX table in uvfits!")
+ scantable = None
+
+ # Integration times
+ try:
+ tints = data['INTTIM'][mask]
+ except KeyError:
+        tints = np.zeros(np.sum(mask))
+
+ # Sites - add names
+ t1c = data['BASELINE'][mask].astype(int) // 256
+ t2c = data['BASELINE'][mask].astype(int) - t1c * 256
+ t1c = t1c - 1
+ t2c = t2c - 1
+
+    # TODO make site identification faster
+ if trial_speedups and (not np.any(tnums!=np.arange(len(tnums)))):
+ sites = tarr['site']
+ t1 = sites[t1c]
+ t2 = sites[t2c]
+ else: # original, slow code
+ t1 = np.array([tarr[np.where(tnums==i)[0][0]]['site'] for i in t1c])
+ t2 = np.array([tarr[np.where(tnums==i)[0][0]]['site'] for i in t2c])
+
+ # Opacities (not in standard files)
+ try:
+ tau1 = data['TAU1'][mask]
+ tau2 = data['TAU2'][mask]
+ except KeyError:
+ tau1 = tau2 = np.zeros(len(t1))
+
+ # Convert uv in lightsec to lambda by multiplying by rf
+ try:
+ u = data['UU---SIN'][mask] * rf
+ v = data['VV---SIN'][mask] * rf
+ except KeyError:
+ try:
+ u = data['UU'][mask] * rf
+ v = data['VV'][mask] * rf
+ except KeyError:
+ try:
+ u = data['UU--'][mask] * rf
+ v = data['VV--'][mask] * rf
+ except KeyError:
+ raise Exception("Cant figure out column label for UV coords")
+
+ # Get and coherently average visibility data in frequency
+ # replace masked vis with nans so they don't mess up the average
+ rr_2d = data['DATA'][:, 0, 0, IF, channel, 0, 0] + \
+ 1j * data['DATA'][:, 0, 0, IF, channel, 0, 1]
+ rr_2d = rr_2d.reshape(nvis, nifs, nchannels)
+ if num_corr >= 2:
+ ll_2d = data['DATA'][:, 0, 0, IF, channel, 1, 0] + \
+ 1j * data['DATA'][:, 0, 0, IF, channel, 1, 1]
+ ll_2d = ll_2d.reshape(nvis, nifs, nchannels)
+ else:
+ ll_2d = rr_2d * 0.0
+ if num_corr >= 3:
+ rl_2d = data['DATA'][:, 0, 0, IF, channel, 2, 0] + \
+ 1j * data['DATA'][:, 0, 0, IF, channel, 2, 1]
+ rl_2d = rl_2d.reshape(nvis, nifs, nchannels)
+ else:
+ rl_2d = rr_2d * 0.0
+ if num_corr >= 4:
+ lr_2d = data['DATA'][:, 0, 0, IF, channel, 3, 0] + \
+ 1j * data['DATA'][:, 0, 0, IF, channel, 3, 1]
+ lr_2d = lr_2d.reshape(nvis, nifs, nchannels)
+ else:
+ lr_2d = rr_2d * 0.0
+
+ if polrep_uvfits == 'circ':
+ if force_singlepol == 'LR':
+ rr_2d = copy.deepcopy(lr_2d)
+ elif force_singlepol == 'RL':
+ rr_2d = copy.deepcopy(rl_2d)
+
+ rr_2d[~rrmask_2d] = np.nan
+ ll_2d[~llmask_2d] = np.nan
+ rl_2d[~rlmask_2d] = np.nan
+ lr_2d[~lrmask_2d] = np.nan
+
+ rr = np.nanmean(np.nanmean(rr_2d, axis=2), axis=1)[mask]
+ ll = np.nanmean(np.nanmean(ll_2d, axis=2), axis=1)[mask]
+ rl = np.nanmean(np.nanmean(rl_2d, axis=2), axis=1)[mask]
+ lr = np.nanmean(np.nanmean(lr_2d, axis=2), axis=1)[mask]
+
+ # average the weights
+ # variances are mean / N , or sum / N^2
+ # then replace masked weights with nans so they don't mess up the average
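+ # (sketch of the error propagation used below: each sample has variance
+ # 1/w_i, so a straight average of N unflagged samples has
+ #     sigma = sqrt( sum_i 1/w_i ) / N,
+ # which is exactly the nansum/nsig expression computed for each polarization)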
+ rrweight[~rrmask_2d] = np.nan
+ llweight[~llmask_2d] = np.nan
+ rlweight[~rlmask_2d] = np.nan
+ lrweight[~lrmask_2d] = np.nan
+
+ nsig_rr = np.sum(np.sum(rrmask_2d, axis=2), axis=1).astype(float)
+ nsig_rr[~rrmask] = np.nan
+ rrsig = np.sqrt(np.nansum(np.nansum(1. / rrweight, axis=2), axis=1)) / nsig_rr
+ rrsig = rrsig[mask]
+
+ nsig_ll = np.sum(np.sum(llmask_2d, axis=2), axis=1).astype(float)
+ nsig_ll[~llmask] = np.nan
+ llsig = np.sqrt(np.nansum(np.nansum(1. / llweight, axis=2), axis=1)) / nsig_ll
+ llsig = llsig[mask]
+
+ nsig_rl = np.sum(np.sum(rlmask_2d, axis=2), axis=1).astype(float)
+ nsig_rl[~rlmask] = np.nan
+ rlsig = np.sqrt(np.nansum(np.nansum(1. / rlweight, axis=2), axis=1)) / nsig_rl
+ rlsig = rlsig[mask]
+
+ nsig_lr = np.sum(np.sum(lrmask_2d, axis=2), axis=1).astype(float)
+ nsig_lr[~lrmask] = np.nan
+ lrsig = np.sqrt(np.nansum(np.nansum(1. / lrweight, axis=2), axis=1)) / nsig_lr
+ lrsig = lrsig[mask]
+
+ # Reverse sign of baselines for correct imaging if asked
+ if flipbl:
+ u = -u
+ v = -v
+
+ # determine correct data type:
+ # TODO add linear!
+ if polrep_uvfits == 'circ':
+ dtpol_out = ehc.DTPOL_CIRC
+ poldict_out = ehc.POLDICT_CIRC
+ elif polrep_uvfits == 'stokes':
+ dtpol_out = ehc.DTPOL_STOKES
+ poldict_out = ehc.POLDICT_STOKES
+
+ # TODO: new, faster path
+ if trial_speedups:
+ datatable = np.empty(len(times), dtype=dtpol_out)
+ datatable['time'] = times
+ datatable['tint'] = tints
+ datatable['t1'] = t1
+ datatable['t2'] = t2
+ datatable['tau1'] = tau1
+ datatable['tau2'] = tau2
+ datatable['u'] = u
+ datatable['v'] = v
+ datatable[poldict_out['vis1']] = rr
+ datatable[poldict_out['vis2']] = ll
+ datatable[poldict_out['vis3']] = rl
+ datatable[poldict_out['vis4']] = lr
+ datatable[poldict_out['sigma1']] = rrsig
+ datatable[poldict_out['sigma2']] = llsig
+ datatable[poldict_out['sigma3']] = rlsig
+ datatable[poldict_out['sigma4']] = lrsig
+ else: # original, slower code
+ datatable = []
+ for i in range(len(times)):
+ datatable.append(np.array
+ ((
+ times[i], tints[i],
+ t1[i], t2[i], tau1[i], tau2[i],
+ u[i], v[i],
+ rr[i], ll[i], rl[i], lr[i],
+ rrsig[i], llsig[i], rlsig[i], lrsig[i]
+ ), dtype=dtpol_out
+ ))
+ datatable = np.array(datatable)
+
+ obs = ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tarr, polrep=polrep_uvfits,
+ source=src, mjd=mjd, scantable=scantable,
+ trial_speedups=trial_speedups)
+
+ # TODO -- this is bad and slow, use masks!
+ if remove_nan:
+ if polrep_uvfits == 'circ':
+ for j in range(len(obs.data)):
+ if np.isnan(obs.data[j]['rrsigma']):
+ obs.data[j]['rrsigma'] = obs.data[j]['llsigma']
+ if np.isnan(obs.data[j]['llsigma']):
+ obs.data[j]['llsigma'] = obs.data[j]['rrsigma']
+ if np.isnan(obs.data[j]['rlsigma']):
+ obs.data[j]['rlsigma'] = obs.data[j]['rrsigma']
+ if np.isnan(obs.data[j]['lrsigma']):
+ obs.data[j]['lrsigma'] = obs.data[j]['rrsigma']
+ else:
+ print("WARNING: remove_nan not implemented with stokes uvfits files!")
+
+ obs = obs.switch_polrep(polrep, allow_singlepol=allow_singlepol)
+
+ # TODO get calibration flags from uvfits?
+ return obs
+
+
+def load_obs_oifits(filename, flux=1.0):
+ """Load data from an oifits file. Does NOT currently support polarization.
+ Args:
+ filename (str): path to input oifits file
+ flux (float): normalization total flux
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ print('Warning: load_obs_oifits does NOT currently support polarimetric data!')
+
+ # open oifits file and get visibilities
+ oidata = ehtim.io.oifits.open(filename)
+ vis_data = oidata.vis
+
+ # get source info
+ src = oidata.target[0].target
+ ra = oidata.target[0].raep0.angle
+ dec = oidata.target[0].decep0.angle
+
+ # get antenna info
+ arr0 = oidata.array[list(oidata.array.keys())[0]] # first (and typically only) array table
+ nAntennas = len(arr0.station)
+ sites = np.array([arr0.station[i].sta_name for i in range(nAntennas)])
+ arrayX, arrayY, arrayZ = arr0.arrxyz
+ x = np.array([arrayX + arr0.station[i].staxyz[0] for i in range(nAntennas)])
+ y = np.array([arrayY + arr0.station[i].staxyz[1] for i in range(nAntennas)])
+ z = np.array([arrayZ + arr0.station[i].staxyz[2] for i in range(nAntennas)])
+
+ # get wavelength and corresponding frequencies
+ wav0 = oidata.wavelength[list(oidata.wavelength.keys())[0]] # first wavelength table
+ wavelength = wav0.eff_wave
+ nWavelengths = wavelength.shape[0]
+ bandpass = wav0.eff_band
+ frequency = ehc.C / wavelength
+
+ # TODO: this result seems wrong...
+ bw = np.mean(2 * (np.sqrt(bandpass**2 * frequency**2 + ehc.C**2) - ehc.C) / bandpass)
+ rf = np.mean(frequency)
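+ # (for reference: in the small-bandwidth limit the standard conversion is
+ #     dnu = C * dlambda / lambda**2,
+ # i.e. bw ~ np.mean(ehc.C * bandpass / wavelength**2); the expression above
+ # agrees with this to first order in bandpass/wavelength)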
+
+ # get the u-v point for each visibility
+ u = np.array([vis_data[i].ucoord / wavelength for i in range(len(vis_data))])
+ v = np.array([vis_data[i].vcoord / wavelength for i in range(len(vis_data))])
+
+ # get visibility info - currently the phase error is not being used properly
+ amp = np.array([vis_data[i]._visamp for i in range(len(vis_data))])
+ phase = np.array([vis_data[i]._visphi for i in range(len(vis_data))])
+ amperr = np.array([vis_data[i]._visamperr for i in range(len(vis_data))])
+ visphierr = np.array([vis_data[i]._visphierr for i in range(len(vis_data))])
+ # observation datetimes, one per visibility (converted to hours of day below)
+ timeobs = np.array([vis_data[i].timeobs for i in range(len(vis_data))])
+
+ time = np.transpose(np.tile(np.array([(ttime.mktime((timeobs[i] +
+ datetime.timedelta(days=1)).timetuple())
+ ) / (60.0 * 60.0)
+ for i in range(len(timeobs))]), [nWavelengths, 1]))
+
+ # integration time
+ tint = np.array([vis_data[i].int_time for i in range(len(vis_data))])
+ # if not all(tint[0] == item for item in np.reshape(tint, (-1)) ):
+ # raise TypeError("The time integrations for each visibility are different")
+ tint = tint[0]
+ tint = tint * np.ones(amp.shape)
+
+ # get telescope names for each visibility
+ t1 = np.transpose(np.tile(np.array([vis_data[i].station[0].sta_name
+ for i in range(len(vis_data))]), [nWavelengths, 1]))
+ t2 = np.transpose(np.tile(np.array([vis_data[i].station[1].sta_name
+ for i in range(len(vis_data))]), [nWavelengths, 1]))
+
+ # dummy variables
+ tau1 = np.zeros(amp.shape)
+ tau2 = np.zeros(amp.shape)
+ qvis = np.zeros(amp.shape)
+ uvis = np.zeros(amp.shape)
+ vvis = np.zeros(amp.shape)
+ sefdr = np.zeros(x.shape)
+ sefdl = np.zeros(x.shape)
+ fr_par = np.zeros(x.shape)
+ fr_el = np.zeros(x.shape)
+ fr_off = np.zeros(x.shape)
+ dr = np.zeros(x.shape) + 1j * np.zeros(x.shape)
+ dl = np.zeros(x.shape) + 1j * np.zeros(x.shape)
+
+ # vectorize
+ time = time.ravel()
+ tint = tint.ravel()
+ t1 = t1.ravel()
+ t2 = t2.ravel()
+
+ tau1 = tau1.ravel()
+ tau2 = tau2.ravel()
+ u = u.ravel()
+ v = v.ravel()
+ vis = amp.ravel() * np.exp(-1j * phase.ravel() * np.pi / 180.0)
+ qvis = qvis.ravel()
+ uvis = uvis.ravel()
+ vvis = vvis.ravel()
+ amperr = amperr.ravel()
+
+ # TODO - check that we are properly using the error from the amplitude and phase
+ # create data tables
+ datatable = np.array([(time[i], tint[i], t1[i], t2[i], tau1[i], tau2[i], u[i], v[i],
+ flux * vis[i], qvis[i], uvis[i], vvis[i],
+ flux * amperr[i], flux * amperr[i], flux * amperr[i], flux * amperr[i]
+ ) for i in range(len(vis))
+ ], dtype=ehc.DTPOL_STOKES)
+
+ tarr = np.array([(sites[i], x[i], y[i], z[i],
+ sefdr[i], sefdl[i], dr[i], dl[i],
+ fr_par[i], fr_el[i], fr_off[i],
+ ) for i in range(nAntennas)
+ ], dtype=ehc.DTARR)
+
+ # return object
+ return ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tarr,
+ polrep='stokes', source=src, mjd=time[0])
+
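+# Example use of load_obs_oifits above (a sketch; 'data.oifits' is a
+# hypothetical filename):
+#   obs = load_obs_oifits('data.oifits', flux=1.0)
+#   print(obs.data['vis'][:5])
+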
+def load_obs_maps(arrfile, obsspec, ifile, qfile=0, ufile=0, vfile=0,
+ src=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, ampcal=False, phasecal=False):
+ """Read an observation from a maps text file and return an Obsdata object.
+
+ Args:
+ arrfile (str): path to input array file
+ obsspec (str): path to input obs spec file
+ ifile (str): path to input Stokes I data file
+ qfile (str): path to input Stokes Q data file
+ ufile (str): path to input Stokes U data file
+ vfile (str): path to input Stokes V data file
+ src (str): source name
+ mjd (int): integer observation MJD
+ ampcal (bool): True if amplitude calibrated
+ phasecal (bool): True if phase calibrated
+
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+ # Read telescope parameters from the array file
+ tdata = np.loadtxt(arrfile, dtype=bytes).astype(str)
+ tdata = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]),
+ float(x[-1]), float(x[-1]), 0., 0., 0., 0., 0.),
+ dtype=ehc.DTARR) for x in tdata]
+ tdata = np.array(tdata)
+
+ # Read parameters from the obs_spec
+ f = open(obsspec)
+ stop = False
+ while not stop:
+ line = f.readline().split()
+ if line == [] or line[0] == '\\':
+ continue
+ elif line[0] == 'FOV_center_RA':
+ x = line[2].split(':')
+ ra = float(x[0]) + float(x[1]) / 60.0 + float(x[2]) / 3600.0
+ elif line[0] == 'FOV_center_Dec':
+ x = line[2].split(':')
+ dec = np.sign(float(x[0])) * (abs(float(x[0])) +
+ float(x[1]) / 60.0 + float(x[2]) / 3600.0)
+ elif line[0] == 'Corr_int_time':
+ tint = float(line[2])
+ elif line[0] == 'Corr_chan_bw': # TODO what if multiple channels?
+ bw = float(line[2]) * 1e6 # MHz -> Hz
+ elif line[0] == 'Channel': # TODO what if multiple scans with different params?
+ rf = float(line[2].split(':')[0]) * 1e6
+ elif line[0] == 'Scan_start':
+ x = line[2].split(':') # TODO properly compute MJD!
+ elif line[0] == 'Endscan':
+ stop = True
+ f.close()
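+ # (for illustration, the parser above expects 'key = value' lines such as
+ #     FOV_center_RA = 12:30:49.42
+ #     Corr_int_time = 10
+ #     Corr_chan_bw = 0.5
+ # with sexagesimal fields split on ':'; the values here are hypothetical)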
+
+ # Load the data, convert to list format, return object
+ datatable = []
+ f = open(ifile)
+
+ for line in f:
+ line = line.split()
+ if not (line[0] in ['UV', 'Scan', '\n']):
+ time = line[0].split(':')
+ time = float(time[2]) + float(time[3]) / 60.0 + float(time[4]) / 3600.0
+ u = float(line[1]) * 1000
+ v = float(line[2]) * 1000
+ bl = line[4].split('-')
+ t1 = tdata[int(bl[0]) - 1]['site']
+ t2 = tdata[int(bl[1]) - 1]['site']
+ tau1 = 0.
+ tau2 = 0.
+ vis = float(line[7][:-1]) * np.exp(1j * float(line[8][:-1]) * ehc.DEGREE)
+ sigma = float(line[10])
+ datatable.append(np.array((time, tint, t1, t2, tau1, tau2,
+ u, v, vis, 0.0, 0.0, 0.0,
+ sigma, 0.0, 0.0, 0.0), dtype=ehc.DTPOL_STOKES))
+
+ datatable = np.array(datatable)
+
+ # TODO qfile ufile and vfile must have exactly the same format as ifile!
+ # add some consistency check
+ if qfile != 0:
+ f = open(qfile)
+ i = 0
+ for line in f:
+ line = line.split()
+ if not (line[0] in ['UV', 'Scan', '\n']):
+ datatable[i]['qvis'] = float(line[7][:-1]) * \
+ np.exp(1j * float(line[8][:-1]) * ehc.DEGREE)
+ datatable[i]['qsigma'] = float(line[10])
+ i += 1
+
+ if ufile != 0:
+ f = open(ufile)
+ i = 0
+ for line in f:
+ line = line.split()
+ if not (line[0] in ['UV', 'Scan', '\n']):
+ datatable[i]['uvis'] = float(line[7][:-1]) * \
+ np.exp(1j * float(line[8][:-1]) * ehc.DEGREE)
+ datatable[i]['usigma'] = float(line[10])
+ i += 1
+
+ if vfile != 0:
+ f = open(vfile)
+ i = 0
+ for line in f:
+ line = line.split()
+ if not (line[0] in ['UV', 'Scan', '\n']):
+ datatable[i]['vvis'] = float(line[7][:-1]) * \
+ np.exp(1j * float(line[8][:-1]) * ehc.DEGREE)
+ datatable[i]['vsigma'] = float(line[10])
+ i += 1
+
+ # Return the data object
+ return ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tdata,
+ source=src, mjd=mjd, polrep='stokes')
+
+def load_dtype_txt(obs, filename, dtype='cphase'):
+ """Load derived data of type dtype from a text file into an existing Obsdata object.
+ Args:
+ obs (Obsdata): the Obsdata object to attach the loaded data to
+ filename (str): path to input text file
+ dtype (str): data type to load ('cphase', 'logcamp', 'camp', 'bs', or 'amp')
+ """
+
+ print("Loading text observation: ", filename)
+
+ # Read the header parameters
+ file = open(filename)
+ src = ' '.join(file.readline().split()[2:])
+ ra = file.readline().split()
+ ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0
+ dec = file.readline().split()
+ dec = np.sign(float(dec[2])) * (abs(float(dec[2])) +
+ float(dec[4]) / 60.0 + float(dec[6]) / 3600.0)
+ mjd = float(file.readline().split()[2])
+ rf = float(file.readline().split()[2]) * 1e9
+ bw = float(file.readline().split()[2]) * 1e9
+ # compare against the literal string: bool('False') would evaluate to True
+ phasecal = (file.readline().split()[2] == 'True')
+ ampcal = (file.readline().split()[2] == 'True')
+
+ # Load the data, convert to list format, return object
+ datatable = np.loadtxt(filename, dtype=bytes).astype(str)
+
+ if dtype == 'cphase':
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ t1 = row[1]
+ t2 = row[2]
+ t3 = row[3]
+ u1 = float(row[4])
+ v1 = float(row[5])
+ u2 = float(row[6])
+ v2 = float(row[7])
+ u3 = float(row[8])
+ v3 = float(row[9])
+ cphase = float(row[10])
+ sigmacp = float(row[11])
+ datatable2.append(np.array((time, t1, t2, t3, u1, v1, u2, v2,
+ u3, v3, cphase, sigmacp), dtype=ehc.DTCPHASE))
+ obs.cphase = np.array(datatable2)
+
+ elif dtype == 'logcamp':
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ t1 = row[1]
+ t2 = row[2]
+ t3 = row[3]
+ t4 = row[4]
+ u1 = float(row[5])
+ v1 = float(row[6])
+ u2 = float(row[7])
+ v2 = float(row[8])
+ u3 = float(row[9])
+ v3 = float(row[10])
+ u4 = float(row[11])
+ v4 = float(row[12])
+ logcamp = float(row[13])
+ sigmalogcamp = float(row[14])
+ datatable2.append(np.array((time, t1, t2, t3, t4, u1, v1, u2, v2, u3,
+ v3, u4, v4, logcamp, sigmalogcamp), dtype=ehc.DTCAMP))
+ obs.logcamp = np.array(datatable2)
+
+ elif dtype == 'camp':
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ t1 = row[1]
+ t2 = row[2]
+ t3 = row[3]
+ t4 = row[4]
+ u1 = float(row[5])
+ v1 = float(row[6])
+ u2 = float(row[7])
+ v2 = float(row[8])
+ u3 = float(row[9])
+ v3 = float(row[10])
+ u4 = float(row[11])
+ v4 = float(row[12])
+ camp = float(row[13])
+ sigmacamp = float(row[14])
+ datatable2.append(np.array((time, t1, t2, t3, t4, u1, v1, u2, v2,
+ u3, v3, u4, v4, camp, sigmacamp), dtype=ehc.DTCAMP))
+ obs.camp = np.array(datatable2)
+
+ elif dtype == 'bs':
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ t1 = row[1]
+ t2 = row[2]
+ t3 = row[3]
+ u1 = float(row[4])
+ v1 = float(row[5])
+ u2 = float(row[6])
+ v2 = float(row[7])
+ u3 = float(row[8])
+ v3 = float(row[9])
+ bispec = float(row[10])
+ sigmab = float(row[11])
+ datatable2.append(np.array((time, t1, t2, t3, u1, v1, u2,
+ v2, u3, v3, bispec, sigmab), dtype=ehc.DTBIS))
+ obs.bispec = np.array(datatable2)
+
+ elif dtype == 'amp':
+ datatable2 = []
+ for row in datatable:
+ time = float(row[0])
+ tint = float(row[1])
+ t1 = row[2]
+ t2 = row[3]
+ u = float(row[4])
+ v = float(row[5])
+ amp = float(row[6])
+ sigmaamp = float(row[7])
+ datatable2.append(np.array((time, tint, t1, t2, u, v, amp, sigmaamp), dtype=ehc.DTAMP))
+ obs.amp = np.array(datatable2)
+
+ else:
+ raise Exception(dtype + ' is not a possible data type!')
+
+ return
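+
+# Example usage of load_dtype_txt (a sketch; assumes obs was built by one of
+# the loaders above):
+#   load_dtype_txt(obs, 'cphases.txt', dtype='cphase')
+#   print(obs.cphase['cphase'][:5])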
diff --git a/io/oifits.py b/io/oifits.py
new file mode 100644
index 00000000..4583dc63
--- /dev/null
+++ b/io/oifits.py
@@ -0,0 +1,1358 @@
+"""
+A module for reading/writing OIFITS files
+
+This module is NOT related to the OIFITS Python module provided at
+http://www.mrao.cam.ac.uk/research/OAS/oi_data/oifits.html
+It is a (better) alternative.
+
+To open an existing OIFITS file, use the oifits.open(filename)
+function. This will return an oifits object with the following
+members (any of which can be empty dictionaries or numpy arrays):
+
+ array: a dictionary of interferometric arrays, as defined by the
+ OI_ARRAY tables. The dictionary key is the name of the array
+ (ARRNAME).
+
+ target: a numpy array of targets, as defined by the rows of the
+ OI_TARGET table.
+
+ wavelength: a dictionary of wavelength tables (OI_WAVELENGTH). The
+ dictionary key is the name of the instrument/settings (INSNAME).
+
+ vis, vis2 and t3: numpy arrays of objects containing all the
+ measurement information. Each list member corresponds to a row in
+ an OI_VIS/OI_VIS2/OI_T3 table.
+
+This module makes an ad-hoc, backwards-compatible change to the OIFITS
+revision 1 standard originally described by Pauls et al., 2005, PASP,
+117, 1255. The OI_VIS and OI_VIS2 tables in OIFITS files produced by
+this file contain two additional columns for the correlated flux,
+CFLUX and CFLUXERR, which are arrays with a length corresponding to
+the number of wavelength elements (just as VISAMP/VIS2DATA).
+
+The main purpose of this module is to allow easy access to your OIFITS
+data within Python, where you can then analyze it in any way you want.
+As of version 0.3, the module can now be used to create OIFITS files
+from scratch without serious pain. Be warned, creating an array table
+from scratch is probably like nailing jelly to a tree. In a future
+version this will become easier.
+
+The module also provides a simple mechanism for combining multiple
+oifits objects, achieved by using the '+' operator on two oifits
+objects: result = a + b. The result can then be written to a file
+using result.save(filename).
+
+Many of the parameters and their meanings are not specifically
+documented here. However, the nomenclature mirrors that of the OIFITS
+standard, so it is recommended to use this module with the PASP
+reference above in hand.
+
+Beginning with version 0.3, the OI_VIS/OI_VIS2/OI_T3 classes now use
+masked arrays for convenience, where the mask is defined via the
+'flag' member of these classes. Beware of the following subtlety: as
+before, the array data are accessed via (for example) OI_VIS.visamp;
+however, OI_VIS.visamp is just a method which constructs (on the fly)
+a masked array from OI_VIS._visamp, which is where the data are
+actually stored. This is done transparently, and the data can be
+accessed and modified transparently via the "visamp" hidden attribute.
+The same goes for correlated fluxes, differential/closure phases,
+triple products, etc. See the notes on the individual classes for a
+list of all the "hidden" attributes.
+
+For further information, contact Paul Boley (boley@mpia-hd.mpg.de).
+
+"""
+from __future__ import division
+from __future__ import print_function
+from builtins import str
+from builtins import object
+
+import numpy as np
+from numpy import double, ma
+from astropy.io import fits as pyfits
+import datetime
+import copy
+
+__author__ = "Paul Boley"
+__email__ = "boley@mpia-hd.mpg.de"
+__date__ = '1 October 2012'
+__version__ = '0.3.1'
+_mjdzero = datetime.datetime(1858, 11, 17)
+
+matchtargetbyname = False
+matchstationbyname = False
+refdate = datetime.datetime(2000, 1, 1)
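+# Example: to combine files whose targets differ only slightly in their
+# coordinates, set the flag before using '+' (a sketch):
+#   oifits.matchtargetbyname = True
+#   combined = oifits.open('a.fits') + oifits.open('b.fits')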
+
+def _plurals(count):
+ if count != 1: return 's'
+ return ''
+
+def _array_eq(a, b):
+ "Test whether all the elements of two arrays are equal."
+
+ try:
+ return not (a != b).any()
+ except AttributeError: # (a != b) was a plain bool with no .any()
+ return not (a != b)
+
+class _angpoint(float):
+ "Convenience object for representing angles."
+
+ def __init__(self, angle):
+ self.angle = angle
+
+ def __repr__(self):
+ return '_angpoint(%s)'%self.angle.__repr__()
+
+ def __str__(self):
+ return "%g degrees"%(self.angle)
+
+ def __eq__(self, other):
+ return self.angle == other.angle
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def asdms(self):
+ """Return the value as a string in dms format,
+ e.g. +25:30:22.55. Useful for declination."""
+ angle = self.angle
+ if angle < 0:
+ negative = True
+ angle *= -1.0
+ else:
+ negative = False
+ degrees = np.floor(angle)
+ minutes = np.floor((angle - degrees)*60.0)
+ seconds = (angle - degrees - minutes/60.0)*3600.0
+ try:
+ if negative:
+ return "-%02d:%02d:%05.2f"%(degrees,minutes,seconds)
+ else:
+ return "+%02d:%02d:%05.2f"%(degrees,minutes,seconds)
+ except TypeError:
+ return self.__repr__()
+
+ def ashms(self):
+ """Return the value as a string in hms format,
+ e.g. 5:12:17.21. Useful for right ascension."""
+ angle = self.angle*24.0/360.0
+
+ hours = np.floor(angle)
+ minutes = np.floor((angle - hours)*60.0)
+ seconds = (angle - hours - minutes/60.0)*3600.0
+ try:
+ return "%02d:%02d:%05.2f"%(hours,minutes,seconds)
+ except TypeError:
+ return self.__repr__()
+
+class OI_TARGET(object):
+
+ def __init__(self, target, raep0, decep0, equinox=2000.0, ra_err=0.0, dec_err=0.0,
+ sysvel=0.0, veltyp='TOPCENT', veldef='OPTICAL', pmra=0.0, pmdec=0.0,
+ pmra_err=0.0, pmdec_err=0.0, parallax=0.0, para_err=0.0, spectyp='UNKNOWN'):
+ self.target = target
+ self.raep0 = _angpoint(raep0)
+ self.decep0 = _angpoint(decep0)
+ self.equinox = equinox
+ self.ra_err = ra_err
+ self.dec_err = dec_err
+ self.sysvel = sysvel
+ self.veltyp = veltyp
+ self.veldef = veldef
+ self.pmra = pmra
+ self.pmdec = pmdec
+ self.pmra_err = pmra_err
+ self.pmdec_err = pmdec_err
+ self.parallax = parallax
+ self.para_err = para_err
+ self.spectyp = spectyp
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.target != other.target) or
+ (self.raep0 != other.raep0) or
+ (self.decep0 != other.decep0) or
+ (self.equinox != other.equinox) or
+ (self.ra_err != other.ra_err) or
+ (self.dec_err != other.dec_err) or
+ (self.sysvel != other.sysvel) or
+ (self.veltyp != other.veltyp) or
+ (self.veldef != other.veldef) or
+ (self.pmra != other.pmra) or
+ (self.pmdec != other.pmdec) or
+ (self.pmra_err != other.pmra_err) or
+ (self.pmdec_err != other.pmdec_err) or
+ (self.parallax != other.parallax) or
+ (self.para_err != other.para_err) or
+ (self.spectyp != other.spectyp))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __str__(self):
+ return "%s: %s %s (%g)"%(self.target, self.raep0.ashms(), self.decep0.asdms(), self.equinox)
+
+ def info(self):
+ print(str(self))
+
+class OI_WAVELENGTH(object):
+
+ def __init__(self, eff_wave, eff_band=None):
+ self.eff_wave = np.array(eff_wave, dtype=double).reshape(-1)
+ if eff_band is None:
+ eff_band = np.zeros_like(eff_wave)
+ self.eff_band = np.array(eff_band, dtype=double).reshape(-1)
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (not _array_eq(self.eff_wave, other.eff_wave)) or
+ (not _array_eq(self.eff_band, other.eff_band)))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __repr__(self):
+ return "%d wavelength%s (%.3g-%.3g um)"%(len(self.eff_wave), _plurals(len(self.eff_wave)), 1e6*np.min(self.eff_wave),1e6*np.max(self.eff_wave))
+
+ def info(self):
+ print(str(self))
+
+
+class OI_VIS(object):
+ """
+ Class for storing visibility amplitude and differential phase data.
+ To access the data, use the following hidden attributes:
+
+ visamp, visamperr, visphi, visphierr, flag;
+ and possibly cflux, cfluxerr.
+
+ """
+
+ def __init__(self, timeobs, int_time, visamp, visamperr, visphi, visphierr, flag, ucoord,
+ vcoord, wavelength, target, array=None, station=(None,None), cflux=None, cfluxerr=None):
+ self.timeobs = timeobs
+ self.array = array
+ self.wavelength = wavelength
+ self.target = target
+ self.int_time = int_time
+ self._visamp = np.array(visamp, dtype=double).reshape(-1)
+ self._visamperr = np.array(visamperr, dtype=double).reshape(-1)
+ self._visphi = np.array(visphi, dtype=double).reshape(-1)
+ self._visphierr = np.array(visphierr, dtype=double).reshape(-1)
+ if cflux is not None: self._cflux = np.array(cflux, dtype=double).reshape(-1)
+ else: self._cflux = None
+ if cfluxerr is not None: self._cfluxerr = np.array(cfluxerr, dtype=double).reshape(-1)
+ else: self._cfluxerr = None
+ self.flag = np.array(flag, dtype=bool).reshape(-1)
+ self.ucoord = ucoord
+ self.vcoord = vcoord
+ self.station = station
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.timeobs != other.timeobs) or
+ (self.array != other.array) or
+ (self.wavelength != other.wavelength) or
+ (self.target != other.target) or
+ (self.int_time != other.int_time) or
+ (self.ucoord != other.ucoord) or
+ (self.vcoord != other.vcoord) or
+ (self.array != other.array) or
+ (self.station != other.station) or
+ (not _array_eq(self.visamp, other.visamp)) or
+ (not _array_eq(self.visamperr, other.visamperr)) or
+ (not _array_eq(self.visphi, other.visphi)) or
+ (not _array_eq(self.visphierr, other.visphierr)) or
+ (not _array_eq(self.flag, other.flag)))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getattr__(self, attrname):
+ if attrname in ('visamp', 'visamperr', 'visphi', 'visphierr'):
+ return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag)
+ elif attrname in ('cflux', 'cfluxerr'):
+ if self.__dict__['_' + attrname] is not None:
+ return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag)
+ else:
+ return None
+ else:
+ raise AttributeError(attrname)
+
+ def __setattr__(self, attrname, value):
+ if attrname in ('visamp', 'visamperr', 'visphi', 'visphierr', 'cflux', 'cfluxerr'):
+ self.__dict__['_' + attrname] = value
+ else:
+ self.__dict__[attrname] = value
+
+ def __repr__(self):
+ meanvis = ma.mean(self.visamp)
+ if self.station[0] and self.station[1]:
+ baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + ')'
+ else:
+ baselinename = ''
+ return '%s %s%s: %d point%s (%d masked), B = %5.1f m, PA = %5.1f deg, <V> = %4.2g'%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.visamp), _plurals(len(self.visamp)), np.sum(self.flag), np.sqrt(self.ucoord**2 + self.vcoord**2), np.arctan(self.ucoord/self.vcoord) * 180.0 / np.pi % 180.0, meanvis)
+
+ def info(self):
+ print(str(self))
+
+class OI_VIS2(object):
+ """
+ Class for storing squared visibility amplitude data.
+ To access the data, use the following hidden attributes:
+
+ vis2data, vis2err
+
+ """
+ def __init__(self, timeobs, int_time, vis2data, vis2err, flag, ucoord, vcoord, wavelength,
+ target, array=None, station=(None, None)):
+ self.timeobs = timeobs
+ self.array = array
+ self.wavelength = wavelength
+ self.target = target
+ self.int_time = int_time
+ self._vis2data = np.array(vis2data, dtype=double).reshape(-1)
+ self._vis2err = np.array(vis2err, dtype=double).reshape(-1)
+ self.flag = np.array(flag, dtype=bool).reshape(-1)
+ self.ucoord = ucoord
+ self.vcoord = vcoord
+ self.station = station
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.timeobs != other.timeobs) or
+ (self.array != other.array) or
+ (self.wavelength != other.wavelength) or
+ (self.target != other.target) or
+ (self.int_time != other.int_time) or
+ (self.ucoord != other.ucoord) or
+ (self.vcoord != other.vcoord) or
+ (self.array != other.array) or
+ (self.station != other.station) or
+ (not _array_eq(self.vis2data, other.vis2data)) or
+ (not _array_eq(self.vis2err, other.vis2err)) or
+ (not _array_eq(self.flag, other.flag)))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getattr__(self, attrname):
+ if attrname in ('vis2data', 'vis2err'):
+ return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag)
+ else:
+ raise AttributeError(attrname)
+
+ def __setattr__(self, attrname, value):
+ if attrname in ('vis2data', 'vis2err'):
+ self.__dict__['_' + attrname] = value
+ else:
+ self.__dict__[attrname] = value
+
+ def __repr__(self):
+ meanvis = ma.mean(self.vis2data)
+ if self.station[0] and self.station[1]:
+ baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + ')'
+ else:
+ baselinename = ''
+ return "%s %s%s: %d point%s (%d masked), B = %5.1f m, PA = %5.1f deg, = %4.2g"%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.vis2data), _plurals(len(self.vis2data)), np.sum(self.flag), np.sqrt(self.ucoord**2 + self.vcoord**2), np.arctan(self.ucoord/self.vcoord) * 180.0 / np.pi % 180.0, meanvis)
+
+ def info(self):
+ print(str(self))
+
+
+class OI_T3(object):
+ """
+ Class for storing triple product and closure phase data.
+ To access the data, use the following hidden attributes:
+
+ t3amp, t3amperr, t3phi, t3phierr
+
+ """
+
+ def __init__(self, timeobs, int_time, t3amp, t3amperr, t3phi, t3phierr, flag, u1coord,
+ v1coord, u2coord, v2coord, wavelength, target, array=None, station=(None,None,None)):
+ self.timeobs = timeobs
+ self.array = array
+ self.wavelength = wavelength
+ self.target = target
+ self.int_time = int_time
+ self._t3amp = np.array(t3amp, dtype=double).reshape(-1)
+ self._t3amperr = np.array(t3amperr, dtype=double).reshape(-1)
+ self._t3phi = np.array(t3phi, dtype=double).reshape(-1)
+ self._t3phierr = np.array(t3phierr, dtype=double).reshape(-1)
+ self.flag = np.array(flag, dtype=bool).reshape(-1)
+ self.u1coord = u1coord
+ self.v1coord = v1coord
+ self.u2coord = u2coord
+ self.v2coord = v2coord
+ self.station = station
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.timeobs != other.timeobs) or
+ (self.array != other.array) or
+ (self.wavelength != other.wavelength) or
+ (self.target != other.target) or
+ (self.int_time != other.int_time) or
+ (self.u1coord != other.u1coord) or
+ (self.v1coord != other.v1coord) or
+ (self.u2coord != other.u2coord) or
+ (self.v2coord != other.v2coord) or
+ (self.array != other.array) or
+ (self.station != other.station) or
+ (not _array_eq(self.t3amp, other.t3amp)) or
+ (not _array_eq(self.t3amperr, other.t3amperr)) or
+ (not _array_eq(self.t3phi, other.t3phi)) or
+ (not _array_eq(self.t3phierr, other.t3phierr)) or
+ (not _array_eq(self.flag, other.flag)))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __getattr__(self, attrname):
+ if attrname in ('t3amp', 't3amperr', 't3phi', 't3phierr'):
+ return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag)
+ else:
+ raise AttributeError(attrname)
+
+ def __setattr__(self, attrname, value):
+ if attrname in ('t3amp', 't3amperr', 't3phi', 't3phierr'):
+ self.__dict__['_' + attrname] = value
+ else:
+ self.__dict__[attrname] = value
+
+ def __repr__(self):
+ meant3 = np.mean(self.t3amp[np.where(self.flag == False)])
+ if self.station[0] and self.station[1] and self.station[2]:
+ baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + self.station[2].sta_name + ')'
+ else:
+ baselinename = ''
+ return "%s %s%s: %d point%s (%d masked), B = %5.1fm, %5.1fm, = %4.2g"%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.t3amp), _plurals(len(self.t3amp)), np.sum(self.flag), np.sqrt(self.u1coord**2 + self.v1coord**2), np.sqrt(self.u2coord**2 + self.v2coord**2), meant3)
+
+ def info(self):
+ print(str(self))
+
+class OI_STATION(object):
+ """ This class corresponds to a single row (i.e. single
+ station/telescope) of an OI_ARRAY table."""
+
+ def __init__(self, tel_name=None, sta_name=None, diameter=None, staxyz=[None, None, None]):
+ self.tel_name = tel_name
+ self.sta_name = sta_name
+ self.diameter = diameter
+ self.staxyz = staxyz
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.tel_name != other.tel_name) or
+ (self.sta_name != other.sta_name) or
+ (self.diameter != other.diameter) or
+ (not _array_eq(self.staxyz, other.staxyz)))
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def __repr__(self):
+ return '%s/%s (%g m)'%(self.sta_name, self.tel_name, self.diameter)
+
+class OI_ARRAY(object):
+ """Contains all the data for a single OI_ARRAY table. Note the
+ hidden convenience attributes latitude, longitude, and altitude."""
+
+ def __init__(self, frame, arrxyz, stations=()):
+ self.frame = frame
+ self.arrxyz = arrxyz
+ self.station = np.empty(0)
+ for station in stations:
+ tel_name, sta_name, sta_index, diameter, staxyz = station
+ self.station = np.append(self.station, OI_STATION(tel_name=tel_name, sta_name=sta_name, diameter=diameter, staxyz=staxyz))
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ equal = not (
+ (self.frame != other.frame) or
+ (not _array_eq(self.arrxyz, other.arrxyz)))
+
+ if not equal: return False
+
+ # If position appears to be the same, check that the stations
+ # (and ordering) are also the same
+ if (self.station != other.station).any():
+ return False
+
+ return True
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
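+ # (note: the latitude/longitude/altitude computed below are geocentric
+ # approximations on a sphere of radius 6378.1 km, not geodetic WGS84 values)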
+ def __getattr__(self, attrname):
+ if attrname == 'latitude':
+ radius = np.sqrt((self.arrxyz**2).sum())
+ return _angpoint(np.arcsin(self.arrxyz[2]/radius)*180.0/np.pi)
+ elif attrname == 'longitude':
+ # arctan2 handles all four quadrants; arcsin(y/xylen) is wrong for x < 0
+ return _angpoint(np.arctan2(self.arrxyz[1], self.arrxyz[0])*180.0/np.pi)
+ elif attrname == 'altitude':
+ radius = np.sqrt((self.arrxyz**2).sum())
+ return radius - 6378100.0
+ else:
+ raise AttributeError(attrname)
+
+ def __repr__(self):
+ return '%s %s %g m, %d station%s'%(self.latitude.asdms(), self.longitude.asdms(), self.altitude, len(self.station), _plurals(len(self.station)))
+
+ def info(self, verbose=0):
+ """Print the array's center coordinates. If verbosity >= 1,
+ print information about each station."""
+ print(str(self))
+ if verbose >= 1:
+ for station in self.station:
+ print(" %s"%str(station))
+
+ def get_station_by_name(self, name):
+
+ for station in self.station:
+ if station.sta_name == name:
+ return station
+
+ raise LookupError('No such station %s'%name)
+
+class oifits(object):
+
+ def __init__(self):
+
+ self.wavelength = {}
+ self.target = np.empty(0)
+ self.array = {}
+ self.vis = np.empty(0)
+ self.vis2 = np.empty(0)
+ self.t3 = np.empty(0)
+
+ def __add__(self, other):
+ """Consistently combine two separate oifits objects. Note
+ that targets can be matched by name only (e.g. if coordinates
+ differ) by setting oifits.matchtargetbyname to True. The same
+ goes for stations of the array (controlled by
+ oifits.matchstationbyname)"""
+ # Don't do anything if the two oifits objects are not CONSISTENT!
+ if not (self.isconsistent() and other.isconsistent()):
+ print('oifits objects are not consistent, bailing.')
+ return
+
+ new = copy.deepcopy(self)
+ if len(other.wavelength):
+ wavelengthmap = {}
+ for key in list(other.wavelength.keys()):
+ if key not in list(new.wavelength.keys()):
+ new.wavelength[key] = copy.deepcopy(other.wavelength[key])
+ elif new.wavelength[key] != other.wavelength[key]:
+ raise ValueError('Wavelength tables have the same key but differing contents.')
+ wavelengthmap[id(other.wavelength[key])] = new.wavelength[key]
+
+ if len(other.target):
+ targetmap = {}
+ for otarget in other.target:
+ for ntarget in new.target:
+ if matchtargetbyname and ntarget.target == otarget.target:
+ targetmap[id(otarget)] = ntarget
+ break
+ elif ntarget == otarget:
+ targetmap[id(otarget)] = ntarget
+ break
+ elif ntarget.target == otarget.target:
+ print('Found a target with a matching name, but some differences in the target specification. Creating a new target. Set oifits.matchtargetbyname to True to override this behavior.')
+ # If 'id(otarget)' is not in targetmap, then this is a new
+ # target and should be added to the array of targets
+ if id(otarget) not in list(targetmap.keys()):
+ target = copy.deepcopy(otarget)
+ new.target = np.append(new.target, target)
+ targetmap[id(otarget)] = target
+
+ if len(other.array):
+ stationmap = {}
+ arraymap = {}
+ for key, otharray in other.array.items():
+ arraymap[id(otharray)] = key
+ if key not in list(new.array.keys()):
+ new.array[key] = copy.deepcopy(other.array[key])
+ # If arrays have the same name but seem to differ, try
+ # to combine the two (by including the union of both
+ # sets of stations)
+ for othsta in other.array[key].station:
+ for newsta in new.array[key].station:
+ if newsta == othsta:
+ stationmap[id(othsta)] = newsta
+ break
+ elif matchstationbyname and newsta.sta_name == othsta.sta_name:
+ stationmap[id(othsta)] = newsta
+ break
+ elif newsta.sta_name == othsta.sta_name and not matchstationbyname:
+ raise ValueError('Stations have matching names but conflicting data.')
+ # If 'id(othsta)' is not in the stationmap
+ # dictionary, then this is a new station and
+ # should be added to the current array
+ if id(othsta) not in list(stationmap.keys()):
+ newsta = copy.deepcopy(othsta)
+ new.array[key].station = np.append(new.array[key].station, newsta)
+ stationmap[id(othsta)] = newsta
+ # Make sure that staxyz of the new station is relative to the new array center
+ newsta.staxyz = othsta.staxyz - other.array[key].arrxyz + new.array[key].arrxyz
+
+ for vis in other.vis:
+ if vis not in new.vis:
+ newvis = copy.copy(vis)
+ # The wavelength, target, array and station objects
+ # should point to the appropriate objects inside the
+ # 'new' structure
+ newvis.wavelength = wavelengthmap[id(vis.wavelength)]
+ newvis.target = targetmap[id(vis.target)]
+ if (vis.array):
+ newvis.array = new.array[arraymap[id(vis.array)]]
+ newvis.station = [None, None]
+ newvis.station[0] = stationmap[id(vis.station[0])]
+ newvis.station[1] = stationmap[id(vis.station[1])]
+ new.vis = np.append(new.vis, newvis)
+
+ for vis2 in other.vis2:
+ if vis2 not in new.vis2:
+ newvis2 = copy.copy(vis2)
+ # The wavelength, target, array and station objects
+ # should point to the appropriate objects inside the
+ # 'new' structure
+ newvis2.wavelength = wavelengthmap[id(vis2.wavelength)]
+ newvis2.target = targetmap[id(vis2.target)]
+ if (vis2.array):
+ newvis2.array = new.array[arraymap[id(vis2.array)]]
+ newvis2.station = [None, None]
+ newvis2.station[0] = stationmap[id(vis2.station[0])]
+ newvis2.station[1] = stationmap[id(vis2.station[1])]
+ new.vis2 = np.append(new.vis2, newvis2)
+
+ for t3 in other.t3:
+ if t3 not in new.t3:
+ newt3 = copy.copy(t3)
+ # The wavelength, target, array and station objects
+ # should point to the appropriate objects inside the
+ # 'new' structure
+ newt3.wavelength = wavelengthmap[id(t3.wavelength)]
+ newt3.target = targetmap[id(t3.target)]
+ if (t3.array):
+ newt3.array = new.array[arraymap[id(t3.array)]]
+ newt3.station = [None, None, None]
+ newt3.station[0] = stationmap[id(t3.station[0])]
+ newt3.station[1] = stationmap[id(t3.station[1])]
+ newt3.station[2] = stationmap[id(t3.station[2])]
+ new.t3 = np.append(new.t3, newt3)
+
+ return new
+
+
+ def __eq__(self, other):
+
+ if type(self) != type(other): return False
+
+ return not (
+ (self.wavelength != other.wavelength) or
+ (self.target != other.target).any() or
+ (self.array != other.array) or
+ (self.vis != other.vis).any() or
+ (self.vis2 != other.vis2).any() or
+ (self.t3 != other.t3).any())
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def isvalid(self):
+ """Returns True of the oifits object is both consistent (as
+ determined by isconsistent()) and conforms to the OIFITS
+ standard (according to Pauls et al., 2005, PASP, 117, 1255)."""
+
+ warnings = []
+ errors = []
+ if not self.isconsistent():
+ errors.append('oifits object is not consistent')
+ if not self.target.size:
+ errors.append('No OI_TARGET data')
+ if not self.wavelength:
+ errors.append('No OI_WAVELENGTH data')
+ else:
+ for key, wavelength in self.wavelength.items():
+ if len(wavelength.eff_wave) != len(wavelength.eff_band):
+ errors.append("eff_wave and eff_band are of different lengths for wavelength table '%s'"%key)
+ if (self.vis.size + self.vis2.size + self.t3.size == 0):
+ errors.append('Need at least one measurement table (vis, vis2 or t3)')
+ for vis in self.vis:
+ nwave = len(vis.wavelength.eff_band)
+ if (len(vis.visamp) != nwave) or (len(vis.visamperr) != nwave) or (len(vis.visphi) != nwave) or (len(vis.visphierr) != nwave) or (len(vis.flag) != nwave):
+ errors.append("Data size mismatch for visibility measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave))
+ for vis2 in self.vis2:
+ nwave = len(vis2.wavelength.eff_band)
+ if (len(vis2.vis2data) != nwave) or (len(vis2.vis2err) != nwave) or (len(vis2.flag) != nwave):
+ errors.append("Data size mismatch for visibility^2 measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave))
+ for t3 in self.t3:
+ nwave = len(t3.wavelength.eff_band)
+ if (len(t3.t3amp) != nwave) or (len(t3.t3amperr) != nwave) or (len(t3.t3phi) != nwave) or (len(t3.t3phierr) != nwave) or (len(t3.flag) != nwave):
+ errors.append("Data size mismatch for visibility measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave))
+
+ if warnings:
+ print("*** %d warning%s:"%(len(warnings), _plurals(len(warnings))))
+ for warning in warnings:
+ print(' ' + warning)
+ if errors:
+ print("*** %d ERROR%s:"%(len(errors), _plurals(len(errors)).upper()))
+ for error in errors:
+ print(' ' + error)
+
+ return not (len(warnings) or len(errors))
+
+ def isconsistent(self):
+ """Returns True if the object is entirely self-contained,
+ i.e. all cross-references to wavelength tables, arrays,
+ stations etc. in the measurements refer to elements which are
+ stored in the oifits object. Note that an oifits object can
+ be 'consistent' in this sense without being 'valid' as checked
+ by isvalid()."""
+
+ for vis in self.vis:
+ if vis.array and (vis.array not in list(self.array.values())):
+ print('A visibility measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(vis))
+ return False
+ if ((vis.station[0] and (vis.station[0] not in vis.array.station)) or
+ (vis.station[1] and (vis.station[1] not in vis.array.station))):
+ print('A visibility measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(vis))
+ return False
+ if vis.wavelength not in list(self.wavelength.values()):
+ print('A visibility measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(vis))
+ return False
+ if vis.target not in self.target:
+ print('A visibility measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(vis))
+ return False
+
+ for vis2 in self.vis2:
+ if vis2.array and (vis2.array not in list(self.array.values())):
+ print('A visibility^2 measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(vis2))
+ return False
+ if ((vis2.station[0] and (vis2.station[0] not in vis2.array.station)) or
+ (vis2.station[1] and (vis2.station[1] not in vis2.array.station))):
+ print('A visibility^2 measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(vis2))
+ return False
+ if vis2.wavelength not in list(self.wavelength.values()):
+ print('A visibility^2 measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(vis2))
+ return False
+ if vis2.target not in self.target:
+ print('A visibility^2 measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(vis2))
+ return False
+
+ for t3 in self.t3:
+ if t3.array and (t3.array not in list(self.array.values())):
+ print('A closure phase measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(t3))
+ return False
+ if ((t3.station[0] and (t3.station[0] not in t3.array.station)) or
+ (t3.station[1] and (t3.station[1] not in t3.array.station)) or
+ (t3.station[2] and (t3.station[2] not in t3.array.station))):
+ print('A closure phase measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(t3))
+ return False
+ if t3.wavelength not in list(self.wavelength.values()):
+ print('A closure phase measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(t3))
+ return False
+ if t3.target not in self.target:
+ print('A closure phase measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(t3))
+ return False
+
+ return True
+
+ def info(self, recursive=True, verbose=0):
+ """Print out a summary of the contents of the oifits object.
+ Set recursive=True to obtain more specific information about
+ each of the individual components, and verbose to an integer
+ to increase the verbosity level."""
+
+ if self.wavelength:
+ wavelengths = 0
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF WAVELENGTH TABLES")
+ print("====================================================================")
+ for key in list(self.wavelength.keys()):
+ wavelengths += len(self.wavelength[key].eff_wave)
+ if recursive: print("'%s': %s"%(key, str(self.wavelength[key])))
+ print("%d wavelength table%s with %d wavelength%s in total"%(len(self.wavelength), _plurals(len(self.wavelength)), wavelengths, _plurals(wavelengths)))
+ if self.target.size:
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF TARGET TABLES")
+ print("====================================================================")
+ for target in self.target:
+ target.info()
+ print("%d target%s"%(len(self.target), _plurals(len(self.target))))
+ if self.array:
+ stations = 0
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF ARRAY TABLES")
+ print("====================================================================")
+ for key in list(self.array.keys()):
+ if recursive:
+ print(key + ':')
+ self.array[key].info(verbose=verbose)
+ stations += len(self.array[key].station)
+ print("%d array%s with %d station%s"%(len(self.array), _plurals(len(self.array)), stations, _plurals(stations)))
+ if self.vis.size:
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF VISIBILITY MEASUREMENTS")
+ print("====================================================================")
+ for vis in self.vis:
+ vis.info()
+ print("%d visibility measurement%s"%(len(self.vis), _plurals(len(self.vis))))
+ if self.vis2.size:
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF VISIBILITY^2 MEASUREMENTS")
+ print("====================================================================")
+ for vis2 in self.vis2:
+ vis2.info()
+ print("%d visibility^2 measurement%s"%(len(self.vis2), _plurals(len(self.vis2))))
+ if self.t3.size:
+ if recursive:
+ print("====================================================================")
+ print("SUMMARY OF T3 MEASUREMENTS")
+ print("====================================================================")
+ for t3 in self.t3:
+ t3.info()
+ print("%d closure phase measurement%s"%(len(self.t3), _plurals(len(self.t3))))
+
+ def save(self, filename):
+ """Write the contents of the oifits object to a file in OIFITS
+ format."""
+
+ if not self.isconsistent():
+ print('oifits object is not consistent, refusing to go further')
+ return
+
+ hdulist = pyfits.HDUList()
+ hdu = pyfits.PrimaryHDU()
+ hdu.header.update('DATE', datetime.datetime.now().strftime(format='%F'), comment='Creation date')
+ hdu.header.add_comment('Written by OIFITS Python module version %s'%__version__)
+ hdu.header.add_comment('http://www.mpia-hd.mpg.de/homes/boley/oifits/')
+
+ wavelengthmap = {}
+ hdulist.append(hdu)
+ for insname, wavelength in self.wavelength.items():
+ wavelengthmap[id(wavelength)] = insname
+ hdu = pyfits.new_table(pyfits.ColDefs((
+ pyfits.Column(name='EFF_WAVE', format='1E', unit='METERS', array=wavelength.eff_wave),
+ pyfits.Column(name='EFF_BAND', format='1E', unit='METERS', array=wavelength.eff_band)
+ )))
+ hdu.header.update('EXTNAME', 'OI_WAVELENGTH')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdu.header.update('INSNAME', insname, 'Name of detector, for cross-referencing')
+ hdulist.append(hdu)
+
+ targetmap = {}
+ if self.target.size:
+ target_id = []
+ target = []
+ raep0 = []
+ decep0 = []
+ equinox = []
+ ra_err = []
+ dec_err = []
+ sysvel = []
+ veltyp = []
+ veldef = []
+ pmra = []
+ pmdec = []
+ pmra_err = []
+ pmdec_err = []
+ parallax = []
+ para_err = []
+ spectyp = []
+ for i, targ in enumerate(self.target):
+ key = i+1
+ targetmap[id(targ)] = key
+ target_id.append(key)
+ target.append(targ.target)
+ raep0.append(targ.raep0)
+ decep0.append(targ.decep0)
+ equinox.append(targ.equinox)
+ ra_err.append(targ.ra_err)
+ dec_err.append(targ.dec_err)
+ sysvel.append(targ.sysvel)
+ veltyp.append(targ.veltyp)
+ veldef.append(targ.veldef)
+ pmra.append(targ.pmra)
+ pmdec.append(targ.pmdec)
+ pmra_err.append(targ.pmra_err)
+ pmdec_err.append(targ.pmdec_err)
+ parallax.append(targ.parallax)
+ para_err.append(targ.para_err)
+ spectyp.append(targ.spectyp)
+
+ hdu = pyfits.new_table(pyfits.ColDefs((
+ pyfits.Column(name='TARGET_ID', format='1I', array=target_id),
+ pyfits.Column(name='TARGET', format='16A', array=target),
+ pyfits.Column(name='RAEP0', format='D1', unit='DEGREES', array=raep0),
+ pyfits.Column(name='DECEP0', format='D1', unit='DEGREES', array=decep0),
+ pyfits.Column(name='EQUINOX', format='E1', unit='YEARS', array=equinox),
+ pyfits.Column(name='RA_ERR', format='D1', unit='DEGREES', array=ra_err),
+ pyfits.Column(name='DEC_ERR', format='D1', unit='DEGREES', array=dec_err),
+ pyfits.Column(name='SYSVEL', format='D1', unit='M/S', array=sysvel),
+ pyfits.Column(name='VELTYP', format='A8', array=veltyp),
+ pyfits.Column(name='VELDEF', format='A8', array=veldef),
+ pyfits.Column(name='PMRA', format='D1', unit='DEG/YR', array=pmra),
+ pyfits.Column(name='PMDEC', format='D1', unit='DEG/YR', array=pmdec),
+ pyfits.Column(name='PMRA_ERR', format='D1', unit='DEG/YR', array=pmra_err),
+ pyfits.Column(name='PMDEC_ERR', format='D1', unit='DEG/YR', array=pmdec_err),
+ pyfits.Column(name='PARALLAX', format='E1', unit='DEGREES', array=parallax),
+ pyfits.Column(name='PARA_ERR', format='E1', unit='DEGREES', array=para_err),
+ pyfits.Column(name='SPECTYP', format='A16', array=spectyp)
+ )))
+ hdu.header.update('EXTNAME', 'OI_TARGET')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdulist.append(hdu)
+
+ arraymap = {}
+ stationmap = {}
+ for arrname, array in self.array.items():
+ arraymap[id(array)] = arrname
+ tel_name = []
+ sta_name = []
+ sta_index = []
+ diameter = []
+ staxyz = []
+ if array.station.size:
+ for i, station in enumerate(array.station, 1):
+ stationmap[id(station)] = i
+ tel_name.append(station.tel_name)
+ sta_name.append(station.sta_name)
+ sta_index.append(i)
+ diameter.append(station.diameter)
+ staxyz.append(station.staxyz)
+ hdu = pyfits.new_table(pyfits.ColDefs((
+ pyfits.Column(name='TEL_NAME', format='16A', array=tel_name),
+ pyfits.Column(name='STA_NAME', format='16A', array=sta_name),
+ pyfits.Column(name='STA_INDEX', format='1I', array=sta_index),
+ pyfits.Column(name='DIAMETER', unit='METERS', format='1E', array=diameter),
+ pyfits.Column(name='STAXYZ', unit='METERS', format='3D', array=staxyz)
+ )))
+ hdu.header.update('EXTNAME', 'OI_ARRAY')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdu.header.update('ARRNAME', arrname, comment='Array name, for cross-referencing')
+ hdu.header.update('FRAME', array.frame, comment='Coordinate frame')
+ hdu.header.update('ARRAYX', array.arrxyz[0], comment='Array center x coordinate (m)')
+ hdu.header.update('ARRAYY', array.arrxyz[1], comment='Array center y coordinate (m)')
+ hdu.header.update('ARRAYZ', array.arrxyz[2], comment='Array center z coordinate (m)')
+ hdulist.append(hdu)
+
+ if self.vis.size:
+ # The tables are grouped by ARRNAME and INSNAME -- all
+ # observations which have the same ARRNAME and INSNAME are
+ # put into a single FITS binary table.
+ tables = {}
+ for vis in self.vis:
+ nwave = vis.wavelength.eff_wave.size
+ if vis.array:
+ key = (arraymap[id(vis.array)], wavelengthmap[id(vis.wavelength)])
+ else:
+ key = (None, wavelengthmap[id(vis.wavelength)])
+ if key in list(tables.keys()):
+ data = tables[key]
+ else:
+ data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[],
+ 'visamp':[], 'visamperr':[], 'visphi':[], 'visphierr':[],
+ 'cflux':[], 'cfluxerr':[], 'ucoord':[], 'vcoord':[],
+ 'sta_index':[], 'flag':[]}
+ data['target_id'].append(targetmap[id(vis.target)])
+ if vis.timeobs:
+ time = vis.timeobs - refdate
+ data['time'].append(time.days * 24.0 * 3600.0 + time.seconds)
+ mjd = (vis.timeobs - _mjdzero).days + (vis.timeobs - _mjdzero).seconds / 3600.0 / 24.0
+ data['mjd'].append(mjd)
+ else:
+ data['time'].append(None)
+ data['mjd'].append(None)
+ data['int_time'].append(vis.int_time)
+ if nwave == 1:
+ data['visamp'].append(vis.visamp[0])
+ data['visamperr'].append(vis.visamperr[0])
+ data['visphi'].append(vis.visphi[0])
+ data['visphierr'].append(vis.visphierr[0])
+ data['flag'].append(vis.flag[0])
+ if vis.cflux is not None:
+ data['cflux'].append(vis.cflux[0])
+ else:
+ data['cflux'].append(None)
+ if vis.cfluxerr is not None:
+ data['cfluxerr'].append(vis.cfluxerr[0])
+ else:
+ data['cfluxerr'].append(None)
+ else:
+ data['visamp'].append(vis.visamp)
+ data['visamperr'].append(vis.visamperr)
+ data['visphi'].append(vis.visphi)
+ data['visphierr'].append(vis.visphierr)
+ data['flag'].append(vis.flag)
+ if vis.cflux is not None:
+ data['cflux'].append(vis.cflux)
+ else:
+ cflux = np.empty(nwave)
+ cflux[:] = np.nan
+ data['cflux'].append(cflux)
+ if vis.cfluxerr is not None:
+ data['cfluxerr'].append(vis.cfluxerr)
+ else:
+ cfluxerr = np.empty(nwave)
+ cfluxerr[:] = np.nan
+ data['cfluxerr'].append(cfluxerr)
+ data['ucoord'].append(vis.ucoord)
+ data['vcoord'].append(vis.vcoord)
+ if vis.station[0] and vis.station[1]:
+ data['sta_index'].append([stationmap[id(vis.station[0])], stationmap[id(vis.station[1])]])
+ else:
+ data['sta_index'].append([-1, -1])
+ for key in list(tables.keys()):
+ data = tables[key]
+ nwave = self.wavelength[key[1]].eff_wave.size
+
+ hdu = pyfits.new_table(pyfits.ColDefs([
+ pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']),
+ pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']),
+ pyfits.Column(name='MJD', unit='DAY', format='1D', array=data['mjd']),
+ pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']),
+ pyfits.Column(name='VISAMP', format='%dD'%nwave, array=data['visamp']),
+ pyfits.Column(name='VISAMPERR', format='%dD'%nwave, array=data['visamperr']),
+ pyfits.Column(name='VISPHI', unit='DEGREES', format='%dD'%nwave, array=data['visphi']),
+ pyfits.Column(name='VISPHIERR', unit='DEGREES', format='%dD'%nwave, array=data['visphierr']),
+ pyfits.Column(name='CFLUX', format='%dD'%nwave, array=data['cflux']),
+ pyfits.Column(name='CFLUXERR', format='%dD'%nwave, array=data['cfluxerr']),
+ pyfits.Column(name='UCOORD', format='1D', unit='METERS', array=data['ucoord']),
+ pyfits.Column(name='VCOORD', format='1D', unit='METERS', array=data['vcoord']),
+ pyfits.Column(name='STA_INDEX', format='2I', array=data['sta_index'], null=-1),
+ pyfits.Column(name='FLAG', format='%dL'%nwave)
+ ]))
+
+ # Setting the data of logical field via the
+ # pyfits.Column call above with length > 1 (eg
+ # format='171L' above) seems to be broken, at least as
+ # of PyFITS 2.2.2
+ hdu.data.field('FLAG').setfield(data['flag'], bool)
+ hdu.header.update('EXTNAME', 'OI_VIS')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdu.header.update('DATE-OBS', refdate.strftime('%F'), comment='Zero-point for table (UTC)')
+ if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY')
+ hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table')
+ hdulist.append(hdu)
+
+ if self.vis2.size:
+ tables = {}
+ for vis in self.vis2:
+ nwave = vis.wavelength.eff_wave.size
+ if vis.array:
+ key = (arraymap[id(vis.array)], wavelengthmap[id(vis.wavelength)])
+ else:
+ key = (None, wavelengthmap[id(vis.wavelength)])
+ if key in list(tables.keys()):
+ data = tables[key]
+ else:
+ data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[],
+ 'vis2data':[], 'vis2err':[], 'ucoord':[], 'vcoord':[],
+ 'sta_index':[], 'flag':[]}
+ data['target_id'].append(targetmap[id(vis.target)])
+ if vis.timeobs:
+ time = vis.timeobs - refdate
+ data['time'].append(time.days * 24.0 * 3600.0 + time.seconds)
+ mjd = (vis.timeobs - _mjdzero).days + (vis.timeobs - _mjdzero).seconds / 3600.0 / 24.0
+ data['mjd'].append(mjd)
+ else:
+ data['time'].append(None)
+ data['mjd'].append(None)
+ data['int_time'].append(vis.int_time)
+ if nwave == 1:
+ data['vis2data'].append(vis.vis2data[0])
+ data['vis2err'].append(vis.vis2err[0])
+ data['flag'].append(vis.flag[0])
+ else:
+ data['vis2data'].append(vis.vis2data)
+ data['vis2err'].append(vis.vis2err)
+ data['flag'].append(vis.flag)
+ data['ucoord'].append(vis.ucoord)
+ data['vcoord'].append(vis.vcoord)
+ if vis.station[0] and vis.station[1]:
+ data['sta_index'].append([stationmap[id(vis.station[0])], stationmap[id(vis.station[1])]])
+ else:
+ data['sta_index'].append([-1, -1])
+ for key in list(tables.keys()):
+ data = tables[key]
+ nwave = self.wavelength[key[1]].eff_wave.size
+
+ hdu = pyfits.new_table(pyfits.ColDefs([
+ pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']),
+ pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']),
+ pyfits.Column(name='MJD', format='1D', unit='DAY', array=data['mjd']),
+ pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']),
+ pyfits.Column(name='VIS2DATA', format='%dD'%nwave, array=data['vis2data']),
+ pyfits.Column(name='VIS2ERR', format='%dD'%nwave, array=data['vis2err']),
+ pyfits.Column(name='UCOORD', format='1D', unit='METERS', array=data['ucoord']),
+ pyfits.Column(name='VCOORD', format='1D', unit='METERS', array=data['vcoord']),
+ pyfits.Column(name='STA_INDEX', format='2I', array=data['sta_index'], null=-1),
+ pyfits.Column(name='FLAG', format='%dL'%nwave, array=data['flag'])
+ ]))
+            # Setting the data of a logical field via the
+            # pyfits.Column call above with length > 1 (e.g.
+            # format='171L' above) seems to be broken, at least as
+            # of PyFITS 2.2.2
+ hdu.data.field('FLAG').setfield(data['flag'], bool)
+ hdu.header.update('EXTNAME', 'OI_VIS2')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdu.header.update('DATE-OBS', refdate.strftime('%F'), comment='Zero-point for table (UTC)')
+ if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY')
+ hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table')
+ hdulist.append(hdu)
+
+ if self.t3.size:
+ tables = {}
+ for t3 in self.t3:
+ nwave = t3.wavelength.eff_wave.size
+ if t3.array:
+ key = (arraymap[id(t3.array)], wavelengthmap[id(t3.wavelength)])
+ else:
+ key = (None, wavelengthmap[id(t3.wavelength)])
+ if key in list(tables.keys()):
+ data = tables[key]
+ else:
+ data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[],
+ 't3amp':[], 't3amperr':[], 't3phi':[], 't3phierr':[],
+ 'u1coord':[], 'v1coord':[], 'u2coord':[], 'v2coord':[],
+ 'sta_index':[], 'flag':[]}
+ data['target_id'].append(targetmap[id(t3.target)])
+ if t3.timeobs:
+ time = t3.timeobs - refdate
+ data['time'].append(time.days * 24.0 * 3600.0 + time.seconds)
+ mjd = (t3.timeobs - _mjdzero).days + (t3.timeobs - _mjdzero).seconds / 3600.0 / 24.0
+ data['mjd'].append(mjd)
+ else:
+ data['time'].append(None)
+ data['mjd'].append(None)
+ data['int_time'].append(t3.int_time)
+ if nwave == 1:
+ data['t3amp'].append(t3.t3amp[0])
+ data['t3amperr'].append(t3.t3amperr[0])
+ data['t3phi'].append(t3.t3phi[0])
+ data['t3phierr'].append(t3.t3phierr[0])
+ data['flag'].append(t3.flag[0])
+ else:
+ data['t3amp'].append(t3.t3amp)
+ data['t3amperr'].append(t3.t3amperr)
+ data['t3phi'].append(t3.t3phi)
+ data['t3phierr'].append(t3.t3phierr)
+ data['flag'].append(t3.flag)
+ data['u1coord'].append(t3.u1coord)
+ data['v1coord'].append(t3.v1coord)
+ data['u2coord'].append(t3.u2coord)
+ data['v2coord'].append(t3.v2coord)
+ if t3.station[0] and t3.station[1] and t3.station[2]:
+ data['sta_index'].append([stationmap[id(t3.station[0])], stationmap[id(t3.station[1])], stationmap[id(t3.station[2])]])
+ else:
+ data['sta_index'].append([-1, -1, -1])
+ for key in list(tables.keys()):
+ data = tables[key]
+ nwave = self.wavelength[key[1]].eff_wave.size
+
+ hdu = pyfits.new_table(pyfits.ColDefs((
+ pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']),
+ pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']),
+ pyfits.Column(name='MJD', format='1D', unit='DAY', array=data['mjd']),
+ pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']),
+ pyfits.Column(name='T3AMP', format='%dD'%nwave, array=data['t3amp']),
+ pyfits.Column(name='T3AMPERR', format='%dD'%nwave, array=data['t3amperr']),
+ pyfits.Column(name='T3PHI', format='%dD'%nwave, unit='DEGREES', array=data['t3phi']),
+ pyfits.Column(name='T3PHIERR', format='%dD'%nwave, unit='DEGREES', array=data['t3phierr']),
+ pyfits.Column(name='U1COORD', format='1D', unit='METERS', array=data['u1coord']),
+ pyfits.Column(name='V1COORD', format='1D', unit='METERS', array=data['v1coord']),
+ pyfits.Column(name='U2COORD', format='1D', unit='METERS', array=data['u2coord']),
+ pyfits.Column(name='V2COORD', format='1D', unit='METERS', array=data['v2coord']),
+ pyfits.Column(name='STA_INDEX', format='3I', array=data['sta_index'], null=-1),
+ pyfits.Column(name='FLAG', format='%dL'%nwave, array=data['flag'])
+ )))
+            # Setting the data of a logical field via the
+            # pyfits.Column call above with length > 1 (e.g.
+            # format='171L' above) seems to be broken, at least as
+            # of PyFITS 2.2.2
+ hdu.data.field('FLAG').setfield(data['flag'], bool)
+ hdu.header.update('EXTNAME', 'OI_T3')
+ hdu.header.update('OI_REVN', 1, 'Revision number of the table definition')
+ hdu.header.update('DATE-OBS', refdate.strftime('%F'), 'Zero-point for table (UTC)')
+ if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY')
+ hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table')
+ hdulist.append(hdu)
+
+ hdulist.writeto(filename, clobber=True)
+
+
+
+def open(filename, quiet=False):
+ """Open an OIFITS file."""
+
+ newobj = oifits()
+ targetmap = {}
+ sta_indices = {}
+
+ if not quiet:
+ print("Opening %s"%filename)
+ hdulist = pyfits.open(filename)
+ # First get all the OI_TARGET, OI_WAVELENGTH and OI_ARRAY tables
+ for hdu in hdulist:
+ header = hdu.header
+ data = hdu.data
+ if hdu.name == 'OI_WAVELENGTH':
+            if newobj.wavelength is None: newobj.wavelength = {}
+ insname = header['INSNAME']
+ newobj.wavelength[insname] = OI_WAVELENGTH(data.field('EFF_WAVE'), data.field('EFF_BAND'))
+ elif hdu.name == 'OI_TARGET':
+ for row in data:
+ target_id = row['TARGET_ID']
+ target = OI_TARGET(target=row['TARGET'], raep0=row['RAEP0'], decep0=row['DECEP0'],
+ equinox=row['EQUINOX'], ra_err=row['RA_ERR'], dec_err=row['DEC_ERR'],
+ sysvel=row['SYSVEL'], veltyp=row['VELTYP'], veldef=row['VELDEF'],
+ pmra=row['PMRA'], pmdec=row['PMDEC'], pmra_err=row['PMRA_ERR'],
+ pmdec_err=row['PMDEC_ERR'], parallax=row['PARALLAX'],
+ para_err=row['PARA_ERR'], spectyp=row['SPECTYP'])
+ newobj.target = np.append(newobj.target, target)
+ targetmap[target_id] = target
+ elif hdu.name == 'OI_ARRAY':
+            if newobj.array is None: newobj.array = {}
+ arrname = header['ARRNAME']
+ frame = header['FRAME']
+ arrxyz = np.array([header['ARRAYX'], header['ARRAYY'], header['ARRAYZ']])
+ newobj.array[arrname] = OI_ARRAY(frame, arrxyz, stations=data)
+ # Save the sta_index for each array, as we will need it
+ # later to match measurements to stations
+ sta_indices[arrname] = data.field('sta_index')
+
+ # Then get any science measurements
+ for hdu in hdulist:
+ header = hdu.header
+ data = hdu.data
+ if hdu.name in ('OI_VIS', 'OI_VIS2', 'OI_T3'):
+ if 'ARRNAME' in list(header.keys()):
+ arrname = header['ARRNAME']
+ else:
+ arrname = None
+ if arrname and newobj.array:
+ array = newobj.array[arrname]
+ else:
+ array = None
+ wavelength = newobj.wavelength[header['INSNAME']]
+ if hdu.name == 'OI_VIS':
+ for row in data:
+ date = header['DATE-OBS'].split('-')
+ timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2))
+ int_time = row.field('INT_TIME')
+ visamp = np.reshape(row.field('VISAMP'), -1)
+ visamperr = np.reshape(row.field('VISAMPERR'), -1)
+ visphi = np.reshape(row.field('VISPHI'), -1)
+ visphierr = np.reshape(row.field('VISPHIERR'), -1)
+ if 'CFLUX' in row.array.names: cflux = np.reshape(row.field('CFLUX'), -1)
+ else: cflux = None
+ if 'CFLUXERR' in row.array.names: cfluxerr = np.reshape(row.field('CFLUXERR'), -1)
+ else: cfluxerr = None
+ flag = np.reshape(row.field('FLAG'), -1)
+ ucoord = row.field('UCOORD')
+ vcoord = row.field('VCOORD')
+ target = targetmap[row.field('TARGET_ID')]
+ if array:
+ sta_index = row.field('STA_INDEX')
+ s1 = array.station[sta_indices[arrname] == sta_index[0]][0]
+ s2 = array.station[sta_indices[arrname] == sta_index[1]][0]
+ station = [s1, s2]
+ else:
+ station = [None, None]
+ newobj.vis = np.append(newobj.vis, OI_VIS(timeobs=timeobs, int_time=int_time, visamp=visamp,
+ visamperr=visamperr, visphi=visphi, visphierr=visphierr,
+ flag=flag, ucoord=ucoord, vcoord=vcoord, wavelength=wavelength,
+ target=target, array=array, station=station, cflux=cflux,
+ cfluxerr=cfluxerr))
+ elif hdu.name == 'OI_VIS2':
+ for row in data:
+ date = header['DATE-OBS'].split('-')
+ timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2))
+ int_time = row.field('INT_TIME')
+ vis2data = np.reshape(row.field('VIS2DATA'), -1)
+ vis2err = np.reshape(row.field('VIS2ERR'), -1)
+ flag = np.reshape(row.field('FLAG'), -1)
+ ucoord = row.field('UCOORD')
+ vcoord = row.field('VCOORD')
+ target = targetmap[row.field('TARGET_ID')]
+ if array:
+ sta_index = row.field('STA_INDEX')
+ s1 = array.station[sta_indices[arrname] == sta_index[0]][0]
+ s2 = array.station[sta_indices[arrname] == sta_index[1]][0]
+ station = [s1, s2]
+ else:
+ station = [None, None]
+ newobj.vis2 = np.append(newobj.vis2, OI_VIS2(timeobs=timeobs, int_time=int_time, vis2data=vis2data,
+ vis2err=vis2err, flag=flag, ucoord=ucoord, vcoord=vcoord,
+ wavelength=wavelength, target=target, array=array,
+ station=station))
+ elif hdu.name == 'OI_T3':
+ for row in data:
+ date = header['DATE-OBS'].split('-')
+ timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2))
+ int_time = row.field('INT_TIME')
+ t3amp = np.reshape(row.field('T3AMP'), -1)
+ t3amperr = np.reshape(row.field('T3AMPERR'), -1)
+ t3phi = np.reshape(row.field('T3PHI'), -1)
+ t3phierr = np.reshape(row.field('T3PHIERR'), -1)
+ flag = np.reshape(row.field('FLAG'), -1)
+ u1coord = row.field('U1COORD')
+ v1coord = row.field('V1COORD')
+ u2coord = row.field('U2COORD')
+ v2coord = row.field('V2COORD')
+ target = targetmap[row.field('TARGET_ID')]
+ if array:
+ sta_index = row.field('STA_INDEX')
+ s1 = array.station[sta_indices[arrname] == sta_index[0]][0]
+ s2 = array.station[sta_indices[arrname] == sta_index[1]][0]
+ s3 = array.station[sta_indices[arrname] == sta_index[2]][0]
+ station = [s1, s2, s3]
+ else:
+ station = [None, None, None]
+ newobj.t3 = np.append(newobj.t3, OI_T3(timeobs=timeobs, int_time=int_time, t3amp=t3amp,
+ t3amperr=t3amperr, t3phi=t3phi, t3phierr=t3phierr,
+ flag=flag, u1coord=u1coord, v1coord=v1coord, u2coord=u2coord,
+ v2coord=v2coord, wavelength=wavelength, target=target,
+ array=array, station=station))
+
+ hdulist.close()
+ if not quiet:
+ newobj.info(recursive=False)
+
+ return newobj
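+
+# Example (illustrative; 'example.oifits' is a placeholder filename): the
+# module-level open() above is typically invoked as ehtim.io.oifits.open.
+#
+#   oidata = open('example.oifits')
+#   print('vis:', oidata.vis.size, 'vis2:', oidata.vis2.size, 't3:', oidata.t3.size)
+#   for vis2 in oidata.vis2[:3]:
+#       print(vis2.ucoord, vis2.vcoord, vis2.vis2data)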
diff --git a/io/save.py b/io/save.py
new file mode 100644
index 00000000..da7d11e7
--- /dev/null
+++ b/io/save.py
@@ -0,0 +1,1017 @@
+# save.py
+# functions to save observation & image data to files
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import astropy.io.fits as fits
+import datetime
+import h5py
+
+import ehtim.io.writeoifits
+import ehtim.io.oifits
+from astropy.time import Time
+
+import ehtim.const_def as ehc
+import ehtim.observing.obs_helpers as obsh
+
+##################################################################################################
+# Image IO
+##################################################################################################
+
+
+def save_im_txt(im, fname, mjd=False, time=False):
+ """Save image data to text file.
+
+ Args:
+ im (Image): image object
+ fname (str): path to output text file
+ mjd (int): MJD of saved image
+ time (float): UTC time of saved image
+
+ Returns:
+ """
+
+ # Transform to Stokes parameters:
+ if im.polrep != 'stokes' or im.pol_prim != 'I':
+ im = im.switch_polrep(polrep_out='stokes', pol_prim_out=None)
+
+ # Coordinate values
+ pdimas = im.psize/ehc.RADPERAS
+ xs = np.array([[j for j in range(im.xdim)] for i in range(im.ydim)]).reshape(im.xdim*im.ydim, 1)
+ xs = pdimas * (xs[::-1] - im.xdim/2.0)
+ ys = np.array([[i for j in range(im.xdim)] for i in range(im.ydim)]).reshape(im.xdim*im.ydim, 1)
+ ys = pdimas * (ys[::-1] - im.xdim/2.0)
+
+ # If V values but no Q/U values, make Q/U zero
+ if len(im.vvec) and not len(im.qvec):
+ im.qvec = 0*im.vvec
+ im.uvec = 0*im.vvec
+
+ # Format Data
+ if len(im.qvec) and len(im.vvec):
+ outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1),
+ (im.qvec).reshape(im.xdim*im.ydim, 1),
+ (im.uvec).reshape(im.xdim*im.ydim, 1),
+ (im.vvec).reshape(im.xdim*im.ydim, 1)))
+ hf = "x (as) y (as) I (Jy/pixel) Q (Jy/pixel) U (Jy/pixel) V (Jy/pixel)"
+
+ fmts = "%10.10f %10.10f %10.10f %10.10f %10.10f %10.10f"
+
+ elif len(im.qvec):
+ outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1),
+ (im.qvec).reshape(im.xdim*im.ydim, 1),
+ (im.uvec).reshape(im.xdim*im.ydim, 1)))
+ hf = "x (as) y (as) I (Jy/pixel) Q (Jy/pixel) U (Jy/pixel)"
+
+ fmts = "%10.10f %10.10f %10.10f %10.10f %10.10f"
+
+ else:
+ outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1)))
+ hf = "x (as) y (as) I (Jy/pixel)"
+ fmts = "%10.10f %10.10f %10.10f"
+
+ # Header
+ if not mjd:
+ mjd = float(im.mjd)
+ if not time:
+ time = im.time
+ mjd += (time/24.)
+
+ head = ("SRC: %s \n" % im.source +
+ "RA: " + obsh.rastring(im.ra) + "\n" + "DEC: " + obsh.decstring(im.dec) + "\n" +
+ "MJD: %.6f \n" % (float(mjd)) +
+ "RF: %.4f GHz \n" % (im.rf/1e9) +
+ "FOVX: %i pix %f as \n" % (im.xdim, pdimas * im.xdim) +
+ "FOVY: %i pix %f as \n" % (im.ydim, pdimas * im.ydim) +
+ "------------------------------------\n" + hf)
+
+ # Save
+ np.savetxt(fname, outdata, header=head, fmt=fmts)
+ return
+
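+# A minimal usage sketch for save_im_txt (illustrative; the filenames are
+# placeholders, and load_fits is the image loader exposed on ehtim.image):
+#
+#   import ehtim
+#   im = ehtim.image.load_fits('image.fits')
+#   save_im_txt(im, 'image.txt')
+#
+# The output is a plain-text table with one pixel per row: x/y offsets in
+# arcseconds followed by the Stokes brightness columns in Jy/pixel, preceded
+# by the '#'-commented header assembled above.
+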
+# TODO save image in circular basis?
+def save_im_fits(im, fname, mjd=False, time=False):
+ """Save image data to a fits file.
+
+ Args:
+ im (Image): image object
+ fname (str): path to output fits file
+ mjd (int): MJD of saved image
+ time (float): UTC time of saved image
+
+ Returns:
+ """
+
+ # Transform to Stokes parameters:
+ if (im.polrep != 'stokes') or (im.pol_prim != 'I'):
+ im = im.switch_polrep(polrep_out='stokes', pol_prim_out=None)
+
+ # Create header and fill in some values
+ header = fits.Header()
+ header['OBJECT'] = im.source
+ header['CTYPE1'] = 'RA---SIN'
+ header['CTYPE2'] = 'DEC--SIN'
+ header['CDELT1'] = -im.psize/ehc.DEGREE
+ header['CDELT2'] = im.psize/ehc.DEGREE
+ header['OBSRA'] = im.ra * 180/12.
+ header['OBSDEC'] = im.dec
+ header['FREQ'] = im.rf
+
+ # TODO these are the default values for centered images
+ # TODO support for arbitrary CRPIX?
+ header['CRPIX1'] = im.xdim/2. + .5
+ header['CRPIX2'] = im.ydim/2. + .5
+
+ if not mjd:
+ mjd = float(im.mjd)
+ if not time:
+ time = im.time
+ mjd += (time/24.)
+
+ header['MJD'] = float(mjd)
+ header['TELESCOP'] = 'VLBI'
+ header['BUNIT'] = 'JY/PIXEL'
+ header['STOKES'] = 'I'
+
+ # Create the fits image
+ image = np.reshape(im.imvec, (im.ydim, im.xdim))[::-1, :] # flip y axis!
+ hdu = fits.PrimaryHDU(image, header=header)
+ hdulist = [hdu]
+ if len(im.qvec):
+ qimage = np.reshape(im.qvec, (im.ydim, im.xdim))[::-1, :]
+ uimage = np.reshape(im.uvec, (im.ydim, im.xdim))[::-1, :]
+ header['STOKES'] = 'Q'
+ hduq = fits.ImageHDU(qimage, name='Q', header=header)
+ header['STOKES'] = 'U'
+ hduu = fits.ImageHDU(uimage, name='U', header=header)
+ hdulist = [hdu, hduq, hduu]
+ if len(im.vvec):
+ vimage = np.reshape(im.vvec, (im.ydim, im.xdim))[::-1, :]
+ header['STOKES'] = 'V'
+ hduv = fits.ImageHDU(vimage, name='V', header=header)
+ hdulist.append(hduv)
+
+ hdulist = fits.HDUList(hdulist)
+
+ # Save fits
+ hdulist.writeto(fname, overwrite=True)
+
+ return
+
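+# Illustrative round-trip sketch (placeholder filename): the writer above
+# stores Stokes I in the primary HDU and, when present, Q/U/V in image
+# extensions named 'Q', 'U', and 'V':
+#
+#   save_im_fits(im, 'image_out.fits')
+#   with fits.open('image_out.fits') as hdul:
+#       hdul.info()  # PRIMARY, plus Q/U/V extensions if polarization is present
+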
+##################################################################################################
+# Movie IO
+##################################################################################################
+
+
+def save_mov_hdf5(mov, fname, mjd=False):
+ """Save movie data to an hdf5 file.
+
+ Args:
+ mov (Movie): movie object
+ fname (str): basename of output fits file
+ mjd (int): MJD of saved movie
+
+ Returns:
+ """
+
+ # TODO: Currently only supports one polarization!
+ with h5py.File(fname, 'w') as file:
+
+ head = file.create_dataset('header', (0,), dtype="S10")
+
+ if mjd is False:
+ mjd = mov.mjd
+
+ head.attrs['mjd'] = np.string_(str(mjd))
+ head.attrs['psize'] = np.string_(str(mov.psize))
+ head.attrs['source'] = np.string_(str(mov.source))
+ head.attrs['ra'] = np.string_(str(mov.ra))
+ head.attrs['dec'] = np.string_(str(mov.dec))
+ head.attrs['rf'] = np.string_(str(mov.rf))
+ head.attrs['polrep'] = np.string_(str(mov.polrep))
+ head.attrs['pol_prim'] = np.string_(str(mov.pol_prim))
+
+ name = 'times'
+ times = mov.times
+ dset = file.create_dataset(name, data=times, dtype='f8')
+
+ name = mov.pol_prim
+ frames = mov.frames.reshape((mov.nframes, mov.ydim, mov.xdim))
+ dset = file.create_dataset(name, data=frames, dtype='f8')
+
+ for pol in list(mov._movdict.keys()):
+ if pol == mov.pol_prim:
+ continue
+ polframes = mov._movdict[pol]
+ if len(polframes):
+ polframes = polframes.reshape((mov.nframes, mov.ydim, mov.xdim))
+ dset = file.create_dataset(pol, data=polframes, dtype='f8')
+ return
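+
+# Illustrative reader for the hdf5 layout written above (placeholder
+# filename): metadata lives as string attributes on the 'header' dataset,
+# frame times in 'times', and one (nframes, ydim, xdim) dataset per
+# polarization, keyed by its pol label:
+#
+#   with h5py.File('movie.hdf5', 'r') as f:
+#       mjd = float(f['header'].attrs['mjd'])
+#       times = f['times'][:]
+#       frames = f['I'][:]  # assuming pol_prim == 'I'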
+
+
+def save_mov_fits(mov, fname, mjd=False):
+ """Save movie data to series of fits files.
+
+ Args:
+ mov (Movie): movie object
+ fname (str): basename of output fits file
+ mjd (int): MJD of saved movie
+
+ Returns:
+ """
+
+ if mjd is False:
+ mjd = mov.mjd
+
+ for i in range(mov.nframes):
+ time_frame = mov.times[i]
+ fname_frame = fname + "%05d" % i
+ print('saving file '+fname_frame)
+ frame_im = mov.get_frame(i)
+ save_im_fits(frame_im, fname_frame, mjd=mjd, time=time_frame)
+
+ return
+
+
+def save_mov_txt(mov, fname, mjd=False):
+ """Save movie data to series of text files.
+
+ Args:
+ mov (Movie): movie object
+ fname (str): basename of output text file
+ mjd (int): MJD of saved movie
+
+ Returns:
+ """
+
+ if mjd is False:
+ mjd = mov.mjd
+
+ for i in range(mov.nframes):
+ time_frame = mov.times[i]
+ fname_frame = fname + "%05d" % i
+ print('saving file '+fname_frame)
+ frame_im = mov.get_frame(i)
+ save_im_txt(frame_im, fname_frame, mjd=mjd, time=time_frame)
+
+ return
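+
+# Note (illustrative): both frame writers above name their outputs by
+# appending a zero-padded frame index to the basename, e.g.
+#
+#   save_mov_fits(mov, 'frame')  # writes frame00000, frame00001, ...
+#
+# so downstream tools can glob the frames back in time order.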
+
+
+##################################################################################################
+# Array IO
+##################################################################################################
+
+def save_array_txt(arr, fname):
+ """Save the array data in a text file.
+
+ Args:
+ arr (Array): array object
+ fname (str): name of output text file
+
+ Returns:
+ """
+
+    if isinstance(arr, np.ndarray):
+        tarr = arr
+    else:
+        try:
+            tarr = arr.tarr
+        except AttributeError:
+            print("Array format not recognized!")
+            return
+
+ out = ("#Site X(m) Y(m) Z(m) " +
+ "SEFDR SEFDL FR_PAR FR_EL FR_OFF " +
+ "DR_RE DR_IM DL_RE DL_IM \n")
+ for scope in range(len(tarr)):
+ dat = (tarr[scope]['site'],
+ tarr[scope]['x'], tarr[scope]['y'], tarr[scope]['z'],
+ tarr[scope]['sefdr'], tarr[scope]['sefdl'],
+ tarr[scope]['fr_par'], tarr[scope]['fr_elev'], tarr[scope]['fr_off'],
+ tarr[scope]['dr'].real, tarr[scope]['dr'].imag,
+ tarr[scope]['dl'].real, tarr[scope]['dl'].imag
+ )
+ out += "%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" % dat
+    with open(fname, 'w') as f:
+        f.write(out)
+ return
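+
+# Illustrative usage (placeholder filename; the reload assumes the companion
+# text loader exposed as ehtim.array.load_txt):
+#
+#   save_array_txt(arr, 'array.txt')
+#   arr2 = ehtim.array.load_txt('array.txt')  # hypothetical round-trip check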
+
+
+##################################################################################################
+# Observation IO
+##################################################################################################
+def save_obs_txt(obs, fname):
+ """Save the observation data in a text file.
+
+ Args:
+ obs (Obsdata): obsdata object
+ fname (str): name of output text file
+
+ Returns:
+ """
+
+ # output times must be in utc
+ obs = obs.switch_timetype(timetype_out='UTC')
+
+ # Get the necessary data and the header
+ if obs.polrep == 'stokes':
+ outdata = obs.unpack(['time', 'tint', 't1', 't2', 'tau1', 'tau2',
+ 'u', 'v', 'amp', 'phase', 'qamp', 'qphase', 'uamp', 'uphase',
+ 'vamp', 'vphase',
+ 'sigma', 'qsigma', 'usigma', 'vsigma'])
+ elif obs.polrep == 'circ':
+ outdata = obs.unpack(['time', 'tint', 't1', 't2', 'tau1', 'tau2',
+ 'u', 'v', 'rramp', 'rrphase', 'llamp', 'llphase', 'rlamp', 'rlphase',
+ 'lramp', 'lrphase',
+ 'rrsigma', 'llsigma', 'rlsigma', 'lrsigma'])
+
+ else:
+ raise Exception("obs.polrep not 'stokes' or 'circ'!")
+
+ head = ("SRC: %s \n" % obs.source +
+ "RA: " + obsh.rastring(obs.ra) + "\n" + "DEC: " + obsh.decstring(obs.dec) + "\n" +
+ "MJD: %i \n" % obs.mjd +
+ "RF: %.4f GHz \n" % (obs.rf/1e9) +
+ "BW: %.4f GHz \n" % (obs.bw/1e9) +
+ "PHASECAL: %i \n" % obs.phasecal +
+ "AMPCAL: %i \n" % obs.ampcal +
+ "OPACITYCAL: %i \n" % obs.opacitycal +
+ "DCAL: %i \n" % obs.dcal +
+ "FRCAL: %i \n" % obs.frcal +
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "Site X(m) Y(m) Z(m) " +
+ "SEFDR SEFDL FR_PAR FR_EL FR_OFF " +
+ "DR_RE DR_IM DL_RE DL_IM \n"
+ )
+
+ for i in range(len(obs.tarr)):
+ head += ("%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" %
+ (obs.tarr[i]['site'],
+ obs.tarr[i]['x'], obs.tarr[i]['y'], obs.tarr[i]['z'],
+ obs.tarr[i]['sefdr'], obs.tarr[i]['sefdl'],
+ obs.tarr[i]['fr_par'], obs.tarr[i]['fr_elev'], obs.tarr[i]['fr_off'],
+ (obs.tarr[i]['dr']).real, (obs.tarr[i]['dr']).imag,
+ (obs.tarr[i]['dl']).real, (obs.tarr[i]['dl']).imag
+ ))
+
+ if obs.polrep == 'stokes':
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) tint T1 T2 Tau1 Tau2 U (lambda) V (lambda) " +
+ "Iamp (Jy) Iphase(d) Qamp (Jy) Qphase(d) Uamp (Jy) Uphase(d) Vamp (Jy) Vphase(d) " +
+ "Isigma (Jy) Qsigma (Jy) Usigma (Jy) Vsigma (Jy)"
+ )
+ elif obs.polrep == 'circ':
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) tint T1 T2 Tau1 Tau2 U (lambda) V (lambda) " +
+ "RRamp (Jy) RRphase(d) LLamp (Jy) LLphase(d) RLamp (Jy) RLphase(d) LRamp (Jy) LRphase(d) " +
+ "RRsigma (Jy) LLsigma (Jy) RLsigma (Jy) LRsigma (Jy)"
+ )
+
+ # Format and save the data
+ fmts = ("%011.8f %4.2f %6s %6s %4.2f %4.2f %16.4f %16.4f " +
+ "%10.8f %10.4f %10.8f %10.4f %10.8f %10.4f %10.8f %10.4f " +
+ "%10.8f %10.8f %10.8f %10.8f")
+ np.savetxt(fname, outdata, header=head, fmt=fmts)
+ return
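+
+# Illustrative usage (placeholder filename): one row per visibility is
+# written after the station table, with columns set by obs.polrep:
+#
+#   save_obs_txt(obs, 'obs.txt')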
+
+
+def save_obs_uvfits(obs, fname=None, force_singlepol=None, polrep_out='circ'):
+ """Save observation data to uvfits.
+
+ Args:
+ obs (Obsdata): obsdata object
+ fname (str): path to output fits file, or None to return HDUList only
+ force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL'
+ polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file
+ Returns:
+ hdulist (astropy.io.fits.HDUList)
+
+ """
+
+ # output times must be in utc
+ obs = obs.switch_timetype(timetype_out='UTC')
+
+ if polrep_out == 'circ':
+ obs = obs.switch_polrep('circ')
+ elif polrep_out == 'stokes':
+ obs = obs.switch_polrep('stokes')
+ else:
+ raise Exception("'polrep_out' in 'save_obs_uvfits' must be 'circ' or 'stokes'!")
+
+ hdulist_new = fits.HDUList()
+ hdulist_new.append(fits.GroupsHDU())
+
+ #####################
+ # AIPS Data TABLE
+ #####################
+
+ # Data header (based on the BU format)
+ MJD_0 = 2400000.5
+ header = hdulist_new['PRIMARY'].header
+ header['OBSRA'] = obs.ra * 180./12.
+ header['OBSDEC'] = obs.dec
+ header['OBJECT'] = obs.source
+ header['MJD'] = float(obs.mjd)
+ header['DATE-OBS'] = Time(obs.mjd + MJD_0, format='jd', scale='utc').iso[0:10]
+ header['BSCALE'] = 1.0
+ header['BZERO'] = 0.0
+ header['BUNIT'] = 'JY'
+ header['VELREF'] = 3 # TODO ??
+ header['EQUINOX'] = 'J2000'
+ header['ALTRPIX'] = 1.e0
+ header['ALTRVAL'] = 0.e0
+ header['TELESCOP'] = 'VLBA' # TODO Can we change this field?
+ header['INSTRUME'] = 'VLBA'
+ header['OBSERVER'] = 'EHT'
+
+ header['CTYPE2'] = 'COMPLEX'
+ header['CRVAL2'] = 1.e0
+ header['CDELT2'] = 1.e0
+ header['CRPIX2'] = 1.e0
+ header['CROTA2'] = 0.e0
+ header['CTYPE3'] = 'STOKES'
+ if polrep_out == 'circ':
+ header['CRVAL3'] = -1.e0
+ header['CDELT3'] = -1.e0
+ elif polrep_out == 'stokes':
+ header['CRVAL3'] = 1.e0
+ header['CDELT3'] = 1.e0
+ header['CRPIX3'] = 1.e0
+ header['CROTA3'] = 0.e0
+ header['CTYPE4'] = 'FREQ'
+ header['CRVAL4'] = obs.rf
+ header['CDELT4'] = obs.bw
+ header['CRPIX4'] = 1.e0
+ header['CROTA4'] = 0.e0
+ header['CTYPE6'] = 'RA'
+ header['CRVAL6'] = header['OBSRA']
+ header['CDELT6'] = 1.e0
+ header['CRPIX6'] = 1.e0
+ header['CROTA6'] = 0.e0
+ header['CTYPE7'] = 'DEC'
+ header['CRVAL7'] = header['OBSDEC']
+ header['CDELT7'] = 1.e0
+ header['CRPIX7'] = 1.e0
+ header['CROTA7'] = 0.e0
+ header['PTYPE1'] = 'UU---SIN'
+ header['PSCAL1'] = 1.0
+ header['PZERO1'] = 0.e0
+ header['PTYPE2'] = 'VV---SIN'
+ header['PSCAL2'] = 1.0
+ header['PZERO2'] = 0.e0
+ header['PTYPE3'] = 'WW---SIN'
+ header['PSCAL3'] = 1.0
+ header['PZERO3'] = 0.e0
+ header['PTYPE4'] = 'BASELINE'
+ header['PSCAL4'] = 1.e0
+ header['PZERO4'] = 0.e0
+ header['PTYPE5'] = 'DATE'
+ header['PSCAL5'] = 1.e0
+ header['PZERO5'] = 0.e0
+ header['PTYPE6'] = 'DATE'
+ header['PSCAL6'] = 1.e0
+ header['PZERO6'] = 0.0
+ header['PTYPE7'] = 'INTTIM'
+ header['PSCAL7'] = 1.e0
+ header['PZERO7'] = 0.e0
+ header['PTYPE8'] = 'TAU1'
+ header['PSCAL8'] = 1.e0
+ header['PZERO8'] = 0.e0
+ header['PTYPE9'] = 'TAU2'
+ header['PSCAL9'] = 1.e0
+ header['PZERO9'] = 0.e0
+ header['history'] = "AIPS SORT ORDER='TB'"
+
+ # Get data
+
+ if polrep_out == 'circ':
+ obsdata = obs.unpack(['time', 'tint', 'u', 'v',
+ 'rrvis', 'llvis', 'rlvis', 'lrvis',
+ 'rrsigma', 'llsigma', 'rlsigma', 'lrsigma',
+ 't1', 't2', 'tau1', 'tau2'])
+ elif polrep_out == 'stokes':
+ obsdata = obs.unpack(['time', 'tint', 'u', 'v', 'vis', 'qvis', 'uvis', 'vvis',
+ 'sigma', 'qsigma', 'usigma', 'vsigma', 't1', 't2', 'tau1', 'tau2'])
+
+ ndat = len(obsdata['time'])
+
+ # times and tints
+ jds = (2400000.5 + obs.mjd) * np.ones(len(obsdata))
+ fractimes = (obsdata['time']/24.0)
+ tints = obsdata['tint']
+
+ # Baselines
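+    # AIPS packs the two 1-indexed station numbers of each baseline into a
+    # single random parameter as 256*ant1 + ant2 (e.g. stations 3 and 7 -> 775)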
+ t1 = [obs.tkey[scope] + 1 for scope in obsdata['t1']]
+ t2 = [obs.tkey[scope] + 1 for scope in obsdata['t2']]
+ bl = 256*np.array(t1) + np.array(t2)
+
+ # opacities
+ tau1 = obsdata['tau1']
+ tau2 = obsdata['tau2']
+
+ # uv are in lightseconds
+ u = obsdata['u']/obs.rf
+ v = obsdata['v']/obs.rf
+
+ # rr, ll, lr, rl, weights
+
+ if polrep_out == 'circ':
+ rr = obsdata['rrvis']
+ ll = obsdata['llvis']
+ rl = obsdata['rlvis']
+ lr = obsdata['lrvis']
+ weightrr = 1.0/(obsdata['rrsigma']**2)
+ weightll = 1.0/(obsdata['llsigma']**2)
+ weightrl = 1.0/(obsdata['rlsigma']**2)
+ weightlr = 1.0/(obsdata['lrsigma']**2)
+
+ # If necessary, enforce single polarization
+ if force_singlepol == 'L':
+ if obs.polrep == 'stokes':
+ raise Exception("force_singlepol only works with obs.polrep=='stokes'!")
+ print("force_singlepol='L': treating Stokes 'I' as LL and ignoring Q,U,V!!")
+ ll = obsdata['vis']
+ rr = rr * 0.0
+ rl = rl * 0.0
+ lr = lr * 0.0
+ weightrr = weightrr * 0.0
+ weightrl = weightrl * 0.0
+ weightlr = weightlr * 0.0
+ elif force_singlepol == 'R':
+ if obs.polrep == 'stokes':
+ raise Exception("force_singlepol only works with obs.polrep=='stokes'!")
+ print("force_singlepol='R': treating Stokes 'I' as RR and ignoring Q,U,V!!")
+ rr = obsdata['vis']
+ ll = rr * 0.0
+ rl = rl * 0.0
+ lr = lr * 0.0
+ weightll = weightll * 0.0
+ weightrl = weightrl * 0.0
+ weightlr = weightlr * 0.0
+
+ dat1 = rr
+ dat2 = ll
+ dat3 = rl
+ dat4 = lr
+ weight1 = weightrr
+ weight2 = weightll
+ weight3 = weightrl
+ weight4 = weightlr
+
+ elif polrep_out == 'stokes':
+ dat1 = obsdata['vis']
+ dat2 = obsdata['qvis']
+ dat3 = obsdata['uvis']
+ dat4 = obsdata['vvis']
+ weight1 = 1.0/(obsdata['sigma']**2)
+ weight2 = 1.0/(obsdata['qsigma']**2)
+ weight3 = 1.0/(obsdata['usigma']**2)
+ weight4 = 1.0/(obsdata['vsigma']**2)
+
+ # Replace nans by zeros (including zero weights)
+ dat1 = np.nan_to_num(dat1)
+ dat2 = np.nan_to_num(dat2)
+ dat3 = np.nan_to_num(dat3)
+ dat4 = np.nan_to_num(dat4)
+ weight1 = np.nan_to_num(weight1)
+ weight2 = np.nan_to_num(weight2)
+ weight3 = np.nan_to_num(weight3)
+ weight4 = np.nan_to_num(weight4)
+
+ # Data array
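+    # uvfits visibility layout: (ndat, 1, 1, 1, 1, 4, 3); the size-4 axis is
+    # polarization and the last axis holds (real, imaginary, weight)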
+ outdat = np.zeros((ndat, 1, 1, 1, 1, 4, 3))
+ outdat[:, 0, 0, 0, 0, 0, 0] = np.real(dat1)
+ outdat[:, 0, 0, 0, 0, 0, 1] = np.imag(dat1)
+ outdat[:, 0, 0, 0, 0, 0, 2] = weight1
+ outdat[:, 0, 0, 0, 0, 1, 0] = np.real(dat2)
+ outdat[:, 0, 0, 0, 0, 1, 1] = np.imag(dat2)
+ outdat[:, 0, 0, 0, 0, 1, 2] = weight2
+ outdat[:, 0, 0, 0, 0, 2, 0] = np.real(dat3)
+ outdat[:, 0, 0, 0, 0, 2, 1] = np.imag(dat3)
+ outdat[:, 0, 0, 0, 0, 2, 2] = weight3
+ outdat[:, 0, 0, 0, 0, 3, 0] = np.real(dat4)
+ outdat[:, 0, 0, 0, 0, 3, 1] = np.imag(dat4)
+ outdat[:, 0, 0, 0, 0, 3, 2] = weight4
+
+ # Save data
+ pars = ['UU---SIN', 'VV---SIN', 'WW---SIN', 'BASELINE', 'DATE', 'DATE',
+ 'INTTIM', 'TAU1', 'TAU2']
+ x = fits.GroupData(outdat, parnames=pars,
+ pardata=[u, v, np.zeros(ndat), bl, jds, fractimes, tints, tau1, tau2],
+ bitpix=-32)
+
+ hdulist_new['PRIMARY'].data = x
+ hdulist_new['PRIMARY'].header = header # TODO necessary ??
+
+ #####################
+ # AIPS AN TABLE
+ #####################
+
+ # Load the array data
+ tarr = obs.tarr
+ tnames = tarr['site']
+ tnums = np.arange(1, len(tarr)+1)
+ xyz = np.array([[tarr[i]['x'], tarr[i]['y'], tarr[i]['z']] for i in np.arange(len(tarr))])
+ sefd = tarr['sefdr']
+
+ nsta = len(tnames)
+ col1 = fits.Column(name='ANNAME', format='8A', array=tnames)
+ col2 = fits.Column(name='STABXYZ', format='3D', unit='METERS', array=xyz)
+ col3 = fits.Column(name='NOSTA', format='1J', array=tnums)
+ colfin = fits.Column(name='SEFD', format='1D', array=sefd)
+
+ # TODO these antenna fields+header are questionable - look into them
+ col4 = fits.Column(name='MNTSTA', format='1J',
+ array=np.zeros(nsta))
+ col5 = fits.Column(name='STAXOF', format='1E', unit='METERS',
+ array=np.zeros(nsta))
+ col6 = fits.Column(name='POLTYA', format='1A',
+ array=np.array(['R' for i in range(nsta)], dtype='|S1'))
+ col7 = fits.Column(name='POLAA', format='1E', unit='DEGREES',
+ array=np.zeros(nsta))
+ col8 = fits.Column(name='POLCALA', format='3E',
+ array=np.zeros((nsta, 3)))
+ col9 = fits.Column(name='POLTYB', format='1A',
+ array=np.array(['L' for i in range(nsta)], dtype='|S1'))
+ col10 = fits.Column(name='POLAB', format='1E', unit='DEGREES',
+ array=(90.*np.ones(nsta)))
+ col11 = fits.Column(name='POLCALB', format='3E',
+ array=np.zeros((nsta, 3)))
+ col25 = fits.Column(name='ORBPARM', format='1E',
+ array=np.zeros(0))
+
+ # Antenna Header params
+ # TODO do we need to change more of these??
+ collist = [col1, col2, col25, col3, col4, col5, col6, col7, col8, col9, col10, col11, colfin]
+ tbhdu = fits.BinTableHDU.from_columns(fits.ColDefs(collist), name='AIPS AN')
+ hdulist_new.append(tbhdu)
+
+ head = hdulist_new['AIPS AN'].header
+
+ head['EXTVER'] = 1
+ head['ARRAYX'] = 0.e0
+ head['ARRAYY'] = 0.e0
+ head['ARRAYZ'] = 0.e0
+
+ # TODO change the reference date
+ #rdate_tt_new = Time(obs.mjd + MJD_0, format='jd', scale='utc', out_subfmt='date')
+ #rdate_out = rdate_tt_new.iso
+
+ rdate_tt_new = Time(obs.mjd + MJD_0, format='jd', scale='utc')
+ rdate_out = rdate_tt_new.iso[0:10]
+
+ rdate_tt_new.out_subfmt = 'float' # TODO -- needed to fix subformat issue in astropy 4.0
+ rdate_jd_out = rdate_tt_new.jd
+ rdate_gstiao_out = rdate_tt_new.sidereal_time('apparent', 'greenwich').degree
+ rdate_offset_out = (rdate_tt_new.ut1.datetime.second - rdate_tt_new.utc.datetime.second)
+ rdate_offset_out += 1.e-6*(rdate_tt_new.ut1.datetime.microsecond -
+ rdate_tt_new.utc.datetime.microsecond)
+
+ head['RDATE'] = rdate_out
+ head['GSTIA0'] = rdate_gstiao_out
+ head['DEGPDY'] = 360.9856
+ head['UT1UTC'] = rdate_offset_out # difference between UT1 and UTC ?
+ head['DATUTC'] = 0.e0
+ head['TIMESYS'] = 'UTC'
+
+ head['FREQ'] = obs.rf
+ head['POLARX'] = 0.e0
+ head['POLARY'] = 0.e0
+
+ head['ARRNAM'] = 'VLBA' # TODO must be recognized by aips/casa
+ head['XYZHAND'] = 'RIGHT'
+ head['FRAME'] = '????'
+ head['NUMORB'] = 0
+ head['NO_IF'] = 1 # TODO nchan
+ head['NOPCAL'] = 0 # TODO add pol cal information
+ head['POLTYPE'] = 'VLBI'
+ head['FREQID'] = 1
+
+ hdulist_new['AIPS AN'].header = head # TODO necessary, or is it a pointer?
+
+ #####################
+ # AIPS FQ TABLE
+ #####################
+ # Convert types & columns
+
+ nif = 1
+ col1 = np.array(1, dtype=np.int32).reshape([nif]) # frqsel
+ col2 = np.array(0.0, dtype=np.float64).reshape([nif]) # iffreq
+ col3 = np.array([obs.bw], dtype=np.float32).reshape([nif]) # chwidth
+ col4 = np.array([obs.bw], dtype=np.float32).reshape([nif]) # bw
+ col5 = np.array([1], dtype=np.int32).reshape([nif]) # sideband
+
+ col1 = fits.Column(name="FRQSEL", format="1J", array=col1)
+ col2 = fits.Column(name="IF FREQ", format="%dD" % (nif), array=col2)
+ col3 = fits.Column(name="CH WIDTH", format="%dE" % (nif), array=col3)
+ col4 = fits.Column(name="TOTAL BANDWIDTH", format="%dE" % (nif), array=col4)
+ col5 = fits.Column(name="SIDEBAND", format="%dJ" % (nif), array=col5)
+ cols = fits.ColDefs([col1, col2, col3, col4, col5])
+
+ # create table
+ tbhdu = fits.BinTableHDU.from_columns(cols)
+
+ # add header information
+ tbhdu.header.append(("NO_IF", nif, "Number IFs"))
+ tbhdu.header.append(("EXTNAME", "AIPS FQ"))
+ tbhdu.header.append(("EXTVER", 1))
+ hdulist_new.append(tbhdu)
+
+ #####################
+ # AIPS NX TABLE
+ #####################
+
+ scan_times = []
+ scan_time_ints = []
+ start_vis = []
+ stop_vis = []
+
+    # TODO: make sure jds AND scan_info are both time-sorted!!
+ jj = 0
+
+ ROUND_SCAN_INT = 5
+ comp_fac = 3600*24*100 # compare to 100th of a second
+ scan_arr = obs.scans
+ print('Building NX table')
+ if (scan_arr is None or len(scan_arr) == 0):
+ print("No NX table in saved uvfits")
+ else:
+ try:
+ scan_arr = scan_arr/24.
+ for scan in scan_arr:
+ scan_start = round(scan[0], ROUND_SCAN_INT)
+ scan_stop = round(scan[1], ROUND_SCAN_INT)
+ scan_dur = (scan_stop - scan_start)
+
+ if jj >= len(fractimes):
+ # print start_vis, stop_vis
+ break
+
+ # print ("%.12f %.12f %.12f" % (fractimes[jj], scan_start, scan_stop))
+ jd = round(fractimes[jj], ROUND_SCAN_INT)*comp_fac # TODO precision??
+
+ if ((np.floor(jd) >= np.floor(scan_start*comp_fac)) and
+ (np.ceil(jd) <= np.ceil(comp_fac*scan_stop))):
+ start_vis.append(jj)
+
+ # TODO AIPS MEMO 117 says scan_times should be midpoint!
+                    # but AIPS data looks like it's at the start?
+ scan_times.append(scan_start + 0.5*scan_dur) # - rdate_jd_out)
+ scan_time_ints.append(scan_dur)
+ ceilcut = np.ceil(comp_fac*scan_stop)
+ while ((jj < len(fractimes) and
+ np.floor(round(fractimes[jj], ROUND_SCAN_INT)*comp_fac) <= ceilcut)):
+ jj += 1
+ stop_vis.append(jj-1)
+ else:
+ continue
+
+ if jj < len(fractimes):
+ print(scan_arr[-1])
+ print(round(scan_arr[-1][0], ROUND_SCAN_INT),
+ round(scan_arr[-1][1], ROUND_SCAN_INT))
+ print(jj, len(jds), round(jds[jj], ROUND_SCAN_INT))
+ print("WARNING!!!: in save_uvfits NX table, " +
+ "didn't get to all entries when computing scan start/stop!")
+ print(scan_times)
+ time_nx = fits.Column(name="TIME", format="1D", unit='DAYS', array=np.array(scan_times))
+ timeint_nx = fits.Column(name="TIME INTERVAL", format="1E",
+ unit='DAYS', array=np.array(scan_time_ints))
+ sourceid_nx = fits.Column(name="SOURCE ID", format="1J",
+ unit='', array=np.ones(len(scan_times)))
+ subarr_nx = fits.Column(name="SUBARRAY", format="1J", unit='',
+ array=np.ones(len(scan_times)))
+ freqid_nx = fits.Column(name="FREQ ID", format="1J", unit='',
+ array=np.ones(len(scan_times)))
+ startvis_nx = fits.Column(name="START VIS", format="1J",
+ unit='', array=np.array(start_vis)+1)
+ endvis_nx = fits.Column(name="END VIS", format="1J",
+ unit='', array=np.array(stop_vis)+1)
+ cols = fits.ColDefs([time_nx, timeint_nx, sourceid_nx, subarr_nx,
+ freqid_nx, startvis_nx, endvis_nx])
+
+ tbhdu = fits.BinTableHDU.from_columns(cols)
+
+ # header information
+ tbhdu.header.append(("EXTNAME", "AIPS NX"))
+ tbhdu.header.append(("EXTVER", 1))
+
+ hdulist_new.append(tbhdu)
+ except TypeError:
+ print("No NX table in saved uvfits")
+
+ # Write final HDUList to file
+ if fname is not None:
+ hdulist_new.writeto(fname, overwrite=True)
+
+ return hdulist_new.copy()
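+
+# Illustrative usage (placeholder filename): write a circular-basis uvfits
+# file, or pass fname=None to get the HDUList without touching disk:
+#
+#   save_obs_uvfits(obs, 'obs.uvfits', polrep_out='circ')
+#   hdulist = save_obs_uvfits(obs, fname=None)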
+
+
+
+def save_obs_oifits(obs, fname, flux=1.0):
+    """Save visibility data to an oifits file.
+       Polarization data is NOT saved.
+       NOTE: as of 2021, this function is very out-of-date and should be updated.
+
+       Args:
+ obs (Obsdata): obsdata object
+ fname (str): path to output uvfits file.
+ flux (float): Flux density normalization
+ Returns:
+ """
+
+ # TODO: Add polarization to oifits??
+ print('Warning: save_oifits does NOT save polarimetric visibility data!')
+
+ # output times must be in utc
+ obs = obs.switch_timetype(timetype_out='UTC')
+
+ if (obs.polrep != 'stokes'):
+ raise Exception("save_obs_oifits only works with polrep 'stokes'!")
+
+ # Normalizing by the total flux passed in - note this is changing the data inside the obs structure
+ obs.data['vis'] /= flux
+ obs.data['sigma'] /= flux
+
+ data = obs.unpack(['u', 'v', 'amp', 'phase', 'sigma', 'time', 't1', 't2', 'tint'])
+ biarr = obs.bispectra(mode="all", count="min")
+
+ # extract the telescope names and parameters
+ antennaNames = obs.tarr['site']
+ sefd = obs.tarr['sefdr']
+ antennaX = obs.tarr['x']
+ antennaY = obs.tarr['y']
+ antennaZ = obs.tarr['z']
+
+ # TODO: this is incorrect and there is just a dummy variable here
+ # antennaDiam = -np.ones(antennaX.shape)
+    antennaDiam = sefd  # replace antennaDiam with SEFD for radio observations
+
+ # create dictionary
+ union = {}
+ union = ehtim.io.writeoifits.arrayUnion(antennaNames, union)
+
+ # extract the integration time
+ intTime = data['tint'][0]
+ if not all(data['tint'][0] == item for item in np.reshape(data['tint'], (-1))):
+ raise TypeError("The time integrations for each visibility are different")
+
+ # get visibility information
+ amp = data['amp']
+ phase = data['phase']
+ viserror = data['sigma']
+ u = data['u']
+ v = data['v']
+
+ # convert antenna name strings to number identifiers
+ ant1 = ehtim.io.writeoifits.convertStrings(data['t1'], union)
+ ant2 = ehtim.io.writeoifits.convertStrings(data['t2'], union)
+
+ # convert times to datetime objects
+    # TODO: these do not correspond to the actual times
+ time = data['time']
+ dttime = np.array([datetime.datetime.utcfromtimestamp(x*60.0*60.0)
+ for x in time])
+
+ # get the bispectrum information
+ bi = biarr['bispec']
+ t3amp = np.abs(bi)
+ t3phi = np.angle(bi, deg=1)
+ t3amperr = biarr['sigmab']
+ t3phierr = 180.0/np.pi * (1.0/t3amp) * t3amperr
+ uClosure = np.transpose(np.array([np.array(biarr['u1']), np.array(biarr['u2'])]))
+ vClosure = np.transpose(np.array([np.array(biarr['v1']), np.array(biarr['v2'])]))
+
+ # convert times to datetime objects
+    # TODO: these do not correspond to the actual times
+ timeClosure = biarr['time']
+ dttimeClosure = np.array([datetime.datetime.utcfromtimestamp(x*60.0*60.0)
+ for x in timeClosure])
+
+ # convert antenna name strings to number identifiers
+ biarr_ant1 = ehtim.io.writeoifits.convertStrings(biarr['t1'], union)
+ biarr_ant2 = ehtim.io.writeoifits.convertStrings(biarr['t2'], union)
+ biarr_ant3 = ehtim.io.writeoifits.convertStrings(biarr['t3'], union)
+ antOrder = np.transpose(np.array([biarr_ant1, biarr_ant2, biarr_ant3]))
+
+ # todo: check that putting the negatives on the phase and t3phi is correct
+ ehtim.io.writeoifits.writeOIFITS(fname, obs.ra, obs.dec, obs.rf, obs.bw, intTime,
+ amp, viserror, phase, viserror, u, v, ant1, ant2, dttime,
+ t3amp, t3amperr, t3phi, t3phierr, uClosure, vClosure, antOrder,
+ dttimeClosure, antennaNames, antennaDiam,
+ antennaX, antennaY, antennaZ)
+
+ # Un-Normalizing by the total flux passed in
+ # NOTE this is changing the data inside the obs structure back to what it originally was
+ obs.data['vis'] *= flux
+ obs.data['sigma'] *= flux
+
+ return
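+
+# Illustrative usage (placeholder filename; obs must be in the 'stokes'
+# polrep, and the visibilities are normalized by flux before writing):
+#
+#   save_obs_oifits(obs, 'obs.oifits', flux=1.0)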
+
+
+def save_dtype_txt(obs, fname, dtype='cphase'):
+    """Save the data product of type 'dtype' in a text file.
+
+       Args:
+ obs (Obsdata): obsdata object
+ fname (str): path to output text file
+ dtype (str): desired data type
+ Returns:
+ """
+
+ head = ("SRC: %s \n" % obs.source +
+ "RA: " + obsh.rastring(obs.ra) + "\n" + "DEC: " + obsh.decstring(obs.dec) + "\n" +
+ "MJD: %i \n" % obs.mjd +
+ "RF: %.4f GHz \n" % (obs.rf/1e9) +
+ "BW: %.4f GHz \n" % (obs.bw/1e9) +
+ "PHASECAL: %i \n" % obs.phasecal +
+ "AMPCAL: %i \n" % obs.ampcal +
+ "OPACITYCAL: %i \n" % obs.opacitycal +
+ "DCAL: %i \n" % obs.dcal +
+ "FRCAL: %i \n" % obs.frcal +
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "Site X(m) Y(m) Z(m) " +
+ "SEFDR SEFDL FR_PAR FR_EL FR_OFF " +
+ "DR_RE DR_IM DL_RE DL_IM \n"
+ )
+
+ for i in range(len(obs.tarr)):
+ head += ("%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" %
+ (obs.tarr[i]['site'],
+ obs.tarr[i]['x'], obs.tarr[i]['y'], obs.tarr[i]['z'],
+ obs.tarr[i]['sefdr'], obs.tarr[i]['sefdl'],
+ obs.tarr[i]['fr_par'], obs.tarr[i]['fr_elev'], obs.tarr[i]['fr_off'],
+ (obs.tarr[i]['dr']).real, (obs.tarr[i]['dr']).imag,
+ (obs.tarr[i]['dl']).real, (obs.tarr[i]['dl']).imag
+ ))
+
+ if dtype == 'cphase':
+ outdata = obs.cphase
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) T1 T2 T3 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) Cphase (d) Sigmacp")
+ fmts = ("%011.8f %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f")
+
+ elif dtype == 'logcamp':
+ outdata = obs.logcamp
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) T1 T2 T3 T4 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) U4 (lambda) V4 (lambda) Logcamp Sigmalogca")
+ fmts = ("%011.8f %6s %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f")
+
+ elif dtype == 'camp':
+ outdata = obs.camp
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) T1 T2 T3 T4 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) U4 (lambda) V4 (lambda) Camp Sigmaca")
+ fmts = ("%011.8f %6s %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f")
+
+ elif dtype == 'bs':
+ outdata = obs.bispec
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) T1 T2 T3 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) Bispec Sigmab")
+ fmts = ("%011.8f %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f")
+
+ elif dtype == 'amp':
+ outdata = obs.amp
+ head += (
+ "----------------------------------------------------------------------" +
+ "------------------------------------------------------------------\n" +
+ "time (hr) tint T1 T2 U (lambda) V (lambda) Amp (Jy) Ampsigma")
+ fmts = ("%011.8f %4.2f %6s %6s %16.4f %16.4f %10.8f %10.8f")
+
+ else:
+ raise Exception(dtype + ' is not a possible data type!')
+
+ np.savetxt(fname, outdata, header=head, fmt=fmts)
+ return
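+
+# Illustrative usage (placeholder filename): dtype may be 'cphase',
+# 'logcamp', 'camp', 'bs', or 'amp', matching the branches above:
+#
+#   save_dtype_txt(obs, 'obs_cphase.txt', dtype='cphase')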
diff --git a/io/writeoifits.py b/io/writeoifits.py
new file mode 100644
index 00000000..0d5eb250
--- /dev/null
+++ b/io/writeoifits.py
@@ -0,0 +1,102 @@
+# writeoifits.py
+# function to save observation data to an OIFITS file
+# author: Katie Bouman
+
+from __future__ import division
+from __future__ import print_function
+from builtins import range
+
+import numpy as np
+import ehtim.io.oifits
+import ehtim.const_def as ehc
+
+
+def writeOIFITS(filename, RA, DEC, frequency, bandWidth, intTime,
+ visamp, visamperr, visphi, visphierr, u, v, ant1, ant2, timeobs,
+ t3amp, t3amperr, t3phi, t3phierr, uClosure, vClosure, antOrder, timeClosure,
+ antennaNames, antennaDiam, antennaX, antennaY, antennaZ):
+
+ speedoflight = ehc.C
+ flagVis = False # do not flag any data
+
+ # open a new oifits file
+ data = ehtim.io.oifits.oifits()
+
+ # put in the target information - RA and DEC should be in degrees
+ name = 'TARGET_NAME'
+ data.target = np.append(data.target, ehtim.io.oifits.OI_TARGET(name, RA, DEC, veltyp='LSR'))
+
+    # calculate wavelength and bandpass
+ wavelength = speedoflight/frequency
+ bandlow = speedoflight/(frequency+(0.5*bandWidth))
+ bandhigh = speedoflight/(frequency-(0.5*bandWidth))
+ bandpass = bandhigh-bandlow
+
+    # put in the wavelength information - only using a single frequency
+ data.wavelength['WAVELENGTH_NAME'] = ehtim.io.oifits.OI_WAVELENGTH(
+ wavelength, eff_band=bandpass)
+
+ # put in information about the telescope stations in the array
+ stations = []
+ for i in range(0, len(antennaNames)):
+ stations.append((antennaNames[i], antennaNames[i], i+1,
+ antennaDiam[i], [antennaX[i], antennaY[i], antennaZ[i]]))
+ data.array['ARRAY_NAME'] = ehtim.io.oifits.OI_ARRAY('GEOCENTRIC', [0, 0, 0], stations)
+
+ print('Warning: set cflux and cfluxerr = False ' +
+ 'because otherwise problems were being generated ' +
+ '...are they the total flux density?')
+ print('Warning: are there any true flags?')
+
+ # put in the visibility information - note this does not include phase errors!
+ for i in range(0, len(u)):
+ station_curr = (data.array['ARRAY_NAME'].station[int(ant1[i] - 1)],
+ data.array['ARRAY_NAME'].station[int(ant2[i] - 1)])
+ currVis = ehtim.io.oifits.OI_VIS(timeobs[i], intTime, visamp[i], visamperr[i],
+ visphi[i], visphierr[i], flagVis,
+ u[i]*wavelength, v[i]*wavelength,
+ data.wavelength['WAVELENGTH_NAME'], data.target[0],
+ array=data.array['ARRAY_NAME'], station=station_curr,
+ cflux=False, cfluxerr=False)
+ data.vis = np.append(data.vis, currVis)
+
+ # put in bispectrum information
+ for j in range(0, len(uClosure)):
+ station_curr = (data.array['ARRAY_NAME'].station[int(antOrder[j][0] - 1)],
+ data.array['ARRAY_NAME'].station[int(antOrder[j][1] - 1)],
+ data.array['ARRAY_NAME'].station[int(antOrder[j][2] - 1)])
+ currT3 = ehtim.io.oifits.OI_T3(timeClosure[j], intTime, t3amp[j], t3amperr[j],
+ t3phi[j], t3phierr[j], flagVis,
+ uClosure[j][0]*wavelength, vClosure[j][0]*wavelength,
+ uClosure[j][1]*wavelength, vClosure[j][1]*wavelength,
+ data.wavelength['WAVELENGTH_NAME'], data.target[0],
+ array=data.array['ARRAY_NAME'], station=station_curr)
+ data.t3 = np.append(data.t3, currT3)
+
+ # put in visibility squared information
+ for k in range(0, len(u)):
+ station_curr = (data.array['ARRAY_NAME'].station[int(ant1[k] - 1)],
+ data.array['ARRAY_NAME'].station[int(ant2[k] - 1)])
+ currVis2 = ehtim.io.oifits.OI_VIS2(timeobs[k], intTime, visamp[k]**2,
+ 2.0*visamp[k]*visamperr[k], flagVis,
+ u[k]*wavelength, v[k]*wavelength,
+ data.wavelength['WAVELENGTH_NAME'], data.target[0],
+ array=data.array['ARRAY_NAME'], station=station_curr)
+ data.vis2 = np.append(data.vis2, currVis2)
+
+    # save oifits file
+ data.save(filename)
+
+
+def arrayUnion(array, union):
+ for item in array:
+ if not (item in list(union.keys())):
+ union[item] = len(union)+1
+ return union
+
+
+def convertStrings(array, union):
+ returnarray = np.zeros(array.shape)
+ for i in range(len(array)):
+ returnarray[i] = union[array[i]]
+ return returnarray
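+
+# Illustrative worked example of the two helpers above:
+#
+#   union = arrayUnion(np.array(['ALMA', 'APEX', 'SMA']), {})
+#   # union == {'ALMA': 1, 'APEX': 2, 'SMA': 3}
+#   convertStrings(np.array(['APEX', 'ALMA']), union)
+#   # -> array([2., 1.])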
diff --git a/model.py b/model.py
new file mode 100644
index 00000000..839dcda4
--- /dev/null
+++ b/model.py
@@ -0,0 +1,2412 @@
+# model.py
+# an interferometric model class
+
+from __future__ import division
+from __future__ import print_function
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import scipy.special as sps
+import scipy.integrate as integrate
+import scipy.interpolate as interpolate
+import copy
+
+import ehtim.observing.obs_simulate as simobs
+import ehtim.observing.pulses
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+#from ehtim.modeling.modeling_utils import *
+
+import ehtim.image as image
+
+LINE_THICKNESS = 2 # Thickness of 1D models on the image, in pixels
+FOV_DEFAULT = 100.*RADPERUAS
+NPIX_DEFAULT = 256
+COMPLEX_BASIS = 'abs-arg' # Basis for representing (most) complex quantities: 'abs-arg' or 're-im'
+
+###################################################################################################
+# Model object
+###################################################################################################
+
+def model_params(model_type, model_params=None, fit_pol=False, fit_cpol=False):
+ """Return the ordered list of model parameters for a specified model type. This order must match that of the gradient function, sample_1model_grad_uv.
+ """
+
+ if COMPLEX_BASIS == 're-im':
+ complex_labels = ['_re','_im']
+ elif COMPLEX_BASIS == 'abs-arg':
+ complex_labels = ['_abs','_arg']
+ else:
+ raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!')
+
+ params = []
+
+ # Function to add polarimetric parameters; these must be added before stretch parameters
+ def add_pol():
+ if fit_pol:
+ if model_type.find('mring') == -1:
+ params.append('pol_frac')
+ params.append('pol_evpa')
+ else:
+ for j in range(-(len(model_params['beta_list_pol'])-1)//2,(len(model_params['beta_list_pol'])+1)//2):
+ params.append('betapol' + str(j) + complex_labels[0])
+ params.append('betapol' + str(j) + complex_labels[1])
+ if fit_cpol:
+ if model_type.find('mring') == -1:
+ params.append('cpol_frac')
+ else:
+ for j in range(len(model_params['beta_list_cpol'])):
+ if j==0:
+ params.append('betacpol0')
+ else:
+ params.append('betacpol' + str(j) + complex_labels[0])
+ params.append('betacpol' + str(j) + complex_labels[1])
+
+ if model_type == 'point':
+ params = ['F0','x0','y0']
+ add_pol()
+ elif model_type == 'circ_gauss':
+ params = ['F0','FWHM','x0','y0']
+ add_pol()
+ elif model_type == 'gauss':
+ params = ['F0','FWHM_maj','FWHM_min','PA','x0','y0']
+ add_pol()
+ elif model_type == 'disk':
+ params = ['F0','d','x0','y0']
+ add_pol()
+ elif model_type == 'blurred_disk':
+ params = ['F0','d','alpha','x0','y0']
+ add_pol()
+ elif model_type == 'crescent':
+ params = ['F0','d', 'fr', 'fo', 'ff', 'phi','x0','y0']
+ add_pol()
+ elif model_type == 'blurred_crescent':
+ params = ['F0','d','alpha','fr', 'fo', 'ff', 'phi','x0','y0']
+ add_pol()
+ elif model_type == 'ring':
+ params = ['F0','d','x0','y0']
+ add_pol()
+ elif model_type == 'stretched_ring':
+ params = ['F0','d','x0','y0','stretch','stretch_PA']
+ add_pol()
+ elif model_type == 'thick_ring':
+ params = ['F0','d','alpha','x0','y0']
+ add_pol()
+ elif model_type == 'stretched_thick_ring':
+ params = ['F0','d','alpha','x0','y0','stretch','stretch_PA']
+ add_pol()
+ elif model_type == 'mring':
+ params = ['F0','d','x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ elif model_type == 'stretched_mring':
+ params = ['F0','d','x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ params.append('stretch')
+ params.append('stretch_PA')
+ elif model_type == 'thick_mring':
+ params = ['F0','d','alpha','x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ elif model_type == 'thick_mring_floor':
+ params = ['F0','d','alpha','ff','x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ elif model_type == 'thick_mring_Gfloor':
+ params = ['F0','d','alpha','ff','FWHM','x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+
+ elif model_type == 'stretched_thick_mring':
+ params = ['F0','d','alpha', 'x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ params.append('stretch')
+ params.append('stretch_PA')
+ elif model_type == 'stretched_thick_mring_floor':
+ params = ['F0','d','alpha','ff', 'x0','y0']
+ for j in range(len(model_params['beta_list'])):
+ params.append('beta' + str(j+1) + complex_labels[0])
+ params.append('beta' + str(j+1) + complex_labels[1])
+ add_pol()
+ params.append('stretch')
+ params.append('stretch_PA')
+ else:
+        print('Model ' + model_type + ' not recognized.')
+ params = []
+
+ return params
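+
+# Illustrative example: for a thick m-ring with one azimuthal mode and the
+# default COMPLEX_BASIS of 'abs-arg',
+#
+#   model_params('thick_mring', {'beta_list': [0.2]})
+#   # -> ['F0', 'd', 'alpha', 'x0', 'y0', 'beta1_abs', 'beta1_arg']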
+
+def default_prior(model_type,model_params=None,fit_pol=False,fit_cpol=False):
+ """Return the default model prior and transformation for a specified model type
+ """
+
+ if COMPLEX_BASIS == 're-im':
+ complex_labels = ['_re','_im']
+ complex_priors = [{'prior_type':'flat','min':-0.5,'max':0.5}, {'prior_type':'flat','min':-0.5,'max':0.5}]
+ complex_priors2 = [{'prior_type':'flat','min':-1,'max':1}, {'prior_type':'flat','min':-1,'max':1}]
+ elif COMPLEX_BASIS == 'abs-arg':
+ complex_labels = ['_abs','_arg']
+ # Note: angle range here must match np.angle(). Need to properly define wrapped distributions
+ complex_priors = [{'prior_type':'flat','min':0.0,'max':0.5}, {'prior_type':'flat','min':-np.pi, 'max':np.pi}]
+ complex_priors2 = [{'prior_type':'flat','min':0.0,'max':1.0}, {'prior_type':'flat','min':-np.pi, 'max':np.pi}]
+ else:
+ raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!')
+
+ prior = {'F0':{'prior_type':'none','transform':'log'},
+ 'x0':{'prior_type':'none'},
+ 'y0':{'prior_type':'none'}}
+ if model_type == 'point':
+ pass
+ elif model_type == 'circ_gauss':
+ prior['FWHM'] = {'prior_type':'none','transform':'log'}
+ elif model_type == 'gauss':
+ prior['FWHM_maj'] = {'prior_type':'positive','transform':'log'}
+ prior['FWHM_min'] = {'prior_type':'positive','transform':'log'}
+ prior['PA'] = {'prior_type':'none'}
+ elif model_type == 'disk':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ elif model_type == 'blurred_disk':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ elif model_type == 'crescent':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['fr'] = {'prior_type':'flat','min':0,'max':1}
+ prior['fo'] = {'prior_type':'flat','min':0,'max':1}
+ prior['ff'] = {'prior_type':'flat','min':0,'max':1}
+ prior['phi'] = {'prior_type':'flat','min':0,'max':2.*np.pi}
+ elif model_type == 'blurred_crescent':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ prior['fr'] = {'prior_type':'flat','min':0,'max':1}
+ prior['fo'] = {'prior_type':'flat','min':0,'max':1}
+ prior['ff'] = {'prior_type':'flat','min':0,'max':1}
+ prior['phi'] = {'prior_type':'flat','min':0,'max':2.*np.pi}
+ elif model_type == 'ring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ elif model_type == 'stretched_ring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch_PA'] = {'prior_type':'none'}
+ elif model_type == 'thick_ring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ elif model_type == 'stretched_thick_ring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch_PA'] = {'prior_type':'none'}
+ elif model_type == 'mring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ elif model_type == 'stretched_mring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ prior['stretch'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch_PA'] = {'prior_type':'none'}
+ elif model_type == 'thick_mring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ elif model_type == 'thick_mring_floor':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ prior['ff'] = {'prior_type':'flat','min':0,'max':1}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ elif model_type == 'thick_mring_Gfloor':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ prior['ff'] = {'prior_type':'flat','min':0,'max':1}
+ prior['FWHM'] = {'prior_type':'positive','transform':'log'}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ elif model_type == 'stretched_thick_mring':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ prior['stretch'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch_PA'] = {'prior_type':'none'}
+ elif model_type == 'stretched_thick_mring_floor':
+ prior['d'] = {'prior_type':'positive','transform':'log'}
+ prior['alpha'] = {'prior_type':'positive','transform':'log'}
+ prior['ff'] = {'prior_type':'flat','min':0,'max':1}
+ for j in range(len(model_params['beta_list'])):
+ prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0]
+ prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1]
+ prior['stretch'] = {'prior_type':'positive','transform':'log'}
+ prior['stretch_PA'] = {'prior_type':'none'}
+ else:
+        print('Model ' + model_type + ' not recognized!')
+
+ if fit_pol:
+ if model_type.find('mring') == -1:
+ prior['pol_frac'] = {'prior_type':'flat','min':0.0,'max':1.0}
+ prior['pol_evpa'] = {'prior_type':'flat','min':0.0,'max':np.pi}
+ else:
+ for j in range(-(len(model_params['beta_list_pol'])-1)//2,(len(model_params['beta_list_pol'])+1)//2):
+ prior['betapol' + str(j) + complex_labels[0]] = complex_priors2[0]
+ prior['betapol' + str(j) + complex_labels[1]] = complex_priors2[1]
+
+ if fit_cpol:
+ if model_type.find('mring') == -1:
+ prior['cpol_frac'] = {'prior_type':'flat','min':-1.0,'max':1.0}
+ else:
+ for j in range(len(model_params['beta_list_cpol'])):
+ if j > 0:
+ prior['betacpol' + str(j) + complex_labels[0]] = complex_priors2[0]
+ prior['betacpol' + str(j) + complex_labels[1]] = complex_priors2[1]
+ else:
+ prior['betacpol0'] = {'prior_type':'flat','min':-1.0,'max':1.0}
+
+ return prior
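+
+# Example (illustrative): default_prior('thick_ring') returns
+#   {'F0':    {'prior_type':'none','transform':'log'},
+#    'x0':    {'prior_type':'none'},
+#    'y0':    {'prior_type':'none'},
+#    'd':     {'prior_type':'positive','transform':'log'},
+#    'alpha': {'prior_type':'positive','transform':'log'}}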
+
+def stretch_xy(x, y, params):
+ x_stretch = ((x - params['x0']) * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2 / params['stretch'])
+ + (y - params['y0']) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (1.0/params['stretch'] - 1.0))
+ y_stretch = ((y - params['y0']) * (np.cos(params['stretch_PA'])**2 / params['stretch'] + np.sin(params['stretch_PA'])**2)
+ + (x - params['x0']) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (1.0/params['stretch'] - 1.0))
+ return (params['x0'] + x_stretch,params['y0'] + y_stretch)
+
+def stretch_uv(u, v, params):
+ u_stretch = (u * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2 * params['stretch'])
+ + v * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (params['stretch'] - 1.0))
+ v_stretch = (v * (np.cos(params['stretch_PA'])**2 * params['stretch'] + np.sin(params['stretch_PA'])**2)
+ + u * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (params['stretch'] - 1.0))
+ return (u_stretch,v_stretch)
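+
+# Note: stretch_xy and stretch_uv are Fourier duals. A flux-conserving stretch of an
+# image by a factor 'stretch' along position angle stretch_PA rescales its visibilities
+# as V'(u,v) = V(u',v'), where (u',v') is scaled *up* by 'stretch' along that axis.
+# Accordingly, stretch_xy maps image coordinates back to the unstretched frame
+# (dividing by 'stretch'), while stretch_uv applies the forward scaling in the uv-plane.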
+
+def get_const_polfac(model_type, params, pol):
+ # Return the scaling factor for models with constant fractional polarization
+
+ if model_type.find('mring') != -1:
+ # mring models have polarization information specified differently than a constant scaling factor
+ return 1.0
+
+ try:
+ if pol == 'I':
+ return 1.0
+ elif pol == 'Q':
+ return params['pol_frac'] * np.cos(2.0 * params['pol_evpa'])
+ elif pol == 'U':
+ return params['pol_frac'] * np.sin(2.0 * params['pol_evpa'])
+ elif pol == 'V':
+ return params['cpol_frac']
+ elif pol == 'P':
+ return params['pol_frac'] * np.exp(1j * 2.0 * params['pol_evpa'])
+ elif pol == 'RR':
+ return get_const_polfac(model_type, params, 'I') + get_const_polfac(model_type, params, 'V')
+ elif pol == 'RL':
+ return get_const_polfac(model_type, params, 'Q') + 1j*get_const_polfac(model_type, params, 'U')
+ elif pol == 'LR':
+ return get_const_polfac(model_type, params, 'Q') - 1j*get_const_polfac(model_type, params, 'U')
+ elif pol == 'LL':
+ return get_const_polfac(model_type, params, 'I') - get_const_polfac(model_type, params, 'V')
+ except Exception:
+ pass
+
+ return 0.0
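+
+# Example (illustrative): with params = {'pol_frac': 0.3, 'pol_evpa': np.pi/4}, the
+# scaling is 0.3*cos(pi/2) = 0.0 for 'Q', 0.3*sin(pi/2) = 0.3 for 'U', and
+# 0.3*exp(1j*pi/2) = 0.3j for 'P'. Missing keys fall through the try/except to 0.0.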
+
+def sample_1model_xy(x, y, model_type, params, psize=1.*RADPERUAS, pol='I'):
+ if pol == 'Q':
+ return np.real(sample_1model_xy(x, y, model_type, params, psize=psize, pol='P'))
+ elif pol == 'U':
+ return np.imag(sample_1model_xy(x, y, model_type, params, psize=psize, pol='P'))
+ elif pol in ['I','V','P']:
+ pass
+ else:
+ raise Exception('Polarization ' + pol + ' not implemented!')
+
+ if model_type == 'point':
+ val = params['F0'] * (np.abs( x - params['x0']) < psize/2.0) * (np.abs( y - params['y0']) < psize/2.0)
+ elif model_type == 'circ_gauss':
+ sigma = params['FWHM'] / (2. * np.sqrt(2. * np.log(2.)))
+ val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['FWHM']**2) *
+ np.exp(-((x - params['x0'])**2 + (y - params['y0'])**2)/(2*sigma**2)))
+ elif model_type == 'gauss':
+ sigma_maj = params['FWHM_maj'] / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_min = params['FWHM_min'] / (2. * np.sqrt(2. * np.log(2.)))
+ cth = np.cos(params['PA'])
+ sth = np.sin(params['PA'])
+ val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['FWHM_maj'] * params['FWHM_min']) *
+ np.exp(-((y - params['y0'])*np.cos(params['PA']) + (x - params['x0'])*np.sin(params['PA']))**2/(2*sigma_maj**2) +
+ -((x - params['x0'])*np.cos(params['PA']) - (y - params['y0'])*np.sin(params['PA']))**2/(2*sigma_min**2)))
+ elif model_type == 'disk':
+ val = params['F0']*psize**2/(np.pi*params['d']**2/4.) * (np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2) < params['d']/2.0)
+ elif model_type == 'blurred_disk':
+ # Note: the exact form of a blurred disk requires numeric integration
+
+ # This is the peak brightness of the blurred disk
+ I_peak = 4.0/(np.pi*params['d']**2) * (1.0 - 2.0**(-params['d']**2/params['alpha']**2))
+
+ # Constant prefactor
+ prefac = 32.0 * np.log(2.0)/(np.pi * params['alpha']**2 * params['d']**2)
+
+ def f(r):
+ return integrate.quad(lambda rp:
+ prefac * rp * np.exp( -4.0 * np.log(2.0)/params['alpha']**2 * (r**2 + rp**2 - 2.0*r * rp) )
+ * sps.ive(0, 8.0*np.log(2.0) * r * rp/params['alpha']**2),
+ 0, params['d']/2.0, limit=1000, epsabs=I_peak/1e9, epsrel=1.0e-6)[0]
+ f=np.vectorize(f)
+ r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)
+
+ # For images, it's much quicker to do the 1-D problem and interpolate
+ if np.ndim(r) > 0:
+ r_min = np.min(r)
+ r_max = np.max(r)
+ r_list = np.linspace(r_min, r_max, int((r_max-r_min)/(params['alpha']) * 20))
+ if len(r_list) < len(np.ravel(r))/2 and len(r) > 100:
+ f = interpolate.interp1d(r_list, f(r_list), kind='cubic')
+ val = params['F0'] * psize**2 * f(r)
+ elif model_type == 'crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
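+        # The crescent is the difference of two disks: an outer disk of diameter d
+        # minus an inner disk of diameter fr*d whose center is offset by r*(1-fr)*fo
+        # along position angle phi. The F0 prefactors are chosen so the net flux is
+        # F0; ff is the fraction of the inner-disk flux retained as a floor
+        # (ff = 0: fully dark interior; ff = 1: no subtraction).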
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_xy(x, y, 'disk', params0, psize=psize, pol=pol)
+ val -= sample_1model_xy(x, y, 'disk', params1, psize=psize, pol=pol)
+ elif model_type == 'blurred_crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_xy(x, y, 'blurred_disk', params0, psize=psize, pol=pol)
+ val -= sample_1model_xy(x, y, 'blurred_disk', params1, psize=psize, pol=pol)
+ elif model_type == 'ring':
+ val = (params['F0']*psize**2/(np.pi*params['d']*psize*LINE_THICKNESS)
+ * (params['d']/2.0 - psize*LINE_THICKNESS/2 < np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2))
+ * (params['d']/2.0 + psize*LINE_THICKNESS/2 > np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)))
+ elif model_type == 'thick_ring':
+ r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)
+ z = 4.*np.log(2.) * r * params['d']/params['alpha']**2
+ val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['alpha']**2)
+ * np.exp(-4.*np.log(2.)/params['alpha']**2*(r**2 + params['d']**2/4.) + z)
+ * sps.ive(0, z))
+ elif model_type == 'mring':
+ phi = np.angle((y - params['y0']) + 1j*(x - params['x0']))
+ if pol == 'I':
+ beta_factor = (1.0 + np.sum([2.*np.real(params['beta_list'][m-1] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor = np.real(params['beta_list_cpol'][0]) + np.sum([2.*np.real(params['beta_list_cpol'][m] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * np.exp(1j * m * phi) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ else:
+ beta_factor = 0.0
+
+ val = (params['F0']*psize**2/(np.pi*params['d']*psize*LINE_THICKNESS)
+ * beta_factor
+ * (params['d']/2.0 - psize*LINE_THICKNESS/2 < np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2))
+ * (params['d']/2.0 + psize*LINE_THICKNESS/2 > np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)))
+ elif model_type == 'thick_mring':
+ phi = np.angle((y - params['y0']) + 1j*(x - params['x0']))
+ r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)
+ z = 4.*np.log(2.) * r * params['d']/params['alpha']**2
+ if pol == 'I':
+ beta_factor = (sps.ive(0, z) + np.sum([2.*np.real(sps.ive(m, z) * params['beta_list'][m-1] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor = (sps.ive(0, z) * np.real(params['beta_list_cpol'][0]) + np.sum([2.*np.real(sps.ive(m, z) * params['beta_list_cpol'][m] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.ive(m, z) * np.exp(1j * m * phi) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ else:
+ # Note: not all polarizations accounted for yet (need RR, RL, LR, LL; do these by calling for linear combinations of I, Q, U, V)!
+ beta_factor = 0.0
+
+ val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['alpha']**2)
+ * np.exp(-4.*np.log(2.)/params['alpha']**2*(r**2 + params['d']**2/4.) + z)
+ * beta_factor)
+ elif model_type == 'thick_mring_floor':
+ val = (1.0 - params['ff']) * sample_1model_xy(x, y, 'thick_mring', params, psize=psize, pol=pol)
+ val += params['ff'] * sample_1model_xy(x, y, 'blurred_disk', params, psize=psize, pol=pol)
+ elif model_type == 'thick_mring_Gfloor':
+ val = (1.0 - params['ff']) * sample_1model_xy(x, y, 'thick_mring', params, psize=psize, pol=pol)
+ val += params['ff'] * sample_1model_xy(x, y, 'circ_gauss', params, psize=psize, pol=pol)
+ elif model_type[:9] == 'stretched':
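+        # A stretch by 'stretch' dilutes the surface brightness by the same factor,
+        # so divide F0 to conserve the total flux.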
+ params_stretch = params.copy()
+ params_stretch['F0'] /= params['stretch']
+ val = sample_1model_xy(*stretch_xy(x, y, params), model_type[10:], params_stretch, psize, pol=pol)
+ else:
+ print('Model ' + model_type + ' not recognized!')
+ val = 0.0
+ return val * get_const_polfac(model_type, params, pol)
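+
+# Minimal usage sketch (illustrative values; RADPERUAS as in the psize default above):
+#   params = {'F0': 1.0, 'd': 40.*RADPERUAS, 'alpha': 10.*RADPERUAS, 'x0': 0., 'y0': 0.}
+#   xs = np.linspace(-50., 50., 128) * RADPERUAS
+#   xx, yy = np.meshgrid(xs, xs)
+#   im = sample_1model_xy(xx, yy, 'thick_ring', params, psize=xs[1]-xs[0])  # Jy/pixel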
+
+def sample_1model_uv(u, v, model_type, params, pol='I', jonesdict=None):
+ if jonesdict is not None:
+ # Define the various lists
+ fr1 = jonesdict['fr1'] # Field rotation of site 1
+ fr2 = jonesdict['fr2'] # Field rotation of site 2
+ DR1 = jonesdict['DR1'] # Right leakage term of site 1
+ DL1 = jonesdict['DL1'] # Left leakage term of site 1
+ DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2
+ DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2
+ # Sample the model without leakage
+ RR = sample_1model_uv(u, v, model_type, params, pol='RR')
+ RL = sample_1model_uv(u, v, model_type, params, pol='RL')
+ LR = sample_1model_uv(u, v, model_type, params, pol='LR')
+ LL = sample_1model_uv(u, v, model_type, params, pol='LL')
+ # Apply the Jones matrices
+ RRp = RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2))
+ RLp = RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2))
+ LRp = LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2))
+ LLp = LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2))
+ # Return the specified polarization
+ if pol == 'RR': return RRp
+ elif pol == 'RL': return RLp
+ elif pol == 'LR': return LRp
+ elif pol == 'LL': return LLp
+ elif pol == 'I': return 0.5 * (RRp + LLp)
+ elif pol == 'Q': return 0.5 * (LRp + RLp)
+ elif pol == 'U': return 0.5j* (LRp - RLp)
+ elif pol == 'V': return 0.5 * (RRp - LLp)
+ elif pol == 'P': return RLp
+ else:
+ raise Exception('Polarization ' + pol + ' not recognized!')
+
+ if pol == 'Q':
+ return 0.5 * (sample_1model_uv(u, v, model_type, params, pol='P') + np.conj(sample_1model_uv(-u, -v, model_type, params, pol='P')))
+ elif pol == 'U':
+ return -0.5j * (sample_1model_uv(u, v, model_type, params, pol='P') - np.conj(sample_1model_uv(-u, -v, model_type, params, pol='P')))
+ elif pol in ['I','V','P']:
+ pass
+ elif pol == 'RR':
+ return sample_1model_uv(u, v, model_type, params, pol='I') + sample_1model_uv(u, v, model_type, params, pol='V')
+ elif pol == 'LL':
+ return sample_1model_uv(u, v, model_type, params, pol='I') - sample_1model_uv(u, v, model_type, params, pol='V')
+ elif pol == 'RL':
+ return sample_1model_uv(u, v, model_type, params, pol='Q') + 1j*sample_1model_uv(u, v, model_type, params, pol='U')
+ elif pol == 'LR':
+ return sample_1model_uv(u, v, model_type, params, pol='Q') - 1j*sample_1model_uv(u, v, model_type, params, pol='U')
+ else:
+ raise Exception('Polarization ' + pol + ' not implemented!')
+
+ if model_type == 'point':
+ val = params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ elif model_type == 'circ_gauss':
+ val = (params['F0']
+ * np.exp(-np.pi**2/(4.*np.log(2.)) * (u**2 + v**2) * params['FWHM']**2)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'gauss':
+ u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA'])
+ u_min = u*np.cos(params['PA']) - v*np.sin(params['PA'])
+ val = (params['F0']
+ * np.exp(-np.pi**2/(4.*np.log(2.)) * ((u_maj * params['FWHM_maj'])**2 + (u_min * params['FWHM_min'])**2))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'disk':
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ #Add a small offset to avoid issues with division by zero
+ z += (z == 0.0) * 1e-10
+ val = (params['F0'] * 2.0/z * sps.jv(1, z)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'blurred_disk':
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ #Add a small offset to avoid issues with division by zero
+ z += (z == 0.0) * 1e-10
+ val = (params['F0'] * 2.0/z * sps.jv(1, z)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))))
+ elif model_type == 'crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_uv(u, v, 'disk', params0, pol=pol, jonesdict=jonesdict)
+ val -= sample_1model_uv(u, v, 'disk', params1, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'blurred_crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_uv(u, v, 'blurred_disk', params0, pol=pol, jonesdict=jonesdict)
+ val -= sample_1model_uv(u, v, 'blurred_disk', params1, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'ring':
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ val = (params['F0'] * sps.jv(0, z)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'thick_ring':
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ val = (params['F0'] * sps.jv(0, z)
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'mring':
+ phi = np.angle(v + 1j*u)
+ # Flip the baseline sign to match eht-imaging conventions
+ phi += np.pi
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+
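+        # Closed form implemented below:
+        #   V(u,v) = F0 * exp(2j*pi*(u*x0 + v*y0)) * sum_m beta_m J_m(z) exp(1j*m*(phi - pi/2)),
+        # with beta_0 = 1 and beta_{-m} = conj(beta_m) for Stokes I.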
+ if pol == 'I':
+ beta_factor = (sps.jv(0, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor = (np.real(params['beta_list_cpol'][0]) * sps.jv(0, z)
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ else:
+ beta_factor = 0.0
+
+ val = params['F0'] * beta_factor * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ elif model_type == 'thick_mring':
+ phi = np.angle(v + 1j*u)
+ # Flip the baseline sign to match eht-imaging conventions
+ phi += np.pi
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+
+ if pol == 'I':
+ beta_factor = (sps.jv(0, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor = (np.real(params['beta_list_cpol'][0]) * sps.jv(0, z)
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ else:
+ beta_factor = 0.0
+
+ val = (params['F0'] * beta_factor
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ elif model_type == 'thick_mring_floor':
+ val = (1.0 - params['ff']) * sample_1model_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict)
+ val += params['ff'] * sample_1model_uv(u, v, 'blurred_disk', params, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'thick_mring_Gfloor':
+ val = (1.0 - params['ff']) * sample_1model_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict)
+ val += params['ff'] * sample_1model_uv(u, v, 'circ_gauss', params, pol=pol, jonesdict=jonesdict)
+ elif model_type[:9] == 'stretched':
+ params_stretch = params.copy()
+ params_stretch['x0'] = 0.0
+ params_stretch['y0'] = 0.0
+ val = sample_1model_uv(*stretch_uv(u,v,params), model_type[10:], params_stretch, pol=pol) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ else:
+ print('Model ' + model_type + ' not recognized!')
+ val = 0.0
+ return val * get_const_polfac(model_type, params, pol)
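+
+# Minimal usage sketch (illustrative): Stokes I visibilities of a ring sampled on
+# baselines along the u-axis (coordinates in units of lambda):
+#   u = np.linspace(0.1e9, 10.e9, 256)
+#   params = {'F0': 1.0, 'd': 40.*RADPERUAS, 'x0': 0., 'y0': 0.}
+#   vis = sample_1model_uv(u, 0.*u, 'ring', params)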
+
+def sample_1model_graduv_uv(u, v, model_type, params, pol='I', jonesdict=None):
+ # Gradient of the visibility function, (dV/du, dV/dv)
+ # This function makes it convenient to, e.g., compute gradients of stretched images and to compute the model centroid
+
+ if jonesdict is not None:
+ # Define the various lists
+ fr1 = jonesdict['fr1'] # Field rotation of site 1
+ fr2 = jonesdict['fr2'] # Field rotation of site 2
+ DR1 = jonesdict['DR1'] # Right leakage term of site 1
+ DL1 = jonesdict['DL1'] # Left leakage term of site 1
+ DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2
+ DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2
+ # Sample the model without leakage
+ RR = sample_1model_graduv_uv(u, v, model_type, params, pol='RR').reshape(2,len(u))
+ RL = sample_1model_graduv_uv(u, v, model_type, params, pol='RL').reshape(2,len(u))
+ LR = sample_1model_graduv_uv(u, v, model_type, params, pol='LR').reshape(2,len(u))
+ LL = sample_1model_graduv_uv(u, v, model_type, params, pol='LL').reshape(2,len(u))
+ # Apply the Jones matrices
+ RRp = (RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2)))
+ RLp = (RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2)))
+ LRp = (LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2)))
+ LLp = (LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2)))
+ # Return the specified polarization
+ if pol == 'RR': return RRp
+ elif pol == 'RL': return RLp
+ elif pol == 'LR': return LRp
+ elif pol == 'LL': return LLp
+ elif pol == 'I': return 0.5 * (RRp + LLp)
+ elif pol == 'Q': return 0.5 * (LRp + RLp)
+ elif pol == 'U': return 0.5j* (LRp - RLp)
+ elif pol == 'V': return 0.5 * (RRp - LLp)
+ elif pol == 'P': return RLp
+ else:
+ raise Exception('Polarization ' + pol + ' not recognized!')
+
+ if pol == 'Q':
+ return 0.5 * (sample_1model_graduv_uv(u, v, model_type, params, pol='P') + np.conj(sample_1model_graduv_uv(-u, -v, model_type, params, pol='P')))
+ elif pol == 'U':
+ return -0.5j * (sample_1model_graduv_uv(u, v, model_type, params, pol='P') - np.conj(sample_1model_graduv_uv(-u, -v, model_type, params, pol='P')))
+ elif pol in ['I','V','P']:
+ pass
+ elif pol == 'RR':
+ return sample_1model_graduv_uv(u, v, model_type, params, pol='I') + sample_1model_graduv_uv(u, v, model_type, params, pol='V')
+ elif pol == 'LL':
+ return sample_1model_graduv_uv(u, v, model_type, params, pol='I') - sample_1model_graduv_uv(u, v, model_type, params, pol='V')
+ elif pol == 'RL':
+ return sample_1model_graduv_uv(u, v, model_type, params, pol='Q') + 1j*sample_1model_graduv_uv(u, v, model_type, params, pol='U')
+ elif pol == 'LR':
+ return sample_1model_graduv_uv(u, v, model_type, params, pol='Q') - 1j*sample_1model_graduv_uv(u, v, model_type, params, pol='U')
+ else:
+ raise Exception('Polarization ' + pol + ' not implemented!')
+
+ vis = sample_1model_uv(u, v, model_type, params, jonesdict=jonesdict)
+ if model_type == 'point':
+ val = np.array([ 1j * 2.0 * np.pi * params['x0'] * vis,
+ 1j * 2.0 * np.pi * params['y0'] * vis])
+ elif model_type == 'circ_gauss':
+ val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['FWHM']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis,
+ (1j * 2.0 * np.pi * params['y0'] - params['FWHM']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis])
+ elif model_type == 'gauss':
+ u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA'])
+ u_min = u*np.cos(params['PA']) - v*np.sin(params['PA'])
+ val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['FWHM_maj']**2 * np.pi**2 * u_maj/(2. * np.log(2.)) * np.sin(params['PA']) - params['FWHM_min']**2 * np.pi**2 * u_min/(2. * np.log(2.)) * np.cos(params['PA'])) * vis,
+ (1j * 2.0 * np.pi * params['y0'] - params['FWHM_maj']**2 * np.pi**2 * u_maj/(2. * np.log(2.)) * np.cos(params['PA']) + params['FWHM_min']**2 * np.pi**2 * u_min/(2. * np.log(2.)) * np.sin(params['PA'])) * vis])
+ elif model_type == 'disk':
+ # Take care of the degenerate origin point by a small offset
+ #v += (u==0.)*(v==0.)*1e-10
+ uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5
+ z = np.pi * params['d'] * uvdist
+ bessel_deriv = 0.5 * (sps.jv( 0, z) - sps.jv( 2, z))
+ val = np.array([ (1j * 2.0 * np.pi * params['x0'] - u/uvdist**2) * vis
+ + params['F0'] * 2./z * np.pi * params['d'] * u/uvdist * bessel_deriv * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ (1j * 2.0 * np.pi * params['y0'] - v/uvdist**2) * vis
+ + params['F0'] * 2./z * np.pi * params['d'] * v/uvdist * bessel_deriv * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'blurred_disk':
+ # Take care of the degenerate origin point by a small offset
+ #u += (u==0.)*(v==0.)*1e-10
+ uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5
+ z = np.pi * params['d'] * uvdist
+ blur = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ bessel_deriv = 0.5 * (sps.jv( 0, z) - sps.jv( 2, z))
+ val = np.array([ (1j * 2.0 * np.pi * params['x0'] - u/uvdist**2 - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis
+ + params['F0'] * 2./z * np.pi * params['d'] * u/uvdist * bessel_deriv * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ (1j * 2.0 * np.pi * params['y0'] - v/uvdist**2 - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis
+ + params['F0'] * 2./z * np.pi * params['d'] * v/uvdist * bessel_deriv * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_graduv_uv(u, v, 'disk', params0, pol=pol, jonesdict=jonesdict)
+ val -= sample_1model_graduv_uv(u, v, 'disk', params1, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'blurred_crescent':
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+ val = sample_1model_graduv_uv(u, v, 'blurred_disk', params0, pol=pol, jonesdict=jonesdict)
+ val -= sample_1model_graduv_uv(u, v, 'blurred_disk', params1, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'ring':
+ # Take care of the degenerate origin point by a small offset
+        u = u + (u==0.)*(v==0.)*1e-10  # avoid in-place += so the caller's array is not modified
+ uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5
+ z = np.pi * params['d'] * uvdist
+ val = np.array([ 1j * 2.0 * np.pi * params['x0'] * vis
+ - params['F0'] * np.pi*params['d']*u/uvdist * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 1j * 2.0 * np.pi * params['y0'] * vis
+ - params['F0'] * np.pi*params['d']*v/uvdist * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'thick_ring':
+ uvdist = (u**2 + v**2)**0.5
+ #Add a small offset to avoid issues with division by zero
+ uvdist += (uvdist == 0.0) * 1e-10
+ z = np.pi * params['d'] * uvdist
+ val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis
+ - params['F0'] * np.pi*params['d']*u/uvdist * sps.jv(1, z)
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ (1j * 2.0 * np.pi * params['y0'] - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis
+                        - params['F0'] * np.pi*params['d']*v/uvdist * sps.jv(1, z)
+                            * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'mring':
+ # Take care of the degenerate origin point by a small offset
+        u = u + (u==0.)*(v==0.)*1e-10
+ phi = np.angle(v + 1j*u)
+ # Flip the baseline sign to match eht-imaging conventions
+ phi += np.pi
+ uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5
+ dphidu = v/uvdist**2
+ dphidv = -u/uvdist**2
+ z = np.pi * params['d'] * uvdist
+
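+        # Below, the radial derivative uses the Bessel recurrence
+        # J_m'(z) = (J_{m-1}(z) - J_{m+1}(z))/2, while dphidu and dphidv carry
+        # the angular part of the chain rule.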
+ if pol == 'I':
+ beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0])
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+            beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0])
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor_u = (
+ np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0))
+ beta_factor_v = (
+ np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0))
+ else:
+ beta_factor_u = beta_factor_v = 0.0
+
+ val = np.array([
+ 1j * 2.0 * np.pi * params['x0'] * vis
+ + params['F0'] * beta_factor_u
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 1j * 2.0 * np.pi * params['y0'] * vis
+ + params['F0'] * beta_factor_v
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'thick_mring':
+ # Take care of the degenerate origin point by a small offset
+        u = u + (u==0.)*(v==0.)*1e-10
+ phi = np.angle(v + 1j*u)
+ # Flip the baseline sign to match eht-imaging conventions
+ phi += np.pi
+ uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5
+ dphidu = v/uvdist**2
+ dphidv = -u/uvdist**2
+ z = np.pi * params['d'] * uvdist
+
+ if pol == 'I':
+ beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z)
+ + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0])
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+            beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0])
+ + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor_u = (0.0
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0))
+ beta_factor_v = (0.0
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0))
+ else:
+ beta_factor_u = beta_factor_v = 0.0
+
+ val = np.array([
+ (1j * 2.0 * np.pi * params['x0'] - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis
+ + params['F0'] * beta_factor_u
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ (1j * 2.0 * np.pi * params['y0'] - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis
+ + params['F0'] * beta_factor_v
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ])
+ elif model_type == 'thick_mring_floor':
+ val = (1.0 - params['ff']) * sample_1model_graduv_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict)
+ val += params['ff'] * sample_1model_graduv_uv(u, v, 'blurred_disk', params, pol=pol, jonesdict=jonesdict)
+ elif model_type == 'thick_mring_Gfloor':
+ val = (1.0 - params['ff']) * sample_1model_graduv_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict)
+ val += params['ff'] * sample_1model_graduv_uv(u, v, 'circ_gauss', params, pol=pol, jonesdict=jonesdict)
+ elif model_type[:9] == 'stretched':
+ # Take care of the degenerate origin point by a small offset
+        u = u + (u==0.)*(v==0.)*1e-10
+ params_stretch = params.copy()
+ params_stretch['x0'] = 0.0
+ params_stretch['y0'] = 0.0
+ (u_stretch, v_stretch) = stretch_uv(u,v,params)
+
+ # First calculate the gradient of the unshifted but stretched image
+ grad0 = sample_1model_graduv_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+        grad = np.zeros_like(grad0)
+ grad[0] = ( grad0[0] * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2*params['stretch'])
+ + grad0[1] * ((params['stretch'] - 1.0) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA'])))
+ grad[1] = ( grad0[1] * (np.cos(params['stretch_PA'])**2*params['stretch'] + np.sin(params['stretch_PA'])**2)
+ + grad0[0] * ((params['stretch'] - 1.0) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA'])))
+
+ # Add the gradient term from the shift
+ vis = sample_1model_uv(u_stretch, v_stretch, model_type[10:], params_stretch, jonesdict=jonesdict)
+ grad[0] += vis * 1j * 2.0 * np.pi * params['x0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ grad[1] += vis * 1j * 2.0 * np.pi * params['y0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+
+ val = grad
+ else:
+ print('Model ' + model_type + ' not recognized!')
+ val = 0.0
+ return val * get_const_polfac(model_type, params, pol)
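+
+# Note: at the origin (u = v = 0) the gradient above reduces to
+# (dV/du, dV/dv) = 2j*pi*F_tot*(x_cen, y_cen), which is one way to read off
+# the model centroid.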
+
+def sample_1model_grad_leakage_uv_re(u, v, model_type, params, pol, site, hand, jonesdict):
+ # Convenience function to calculate the gradient with respect to the real part of a specified site/hand leakage
+
+ # Define the various lists
+ fr1 = jonesdict['fr1'] # Field rotation of site 1
+ fr2 = jonesdict['fr2'] # Field rotation of site 2
+ DR1 = jonesdict['DR1'] # Right leakage term of site 1
+ DL1 = jonesdict['DL1'] # Left leakage term of site 1
+ DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2
+ DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2
+ # Sample the model without leakage
+ RR = sample_1model_uv(u, v, model_type, params, pol='RR')
+ RL = sample_1model_uv(u, v, model_type, params, pol='RL')
+ LR = sample_1model_uv(u, v, model_type, params, pol='LR')
+ LL = sample_1model_uv(u, v, model_type, params, pol='LL')
+
+ # Figure out which terms to include in the gradient
+ DR1mask = 0.0 + (hand == 'R') * (jonesdict['t1'] == site)
+ DR2mask = 0.0 + (hand == 'R') * (jonesdict['t2'] == site)
+ DL1mask = 0.0 + (hand == 'L') * (jonesdict['t1'] == site)
+ DL2mask = 0.0 + (hand == 'L') * (jonesdict['t2'] == site)
+
+ # These are the leakage gradient terms
+ RRp = LR * DR1mask * np.exp( 2j*fr1) + RL * DR2mask * np.exp(-2j*fr2) + LL * DR1mask * DR2 * np.exp( 2j*(fr1-fr2)) + LL * DR1 * DR2mask * np.exp( 2j*(fr1-fr2))
+ RLp = LL * DR1mask * np.exp( 2j*fr1) + RR * DL2mask * np.exp( 2j*fr2) + LR * DR1mask * DL2 * np.exp( 2j*(fr1+fr2)) + LR * DR1 * DL2mask * np.exp( 2j*(fr1+fr2))
+ LRp = RR * DL1mask * np.exp(-2j*fr1) + LL * DR2mask * np.exp(-2j*fr2) + RL * DL1mask * DR2 * np.exp(-2j*(fr1+fr2)) + RL * DL1 * DR2mask * np.exp(-2j*(fr1+fr2))
+ LLp = LR * DL2mask * np.exp( 2j*fr2) + RL * DL1mask * np.exp(-2j*fr1) + RR * DL1mask * DL2 * np.exp(-2j*(fr1-fr2)) + RR * DL1 * DL2mask * np.exp(-2j*(fr1-fr2))
+
+ # Return the specified polarization
+ if pol == 'RR': return RRp
+ elif pol == 'RL': return RLp
+ elif pol == 'LR': return LRp
+ elif pol == 'LL': return LLp
+ elif pol == 'I': return 0.5 * (RRp + LLp)
+ elif pol == 'Q': return 0.5 * (LRp + RLp)
+ elif pol == 'U': return 0.5j* (LRp - RLp)
+ elif pol == 'V': return 0.5 * (RRp - LLp)
+ elif pol == 'P': return RLp
+ else:
+ raise Exception('Polarization ' + pol + ' not recognized!')
+
+def sample_1model_grad_leakage_uv_im(u, v, model_type, params, pol, site, hand, jonesdict):
+ # Convenience function to calculate the gradient with respect to the imaginary part of a specified site/hand leakage
+ # The tricky thing here is the conjugation of the second leakage site, flipping the sign of the gradient
+
+ # Define the various lists
+ fr1 = jonesdict['fr1'] # Field rotation of site 1
+ fr2 = jonesdict['fr2'] # Field rotation of site 2
+ DR1 = jonesdict['DR1'] # Right leakage term of site 1
+ DL1 = jonesdict['DL1'] # Left leakage term of site 1
+ DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2
+ DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2
+ # Sample the model without leakage
+ RR = sample_1model_uv(u, v, model_type, params, pol='RR')
+ RL = sample_1model_uv(u, v, model_type, params, pol='RL')
+ LR = sample_1model_uv(u, v, model_type, params, pol='LR')
+ LL = sample_1model_uv(u, v, model_type, params, pol='LL')
+
+ # Figure out which terms to include in the gradient
+ DR1mask = 0.0 + (hand == 'R') * (jonesdict['t1'] == site)
+ DR2mask = 0.0 + (hand == 'R') * (jonesdict['t2'] == site)
+ DL1mask = 0.0 + (hand == 'L') * (jonesdict['t1'] == site)
+ DL2mask = 0.0 + (hand == 'L') * (jonesdict['t2'] == site)
+
+ # These are the leakage gradient terms
+ RRp = 1j*( LR * DR1mask * np.exp( 2j*fr1) - RL * DR2mask * np.exp(-2j*fr2) + LL * DR1mask * DR2 * np.exp( 2j*(fr1-fr2)) - LL * DR1 * DR2mask * np.exp( 2j*(fr1-fr2)))
+ RLp = 1j*( LL * DR1mask * np.exp( 2j*fr1) - RR * DL2mask * np.exp( 2j*fr2) + LR * DR1mask * DL2 * np.exp( 2j*(fr1+fr2)) - LR * DR1 * DL2mask * np.exp( 2j*(fr1+fr2)))
+ LRp = 1j*( RR * DL1mask * np.exp(-2j*fr1) - LL * DR2mask * np.exp(-2j*fr2) + RL * DL1mask * DR2 * np.exp(-2j*(fr1+fr2)) - RL * DL1 * DR2mask * np.exp(-2j*(fr1+fr2)))
+ LLp = 1j*(-LR * DL2mask * np.exp( 2j*fr2) + RL * DL1mask * np.exp(-2j*fr1) + RR * DL1mask * DL2 * np.exp(-2j*(fr1-fr2)) - RR * DL1 * DL2mask * np.exp(-2j*(fr1-fr2)))
+
+ # Return the specified polarization
+ if pol == 'RR': return RRp
+ elif pol == 'RL': return RLp
+ elif pol == 'LR': return LRp
+ elif pol == 'LL': return LLp
+ elif pol == 'I': return 0.5 * (RRp + LLp)
+ elif pol == 'Q': return 0.5 * (LRp + RLp)
+ elif pol == 'U': return 0.5j* (LRp - RLp)
+ elif pol == 'V': return 0.5 * (RRp - LLp)
+ elif pol == 'P': return RLp
+ else:
+ raise Exception('Polarization ' + pol + ' not recognized!')
+
+def sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ # Gradient of the model for each model parameter
+
+ if jonesdict is not None:
+ # Define the various lists
+ fr1 = jonesdict['fr1'] # Field rotation of site 1
+ fr2 = jonesdict['fr2'] # Field rotation of site 2
+ DR1 = jonesdict['DR1'] # Right leakage term of site 1
+ DL1 = jonesdict['DL1'] # Left leakage term of site 1
+ DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2
+ DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2
+ # Sample the gradients without leakage
+ RR = sample_1model_grad_uv(u, v, model_type, params, pol='RR', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ RL = sample_1model_grad_uv(u, v, model_type, params, pol='RL', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ LR = sample_1model_grad_uv(u, v, model_type, params, pol='LR', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ LL = sample_1model_grad_uv(u, v, model_type, params, pol='LL', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ # Apply the Jones matrices
+ RRp = (RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2)))
+ RLp = (RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2)))
+ LRp = (LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2)))
+ LLp = (LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2)))
+ # Return the specified polarization
+ if pol == 'RR': grad = RRp
+ elif pol == 'RL': grad = RLp
+ elif pol == 'LR': grad = LRp
+ elif pol == 'LL': grad = LLp
+ elif pol == 'I': grad = 0.5 * (RRp + LLp)
+ elif pol == 'Q': grad = 0.5 * (LRp + RLp)
+ elif pol == 'U': grad = 0.5j* (LRp - RLp)
+ elif pol == 'V': grad = 0.5 * (RRp - LLp)
+ elif pol == 'P': grad = RLp
+ else:
+ raise Exception('Polarization ' + pol + ' not recognized!')
+ # If necessary, add the gradient components from the leakage terms
+ # Each leakage term has two corresponding gradient terms: d/dRe and d/dIm.
+ if fit_leakage:
+ # 'leakage_fit' is a list of tuples [site, 'R' or 'L'] denoting the fitted leakage terms
+ for (site, hand) in jonesdict['leakage_fit']:
+ grad = np.vstack([grad, sample_1model_grad_leakage_uv_re(u, v, model_type, params, pol, site, hand, jonesdict), sample_1model_grad_leakage_uv_im(u, v, model_type, params, pol, site, hand, jonesdict)])
+
+ return grad
+
+ if pol == 'Q':
+ return 0.5 * (sample_1model_grad_uv(u, v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol) + np.conj(sample_1model_grad_uv(-u, -v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol)))
+ elif pol == 'U':
+ return -0.5j * (sample_1model_grad_uv(u, v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol) - np.conj(sample_1model_grad_uv(-u, -v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol)))
+ elif pol in ['I','V','P']:
+ pass
+ elif pol == 'RR':
+ return sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=fit_pol, fit_cpol=fit_cpol) + sample_1model_grad_uv(u, v, model_type, params, pol='V', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ elif pol == 'LL':
+ return sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=fit_pol, fit_cpol=fit_cpol) - sample_1model_grad_uv(u, v, model_type, params, pol='V', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ elif pol == 'RL':
+ return sample_1model_grad_uv(u, v, model_type, params, pol='Q', fit_pol=fit_pol, fit_cpol=fit_cpol) + 1j*sample_1model_grad_uv(u, v, model_type, params, pol='U', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ elif pol == 'LR':
+ return sample_1model_grad_uv(u, v, model_type, params, pol='Q', fit_pol=fit_pol, fit_cpol=fit_cpol) - 1j*sample_1model_grad_uv(u, v, model_type, params, pol='U', fit_pol=fit_pol, fit_cpol=fit_cpol)
+ else:
+ raise Exception('Polarization ' + pol + ' not implemented!')
+
+ if model_type == 'point': # F0, x0, y0
+ val = np.array([ np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 1j * 2.0 * np.pi * u * params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 1j * 2.0 * np.pi * v * params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))])
+ elif model_type == 'circ_gauss': # F0, FWHM, x0, y0
+ gauss = (params['F0'] * np.exp(-np.pi**2/(4.*np.log(2.)) * (u**2 + v**2) * params['FWHM']**2)
+ *np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ val = np.array([ 1.0/params['F0'] * gauss,
+ -np.pi**2/(2.*np.log(2.)) * (u**2 + v**2) * params['FWHM'] * gauss,
+ 1j * 2.0 * np.pi * u * gauss,
+ 1j * 2.0 * np.pi * v * gauss])
+ elif model_type == 'gauss': # F0, FWHM_maj, FWHM_min, PA, x0, y0
+ u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA'])
+ u_min = u*np.cos(params['PA']) - v*np.sin(params['PA'])
+ vis = (params['F0']
+ * np.exp(-np.pi**2/(4.*np.log(2.)) * ((u_maj * params['FWHM_maj'])**2 + (u_min * params['FWHM_min'])**2))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ val = np.array([ 1.0/params['F0'] * vis,
+ -np.pi**2/(2.*np.log(2.)) * params['FWHM_maj'] * u_maj**2 * vis,
+ -np.pi**2/(2.*np.log(2.)) * params['FWHM_min'] * u_min**2 * vis,
+ -np.pi**2/(2.*np.log(2.)) * (params['FWHM_maj']**2 - params['FWHM_min']**2) * u_maj * u_min * vis,
+ 1j * 2.0 * np.pi * u * vis,
+ 1j * 2.0 * np.pi * v * vis])
+ elif model_type == 'disk': # F0, d, x0, y0
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ #Add a small offset to avoid issues with division by zero
+ z += (z == 0.0) * 1e-10
+ vis = (params['F0'] * 2.0/z * sps.jv(1, z)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ val = np.array([ 1.0/params['F0'] * vis,
+ -(params['F0'] * 2.0/z * sps.jv(2, z) * np.pi * (u**2 + v**2)**0.5 * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) ,
+ 1j * 2.0 * np.pi * u * vis,
+ 1j * 2.0 * np.pi * v * vis])
+ elif model_type == 'blurred_disk': # F0, d, alpha, x0, y0
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ #Add a small offset to avoid issues with division by zero
+ z += (z == 0.0) * 1e-10
+ blur = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ vis = (params['F0'] * 2.0/z * sps.jv(1, z)
+ * blur
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ val = np.array([ 1.0/params['F0'] * vis,
+ -params['F0'] * 2.0/z * sps.jv(2, z) * np.pi * (u**2 + v**2)**0.5 * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ -np.pi**2 * (u**2 + v**2) * params['alpha']/(2.*np.log(2.)) * vis,
+ 1j * 2.0 * np.pi * u * vis,
+ 1j * 2.0 * np.pi * v * vis])
+    elif model_type == 'crescent': # F0, d, fr, fo, ff, phi, x0, y0
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+
+ grad0 = sample_1model_grad_uv(u, v, 'disk', params0, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ grad1 = sample_1model_grad_uv(u, v, 'disk', params1, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+
+ # Add the derivatives one by one
+ grad = []
+
+ # F0
+ grad.append( 1.0/(1.0-(1-ff)*fr**2)*grad0[0] - (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*grad1[0] )
+
+ # d
+ grad.append( grad0[1] - fr*grad1[1] - 0.5 * (1.0 - fr) * fo * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) )
+
+ # fr
+ grad.append( 2.0*params['F0']*(1-ff)*fr/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) - params['d']*grad1[1] + r * fo * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) )
+
+ # fo
+ grad.append( -r * (1-fr) * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) )
+
+ # ff
+ grad.append( -params['F0']*fr**2/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) )
+
+ # phi
+ grad.append( -r*(1-fr)*fo* (np.cos(phi) * grad1[2] - np.sin(phi) * grad1[3]) )
+
+ # x0, y0
+ grad.append( grad0[2] - grad1[2] )
+ grad.append( grad0[3] - grad1[3] )
+
+ val = np.array(grad)
+ elif model_type == 'blurred_crescent': #['F0','d','alpha','fr','fo','ff','phi','x0','y0']
+ phi = params['phi']
+ fr = params['fr']
+ fo = params['fo']
+ ff = params['ff']
+ r = params['d'] / 2.
+ params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']}
+ params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)}
+
+ grad0 = sample_1model_grad_uv(u, v, 'blurred_disk', params0, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ grad1 = sample_1model_grad_uv(u, v, 'blurred_disk', params1, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+
+ # Add the derivatives one by one
+ grad = []
+
+ # F0
+ grad.append( 1.0/(1.0-(1-ff)*fr**2)*grad0[0] - (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*grad1[0] )
+
+ # d
+ grad.append( grad0[1] - fr*grad1[1] - 0.5 * (1.0 - fr) * fo * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) )
+
+ # alpha
+ grad.append( grad0[2] - grad1[2] )
+
+ # fr
+ grad.append( 2.0*params['F0']*(1-ff)*fr/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) - params['d']*grad1[1] + r * fo * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) )
+
+ # fo
+ grad.append( -r * (1-fr) * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) )
+
+ # ff
+ grad.append( -params['F0']*fr**2/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) )
+
+ # phi
+ grad.append( -r*(1-fr)*fo* (np.cos(phi) * grad1[3] - np.sin(phi) * grad1[4]) )
+
+ # x0, y0
+ grad.append( grad0[3] - grad1[3] )
+ grad.append( grad0[4] - grad1[4] )
+
+ val = np.array(grad)
+
+ elif model_type == 'ring': # F0, d, x0, y0
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ val = np.array([ sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ -np.pi * (u**2 + v**2)**0.5 * params['F0'] * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 2.0 * np.pi * 1j * u * params['F0'] * sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])),
+ 2.0 * np.pi * 1j * v * params['F0'] * sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))])
+ elif model_type == 'thick_ring': # F0, d, alpha, x0, y0
+ z = np.pi * params['d'] * (u**2 + v**2)**0.5
+ vis = (params['F0'] * sps.jv(0, z)
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+ val = np.array([ 1.0/params['F0'] * vis,
+ -(params['F0'] * np.pi * (u**2 + v**2)**0.5 * sps.jv(1, z)
+ * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))),
+ -np.pi**2 * (u**2 + v**2) * params['alpha']/(2.*np.log(2.)) * vis,
+ 1j * 2.0 * np.pi * u * vis,
+ 1j * 2.0 * np.pi * v * vis])
+ elif model_type in ['mring','thick_mring']: # F0, d, [alpha], x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ...
+ phi = np.angle(v + 1j*u)
+ # Flip the baseline sign to match eht-imaging conventions
+ phi += np.pi
+ uvdist = (u**2 + v**2)**0.5
+ z = np.pi * params['d'] * uvdist
+ if model_type == 'thick_mring':
+ alpha_factor = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))
+ else:
+ alpha_factor = 1
+
+ # Only one of the beta_lists will affect the measurement and have non-zero gradients. Figure out which:
+ # These are for the derivatives wrt diameter
+ if pol == 'I':
+ beta_factor = (-np.pi * uvdist * sps.jv(1, z)
+ + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)
+ + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv( -m-1, z) - sps.jv( -m+1, z)) * np.pi * uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0))
+ elif pol == 'P' and len(params['beta_list_pol']) > 0:
+ num_coeff = len(params['beta_list_pol'])
+ beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)
+ elif pol == 'V' and len(params['beta_list_cpol']) > 0:
+ beta_factor = (-np.pi * uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0])
+ + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)
+ + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv( -m-1, z) - sps.jv( -m+1, z)) * np.pi * uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0))
+ else:
+ beta_factor = 0.0
+
+ vis = sample_1model_uv(u, v, model_type, params, pol=pol, jonesdict=jonesdict)
+ grad = [ 1.0/params['F0'] * vis,
+ (params['F0'] * alpha_factor * beta_factor * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))]
+ if model_type == 'thick_mring':
+ grad.append(-np.pi**2/(2.*np.log(2)) * uvdist**2 * params['alpha'] * vis)
+ grad.append(1j * 2.0 * np.pi * u * vis)
+ grad.append(1j * 2.0 * np.pi * v * vis)
+
+ if pol=='I':
+ # Add derivatives of the beta terms
+ for m in range(1,len(params['beta_list'])+1):
+ beta_grad_re = params['F0'] * alpha_factor * (
+ sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) + sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.))
+ ) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ beta_grad_im = 1j * params['F0'] * alpha_factor * (
+ sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) - sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.))
+ ) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ if COMPLEX_BASIS == 're-im':
+ grad.append(beta_grad_re)
+ grad.append(beta_grad_im)
+ elif COMPLEX_BASIS == 'abs-arg':
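+ # Chain rule for the change of basis: with beta = |beta| exp(1j*arg),
+ # d/d|beta| = cos(arg) d/dRe + sin(arg) d/dIm and d/darg = |beta| (cos(arg) d/dIm - sin(arg) d/dRe)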
+ beta_abs = np.abs(params['beta_list'][m-1])
+ beta_arg = np.angle(params['beta_list'][m-1])
+ grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg))
+ grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im)
+ else:
+ raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!')
+ else:
+ grad.extend([np.zeros_like(grad[0]) for _ in range(2*len(params['beta_list']))])
+
+ if pol=='P' and fit_pol:
+ # Add derivatives of the beta_pol terms
+ num_coeff = len(params['beta_list_pol'])
+ for m in range(-(num_coeff-1)//2,(num_coeff+1)//2):
+ beta_grad_re = params['F0'] * alpha_factor * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ beta_grad_im = 1j * beta_grad_re
+ if COMPLEX_BASIS == 're-im':
+ grad.append(beta_grad_re)
+ grad.append(beta_grad_im)
+ elif COMPLEX_BASIS == 'abs-arg':
+ beta_abs = np.abs(params['beta_list_pol'][m+(num_coeff-1)//2])
+ beta_arg = np.angle(params['beta_list_pol'][m+(num_coeff-1)//2])
+ grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg))
+ grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im)
+ else:
+ raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!')
+ elif pol!='P' and fit_pol:
+ grad.extend([np.zeros_like(grad[0]) for _ in range(2*len(params['beta_list_pol']))])
+
+ if pol=='V' and fit_cpol:
+ # Add derivatives of the beta_cpol terms
+ num_coeff = len(params['beta_list_cpol']) - 1
+
+ # First do the beta0 mode (real)
+ beta_grad_re = params['F0'] * alpha_factor * sps.jv( 0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ grad.append(beta_grad_re)
+
+ # Now do the remaining modes (complex)
+ for m in range(1,num_coeff+1):
+ beta_grad_re = params['F0'] * alpha_factor * (
+ sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) + sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.))
+ ) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ beta_grad_im = 1j * params['F0'] * alpha_factor * (
+ sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) - sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.))
+ ) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+ if COMPLEX_BASIS == 're-im':
+ grad.append(beta_grad_re)
+ grad.append(beta_grad_im)
+ elif COMPLEX_BASIS == 'abs-arg':
+ beta_abs = np.abs(params['beta_list_cpol'][m])
+ beta_arg = np.angle(params['beta_list_cpol'][m])
+ grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg))
+ grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im)
+ else:
+ raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!')
+ elif pol!='V' and fit_cpol:
+ grad.extend([np.zeros_like(grad[0]) for _ in range(2*len(params['beta_list_cpol'])-1)])
+
+ val = np.array(grad)
+ elif model_type == 'thick_mring_floor': # F0, d, [alpha], ff, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ...
+ # We need to stitch together the two gradients for the mring and the disk; we also need to add the gradient for the floor fraction ff
+ grad_mring = (1.0 - params['ff']) * sample_1model_grad_uv(u, v, 'thick_mring', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol)
+ grad_disk = params['ff'] * sample_1model_grad_uv(u, v, 'blurred_disk', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol)
+
+ # mring: F0, d, alpha, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ...
+ # disk: F0, d, alpha, x0, y0
+
+ # Here are derivatives for F0, d, and alpha
+ grad = []
+ for j in range(3):
+ grad.append(grad_mring[j] + grad_disk[j])
+
+ # Here is the derivative for ff
+ grad.append( params['F0'] * (grad_disk[0]/params['ff'] - grad_mring[0]/(1.0 - params['ff'])) )
+
+ # Now the derivatives for x0 and y0
+ grad.append(grad_mring[3] + grad_disk[3])
+ grad.append(grad_mring[4] + grad_disk[4])
+
+ # Add remaining gradients for the mring
+ for j in range(5,len(grad_mring)):
+ grad.append(grad_mring[j])
+
+ val = np.array(grad)
+ elif model_type == 'thick_mring_Gfloor': # F0, d, [alpha], ff, FWHM, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ...
+ # We need to stitch together the two gradients for the mring and the Gaussian; we also need to add the gradient for the floor fraction ff
+ grad_mring = (1.0 - params['ff']) * sample_1model_grad_uv(u, v, 'thick_mring', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol)
+ grad_gauss = params['ff'] * sample_1model_grad_uv(u, v, 'circ_gauss', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol)
+
+ # mring: F0, d, alpha, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ...
+ # gauss: F0, FWHM, x0, y0 (the extra d and alpha entries in params are ignored by 'circ_gauss')
+
+ grad = []
+ grad.append(grad_mring[0] + grad_gauss[0]) # Here is the derivative for F0
+
+ # Here are the derivatives for d and alpha
+ grad.append(grad_mring[1])
+ grad.append(grad_mring[2])
+
+ # Here is the derivative for ff
+ grad.append( params['F0'] * (grad_gauss[0]/params['ff'] - grad_mring[0]/(1.0 - params['ff'])) )
+
+ # Now the derivative for FWHM
+ grad.append(grad_gauss[1])
+
+ # Now the derivatives for x0 and y0
+ grad.append(grad_mring[3] + grad_gauss[2])
+ grad.append(grad_mring[4] + grad_gauss[3])
+
+ # Add remaining gradients for the mring
+ for j in range(5,len(grad_mring)):
+ grad.append(grad_mring[j])
+
+ val = np.array(grad)
+ elif model_type[:9] == 'stretched':
+ # Start with the model visibility
+ vis = sample_1model_uv(u, v, model_type, params, pol=pol, jonesdict=jonesdict)
+
+ # Next, calculate the gradient wrt model parameters other than stretch and stretch_PA
+ # These are the same as the gradient of the unstretched model on stretched baselines
+ params_stretch = params.copy()
+ params_stretch['x0'] = 0.0
+ params_stretch['y0'] = 0.0
+ (u_stretch, v_stretch) = stretch_uv(u,v,params)
+ grad = (sample_1model_grad_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol)
+ * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))
+
+ # Add the gradient terms for the centroid
+ grad[model_params(model_type, params).index('x0')] = 1j * 2.0 * np.pi * u * vis
+ grad[model_params(model_type, params).index('y0')] = 1j * 2.0 * np.pi * v * vis
+
+ # Now calculate the gradient with respect to stretch and stretch PA
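+ # Chain rule through the stretched coordinates (u', v'): e.g. d(vis)/d(stretch) =
+ # dV/du' * du'/d(stretch) + dV/dv' * dv'/d(stretch), with the centroid phase applied afterward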
+ grad_uv = sample_1model_graduv_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol)
+ grad_stretch = grad_uv.copy() * 0.0
+ grad_stretch[0] = ( grad_uv[0] * (u * np.sin(params['stretch_PA'])**2 + v * np.sin(params['stretch_PA']) * np.cos(params['stretch_PA']))
+ + grad_uv[1] * (v * np.cos(params['stretch_PA'])**2 + u * np.sin(params['stretch_PA']) * np.cos(params['stretch_PA'])))
+ grad_stretch[0] *= np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+
+ grad_stretch[1] = ( grad_uv[0] * (params['stretch'] - 1.0) * ( u * np.sin(2.0 * params['stretch_PA']) + v * np.cos(2.0 * params['stretch_PA']))
+ + grad_uv[1] * (params['stretch'] - 1.0) * (-v * np.sin(2.0 * params['stretch_PA']) + u * np.cos(2.0 * params['stretch_PA'])))
+ grad_stretch[1] *= np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))
+
+ val = np.concatenate([grad, grad_stretch])
+ else:
+ print('Model ' + model_type + ' not recognized!')
+ val = 0.0
+
+ grad = val * get_const_polfac(model_type, params, pol)
+
+ if (fit_pol or fit_cpol) and 'mring' not in model_type:
+ # Add gradient contributions for models that have constant polarization
+ if fit_pol:
+ # Add gradient wrt pol_frac if the polarization is P, otherwise ignore
+ grad_params = copy.deepcopy(params)
+ grad_params['pol_frac'] = 1.0
+ grad = np.vstack([grad, (pol == 'P') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)])
+
+ # Add gradient wrt pol_evpa if the polarization is P, otherwise ignore
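+ # (multiplying pol_frac by 2j converts the constant polarization factor
+ # pol_frac * exp(2j * pol_evpa) into its derivative with respect to pol_evpa)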
+ grad_params = copy.deepcopy(params)
+ grad_params['pol_frac'] *= 2j
+ grad = np.vstack([grad, (pol == 'P') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)])
+ if fit_cpol:
+ # Add gradient wrt cpol_frac
+ grad_params = copy.deepcopy(params)
+ grad_params['cpol_frac'] = 1.0
+ grad = np.vstack([grad, (pol == 'V') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)])
+
+ return grad
+
+def sample_model_xy(models, params, x, y, psize=1.*RADPERUAS, pol='I'):
+ return np.sum([sample_1model_xy(x, y, models[j], params[j], psize=psize, pol=pol) for j in range(len(models))], axis=0)
+
+def sample_model_uv(models, params, u, v, pol='I', jonesdict=None):
+ return np.sum([sample_1model_uv(u, v, models[j], params[j], pol=pol, jonesdict=jonesdict) for j in range(len(models))], axis=0)
+
+def sample_model_graduv_uv(models, params, u, v, pol='I', jonesdict=None):
+ # Gradient of a sum of models wrt (u,v)
+ return np.sum([sample_1model_graduv_uv(u, v, models[j], params[j], pol=pol, jonesdict=jonesdict) for j in range(len(models))],axis=0)
+
+def sample_model_grad_uv(models, params, u, v, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ # Gradient of a sum of models for each parameter
+ if not fit_leakage:
+ return np.concatenate([sample_1model_grad_uv(u, v, models[j], params[j], pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) for j in range(len(models))])
+ else:
+ # Need to sum the leakage contributions
+ allgrad = [sample_1model_grad_uv(u, v, models[j], params[j], pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) for j in range(len(models))]
+ n_leakage = len(jonesdict['leakage_fit'])*2
+ grad = np.concatenate([allgrad[j][:-n_leakage] for j in range(len(models))])
+ grad_leakage = np.sum([allgrad[j][-n_leakage:] for j in range(len(models))],axis=0)
+ return np.concatenate([grad, grad_leakage])
+
+def blur_circ_1model(model_type, params, fwhm):
+ """Blur a single model, returning new model type and associated parameters
+
+ Args:
+ fwhm (float) : Full width at half maximum of the kernel (radians)
+
+ Returns:
+ (dict) : Dictionary with new 'model_type' and new 'params'
+ """
+
+ model_type_blur = model_type
+ params_blur = params.copy()
+
+ if model_type == 'point':
+ model_type_blur = 'circ_gauss'
+ params_blur['FWHM'] = fwhm
+ elif model_type == 'circ_gauss':
+ params_blur['FWHM'] = (params_blur['FWHM']**2 + fwhm**2)**0.5
+ elif model_type == 'gauss':
+ params_blur['FWHM_maj'] = (params_blur['FWHM_maj']**2 + fwhm**2)**0.5
+ params_blur['FWHM_min'] = (params_blur['FWHM_min']**2 + fwhm**2)**0.5
+ elif 'thick' in model_type or 'blurred' in model_type:
+ params_blur['alpha'] = (params_blur['alpha']**2 + fwhm**2)**0.5
+ elif model_type == 'disk':
+ model_type_blur = 'blurred_' + model_type
+ params_blur['alpha'] = fwhm
+ elif model_type == 'crescent':
+ model_type_blur = 'blurred_' + model_type
+ params_blur['alpha'] = fwhm
+ elif model_type == 'ring' or model_type == 'mring':
+ model_type_blur = 'thick_' + model_type
+ params_blur['alpha'] = fwhm
+ elif model_type == 'stretched_ring' or model_type == 'stretched_mring':
+ model_type_blur = 'stretched_thick_' + model_type[10:]
+ params_blur['alpha'] = fwhm
+ else:
+ raise Exception("A blurred " + model_type + " is not yet a supported model!")
+
+ return {'model_type':model_type_blur, 'params':params_blur}
+
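+# Example (an illustrative sketch of blur_circ_1model; the values are arbitrary):
+#   blur_circ_1model('ring', {'F0': 1.0, 'd': 40.*RADPERUAS, 'x0': 0.0, 'y0': 0.0}, 20.*RADPERUAS)
+#   # -> {'model_type': 'thick_ring', 'params': {..., 'alpha': 20.*RADPERUAS}}
+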
+class Model(object):
+ """A model with analytic representations in the image and visibility domains.
+
+ Attributes:
+ models (list): List of model-component type strings (e.g., 'ring', 'thick_mring')
+ params (list): List of parameter dictionaries, one per model component
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The model frequency in Hz
+ source (str): The source name
+ mjd (int): The integer MJD of the model
+ time (float): The observing time of the model (UTC hours)
+ """
+
+ def __init__(self, ra=RA_DEFAULT, dec=DEC_DEFAULT, pa=0.0,
+ polrep='stokes', pol_prim=None,
+ rf=RF_DEFAULT, source=SOURCE_DEFAULT,
+ mjd=MJD_DEFAULT, time=0.):
+
+ """A model with analytic representations in the image and visibility domains.
+
+ Args:
+
+ Returns:
+ """
+
+ # The model is a sum of component models, each defined by a tag and associated parameters
+ self.pol_prim = 'I'
+ self.polrep = 'stokes'
+ self._imdict = {'I':{'models':[],'params':[]},'Q':{'models':[],'params':[]},'U':{'models':[],'params':[]},'V':{'models':[],'params':[]}}
+
+ # Save the image metadata
+ self.ra = float(ra)
+ self.dec = float(dec)
+ self.pa = float(pa)
+ self.rf = float(rf)
+ self.source = str(source)
+ self.mjd = int(mjd)
+ if time > 24:
+ self.mjd += int((time - time % 24)/24)
+ self.time = float(time % 24)
+ else:
+ self.time = time
+
+ @property
+ def models(self):
+ return self._imdict[self.pol_prim]['models']
+
+ @models.setter
+ def models(self, model_list):
+ self._imdict[self.pol_prim]['models'] = model_list
+
+ @property
+ def params(self):
+ return self._imdict[self.pol_prim]['params']
+
+ @params.setter
+ def params(self, param_list):
+ self._imdict[self.pol_prim]['params'] = param_list
+
+ def copy(self):
+ """Return a copy of the Model object.
+
+ Args:
+
+ Returns:
+ (Model): copy of the Model.
+ """
+ out = Model(ra=self.ra, dec=self.dec, pa=self.pa, polrep=self.polrep, pol_prim=self.pol_prim,rf=self.rf,source=self.source,mjd=self.mjd,time=self.time)
+ out.models = copy.deepcopy(self.models)
+ out.params = copy.deepcopy(self.params)
+ return out
+
+ def switch_polrep(self, polrep_out='stokes', pol_prim_out=None):
+
+ """Return a new model with the polarization representation changed
+ Args:
+ polrep_out (str): the polrep of the output data
+ pol_prim_out (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for circ
+
+ Returns:
+ (Model): new Model object with potentially different polrep
+ """
+
+ # Note: this currently does nothing, but it is put here for compatibility with functions such as selfcal
+ if polrep_out not in ['stokes','circ']:
+ raise Exception("polrep_out must be either 'stokes' or 'circ'")
+ if pol_prim_out is None:
+ if polrep_out=='stokes': pol_prim_out = 'I'
+ elif polrep_out=='circ': pol_prim_out = 'RR'
+
+ return self.copy()
+
+ def N_models(self):
+ """Return the number of model components
+
+ Args:
+
+ Returns:
+ (int): number of model components
+ """
+ return len(self.models)
+
+ def total_flux(self):
+ """Return the total flux of the model in Jy.
+
+ Args:
+
+ Returns:
+ (float) : model total flux (Jy)
+ """
+ return np.real(self.sample_uv(0,0))
+
+ def blur_circ(self, fwhm):
+ """Return a new model, equal to the current one convolved with a circular Gaussian kernel
+
+ Args:
+ fwhm (float) : Full width at half maximum of the kernel (radians)
+
+ Returns:
+ (Model) : Blurred model
+ """
+
+ out = self.copy()
+
+ for j in range(len(out.models)):
+ blur_model = blur_circ_1model(out.models[j], out.params[j], fwhm)
+ out.models[j] = blur_model['model_type']
+ out.params[j] = blur_model['params']
+
+ return out
+
+ def add_point(self, F0 = 1.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a point source model.
+
+ Args:
+ F0 (float): The total flux of the point source (Jy)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+
+ out = self.copy()
+ out.models.append('point')
+ out.params.append({'F0':F0,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_circ_gauss(self, F0 = 1.0, FWHM = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a circular Gaussian model.
+
+ Args:
+ F0 (float): The total flux of the Gaussian (Jy)
+ FWHM (float): The FWHM of the Gaussian (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('circ_gauss')
+ out.params.append({'F0':F0,'FWHM':FWHM,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_gauss(self, F0 = 1.0, FWHM_maj = 50.*RADPERUAS, FWHM_min = 50.*RADPERUAS, PA = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add an anisotropic Gaussian model.
+
+ Args:
+ F0 (float): The total flux of the Gaussian (Jy)
+ FWHM_maj (float): The FWHM of the Gaussian major axis (radians)
+ FWHM_min (float): The FWHM of the Gaussian minor axis (radians)
+ PA (float): Position angle of the major axis, east of north (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('gauss')
+ out.params.append({'F0':F0,'FWHM_maj':FWHM_maj,'FWHM_min':FWHM_min,'PA':PA,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_disk(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a circular disk model.
+
+ Args:
+ F0 (float): The total flux of the disk (Jy)
+ d (float): The diameter (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('disk')
+ out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_blurred_disk(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a circular disk model that is blurred with a circular Gaussian kernel.
+
+ Args:
+ F0 (float): The total flux of the disk (Jy)
+ d (float): The diameter (radians)
+ alpha (float): The blurring (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('blurred_disk')
+ out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_crescent(self, F0 = 1.0, d = 50.*RADPERUAS, fr = 0.0, fo = 0.0, ff = 0.0, phi = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a crescent model.
+
+ Args:
+ F0 (float): The total flux of the crescent (Jy)
+ d (float): The diameter (radians)
+ fr (float): Fractional radius of the inner subtracted disk with respect to the radius of the outer disk
+ fo (float): Fractional offset of the inner disk from the center of the outer disk
+ ff (float): Fractional brightness of the inner disk
+ phi (float): Angle of the offset of the inner disk, east of north (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('crescent')
+ out.params.append({'F0':F0,'d':d,'fr':fr, 'fo':fo, 'ff':ff, 'phi':phi, 'x0':x0, 'y0':y0, 'pol_frac':pol_frac, 'pol_evpa':pol_evpa, 'cpol_frac':cpol_frac})
+ return out
+
+ def add_blurred_crescent(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, fr = 0.0, fo = 0.0, ff = 0.0, phi = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+
+ """Add a circular disk model that is blurred with a circular Gaussian kernel.
+
+ Args:
+ F0 (float): The total flux of the crescent (Jy)
+ d (float): The diameter (radians)
+ alpha (float): The blurring (FWHM of Gaussian convolution) (radians)
+ fr (float): Fractional radius of the inner subtracted disk with respect to the radius of the outer disk
+ fo (float): Fractional offset of the inner disk from the center of the outer disk
+ ff (float): Fractional brightness of the inner disk
+ phi (float): Angle of the offset of the inner disk, east of north (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('blurred_crescent')
+ out.params.append({'F0':F0,'d':d,'alpha':alpha, 'fr':fr, 'fo':fo, 'ff':ff, 'phi':phi, 'x0':x0, 'y0':y0, 'pol_frac':pol_frac, 'pol_evpa':pol_evpa, 'cpol_frac':cpol_frac})
+ return out
+
+ def add_ring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a ring model with infinitesimal thickness.
+
+ Args:
+ F0 (float): The total flux of the ring (Jy)
+ d (float): The diameter (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('ring')
+ out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_stretched_ring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, stretch = 1.0, stretch_PA = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a stretched ring model with infinitesimal thickness.
+
+ Args:
+ F0 (float): The total flux of the ring (Jy)
+ d (float): The diameter (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ stretch (float): The stretch to apply (1.0 = no stretch)
+ stretch_PA (float): Position angle of the stretch, east of north (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('stretched_ring')
+ out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_thick_ring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a ring model with finite thickness, determined by circular Gaussian convolution of a thin ring.
+ For details, see Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy)
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('thick_ring')
+ out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_stretched_thick_ring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, stretch = 1.0, stretch_PA = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0):
+ """Add a ring model with finite thickness, determined by circular Gaussian convolution of a thin ring.
+ For details, see Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy)
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ stretch (float): The stretch to apply (1.0 = no stretch)
+ stretch_PA (float): Position angle of the stretch, east of north (radians)
+
+ Returns:
+ (Model): Updated Model
+ """
+ out = self.copy()
+ out.models.append('stretched_thick_ring')
+ out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac})
+ return out
+
+ def add_mring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion.
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The diameter (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ beta_list_pol (list): List of complex Fourier coefficients for the linear polarization brightness
+ beta_list_cpol (list): List of complex Fourier coefficients for the circular polarization brightness (beta_0 is real)
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('mring')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'x0':x0,'y0':y0})
+ return out
+
+ def add_stretched_mring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0):
+ """Add a stretched ring model with azimuthal brightness variations determined by a Fourier mode expansion.
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The diameter (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ stretch (float): The stretch to apply (1.0 = no stretch)
+ stretch_PA (float): Position angle of the stretch, east of north (radians)
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('stretched_mring')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA})
+ return out
+
+ def add_thick_mring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution.
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+ The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('thick_mring')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0})
+ return out
+
+ def add_thick_mring_floor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion, thickness determined by circular Gaussian convolution, and a floor
+ The floor is a blurred disk, with diameter d and blurred FWHM alpha
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+ The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ ff (float): The fraction of the total flux in the floor
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('thick_mring_floor')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'ff':ff})
+ return out
+
+ def add_thick_mring_Gfloor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, FWHM = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion, thickness determined by circular Gaussian convolution, and a floor
+ The floor is a circular Gaussian, with size FWHM
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+ The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the model
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ FWHM (float): The FWHM of the Gaussian floor (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ ff (float): The fraction of the total flux in the floor
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('thick_mring_Gfloor')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'ff':ff,'FWHM':FWHM})
+ return out
+
+ def add_stretched_thick_mring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution.
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+ The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ stretch (float): The stretch to apply (1.0 = no stretch)
+ stretch_PA (float): Position angle of the stretch, east of north (radians)
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('stretched_thick_mring')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA})
+ return out
+
+ def add_stretched_thick_mring_floor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0):
+ """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution.
+ For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329
+ The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf
+
+ Args:
+ F0 (float): The total flux of the ring (Jy), which is also beta_0.
+ d (float): The ring diameter (radians)
+ alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians)
+ x0 (float): The x-coordinate (radians)
+ y0 (float): The y-coordinate (radians)
+ ff (float): The fraction of the total flux in the floor
+ beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...].
+ Negative indices are determined by the condition beta_{-m} = beta_m*.
+ Indices are all scaled by F0 = beta_0, so they are dimensionless.
+ stretch (float): The stretch to apply (1.0 = no stretch)
+ stretch_PA (float): Position angle of the stretch, east of north (radians)
+ Returns:
+ (Model): Updated Model
+ """
+ if beta_list is None: beta_list = []
+ if beta_list_pol is None: beta_list_pol = []
+ if beta_list_cpol is None: beta_list_cpol = []
+
+ out = self.copy()
+ out.models.append('stretched_thick_mring_floor')
+ out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'ff':ff})
+ return out
+
+ def sample_xy(self, x, y, psize=1.*RADPERUAS, pol='I'):
+ """Sample model image on the specified x and y coordinates
+
+ Args:
+ x (float): x coordinate (radians)
+ y (float): y coordinate (radians)
+
+ Returns:
+ (float): Image brightness (Jy/radian^2)
+ """
+ return sample_model_xy(self.models, self.params, x, y, psize=psize, pol=pol)
+
+ def sample_uv(self, u, v, polrep_obs='stokes', pol='I', jonesdict=None):
+ """Sample model visibility on the specified u and v coordinates
+
+ Args:
+ u (float): u coordinate (dimensionless)
+ v (float): v coordinate (dimensionless)
+
+ Returns:
+ (complex): complex visibility (Jy)
+ """
+ return sample_model_uv(self.models, self.params, u, v, pol=pol, jonesdict=jonesdict)
+
+ def sample_graduv_uv(self, u, v, pol='I', jonesdict=None):
+ """Sample model visibility gradient on the specified u and v coordinates wrt (u,v)
+
+ Args:
+ u (float): u coordinate (dimensionless)
+ v (float): v coordinate (dimensionless)
+
+ Returns:
+ (np.array): gradient of the complex visibility with respect to (u,v) (Jy)
+ """
+ return sample_model_graduv_uv(self.models, self.params, u, v, pol=pol, jonesdict=jonesdict)
+
+ def sample_grad_uv(self, u, v, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """Sample model visibility gradient on the specified u and v coordinates wrt all model parameters
+
+ Args:
+ u (float): u coordinate (dimensionless)
+ v (float): v coordinate (dimensionless)
+
+ Returns:
+ (np.array): gradient of the complex visibility with respect to the model parameters
+ """
+ return sample_model_grad_uv(self.models, self.params, u, v, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+
+ def centroid(self, pol=None):
+ """Compute the location of the image centroid (corresponding to the polarization pol)
+ Note: This quantity is only well defined for total intensity
+
+ Args:
+ pol (str): The polarization for which to find the image centroid
+
+ Returns:
+ (np.array): centroid positions (x0,y0) in radians
+ """
+
+ if pol is None: pol=self.pol_prim
+ if not (pol in list(self._imdict.keys())):
+ raise Exception("for polrep==%s, pol must be in " %
+ self.polrep + ",".join(list(self._imdict.keys())))
+
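+ # Since V(u,v) = Int I(x,y) exp(2 pi i (u x + v y)) dx dy, the visibility gradient at the
+ # origin equals 2 pi i times the flux-weighted centroid; divide out 2 pi i and the total flux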
+ return np.real(self.sample_graduv_uv(0,0)/(2.*np.pi*1j))/self.total_flux()
+
+ def default_prior(self,fit_pol=False,fit_cpol=False):
+ return [default_prior(self.models[j],self.params[j],fit_pol=fit_pol,fit_cpol=fit_cpol) for j in range(self.N_models())]
+
+ def display(self, fov=FOV_DEFAULT, npix=NPIX_DEFAULT, polrep='stokes', pol_prim=None, pulse=PULSE_DEFAULT, time=0., **kwargs):
+ return self.make_image(fov, npix, polrep, pol_prim, pulse, time).display(**kwargs)
+
+ def make_image(self, fov, npix, polrep='stokes', pol_prim=None, pulse=PULSE_DEFAULT, time=0.):
+ """Sample the model onto a square image.
+
+ Args:
+ fov (float): the field of view of each axis in radians
+ npix (int): the number of pixels on each axis
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The image frequency in Hz
+
+ source (str): The source name
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ pulse (function): The function convolved with the pixel values for continuous image.
+
+ mjd (int): The integer MJD of the image
+ time (float): The observing time of the image (UTC hours)
+
+ Returns:
+ (Image): an image object
+ """
+
+ pdim = fov/float(npix)
+ npix = int(npix)
+ imarr = np.zeros((npix,npix))
+ outim = image.Image(imarr, pdim, self.ra, self.dec,
+ polrep=polrep, pol_prim=pol_prim,
+ rf=self.rf, source=self.source, mjd=self.mjd, time=time, pulse=pulse)
+
+ return self.image_same(outim)
+
+ def image_same(self, im):
+ """Create an image of the model with parameters equal to a reference image.
+
+ Args:
+ im (Image): the reference image
+
+ Returns:
+ (Image): image of the model
+ """
+ out = im.copy()
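+ # Pixel centers run from +fov/2 down to -fov/2 along each axis, following the
+ # eht-imaging image coordinate conventions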
+ xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ x_grid, y_grid = np.meshgrid(xlist, ylist)
+ imarr = self.sample_xy(x_grid, y_grid, im.psize)
+ out.imvec = imarr.flatten() # Change this to init with image_args
+
+ # Add the remaining polarizations
+ for pol in ['Q','U','V']:
+ out.add_pol_image(self.sample_xy(x_grid, y_grid, im.psize, pol=pol), pol)
+
+ return out
+
+ def observe_same_nonoise(self, obs, **kwargs):
+ """Observe the model on the same baselines as an existing observation, without noise.
+
+ Args:
+ obs (Obsdata): the existing observation
+
+ Returns:
+ (Obsdata): an observation object with no noise
+ """
+
+ # Copy data to be safe
+ obsdata = copy.deepcopy(obs.data)
+
+ # Load optional parameters
+ jonesdict = kwargs.get('jonesdict',None)
+
+ # Compute visibilities and put them into the obsdata
+ if obs.polrep=='stokes':
+ obsdata['vis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='I', jonesdict=jonesdict)
+ obsdata['qvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='Q', jonesdict=jonesdict)
+ obsdata['uvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='U', jonesdict=jonesdict)
+ obsdata['vvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='V', jonesdict=jonesdict)
+ elif obs.polrep=='circ':
+ obsdata['rrvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='RR', jonesdict=jonesdict)
+ obsdata['rlvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='RL', jonesdict=jonesdict)
+ obsdata['lrvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='LR', jonesdict=jonesdict)
+ obsdata['llvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='LL', jonesdict=jonesdict)
+
+ obs_no_noise = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs.polrep,
+ ampcal=True, phasecal=True, opacitycal=True,
+ dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ return obs_no_noise
+
+ def observe_same(self, obs_in, add_th_noise=True, sgrscat=False, ttype=False, # Note: sgrscat and ttype are kept for consistency with comp_plots
+ opacitycal=True, ampcal=True, phasecal=True,
+ dcal=True, frcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False, neggains=False,
+ jones=False, inv_jones=False,
+ tau=TAUDEF, taup=GAINPDEF,
+ gain_offset=GAINPDEF, gainp=GAINPDEF,
+ dterm_offset=DTERMPDEF,
+ rlratio_std=0.,rlphase_std=0.,
+ caltable_path=None, seed=False, **kwargs):
+
+ """Observe the image on the same baselines as an existing observation object and add noise.
+
+ Args:
+ obs_in (Obsdata): the existing observation
+
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ dcal (bool): if False, time-dependent gaussian errors added to D-terms.
+
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+ (i.e., assume the telescope performance has been overestimated)
+
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to a priori calibrate data
+
+ tau (float): the base opacity at all sites,
+ or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ caltable_path (string): The path and prefix of a saved caltable
+
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+
+ Returns:
+ (Obsdata): an observation object
+ """
+
+ if seed!=False:
+ np.random.seed(seed=seed)
+
+ obs = self.observe_same_nonoise(obs_in, **kwargs)
+
+ # Jones Matrix Corruption & Calibration
+ if jones:
+ obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, dcal=dcal, frcal=frcal,
+ rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ gainp=gainp, taup=taup, gain_offset=gain_offset,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ caltable_path=caltable_path, seed=seed)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
+ dcal=dcal, frcal=frcal,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ if inv_jones:
+ obsdata = simobs.apply_jones_inverse(obs,
+ opacitycal=opacitycal, dcal=dcal, frcal=frcal)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+ # these are always set to True after the inverse Jones call
+
+
+ # No Jones Matrices, Add noise the old way
+ # NOTE There is an asymmetry here - in the old way, we don't offer the ability to *not*
+ # unscale estimated noise.
+ else:
+
+ if caltable_path:
+ print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix')
+
+ obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise,
+ ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ gainp=gainp, taup=taup, gain_offset=gain_offset, seed=seed)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+ # these are always set to True after the add_noise call
+
+ return obs
+
+ def observe(self, array, tint, tadv, tstart, tstop, bw,
+ mjd=None, timetype='UTC', polrep_obs=None,
+ elevmin=ELEV_LOW, elevmax=ELEV_HIGH,
+ no_elevcut_space=False,
+ fix_theta_GMST=False, add_th_noise=True,
+ opacitycal=True, ampcal=True, phasecal=True,
+ dcal=True, frcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ jones=False, inv_jones=False,
+ tau=TAUDEF, taup=GAINPDEF,
+ gainp=GAINPDEF, gain_offset=GAINPDEF,
+ dterm_offset=DTERMPDEF, rlratio_std=0.,rlphase_std=0.,
+ seed=False, **kwargs):
+
+ """Generate baselines from an array object and observe the image.
+
+ Args:
+ array (Array): an array object containing sites with which to generate baselines
+ tint (float): the scan integration time in seconds
+ tadv (float): the uniform cadence between scans in seconds
+ tstart (float): the start time of the observation in hours
+ tstop (float): the end time of the observation in hours
+ bw (float): the observing bandwidth in Hz
+
+ mjd (int): the mjd of the observation, if set as different from the image mjd
+ timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC'
+ elevmin (float): station minimum elevation in degrees
+ elevmax (float): station maximum elevation in degrees
+ no_elevcut_space (bool): if True, do not apply elevation cut to orbiters
+
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ dcal (bool): if False, time-dependent gaussian errors added to Jones matrices D-terms.
+
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ otherwise uses old formalism without D-terms
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to calibrate data
+
+ tau (float): the base opacity at all sites,
+ or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ seed (int): seeds the random component of noise added. DO NOT set to 0!
+
+ Returns:
+ (Obsdata): an observation object
+ """
+
+ # Generate empty observation
+ print("Generating empty observation file . . . ")
+
+ if mjd is None:
+ mjd = self.mjd
+ if polrep_obs is None:
+ polrep_obs=self.polrep
+
+ obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop, mjd=mjd,
+ polrep=polrep_obs, tau=tau, timetype=timetype,
+ elevmin=elevmin, elevmax=elevmax,
+ no_elevcut_space=no_elevcut_space, fix_theta_GMST=fix_theta_GMST)
+
+ # Observe on the same baselines as the empty observation and add noise
+ obs = self.observe_same(obs, add_th_noise=add_th_noise,
+ opacitycal=opacitycal,ampcal=ampcal,
+ phasecal=phasecal,dcal=dcal,
+ frcal=frcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ gainp=gainp,gain_offset=gain_offset,
+ tau=tau, taup=taup,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ jones=jones, inv_jones=inv_jones, seed=seed, **kwargs)
+
+ obs.mjd = mjd
+
+ return obs
+
+ def save_txt(self,filename):
+ """Save the model to a text file (one model type and one parameter dict per component).
+
+ Args:
+ filename (str): path to the output text file
+ """
+ # Header
+ import ehtim.observing.obs_helpers as obshelp
+ mjd = float(self.mjd)
+ time = self.time
+ mjd += (time/24.)
+
+ head = ("SRC: %s \n" % self.source +
+ "RA: " + obshelp.rastring(self.ra) + "\n" +
+ "DEC: " + obshelp.decstring(self.dec) + "\n" +
+ "MJD: %.6f \n" % (float(mjd)) +
+ "RF: %.4f GHz" % (self.rf/1e9))
+ # Models
+ out = []
+ for j in range(self.N_models()):
+ out.append(self.models[j])
+ out.append(str(self.params[j]).replace('\n','').replace('complex128','np.complex128').replace('array','np.array'))
+ np.savetxt(filename, out, header=head, fmt="%s")
+
+ def load_txt(self,filename):
+ """Read a model from a text file written by save_txt.
+
+ Args:
+ filename (str): path to the input text file
+ """
+ lines = open(filename).read().splitlines()
+
+ self.source = ' '.join(lines[0].split()[2:])
+ ra = lines[1].split()
+ self.ra = float(ra[2]) + float(ra[4])/60.0 + float(ra[6])/3600.0
+ dec = lines[2].split()
+ self.dec = np.sign(float(dec[2])) * (abs(float(dec[2])) + float(dec[4])/60.0 + float(dec[6])/3600.0)
+ mjd_float = float(lines[3].split()[2])
+ self.mjd = int(mjd_float)
+ self.time = (mjd_float - self.mjd) * 24
+ self.rf = float(lines[4].split()[2]) * 1e9
+
+ self.models = lines[5::2]
+ # Parameter dictionaries were serialized with repr() in save_txt;
+ # eval() reconstructs them (the file is trusted input).
+ self.params = [eval(x) for x in lines[6::2]]
+
+def load_txt(filename):
+ out = Model()
+ out.load_txt(filename)
+ return out
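+
+# Usage sketch (illustrative; `mod` is assumed to be an existing Model instance):
+#   mod.save_txt('model.txt')
+#   mod2 = load_txt('model.txt')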
diff --git a/modeling/__init__.py b/modeling/__init__.py
new file mode 100644
index 00000000..2d00f4e5
--- /dev/null
+++ b/modeling/__init__.py
@@ -0,0 +1,9 @@
+"""
+.. module:: ehtim.modeling
+ :platform: Unix
+ :synopsis: EHT Modeling Utilities: modeling functions
+
+.. moduleauthor:: Michael Johnson (mjohnson@cfa.harvard.edu)
+
+"""
+from ..const_def import *
diff --git a/modeling/modeling_utils.py b/modeling/modeling_utils.py
new file mode 100644
index 00000000..d58c637c
--- /dev/null
+++ b/modeling/modeling_utils.py
@@ -0,0 +1,2700 @@
+# modeling_utils.py
+# General modeling functions for total intensity VLBI data
+#
+# Copyright (C) 2020 Michael Johnson
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+## TODO ##
+# >> return jonesdict for all data types <- requires significant modification to eht-imaging
+# >> Deal with nans in fitting (mask chisqdata) <- mostly done
+# >> Add optional transform for leakage and gains
+
+from __future__ import division
+from __future__ import print_function
+from __future__ import absolute_import
+from builtins import range
+
+import string
+import time
+import numpy as np
+import scipy.optimize as opt
+import scipy.ndimage as nd
+import scipy.ndimage.filters as filt
+import matplotlib.pyplot as plt
+import scipy.special as sps
+import scipy.stats as stats
+import copy
+
+import ehtim.obsdata as obsdata
+import ehtim.image as image
+import ehtim.model as model
+import ehtim.caltable as caltable
+
+from ehtim.const_def import *
+from ehtim.observing.obs_helpers import *
+from ehtim.statistics.dataframes import *
+
+#from IPython import display
+
+##################################################################################################
+# Constants & Definitions
+##################################################################################################
+
+MAXLS = 100 # maximum number of line searches in L-BFGS-B
+NHIST = 100 # number of steps to store for hessian approx
+MAXIT = 100 # maximum number of iterations
+STOP = 1.e-8 # convergence criterion
+
+BOUNDS_MIN = -1e4
+BOUNDS_MAX = 1e4
+BOUNDS_GAUSS_NSIGMA = 10.
+BOUNDS_EXP_NSIGMA = 10.
+
+PRIOR_MIN = 1e-200 # to avoid problems with log-prior
+
+DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'logamp', 'pvis', 'm', 'rlrr', 'rlll', 'lrrr', 'lrll','rrll','llrr','polclosure']
+
+nit = 0 # global variable to track the iteration number in the plotting callback
+globdict = {} # global dictionary with all parameters related to the model fitting (mainly for efficient parallelization, but also very useful for debugging)
+
+# Details on each fitted parameter (convenience rescaling factor and associated unit)
+PARAM_DETAILS = {'F0':[1.,'Jy'], 'FWHM':[RADPERUAS,'uas'], 'FWHM_maj':[RADPERUAS,'uas'], 'FWHM_min':[RADPERUAS,'uas'],
+ 'd':[RADPERUAS,'uas'], 'PA':[np.pi/180.,'deg'], 'alpha':[RADPERUAS,'uas'], 'ff':[1.,''],
+ 'x0':[RADPERUAS,'uas'], 'y0':[RADPERUAS,'uas'], 'stretch':[1.,''], 'stretch_PA':[np.pi/180.,'deg'],
+ 'arg':[np.pi/180.,'deg'], 'evpa':[np.pi/180.,'deg'], 'phi':[np.pi/180.,'deg']}
+
+GAIN_PRIOR_DEFAULT = {'prior_type':'lognormal','sigma':0.1,'mu':0.0,'shift':-1.0}
+LEAKAGE_PRIOR_DEFAULT = {'prior_type':'flat','min':-0.5,'max':0.5}
+N_POSTERIOR_SAMPLES = 100
+
+##################################################################################################
+# Priors
+##################################################################################################
+
+def cdf(x, prior_params):
+ """Compute the cumulative distribution function CDF(x) of a given prior at a given point x
+
+ Args:
+ x (float): Value at which to compute the CDF
+ prior_params (dict): Dictionary with information about the prior
+
+ Returns:
+ float: CDF(x)
+ """
+ if prior_params['prior_type'] == 'flat':
+ return ( (x > prior_params['max']) * 1.0
+ + (x > prior_params['min']) * (x < prior_params['max']) * (x - prior_params['min'])/(prior_params['max'] - prior_params['min']))
+ elif prior_params['prior_type'] == 'gauss':
+ return 0.5 * (1.0 + sps.erf( (x - prior_params['mean'])/(prior_params['std'] * np.sqrt(2.0)) ))
+ elif prior_params['prior_type'] == 'exponential':
+ return (1.0 - np.exp(-x/prior_params['std'])) * (x >= 0.0)
+ elif prior_params['prior_type'] == 'lognormal':
+ return (x > prior_params['shift']) * (0.5 * sps.erfc( (prior_params['mu'] - np.log(x - prior_params['shift']))/(np.sqrt(2.0) * prior_params['sigma'])))
+ elif prior_params['prior_type'] == 'positive':
+ raise Exception('CDF is not defined for prior type "positive"')
+ elif prior_params['prior_type'] == 'none':
+ raise Exception('CDF is not defined for prior type "none"')
+ elif prior_params['prior_type'] == 'fixed':
+ raise Exception('CDF is not defined for prior type "fixed"')
+ else:
+ raise Exception('Prior type ' + prior_params['prior_type'] + ' not recognized!')
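+
+# Example (illustrative sketch): for a flat prior the CDF is linear on its
+# support, e.g.
+#   >>> cdf(0.25, {'prior_type': 'flat', 'min': 0.0, 'max': 1.0})
+#   0.25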
+
+def cdf_inverse(x, prior_params):
+ """Compute the inverse cumulative distribution function of a given prior at a given point 0 <= x <= 1
+
+ Args:
+ x (float): Value at which to compute the inverse CDF
+ prior_params (dict): Dictionary with information about the prior
+
+ Returns:
+ float: Inverse CDF at x
+ """
+ if prior_params['prior_type'] == 'flat':
+ return prior_params['min'] * (1.0 - x) + prior_params['max'] * x
+ elif prior_params['prior_type'] == 'gauss':
+ return prior_params['mean'] - np.sqrt(2.0) * prior_params['std'] * sps.erfcinv(2.0 * x)
+ elif prior_params['prior_type'] == 'exponential':
+ return prior_params['std'] * np.log(1.0/(1.0 - x))
+ elif prior_params['prior_type'] == 'lognormal':
+ return np.exp( prior_params['mu'] - np.sqrt(2.0) * prior_params['sigma'] * sps.erfcinv(2.0 * x)) + prior_params['shift']
+ elif prior_params['prior_type'] == 'positive':
+ raise Exception('Inverse CDF is not defined for prior type "positive"')
+ elif prior_params['prior_type'] == 'none':
+ raise Exception('Inverse CDF is not defined for prior type "none"')
+ elif prior_params['prior_type'] == 'fixed':
+ raise Exception('Inverse CDF is not defined for prior type "fixed"')
+ else:
+ raise Exception('Prior type ' + prior_params['prior_type'] + ' not recognized!')
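+
+# Example (illustrative sketch): cdf_inverse maps the unit interval back to
+# parameter space; the median of a Gaussian prior is its mean, e.g.
+#   >>> cdf_inverse(0.5, {'prior_type': 'gauss', 'mean': 2.0, 'std': 1.0})
+#   2.0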
+
+def param_bounds(prior_params):
+ """Compute the parameter boundaries associated with a given prior
+
+ Args:
+ prior_params (dict): Dictionary with information about the prior
+
+ Returns:
+ list: 2-element list specifying the allowed parameter range: [min,max]
+ """
+ if prior_params.get('transform','') == 'cdf':
+ bounds = [0.0, 1.0]
+ elif prior_params['prior_type'] == 'flat':
+ bounds = [prior_params['min'],prior_params['max']]
+ elif prior_params['prior_type'] == 'gauss':
+ bounds = [prior_params['mean'] - prior_params['std'] * BOUNDS_GAUSS_NSIGMA, prior_params['mean'] + prior_params['std'] * BOUNDS_GAUSS_NSIGMA]
+ elif prior_params['prior_type'] == 'exponential':
+ bounds = [PRIOR_MIN, BOUNDS_EXP_NSIGMA * prior_params['std']]
+ elif prior_params['prior_type'] == 'lognormal':
+ bounds = [prior_params['shift'], prior_params['shift'] + np.exp(prior_params['mu'] + BOUNDS_GAUSS_NSIGMA * prior_params['sigma'])]
+ elif prior_params['prior_type'] == 'positive':
+ bounds = [PRIOR_MIN, BOUNDS_MAX]
+ elif prior_params['prior_type'] == 'none':
+ bounds = [BOUNDS_MIN,BOUNDS_MAX]
+ elif prior_params['prior_type'] == 'fixed':
+ bounds = [1.0, 1.0]
+ else:
+ print('Prior type not recognized!')
+ bounds = [BOUNDS_MIN,BOUNDS_MAX]
+
+ return bounds
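+
+# Example (illustrative sketch): a flat prior bounds the parameter directly,
+# while any prior fitted through a 'cdf' transform lives on the unit interval:
+#   >>> param_bounds({'prior_type': 'flat', 'min': 0.0, 'max': 2.0})
+#   [0.0, 2.0]
+#   >>> param_bounds({'prior_type': 'flat', 'min': 0.0, 'max': 2.0, 'transform': 'cdf'})
+#   [0.0, 1.0]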
+
+def prior_func(x, prior_params):
+ """Compute the value of a 1-D prior P(x) at a specified value x.
+
+ Args:
+ x (float): Value at which to compute the prior
+ prior_params (dict): Dictionary with information about the prior
+
+ Returns:
+ float: Prior value P(x)
+ """
+
+ if prior_params['prior_type'] == 'flat':
+ return (x >= prior_params['min']) * (x <= prior_params['max']) * 1.0/(prior_params['max'] - prior_params['min']) + PRIOR_MIN
+ elif prior_params['prior_type'] == 'gauss':
+ return 1./((2.*np.pi)**0.5 * prior_params['std']) * np.exp(-(x - prior_params['mean'])**2/(2.*prior_params['std']**2))
+ elif prior_params['prior_type'] == 'exponential':
+ return (1./prior_params['std'] * np.exp(-x/prior_params['std'])) * (x >= 0.0) + PRIOR_MIN
+ elif prior_params['prior_type'] == 'lognormal':
+ return (x > prior_params['shift']) * (
+ 1.0/((2.0*np.pi)**0.5 * prior_params['sigma'] * (x - prior_params['shift']))
+ * np.exp( -(np.log(x - prior_params['shift']) - prior_params['mu'])**2/(2.0 * prior_params['sigma']**2) ) )
+ elif prior_params['prior_type'] == 'positive':
+ return (x >= 0.0) * 1.0 + PRIOR_MIN
+ elif prior_params['prior_type'] == 'none':
+ return 1.0
+ elif prior_params['prior_type'] == 'fixed':
+ return 1.0
+ else:
+ print('Prior not recognized!')
+ return 1.0
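+
+# Example (illustrative sketch): a flat prior on [0, 2] has density 1/2 on its
+# support (plus the negligible PRIOR_MIN regularizer):
+#   >>> prior_func(1.0, {'prior_type': 'flat', 'min': 0.0, 'max': 2.0})
+#   0.5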
+
+def prior_grad_func(x, prior_params):
+ """Compute the value of the derivative of a 1-D prior, dP/dx at a specified value x.
+
+ Args:
+ x (float): Value at which to compute the prior derivative
+ prior_params (dict): Dictionary with information about the prior
+
+ Returns:
+ float: Prior derivative value dP/dx(x)
+ """
+
+ if prior_params['prior_type'] == 'flat':
+ return 0.0
+ elif prior_params['prior_type'] == 'gauss':
+ return -(x - prior_params['mean'])/((2.*np.pi)**0.5 * prior_params['std']**3) * np.exp(-(x - prior_params['mean'])**2/(2.*prior_params['std']**2))
+ elif prior_params['prior_type'] == 'exponential':
+ return (-1./prior_params['std']**2 * np.exp(-x/prior_params['std'])) * (x >= 0.0)
+ elif prior_params['prior_type'] == 'lognormal':
+ return (x > prior_params['shift']) * (
+ (prior_params['mu'] - prior_params['sigma']**2 - np.log(x - prior_params['shift']))
+ / ((2.0*np.pi)**0.5 * prior_params['sigma']**3 * (x - prior_params['shift'])**2)
+ * np.exp( -(np.log(x - prior_params['shift']) - prior_params['mu'])**2/(2.0 * prior_params['sigma']**2) ) )
+ elif prior_params['prior_type'] == 'positive':
+ return 0.0
+ elif prior_params['prior_type'] == 'none':
+ return 0.0
+ elif prior_params['prior_type'] == 'fixed':
+ return 0.0
+ else:
+ print('Prior not recognized!')
+ return 0.0
+
+def transform_param(x, x_prior, inverse=True):
+ """Compute a specified coordinate transformation T(x) of a parameter value x
+
+ Args:
+ x (float): Untransformed value
+ x_prior (dict): Dictionary with information about the transformation
+ inverse (bool): Whether to compute the forward or inverse transform.
+
+ Returns:
+ float: Transformed parameter value
+ """
+
+ transform = x_prior.get('transform', 'none')
+
+ if transform == 'log':
+ if inverse:
+ return np.exp(x)
+ else:
+ return np.log(x)
+ elif transform == 'cdf':
+ if inverse:
+ return cdf_inverse(x, x_prior)
+ else:
+ return cdf(x, x_prior)
+ else:
+ return x
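+
+# Example (illustrative sketch): with a 'log' transform the solver works with
+# log(x), so the default inverse transform exponentiates:
+#   >>> transform_param(0.0, {'transform': 'log'})                 # solver -> model
+#   1.0
+#   >>> transform_param(1.0, {'transform': 'log'}, inverse=False)  # model -> solver
+#   0.0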
+
+def transform_grad_param(x, x_prior):
+ """Compute the gradient of a specified coordinate transformation T(x) of a parameter value x
+
+ Args:
+ x (float): Untransformed value
+ x_prior (dict): Dictionary with information about the transformation
+
+ Returns:
+ float: Gradient of transformation, dT/dx(x)
+ """
+
+ transform = x_prior.get('transform', 'none')
+
+ if transform == 'log':
+ return np.exp(x)
+ elif transform == 'cdf':
+ return 1.0/prior_func(transform_param(x,x_prior),x_prior)
+ else:
+ return 1.0
+
+##################################################################################################
+# Helper functions
+##################################################################################################
+def shrink_prior(prior, model, shrink=0.1):
+ """Shrink a specified prior volume by centering on a specified fitted model
+
+ Args:
+ prior (list): Model prior (list of dictionaries, one per model component)
+ model (Model): Model to draw central values from
+ shrink (float): Factor to shrink each prior width by
+
+ Returns:
+ prior (list): Model prior with restricted volume
+ """
+
+ prior_shrunk = copy.deepcopy(prior)
+ f = 1.0
+
+ #TODO: this doesn't work for beta lists yet!
+
+ for j in range(len(prior_shrunk)):
+ for key in prior_shrunk[j].keys():
+ if prior_shrunk[j][key]['prior_type'] == 'flat':
+ x = model.params[j][key]
+ w0 = prior_shrunk[j][key]['max'] - prior_shrunk[j][key]['min'] # original width
+ w = w0 * shrink # new width, reduced by the shrink factor
+ prior_shrunk[j][key]['min'] = x - w/2
+ prior_shrunk[j][key]['max'] = x + w/2
+ if prior_shrunk[j][key]['min'] < prior[j][key]['min']: prior_shrunk[j][key]['min'] = prior[j][key]['min']
+ if prior_shrunk[j][key]['max'] > prior[j][key]['max']: prior_shrunk[j][key]['max'] = prior[j][key]['max']
+ f *= (prior_shrunk[j][key]['max'] - prior_shrunk[j][key]['min'])/w0
+ else:
+ pass
+
+ print('(New Prior Volume)/(Original Prior Volume):', f)
+
+ return prior_shrunk
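+
+# Usage sketch (illustrative; `prior` and `fitted_model` are assumed to exist):
+#   prior_refit = shrink_prior(prior, fitted_model, shrink=0.1)
+# Each flat prior is recentered on the fitted value with a tenth of its original
+# width (clipped to the original support), e.g. before refitting with dynesty.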
+
+def selfcal(Obsdata, model,
+ gain_init=None, gain_prior=None,
+ minimizer_func='scipy.optimize.minimize', minimizer_kwargs=None,
+ bounds=None, use_bounds=True,
+ processes=-1, msgtype='bar', quiet=True, **kwargs):
+ """Self-calibrate a specified observation to a given model, accounting for gain priors
+
+ Args:
+
+ Returns:
+ """
+
+ # This is just a convenience function. It will call modeler_func() scan-by-scan fitting only gains.
+ # This function differs from ehtim.calibrating.self_cal in the inclusion of gain priors
+ tlist = Obsdata.tlist()
+ res_list = []
+ for j in range(len(tlist)):
+ if msgtype not in ['none','']:
+ prog_msg(j, len(tlist), msgtype, j-1)
+ obs = Obsdata.copy()
+ obs.data = tlist[j]
+ res_list.append(modeler_func(obs, model, model_prior=None, d1='amp',
+ fit_model=False, fit_gains=True,gain_init=gain_init,gain_prior=gain_prior,
+ minimizer_func=minimizer_func, minimizer_kwargs=minimizer_kwargs,
+ bounds=bounds, use_bounds=use_bounds, processes=-1, quiet=quiet, **kwargs))
+
+ # Assemble a single caltable to return
+ allsites = Obsdata.tarr['site']
+ caldict = res_list[0]['caltable'].data
+ for j in range(1,len(tlist)):
+ row = res_list[j]['caltable'].data
+ for site in allsites:
+ try: dat = row[site]
+ except KeyError: continue
+
+ try: caldict[site] = np.append(caldict[site], row[site])
+ except KeyError: caldict[site] = dat
+
+ ct = caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr,
+ source=obs.source, mjd=obs.mjd, timetype=obs.timetype)
+
+ return ct
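+
+# Usage sketch (illustrative; `obs` and `mod` are assumed to be an existing
+# Obsdata and Model; with gain_prior=None the default prior is applied per site):
+#   ct = selfcal(obs, mod)
+#   # ct is a Caltable holding the scan-by-scan amplitude gain solutions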
+
+def make_param_map(model_init, model_prior, minimizer_func, fit_model, fit_pol=False, fit_cpol=False):
+ # Define the mapping between solved parameters and the model
+ # Each fitted model parameter can be rescaled to give values closer to order unity
+
+ param_map = [] # Define mapping for every fitted parameter: model component #, parameter name, rescale multiplier internal, unit, rescale multiplier external
+ param_mask = [] # True or False for whether to fit each model parameter (because the gradient is computed for all model parameters)
+ for j in range(model_init.N_models()):
+ params = model.model_params(model_init.models[j],model_init.params[j], fit_pol=fit_pol, fit_cpol=fit_cpol)
+ for param in params:
+ if not fit_model:
+ param_mask.append(False)
+ elif model_prior[j][param]['prior_type'] != 'fixed':
+ param_mask.append(True)
+ param_type = param
+ if len(param_type.split('_')) == 2 and param_type not in PARAM_DETAILS:
+ param_type = param_type.split('_')[1]
+ try:
+ if model_prior[j][param].get('transform','') == 'cdf' or minimizer_func in ['dynesty_static','dynesty_dynamic','pymc3']:
+ param_map.append([j,param,1,PARAM_DETAILS[param_type][1],PARAM_DETAILS[param_type][0]])
+ else:
+ param_map.append([j,param,PARAM_DETAILS[param_type][0],PARAM_DETAILS[param_type][1],PARAM_DETAILS[param_type][0]])
+ except KeyError:
+ # No rescaling details for this parameter type; use unit scaling
+ param_map.append([j,param,1,'',1])
+ else:
+ param_mask.append(False)
+ return (param_map, param_mask)
+
+def compute_likelihood_constants(d1, d2, d3, d4, sigma1, sigma2, sigma3, sigma4):
+ # Compute the correct data weights (hyperparameters) and the correct extra constant for the log-likelihood
+ alpha_d1 = alpha_d2 = alpha_d3 = alpha_d4 = ln_norm1 = ln_norm2 = ln_norm3 = ln_norm4 = 0.0
+
+ # sigma_i is False when the corresponding data term is unused; len(False)
+ # then raises TypeError and the zero defaults above are kept.
+ try:
+ alpha_d1 = 0.5 * len(sigma1)
+ ln_norm1 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma1))
+ except: pass
+ try:
+ alpha_d2 = 0.5 * len(sigma2)
+ ln_norm2 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma2))
+ except: pass
+ try:
+ alpha_d3 = 0.5 * len(sigma3)
+ ln_norm3 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma3))
+ except: pass
+ try:
+ alpha_d4 = 0.5 * len(sigma4)
+ ln_norm4 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma4))
+ except: pass
+
+ # If using closure phase, the sigma is given in degrees, not radians!
+ # Use the correct von Mises normalization if using closure phase
+ if d1 in ['cphase','cphase_diag']:
+ ln_norm1 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma1 * DEGREE)**2)))
+ if d2 in ['cphase','cphase_diag']:
+ ln_norm2 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma2 * DEGREE)**2)))
+ if d3 in ['cphase','cphase_diag']:
+ ln_norm3 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma3 * DEGREE)**2)))
+ if d4 in ['cphase','cphase_diag']:
+ ln_norm4 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma4 * DEGREE)**2)))
+
+ if d1 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']:
+ alpha_d1 *= 2
+ ln_norm1 *= 2
+ if d2 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']:
+ alpha_d2 *= 2
+ ln_norm2 *= 2
+ if d3 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']:
+ alpha_d3 *= 2
+ ln_norm3 *= 2
+ if d4 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']:
+ alpha_d4 *= 2
+ ln_norm4 *= 2
+ ln_norm = ln_norm1 + ln_norm2 + ln_norm3 + ln_norm4
+
+ return (alpha_d1, alpha_d2, alpha_d3, alpha_d4, ln_norm)
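+
+# Example (illustrative sketch): a single 'vis' data term with N complex
+# visibilities gets weight alpha_d1 = N (0.5*N from the Gaussian count,
+# doubled for the real and imaginary parts):
+#   >>> compute_likelihood_constants('vis', False, False, False,
+#   ...                              np.ones(10), False, False, False)[0]
+#   10.0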
+
+def default_gain_prior(sites):
+ print('No gain prior specified. Defaulting to ' + str(GAIN_PRIOR_DEFAULT) + ' for all sites.')
+ gain_prior = {}
+ for site in sites:
+ gain_prior[site] = GAIN_PRIOR_DEFAULT
+ return gain_prior
+
+def caltable_to_gains(caltab, gain_list):
+ # Generate an ordered list of gains from a caltable
+ # gain_list is a list of (time, site) pairs
+ gains = [np.abs(caltab.data[site]['rscale'][caltab.data[site]['time'] == time][0]) - 1.0 for (time, site) in gain_list]
+ return gains
+
+def make_gain_map(Obsdata, gain_prior):
+ # gain_list gives all unique (time,site) pairs
+ # gains_t1 gives the gain index for the first site in each measurement
+ # gains_t2 gives the gain index for the second site in each measurement
+ gain_list = []
+ for j in range(len(Obsdata.data)):
+ if ([Obsdata.data[j]['time'],Obsdata.data[j]['t1']] not in gain_list) and (gain_prior[Obsdata.data[j]['t1']]['prior_type'] != 'fixed'):
+ gain_list.append([Obsdata.data[j]['time'],Obsdata.data[j]['t1']])
+ if ([Obsdata.data[j]['time'],Obsdata.data[j]['t2']] not in gain_list) and (gain_prior[Obsdata.data[j]['t2']]['prior_type'] != 'fixed'):
+ gain_list.append([Obsdata.data[j]['time'],Obsdata.data[j]['t2']])
+
+ # Now determine the appropriate mapping; use the final index for all ignored gains, which default to 1
+ def gain_index(j, tnum):
+ try:
+ return gain_list.index([Obsdata.data[j]['time'],Obsdata.data[j][tnum]])
+ except ValueError:
+ return len(gain_list)
+
+ gains_t1 = [gain_index(j, 't1') for j in range(len(Obsdata.data))]
+ gains_t2 = [gain_index(j, 't2') for j in range(len(Obsdata.data))]
+
+ return (gain_list, gains_t1, gains_t2)
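+
+# Example (illustrative sketch): two scans on a single ALMA-APEX baseline with
+# non-fixed priors for both sites would give
+#   gain_list = [[t0, 'ALMA'], [t0, 'APEX'], [t1, 'ALMA'], [t1, 'APEX']]
+# with gains_t1/gains_t2 mapping each visibility to indices into that list.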
+
+def make_bounds(model_prior, param_map, gain_prior, gain_list, n_gains, leakage_fit, leakage_prior):
+ bounds = []
+ for j in range(len(param_map)):
+ pm = param_map[j]
+ pb = param_bounds(model_prior[pm[0]][pm[1]])
+ if (model_prior[pm[0]][pm[1]]['prior_type'] not in ['positive','none','fixed']) and (model_prior[pm[0]][pm[1]].get('transform','') != 'cdf'):
+ pb[0] = transform_param(pb[0]/pm[2], model_prior[pm[0]][pm[1]], inverse=False)
+ pb[1] = transform_param(pb[1]/pm[2], model_prior[pm[0]][pm[1]], inverse=False)
+ bounds.append(pb)
+ for j in range(n_gains):
+ pb = param_bounds(gain_prior[gain_list[j][1]])
+ if (gain_prior[gain_list[j][1]]['prior_type'] not in ['positive','none','fixed']) and (gain_prior[gain_list[j][1]].get('transform','') != 'cdf'):
+ pb[0] = transform_param(pb[0], gain_prior[gain_list[j][1]], inverse=False)
+ pb[1] = transform_param(pb[1], gain_prior[gain_list[j][1]], inverse=False)
+ bounds.append(pb)
+ for j in range(len(leakage_fit)):
+ for cpart in ['re','im']:
+ prior = leakage_prior[leakage_fit[j][0]][leakage_fit[j][1]][cpart]
+ pb = param_bounds(prior)
+ if (prior['prior_type'] not in ['positive','none','fixed']) and (prior.get('transform','') != 'cdf'):
+ pb[0] = transform_param(pb[0], prior, inverse=False)
+ pb[1] = transform_param(pb[1], prior, inverse=False)
+ bounds.append(pb)
+
+ return np.array(bounds)
+
+# Determine multiplicative factor for the gains (amplitude only)
+def gain_factor(dtype,gains,gains_t1,gains_t2, fit_or_marginalize_gains):
+ global globdict
+
+ if not fit_or_marginalize_gains:
+ if globdict['gain_init'] is None:
+ return 1
+ else:
+ gains = globdict['gain_init']
+
+ if globdict['marginalize_gains']:
+ gains = globdict['gain_init']
+
+ if dtype in ['amp','vis']:
+ gains_wzero = np.append(gains,0.0)
+ return (1.0 + gains_wzero[gains_t1])*(1.0 + gains_wzero[gains_t2])
+ else:
+ return 1
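+
+# Example (illustrative sketch): for one amplitude measurement between sites
+# with gain corrections 0.1 and -0.05, the factor is (1 + 0.1)*(1 - 0.05) = 1.045;
+# data types other than 'amp' and 'vis' are unaffected (factor 1).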
+
+def gain_factor_separate(dtype,gains,gains_t1,gains_t2, fit_or_marginalize_gains):
+ # Determine the pair of multiplicative factors for the gains (amplitude only)
+ # Note: these are not displaced by unity!
+ global globdict
+
+ if not fit_or_marginalize_gains:
+ if globdict['gain_init'] is None:
+ return (0., 0.)
+ else:
+ gains = globdict['gain_init']
+
+ if globdict['marginalize_gains']:
+ gains = globdict['gain_init']
+
+ if dtype in ['amp','vis']:
+ gains_wzero = np.append(gains,0.0)
+ return (gains_wzero[gains_t1], gains_wzero[gains_t2])
+ else:
+ return (0, 0)
+
+def prior_leakage(leakage, leakage_fit, leakage_prior, fit_leakage):
+ # Compute the log-prior contribution from the fitted leakage terms
+ if fit_leakage:
+ cparts = ['re','im']
+ return np.sum([np.log(prior_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]])) for j in range(len(leakage))])
+ else:
+ return 0.0
+
+def prior_leakage_grad(leakage, leakage_fit, leakage_prior, fit_leakage):
+ # Compute the log-prior contribution to the gradient from the leakages
+ if fit_leakage:
+ cparts = ['re','im']
+ f = np.array([prior_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage))])
+ df = np.array([prior_grad_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage))])
+ return df/f
+ else:
+ return []
+
+def prior_gain(gains, gain_list, gain_prior, fit_gains):
+ # Compute the log-prior contribution from the gains
+ if fit_gains:
+ return np.sum([np.log(prior_func(gains[j], gain_prior[gain_list[j][1]])) for j in range(len(gains))])
+ else:
+ return 0.0
+
+def prior_gain_grad(gains, gain_list, gain_prior, fit_gains):
+ # Compute the log-prior contribution to the gradient from the gains
+ if fit_gains:
+ f = np.array([prior_func(gains[j], gain_prior[gain_list[j][1]]) for j in range(len(gains))])
+ df = np.array([prior_grad_func(gains[j], gain_prior[gain_list[j][1]]) for j in range(len(gains))])
+ return df/f
+ else:
+ return []
+
+def transform_params(params, param_map, minimizer_func, model_prior, inverse=True):
+ if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']:
+ return [transform_param(params[j], model_prior[param_map[j][0]][param_map[j][1]], inverse=inverse) for j in range(len(params))]
+ else:
+ # For dynesty or pymc3, over-ride all specified parameter transformations to assume CDF mapping to the hypercube
+ # However, the passed parameters to the objective function and gradient are *not* transformed (i.e., they are not in the hypercube), thus the transformation does not need to be inverted
+ return params
+
+def set_params(params, trial_model, param_map, minimizer_func, model_prior):
+ tparams = transform_params(params, param_map, minimizer_func, model_prior)
+
+ for j in range(len(params)):
+ if param_map[j][1] in trial_model.params[param_map[j][0]].keys():
+ trial_model.params[param_map[j][0]][param_map[j][1]] = tparams[j] * param_map[j][2]
+ else: # In this case, the parameter is a list of complex numbers, so the real/imaginary or abs/arg components need to be assigned
+ if param_map[j][1].find('cpol') != -1:
+ param_type = 'beta_list_cpol'
+ idx = int(param_map[j][1].split('_')[0][8:])
+ elif param_map[j][1].find('pol') != -1:
+ param_type = 'beta_list_pol'
+ idx = int(param_map[j][1].split('_')[0][7:]) + (len(trial_model.params[param_map[j][0]][param_type])-1)//2
+ elif param_map[j][1].find('beta') != -1:
+ param_type = 'beta_list'
+ idx = int(param_map[j][1].split('_')[0][4:]) - 1
+ else:
+ raise Exception('Unsure how to interpret ' + param_map[j][1])
+
+ curval = trial_model.params[param_map[j][0]][param_type][idx]
+ if '_' not in param_map[j][1]: # This is for beta0 of cpol
+ trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2]
+ elif param_map[j][1][-2:] == 're':
+ trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] + np.imag(curval)*1j
+ elif param_map[j][1][-2:] == 'im':
+ trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] * 1j + np.real(curval)
+ elif param_map[j][1][-3:] == 'abs':
+ trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] * np.exp(1j * np.angle(curval))
+ elif param_map[j][1][-3:] == 'arg':
+ trial_model.params[param_map[j][0]][param_type][idx] = np.abs(curval) * np.exp(1j * tparams[j] * param_map[j][2])
+ else:
+ print('Parameter ' + param_map[j][1] + ' not understood!')
+
+# Define prior
+def prior(params, param_map, model_prior, minimizer_func):
+ tparams = transform_params(params, param_map, minimizer_func, model_prior)
+ return np.sum([np.log(prior_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]])) for j in range(len(params))])
+
+def prior_grad(params, param_map, model_prior, minimizer_func):
+ tparams = transform_params(params, param_map, minimizer_func, model_prior)
+ f = np.array([prior_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]]) for j in range(len(params))])
+ df = np.array([prior_grad_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]]) for j in range(len(params))])
+ return df/f
+
+# Define constraint functions
+def flux_constraint(trial_model, alpha_flux, flux):
+ if alpha_flux == 0.0:
+ return 0.0
+
+ return ((trial_model.total_flux() - flux)/flux)**2
+
+def flux_constraint_grad(trial_model, alpha_flux, flux, params, param_map):
+ if alpha_flux == 0.0:
+ return 0.0
+
+ fluxmask = np.zeros_like(params)
+ for j in range(len(param_map)):
+ if param_map[j][1] == 'F0':
+ fluxmask[j] = 1.0
+
+ return 2.0 * (trial_model.total_flux() - flux)/flux * fluxmask
+
+##################################################################################################
+# Define the chi^2 and chi^2 gradient functions
+##################################################################################################
+def laplace_approximation(trial_model, dtype, data, uv, sigma, gains_t1, gains_t2):
+ # Compute the approximate contribution to the log-likelihood by marginalizing over gains
+ global globdict
+
+ if globdict['marginalize_gains'] and dtype == 'amp':
+ # Add the log-likelihood term from analytic gain marginalization
+ # Create the Hessian matrix for the argument of the exponential
+ gain_hess = np.zeros((len(globdict['gain_list']), len(globdict['gain_list'])))
+
+ # Add the terms from the likelihood
+ gain = gain_factor(dtype,None,gains_t1,gains_t2,True)
+ amp_model = np.abs(trial_model.sample_uv(uv[:,0],uv[:,1])) # TODO: Add polarization!
+ amp_bar = gain*data
+ sigma_bar = gain*sigma
+ (g1, g2) = gain_factor_separate(dtype,None,gains_t1,gains_t2,True)
+
+ # Each amplitude *measurement* (not fitted gain parameter!) contributes to the hessian in four places; two diagonal and two off-diagonal
+ for j in range(len(gain)):
+ gain_hess[gains_t1[j],gains_t1[j]] += amp_model[j] * (3.0 * amp_model[j] - 2.0 * amp_bar[j])/((1.0 + g1[j])**2 * sigma_bar[j]**2)
+ gain_hess[gains_t2[j],gains_t2[j]] += amp_model[j] * (3.0 * amp_model[j] - 2.0 * amp_bar[j])/((1.0 + g2[j])**2 * sigma_bar[j]**2)
+ gain_hess[gains_t1[j],gains_t2[j]] += amp_model[j] * (2.0 * amp_model[j] - amp_bar[j])/((1.0 + g1[j])*(1.0 + g2[j]) * sigma_bar[j]**2)
+ gain_hess[gains_t2[j],gains_t1[j]] += amp_model[j] * (2.0 * amp_model[j] - amp_bar[j])/((1.0 + g1[j])*(1.0 + g2[j]) * sigma_bar[j]**2)
+
+ # Add contributions from the prior to the diagonal. This ranges over the fitted gain parameters.
+ # Note: for the Laplace approximation, only Gaussian gain priors have any effect!
+ for j in range(len(globdict['gain_list'])):
+ t = globdict['gain_list'][j][1]
+ if globdict['gain_prior'][t]['prior_type'] == 'gauss':
+ # the Hessian of the negative log of a Gaussian prior is 1/std^2
+ gain_hess[j,j] += 1.0/globdict['gain_prior'][t]['std']**2
+ elif globdict['gain_prior'][t]['prior_type'] == 'flat':
+ gain_hess[j,j] += 0.0
+ elif globdict['gain_prior'][t]['prior_type'] == 'exponential':
+ gain_hess[j,j] += 0.0
+ elif globdict['gain_prior'][t]['prior_type'] == 'fixed':
+ gain_hess[j,j] += 0.0
+ else:
+ raise Exception('Gain prior not implemented!')
+ # Laplace approximation: (2*pi)^(k/2) * |H|^(-1/2), with k the dimension of the Hessian
+ return np.log((2.0 * np.pi)**(len(globdict['gain_list'])/2.0) * np.abs(np.linalg.det(gain_hess))**-0.5)
+ else:
+ return 0.0
+
+def laplace_list():
+ global globdict
+ l1 = laplace_approximation(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['gains_t1'], globdict['gains_t2'])
+ l2 = laplace_approximation(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['gains_t1'], globdict['gains_t2'])
+ l3 = laplace_approximation(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['gains_t1'], globdict['gains_t2'])
+ l4 = laplace_approximation(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['gains_t1'], globdict['gains_t2'])
+ return (l1, l2, l3, l4)
+
+def chisq_wgain(trial_model, dtype, data, uv, sigma, pol, jonesdict, gains, gains_t1, gains_t2, fit_or_marginalize_gains):
+ global globdict
+ gain = gain_factor(dtype,gains,gains_t1,gains_t2,fit_or_marginalize_gains)
+ return chisq(trial_model, uv, gain*data, gain*sigma, dtype, pol, jonesdict)
+
+def chisqgrad_wgain(trial_model, dtype, data, uv, sigma, jonesdict, gains, gains_t1, gains_t2, fit_or_marginalize_gains, param_mask, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False):
+ gain = gain_factor(dtype,gains,gains_t1,gains_t2,fit_or_marginalize_gains)
+ return chisqgrad(trial_model, uv, gain*data, gain*sigma, jonesdict, dtype, param_mask, pol, fit_or_marginalize_gains, gains, gains_t1, gains_t2, fit_pol, fit_cpol, fit_leakage)
+
+def chisq_list(gains):
+ global globdict
+ chi2_1 = chisq_wgain(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['pol1'], globdict['jonesdict1'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'])
+ chi2_2 = chisq_wgain(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['pol2'], globdict['jonesdict2'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'])
+ chi2_3 = chisq_wgain(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['pol3'], globdict['jonesdict3'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'])
+ chi2_4 = chisq_wgain(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['pol4'], globdict['jonesdict4'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'])
+
+ return (chi2_1, chi2_2, chi2_3, chi2_4)
+
+def update_leakage(leakage):
+ # This function updates the 'jonesdict' entries based on current leakage estimates
+ # leakage is list of the fitted parameters (re and im are separate)
+ # station_leakages is the dictionary containing all station leakages, some of which may be fixed
+ global globdict
+ if len(leakage) == 0: return
+
+ station_leakages = globdict['station_leakages']
+ leakage_fit = globdict['leakage_fit']
+ # First, update the entries in the leakage dictionary
+ for j in range(len(leakage)//2):
+ station_leakages[leakage_fit[j][0]][leakage_fit[j][1]] = leakage[2*j] + 1j * leakage[2*j + 1]
+
+ # Now, recompute the jonesdict objects
+ for j in range(1,4):
+ jonesdict = globdict['jonesdict' + str(j)]
+ if jonesdict is not None:
+ if type(jonesdict) is dict:
+ jonesdict['DR1'] = np.array([station_leakages[jonesdict['t1'][_]]['R'] for _ in range(len(jonesdict['t1']))])
+ jonesdict['DR2'] = np.array([station_leakages[jonesdict['t2'][_]]['R'] for _ in range(len(jonesdict['t1']))])
+ jonesdict['DL1'] = np.array([station_leakages[jonesdict['t1'][_]]['L'] for _ in range(len(jonesdict['t1']))])
+ jonesdict['DL2'] = np.array([station_leakages[jonesdict['t2'][_]]['L'] for _ in range(len(jonesdict['t1']))])
+ jonesdict['leakage_fit'] = globdict['leakage_fit']
+ else:
+ # In this case, the data product requires a list of jonesdicts
+ for jonesdict2 in jonesdict:
+ jonesdict2['DR1'] = np.array([station_leakages[jonesdict2['t1'][_]]['R'] for _ in range(len(jonesdict2['t1']))])
+ jonesdict2['DR2'] = np.array([station_leakages[jonesdict2['t2'][_]]['R'] for _ in range(len(jonesdict2['t1']))])
+ jonesdict2['DL1'] = np.array([station_leakages[jonesdict2['t1'][_]]['L'] for _ in range(len(jonesdict2['t1']))])
+ jonesdict2['DL2'] = np.array([station_leakages[jonesdict2['t2'][_]]['L'] for _ in range(len(jonesdict2['t1']))])
+ jonesdict2['leakage_fit'] = globdict['leakage_fit']
+
+##################################################################################################
+# Define the objective function and gradient
+##################################################################################################
+def objfunc(params, force_posterior=False):
+ global globdict
+ # Note: model parameters can have transformations applied; gains and leakage do not
+ set_params(params[:globdict['n_params']], globdict['trial_model'], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior'])
+ gains = params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])]
+ leakage = params[(globdict['n_params'] + globdict['n_gains']):]
+ update_leakage(leakage)
+
+ if globdict['marginalize_gains']:
+ # Self-calibrate to the current trial model and store the resulting gains
+ # for the chi^2 and Laplace terms below
+ # Note: a passed gain_init is not handled properly here because its dimensions are incorrect
+ globdict['gain_init'] = caltable_to_gains(selfcal(globdict['Obsdata'], globdict['trial_model'], gain_init=None, gain_prior=globdict['gain_prior'], msgtype='none'),globdict['gain_list'])
+
+ (chi2_1, chi2_2, chi2_3, chi2_4) = chisq_list(gains)
+ datterm = ( globdict['alpha_d1'] * chi2_1
+ + globdict['alpha_d2'] * chi2_2
+ + globdict['alpha_d3'] * chi2_3
+ + globdict['alpha_d4'] * chi2_4)
+
+ if globdict['marginalize_gains']:
+ (l1, l2, l3, l4) = laplace_list()
+ datterm += l1 + l2 + l3 + l4
+
+ if (globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']) or force_posterior:
+ priterm = prior(params[:globdict['n_params']], globdict['param_map'], globdict['model_prior'], globdict['minimizer_func'])
+ priterm += prior_gain(params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])], globdict['gain_list'], globdict['gain_prior'], globdict['fit_gains'])
+ priterm += prior_leakage(params[(globdict['n_params'] + globdict['n_gains']):], globdict['leakage_fit'], globdict['leakage_prior'], globdict['fit_leakage'])
+ else:
+ priterm = 0.0
+ fluxterm = globdict['alpha_flux'] * flux_constraint(globdict['trial_model'], globdict['alpha_flux'], globdict['flux'])
+
+ return datterm - priterm + fluxterm - globdict['ln_norm']
+
+def objgrad(params):
+ global globdict
+ set_params(params[:globdict['n_params']], globdict['trial_model'], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior'])
+ gains = params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])]
+ leakage = params[(globdict['n_params'] + globdict['n_gains']):]
+ update_leakage(leakage)
+
+ datterm = ( globdict['alpha_d1'] * chisqgrad_wgain(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['jonesdict1'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol1'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage'])
+ + globdict['alpha_d2'] * chisqgrad_wgain(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['jonesdict2'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol2'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage'])
+ + globdict['alpha_d3'] * chisqgrad_wgain(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['jonesdict3'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol3'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage'])
+ + globdict['alpha_d4'] * chisqgrad_wgain(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['jonesdict4'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol4'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage']))
+
+ if globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']:
+ priterm = np.concatenate([prior_grad(params[:globdict['n_params']], globdict['param_map'],
+ globdict['model_prior'], globdict['minimizer_func']),
+ prior_gain_grad(params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])],
+ globdict['gain_list'], globdict['gain_prior'], globdict['fit_gains']),
+ prior_leakage_grad(params[(globdict['n_params'] + globdict['n_gains']):], globdict['leakage_fit'],
+ globdict['leakage_prior'], globdict['fit_leakage'])])
+ else:
+ priterm = 0.0
+ fluxterm = globdict['alpha_flux'] * flux_constraint_grad(globdict['trial_model'], globdict['alpha_flux'], globdict['flux'], params, globdict['param_map'])
+
+ grad = datterm - priterm + fluxterm
+
+ if globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']:
+ for j in range(globdict['n_params']):
+ grad[j] *= globdict['param_map'][j][2] * transform_grad_param(params[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]])
+ else:
+ # For dynesty or pymc3, over-ride all specified parameter transformations to assume CDF
+ # However, the passed parameters are *not* transformed (i.e., they are not in the hypercube)
+ # The Jacobian still needs to account for the parameter transformation
+ for j in range(len(params)):
+ if j < globdict['n_params']:
+ j2 = j
+ x = params[j2]
+ prior_params = globdict['model_prior'][globdict['param_map'][j2][0]][globdict['param_map'][j2][1]]
+ grad[j] /= prior_func(x,prior_params)
+ elif j < globdict['n_params'] + globdict['n_gains']:
+ j2 = j-globdict['n_params']
+ x = gains[j2]
+ prior_params = globdict['gain_prior'][globdict['gain_list'][j2][1]]
+ grad[j] /= prior_func(x, prior_params)
+ else:
+ cparts = ['re','im']
+ j2 = j-globdict['n_params']-globdict['n_gains']
+ x = leakage[j2]
+ prior_params = globdict['leakage_prior'][globdict['leakage_fit'][j2//2][0]][globdict['leakage_fit'][j2//2][1]][cparts[j2%2]]
+ grad[j] /= prior_func(x, prior_params)
+
+ if globdict['test_gradient']:
+ print('Testing the gradient at ', params)
+ dx = 1e-5
+ grad_numeric = np.zeros(len(grad))
+ f1 = objfunc(params)
+ print('Objective Function:',f1)
+ print('\nNumeric Gradient Check: Analytic Numeric')
+ for j in range(len(grad)):
+ if globdict['minimizer_func'] in ['dynesty_static','dynesty_dynamic','pymc3']:
+ dx = np.abs(params[j]) * 1e-6
+
+ params2 = copy.deepcopy(params)
+ params2[j] += dx
+ f2 = objfunc(params2)
+ grad_numeric[j] = (f2 - f1)/dx
+
+ if globdict['minimizer_func'] in ['dynesty_static','dynesty_dynamic','pymc3']:
+ if j < globdict['n_params']:
+ j2 = j
+ x = params[j2]
+ prior_params = globdict['model_prior'][globdict['param_map'][j2][0]][globdict['param_map'][j2][1]]
+ grad_numeric[j] /= prior_func(x,prior_params)
+ elif j < globdict['n_params'] + globdict['n_gains']:
+ j2 = j-globdict['n_params']
+ x = gains[j2]
+ prior_params = globdict['gain_prior'][globdict['gain_list'][j2][1]]
+ grad_numeric[j] /= prior_func(x, prior_params)
+ else:
+ cparts = ['re','im']
+ j2 = j-globdict['n_params']-globdict['n_gains']
+ x = leakage[j2]
+ prior_params = globdict['leakage_prior'][globdict['leakage_fit'][j2//2][0]][globdict['leakage_fit'][j2//2][1]][cparts[j2%2]]
+ grad_numeric[j] /= prior_func(x, prior_params)
+
+ if j < globdict['n_params']:
+ print('\nNumeric Gradient Check:',globdict['param_map'][j][0],globdict['param_map'][j][1],grad[j],grad_numeric[j])
+ else:
+ print('\nNumeric Gradient Check:',grad[j],grad_numeric[j])
+
+ return grad
+
+##################################################################################################
+# Modeler
+##################################################################################################
+def modeler_func(Obsdata, model_init, model_prior,
+ d1='vis', d2=False, d3=False, d4=False,
+ pol1='I', pol2='I', pol3='I', pol4='I',
+ normchisq = False, alpha_d1=0, alpha_d2=0, alpha_d3=0, alpha_d4=0,
+ flux=1.0, alpha_flux=0,
+ fit_model=True, fit_pol=False, fit_cpol=False,
+ fit_gains=False,marginalize_gains=False,gain_init=None,gain_prior=None,
+ fit_leakage=False, leakage_init=None, leakage_prior=None,
+ fit_noise_model=False,
+ minimizer_func='scipy.optimize.minimize',
+ minimizer_kwargs=None,
+ bounds=None, use_bounds=False,
+ processes=-1,
+ test_gradient=False, quiet=False, **kwargs):
+
+ """Fit a specified model.
+
+ Args:
+ Obsdata (Obsdata): The Obsdata object with VLBI data
+ model_init (Model): The Model object to fit
+ model_prior (dict): Priors for each model parameter
+
+ d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'm' (see DATATERMS for the full list)
+ d2 (str): The second data term; same options as d1
+ d3 (str): The third data term; same options as d1
+ d4 (str): The fourth data term; same options as d1
+
+ normchisq (bool): If False (default), automatically assign the weights alpha_d1-4 to match the true log-likelihood.
+ alpha_d1 (float): The first data term weighting; ignored unless normchisq is True.
+ alpha_d2 (float): The second data term weighting; ignored unless normchisq is True.
+ alpha_d3 (float): The third data term weighting; ignored unless normchisq is True.
+ alpha_d4 (float): The fourth data term weighting; ignored unless normchisq is True.
+
+ flux (float): Total flux of the fitted model
+ alpha_flux (float): Hyperparameter controlling how strongly to constrain that the total flux matches the specified flux.
+
+ fit_model (bool): Whether or not to fit the model parameters
+ fit_pol (bool): Whether or not to fit linear polarization parameters
+ fit_cpol (bool): Whether or not to fit circular polarization parameters
+ fit_gains (bool): Whether or not to fit time-dependent amplitude gains for each station
+ marginalize_gains (bool): Whether or not to perform analytic gain marginalization (via the Laplace approximation to the posterior)
+
+ gain_init (list or caltable): Initial gain amplitudes to apply; these can be specified even if gains aren't fitted
+ gain_prior (dict): Dictionary with the gain prior for each site.
+
+ minimizer_func (str): Minimizer function to use. Current options are:
+ 'scipy.optimize.minimize'
+ 'scipy.optimize.dual_annealing'
+ 'scipy.optimize.basinhopping'
+ 'dynesty_static'
+ 'dynesty_dynamic'
+ 'pymc3'
+ minimizer_kwargs (dict): kwargs passed to the minimizer.
+
+ bounds (list): List of parameter bounds for the fitted parameters (will automatically compute if needed)
+ use_bounds (bool): Whether or not to use bounds when fitting (required for some minimizers)
+
+ processes (int): Number of processes to use for a multiprocessing pool. -1 disables multiprocessing; 0 uses all that are available. Only used for dynesty.
+
+ Returns:
+ dict: Dictionary with fitted model ('model') and other diagnostics that are minimizer-dependent
+ """
+
+ global nit, globdict
+ nit = n_params = 0
+ ln_norm = 0.0
+
+ if not (fit_model or fit_gains or fit_leakage):
+ raise Exception('fit_model, fit_gains, and fit_leakage are all False. Must fit something!')
+
+ if fit_gains and marginalize_gains:
+ raise Exception('Both fit_gains and marginalize_gains are True. Cannot do both!')
+
+ if not fit_gains and not marginalize_gains and gain_init is not None:
+ if not quiet: print('Both fit_gains and marginalize_gains are False but gain_init was passed. Applying these gains as a fixed correction!')
+
+ if minimizer_kwargs is None:
+ minimizer_kwargs = {}
+
+ # Specifications for verbosity during fits
+ show_updates = kwargs.get('show_updates',True)
+ update_interval = kwargs.get('update_interval',1)
+ run_nested_kwargs = kwargs.get('run_nested_kwargs',{})
+
+ # Make sure data and regularizer options are ok
+ if not d1 and not d2 and not d3 and not d4:
+ raise Exception("Must have at least one data term!")
+ if (not ((d1 in DATATERMS) or d1==False)) or (not ((d2 in DATATERMS) or d2==False)) or (not ((d3 in DATATERMS) or d3==False)) or (not ((d4 in DATATERMS) or d4==False)):
+ raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS))
+
+ # Create the trial model
+ trial_model = model_init.copy()
+
+ # Define mapping for every fitted parameter: model component index, parameter name, rescale multiplier, unit
+ (param_map, param_mask) = make_param_map(model_init, model_prior, minimizer_func, fit_model, fit_pol, fit_cpol)
+
+ # Get data and info for the data terms
+ if type(Obsdata) is obsdata.Obsdata:
+ (data1, sigma1, uv1, jonesdict1) = chisqdata(Obsdata, d1, pol=pol1)
+ (data2, sigma2, uv2, jonesdict2) = chisqdata(Obsdata, d2, pol=pol2)
+ (data3, sigma3, uv3, jonesdict3) = chisqdata(Obsdata, d3, pol=pol3)
+ (data4, sigma4, uv4, jonesdict4) = chisqdata(Obsdata, d4, pol=pol4)
+ elif type(Obsdata) is list:
+ # Combine a list of observations into one.
+ # Allow these to be from multiple sources for polarimetric zero-baseline purposes.
+ # Main thing for different sources is to compute field rotation before combining
+ def combine_data(d1,s1,u1,j1,d2,s2,u2,j2):
+ d = np.concatenate([d1,d2])
+ s = np.concatenate([s1,s2])
+ u = np.concatenate([u1,u2])
+ j = j1.copy()
+ for key in ['fr1', 'fr2', 't1', 't2', 'DR1', 'DR2', 'DL1', 'DL2']:
+ j[key] = np.concatenate([j1[key],j2[key]])
+ return (d, s, u, j)
+
+ (data1, sigma1, uv1, jonesdict1) = chisqdata(Obsdata[0], d1, pol=pol1)
+ (data2, sigma2, uv2, jonesdict2) = chisqdata(Obsdata[0], d2, pol=pol2)
+ (data3, sigma3, uv3, jonesdict3) = chisqdata(Obsdata[0], d3, pol=pol3)
+ (data4, sigma4, uv4, jonesdict4) = chisqdata(Obsdata[0], d4, pol=pol4)
+ for j in range(1,len(Obsdata)):
+ (data1b, sigma1b, uv1b, jonesdict1b) = chisqdata(Obsdata[j], d1, pol=pol1)
+ (data2b, sigma2b, uv2b, jonesdict2b) = chisqdata(Obsdata[j], d2, pol=pol2)
+ (data3b, sigma3b, uv3b, jonesdict3b) = chisqdata(Obsdata[j], d3, pol=pol3)
+ (data4b, sigma4b, uv4b, jonesdict4b) = chisqdata(Obsdata[j], d4, pol=pol4)
+
+ if data1b is not False:
+ (data1, sigma1, uv1, jonesdict1) = combine_data(data1,sigma1,uv1,jonesdict1,data1b,sigma1b,uv1b,jonesdict1b)
+ if data2b is not False:
+ (data2, sigma2, uv2, jonesdict2) = combine_data(data2,sigma2,uv2,jonesdict2,data2b,sigma2b,uv2b,jonesdict2b)
+ if data3b is not False:
+ (data3, sigma3, uv3, jonesdict3) = combine_data(data3,sigma3,uv3,jonesdict3,data3b,sigma3b,uv3b,jonesdict3b)
+ if data4b is not False:
+ (data4, sigma4, uv4, jonesdict4) = combine_data(data4,sigma4,uv4,jonesdict4,data4b,sigma4b,uv4b,jonesdict4b)
+
+ alldata = np.concatenate([_.data for _ in Obsdata])
+ Obsdata = Obsdata[0]
+ Obsdata.data = alldata
+ else:
+ raise Exception("Observation format not recognized!")
+
+ if fit_leakage or leakage_init is not None:
+ # Determine what leakage terms must be fitted. At most, this would be L & R complex leakages terms for every site
+ # leakage_fit is a list of tuples [site, hand] that will be fitted
+ leakage_fit = []
+ if fit_leakage:
+ import copy # Error on the next line if this isn't done again. Why python, why?!?
+ # Start with the list of all sites
+ sites = list(set(np.concatenate(Obsdata.unpack(['t1','t2']).tolist())))
+
+ # Add missing entries to leakage_prior
+ # leakage_prior is a nested dictionary with keys of station, hand, re/im
+ leakage_prior_init = copy.deepcopy(leakage_prior)
+ if leakage_prior_init is None: leakage_prior_init = {}
+ leakage_prior = {}
+ for s in sites:
+ leakage_prior[s] = {}
+ for pol in ['R','L']:
+ leakage_prior[s][pol] = {}
+ for cpart in ['re','im']:
+ # check to see if a prior is specified for the complex part, the pol, or the site (in that order)
+ if leakage_prior_init.get(s,{}).get(pol,{}).get(cpart,{}).get('prior_type','') != '':
+ leakage_prior[s][pol][cpart] = leakage_prior_init[s][pol][cpart]
+ elif leakage_prior_init.get(s,{}).get(pol,{}).get('prior_type','') != '':
+ leakage_prior[s][pol][cpart] = copy.deepcopy(leakage_prior_init[s][pol])
+ elif leakage_prior_init.get(s,{}).get('prior_type','') != '':
+ leakage_prior[s][pol][cpart] = copy.deepcopy(leakage_prior_init[s])
+ else:
+ leakage_prior[s][pol][cpart] = copy.deepcopy(LEAKAGE_PRIOR_DEFAULT)
+
+ if Obsdata.polrep == 'stokes':
+ for s in sites:
+ for pol in ['R','L']:
+ if leakage_prior[s][pol]['re']['prior_type'] == 'fixed': continue
+ leakage_fit.append([s,pol])
+ else:
+ vislist = Obsdata.unpack(['t1','t2','rlvis','lrvis'])
+ # Only fit leakage for sites that include cross hand visibilities
+ DR = list(set(np.concatenate([vislist[~np.isnan(vislist['rlvis'])]['t1'], vislist[~np.isnan(vislist['lrvis'])]['t2']])))
+ DL = list(set(np.concatenate([vislist[~np.isnan(vislist['lrvis'])]['t1'], vislist[~np.isnan(vislist['rlvis'])]['t2']])))
+ [leakage_fit.append([s,'R']) for s in DR if leakage_prior[s]['R']['re']['prior_type'] != 'fixed']
+ [leakage_fit.append([s,'L']) for s in DL if leakage_prior[s]['L']['re']['prior_type'] != 'fixed']
+ sites = list(set(np.concatenate([DR,DL])))
+
+ if type(leakage_init) is dict:
+ station_leakages = copy.deepcopy(leakage_init)
+ else:
+ station_leakages = {}
+
+ # Add missing entries to station_leakages
+ for s in sites:
+ for pol in ['R','L']:
+ if s not in station_leakages.keys():
+ station_leakages[s] = {}
+ if 'R' not in station_leakages[s].keys():
+ station_leakages[s]['R'] = 0.0
+ if 'L' not in station_leakages[s].keys():
+ station_leakages[s]['L'] = 0.0
+ else:
+ # Disable leakage computations
+ jonesdict1 = jonesdict2 = jonesdict3 = None
+ leakage_fit = []
+ station_leakages = None
+
+ if not normchisq:
+ if not quiet: print('Assigning data weights to give the correct log-likelihood...')
+ (alpha_d1, alpha_d2, alpha_d3, alpha_d4, ln_norm) = compute_likelihood_constants(d1, d2, d3, d4, sigma1, sigma2, sigma3, sigma4)
+ else:
+ ln_norm = 0.0
+
+ # Determine the mapping between solution gains and the input visibilities
+ # Use passed gains even if fit_gains=False and marginalize_gains=False
+ # NOTE: THERE IS A PROBLEM IN THIS IMPLEMENTATION. A fixed gain prior is ignored. However, gain_init may still want to apply a constant correction, especially when passing a caltable.
+ # We should maybe have two gain lists: one for constant gains and one for fitted gains
+ mean_g1 = mean_g2 = 0.0
+ if fit_gains or marginalize_gains:
+ if gain_prior is None:
+ gain_prior = default_gain_prior(Obsdata.tarr['site'])
+ (gain_list, gains_t1, gains_t2) = make_gain_map(Obsdata, gain_prior)
+ if type(gain_init) == caltable.Caltable:
+ if not quiet: print('Converting gain_init from caltable to a list')
+ gain_init = caltable_to_gains(gain_init, gain_list)
+ if gain_init is None:
+ if not quiet: print('Initializing all gain corrections to be zero')
+ gain_init = np.zeros(len(gain_list))
+ else:
+ if len(gain_init) != len(gain_list):
+ raise Exception('Gain initialization has incorrect dimensions! %d %d' % (len(gain_init), len(gain_list)))
+ if fit_gains:
+ n_gains = len(gain_list)
+ elif marginalize_gains:
+ n_gains = 0
+ else:
+ if gain_init is None:
+ n_gains = 0
+ gain_list = []
+ gains_t1 = gains_t2 = None
+ else:
+ if gain_prior is None:
+ gain_prior = default_gain_prior(Obsdata.tarr['site'])
+ (gain_list, gains_t1, gains_t2) = make_gain_map(Obsdata, gain_prior)
+ if type(gain_init) == caltable.Caltable:
+ if not quiet: print('Converting gain_init from caltable to a list')
+ gain_init = caltable_to_gains(gain_init, gain_list)
+
+ if fit_leakage:
+ leakage_init = np.zeros(len(leakage_fit) * 2)
+ for j in range(len(leakage_init)//2):
+ leakage_init[2*j] = np.real(station_leakages[leakage_fit[j][0]][leakage_fit[j][1]])
+ leakage_init[2*j + 1] = np.imag(station_leakages[leakage_fit[j][0]][leakage_fit[j][1]])
+ else:
+ leakage_init = []
+
+ # Initial parameters
+ param_init = []
+ for j in range(len(param_map)):
+ pm = param_map[j]
+ if param_map[j][1] in trial_model.params[param_map[j][0]].keys():
+ param_init.append(transform_param(model_init.params[pm[0]][pm[1]]/pm[2], model_prior[pm[0]][pm[1]],inverse=False))
+ else: # In this case, the parameter is a list of complex numbers, so the real/imaginary or abs/arg components need to be assigned
+ if param_map[j][1].find('cpol') != -1:
+ param_type = 'beta_list_cpol'
+ idx = int(param_map[j][1].split('_')[0][8:])
+ elif param_map[j][1].find('pol') != -1:
+ param_type = 'beta_list_pol'
+ idx = int(param_map[j][1].split('_')[0][7:]) + (len(trial_model.params[param_map[j][0]][param_type])-1)//2
+ elif param_map[j][1].find('beta') != -1:
+ param_type = 'beta_list'
+ idx = int(param_map[j][1].split('_')[0][4:]) - 1
+ else:
+ raise Exception('Unsure how to interpret ' + param_map[j][1])
+
+ curval = model_init.params[param_map[j][0]][param_type][idx]
+ if '_' not in param_map[j][1]:
+ param_init.append(transform_param(np.real( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False))
+ elif param_map[j][1][-2:] == 're':
+ param_init.append(transform_param(np.real( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False))
+ elif param_map[j][1][-2:] == 'im':
+ param_init.append(transform_param(np.imag( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False))
+ elif param_map[j][1][-3:] == 'abs':
+ param_init.append(transform_param(np.abs( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False))
+ elif param_map[j][1][-3:] == 'arg':
+ param_init.append(transform_param(np.angle(model_init.params[pm[0]][param_type][idx])/pm[2], model_prior[pm[0]][pm[1]],inverse=False))
+ else:
+ if not quiet: print('Parameter ' + param_map[j][1] + ' not understood!')
+ n_params = len(param_init)
+
+ # Note: model parameters can have transformations applied; gains and leakage do not
+ if fit_gains: # Do not add these if marginalize_gains == True
+ param_init += list(gain_init)
+ if fit_leakage:
+ param_init += list(leakage_init)
+
+ if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']:
+ # Define bounds (irrelevant for dynesty or pymc3)
+ if use_bounds == False and minimizer_func in ['scipy.optimize.dual_annealing']:
+ if not quiet: print('Bounds are required for ' + minimizer_func + '! Setting use_bounds=True.')
+ use_bounds = True
+ if use_bounds == False and bounds is not None:
+ if not quiet: print('Bounds passed but use_bounds=False; setting use_bounds=True.')
+ use_bounds = True
+ if bounds is None and use_bounds:
+ if not quiet: print('No bounds passed. Setting nominal bounds.')
+ bounds = make_bounds(model_prior, param_map, gain_prior, gain_list, n_gains, leakage_fit, leakage_prior)
+ if use_bounds == False:
+ bounds = None
+
+ # Gather global variables into a dictionary
+ globdict = {'trial_model':trial_model,
+ 'd1':d1, 'd2':d2, 'd3':d3, 'd4':d4,
+ 'pol1':pol1, 'pol2':pol2, 'pol3':pol3, 'pol4':pol4,
+ 'data1':data1, 'sigma1':sigma1, 'uv1':uv1, 'jonesdict1':jonesdict1,
+ 'data2':data2, 'sigma2':sigma2, 'uv2':uv2, 'jonesdict2':jonesdict2,
+ 'data3':data3, 'sigma3':sigma3, 'uv3':uv3, 'jonesdict3':jonesdict3,
+ 'data4':data4, 'sigma4':sigma4, 'uv4':uv4, 'jonesdict4':jonesdict4,
+ 'alpha_d1':alpha_d1, 'alpha_d2':alpha_d2, 'alpha_d3':alpha_d3, 'alpha_d4':alpha_d4,
+ 'n_params': n_params, 'n_gains':n_gains, 'n_leakage':len(leakage_init),
+ 'model_prior':model_prior, 'param_map':param_map, 'param_mask':param_mask,
+ 'gain_prior':gain_prior, 'gain_list':gain_list, 'gain_init':gain_init,
+ 'fit_leakage':fit_leakage, 'leakage_init':leakage_init, 'leakage_fit':leakage_fit, 'station_leakages':station_leakages, 'leakage_prior':leakage_prior,
+ 'show_updates':show_updates, 'update_interval':update_interval, 'gains_t1':gains_t1, 'gains_t2':gains_t2,
+ 'minimizer_func':minimizer_func,'Obsdata':Obsdata,
+ 'fit_pol':fit_pol, 'fit_cpol':fit_cpol,
+ 'flux':flux, 'alpha_flux':alpha_flux, 'fit_gains':fit_gains, 'marginalize_gains':marginalize_gains, 'ln_norm':ln_norm, 'param_init':param_init, 'test_gradient':test_gradient}
+ if fit_leakage:
+ update_leakage(leakage_init)
+
+
+ # Define the function that reports progress
+ def plotcur(params_step, *args):
+ global nit, globdict
+ if globdict['show_updates'] and (nit % globdict['update_interval'] == 0) and (quiet == False):
+ if globdict['n_params'] > 0:
+ print('Params:',params_step[:globdict['n_params']])
+ print('Transformed Params:',transform_params(params_step[:globdict['n_params']], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior']))
+ gains = params_step[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])]
+ leakage = params_step[(globdict['n_params'] + globdict['n_gains']):]
+ if len(leakage):
+ print('leakage:',leakage)
+ update_leakage(leakage)
+ (chi2_1, chi2_2, chi2_3, chi2_4) = chisq_list(gains)
+ print("i: %d chi2_1: %0.2f chi2_2: %0.2f chi2_3: %0.2f chi2_4: %0.2f prior: %0.2f" % (nit, chi2_1, chi2_2, chi2_3, chi2_4, prior(params_step[:globdict['n_params']], globdict['param_map'], globdict['model_prior'], globdict['minimizer_func'])))
+ nit += 1
+
+ # Print initial statistics
+ if not quiet:
+ print("Initial Objective Function: %f" % (objfunc(param_init)))
+ if d1 in DATATERMS:
+ print("Total Data 1: ", (len(data1)))
+ if d2 in DATATERMS:
+ print("Total Data 2: ", (len(data2)))
+ if d3 in DATATERMS:
+ print("Total Data 3: ", (len(data3)))
+ if d4 in DATATERMS:
+ print("Total Data 4: ", (len(data4)))
+ print("Total Fitted Real Parameters #: ",(len(param_init)))
+ print("Fitted Model Parameters: ",[_[1] for _ in param_map])
+ print('Fitting Leakage Terms for:',leakage_fit)
+ plotcur(param_init)
+
+ # Run the minimization
+ tstart = time.time()
+ ret = {}
+ if minimizer_func == 'scipy.optimize.minimize':
+ min_kwargs = {'method':minimizer_kwargs.get('method','L-BFGS-B'),
+ 'options':{'maxiter':MAXIT, 'ftol':STOP, 'maxcor':NHIST,'gtol':STOP,'maxls':MAXLS}}
+
+ if 'options' in minimizer_kwargs.keys():
+ for key in minimizer_kwargs['options'].keys():
+ min_kwargs['options'][key] = minimizer_kwargs['options'][key]
+
+ for key in minimizer_kwargs.keys():
+ if key in ['options','method']:
+ continue
+ else:
+ min_kwargs[key] = minimizer_kwargs[key]
+
+ res = opt.minimize(objfunc, param_init, jac=objgrad, callback=plotcur, bounds=bounds, **min_kwargs)
+ elif minimizer_func == 'scipy.optimize.dual_annealing':
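+        # Note: SciPy >= 1.8 renames dual_annealing's 'local_search_options'
+        # argument to 'minimizer_kwargs'; the older name is assumed here.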
+ min_kwargs = {}
+ min_kwargs['local_search_options'] = {'jac':objgrad,
+ 'method':'L-BFGS-B','options':{'maxiter':MAXIT, 'ftol':STOP, 'maxcor':NHIST,'gtol':STOP,'maxls':MAXLS}}
+ if 'local_search_options' in minimizer_kwargs.keys():
+ for key in minimizer_kwargs['local_search_options'].keys():
+ min_kwargs['local_search_options'][key] = minimizer_kwargs['local_search_options'][key]
+
+ for key in minimizer_kwargs.keys():
+ if key in ['local_search_options']:
+ continue
+ min_kwargs[key] = minimizer_kwargs[key]
+
+ res = opt.dual_annealing(objfunc, x0=param_init, bounds=bounds, callback=plotcur, **min_kwargs)
+ elif minimizer_func == 'scipy.optimize.basinhopping':
+ min_kwargs = {}
+ for key in minimizer_kwargs.keys():
+ min_kwargs[key] = minimizer_kwargs[key]
+
+ res = opt.basinhopping(objfunc, param_init, **min_kwargs)
+ elif minimizer_func == 'pymc3':
+ ########################
+ ## Sample using pymc3 ##
+ ########################
+ import pymc3 as pm
+ import theano
+ import theano.tensor as tt
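+        # Note: pymc3 and theano are legacy packages; their successor PyMC (v4+)
+        # uses a different backend and API, so this branch assumes pymc3 3.x.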
+
+        # To simplify things, we'll use cdf transforms to map everything to a hypercube, as in dynesty
+
+ # First, define a theano Op for our likelihood function
+ # This is based on the example here: https://docs.pymc.io/notebooks/blackbox_external_likelihood.html
+ class LogLike(tt.Op):
+ itypes = [tt.dvector] # expects a vector of parameter values when called
+ otypes = [tt.dscalar] # outputs a single scalar value (the log likelihood)
+
+ def __init__(self, objfunc, objgrad):
+ # add inputs as class attributes
+ self.objfunc = objfunc
+ self.objgrad = objgrad
+ self.logpgrad = LogLikeGrad(objfunc, objgrad)
+
+ def prior_transform(self, u):
+ # This function transforms samples from the unit hypercube (u) to the target prior (x)
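+                # Inverse-CDF sampling: each u_j ~ Uniform(0,1) is mapped to an x_j
+                # distributed per the prior. E.g., for a uniform prior on [a,b],
+                # cdf_inverse(u, prior) would return x = a + u*(b - a).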
+ global globdict
+ cparts = ['re','im']
+ model_params_u = u[:n_params]
+ gain_params_u = u[n_params:(n_params+n_gains)]
+ leakage_params_u = u[(n_params+n_gains):]
+ model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))]
+ gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))]
+ leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))]
+ return np.concatenate([model_params_x, gain_params_x, leakage_params_x])
+
+ def perform(self, node, inputs, outputs):
+ # the method that is used when calling the Op
+ theta, = inputs # this will contain my variables
+ # Transform from the hypercube to the prior
+ x = self.prior_transform(theta)
+
+ # call the log-likelihood function
+ logl = -self.objfunc(x)
+
+ outputs[0][0] = np.array(logl) # output the log-likelihood
+
+ def grad(self, inputs, g):
+ # the method that calculates the vector-Jacobian product
+ # http://deeplearning.net/software/theano_versions/dev/extending/op.html#grad
+ theta, = inputs
+ return [g[0]*self.logpgrad(theta)]
+
+ class LogLikeGrad(tt.Op):
+ """
+ This Op will be called with a vector of values and also return a vector of
+ values - the gradients in each dimension.
+ """
+ itypes = [tt.dvector]
+ otypes = [tt.dvector]
+
+ def __init__(self, objfunc, objgrad):
+ self.objfunc = objfunc
+ self.objgrad = objgrad
+
+ def prior_transform(self, u):
+ # This function transforms samples from the unit hypercube (u) to the target prior (x)
+ global globdict
+ cparts = ['re','im']
+ model_params_u = u[:n_params]
+ gain_params_u = u[n_params:(n_params+n_gains)]
+ leakage_params_u = u[(n_params+n_gains):]
+ model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))]
+ gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))]
+ leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))]
+ return np.concatenate([model_params_x, gain_params_x, leakage_params_x])
+
+ def perform(self, node, inputs, outputs):
+ theta, = inputs
+ x = self.prior_transform(theta)
+ outputs[0][0] = -self.objgrad(x)
+
+ # create the log-likelihood Op
+ logl = LogLike(objfunc, objgrad)
+
+ # Define the sampler keywords
+ min_kwargs = {}
+ for key in minimizer_kwargs.keys():
+ min_kwargs[key] = minimizer_kwargs[key]
+
+ # Define the initial value if not passed
+ if 'start' not in min_kwargs.keys():
+ cparts = ['re','im']
+ model_params_x = param_init[:n_params]
+ gain_params_x = param_init[n_params:(n_params+n_gains)]
+ leakage_params_x = param_init[(n_params+n_gains):]
+ model_params_u = [cdf( model_params_x[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_x))]
+ gain_params_u = [cdf( gain_params_x[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_x))]
+ leakage_params_u = [cdf(leakage_params_x[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_x))]
+ param_init_u = np.concatenate([model_params_u, gain_params_u, leakage_params_u])
+ min_kwargs['start'] = {}
+ for j in range(len(param_init)):
+ min_kwargs['start']['var' + str(j)] = param_init_u[j]
+
+ # Setup the sampler
+ with pm.Model() as model:
+ theta = tt.as_tensor_variable([ pm.Uniform('var' + str(j), lower=0., upper=1.) for j in range(len(param_init)) ])
+ pm.DensityDist('likelihood', lambda v: logl(v), observed={'v': theta})
+ trace = pm.sample(**min_kwargs)
+
+ # Extract useful sampling diagnostics.
+ samples_u = np.vstack([trace['var' + str(j)] for j in range(len(param_init))]).T # samples in the hypercube
+ samples = np.array([logl.prior_transform(u) for u in samples_u]) # samples
+ mean = np.mean(samples,axis=0)
+ var = np.var(samples,axis=0)
+
+ # Compute the log-posterior
+ if not quiet: print('Calculating the posterior values for the samples...')
+ logposterior = np.array([-objfunc(x, force_posterior=True) for x in samples])
+
+ # Select the MAP
+ j_MAP = np.argmax(logposterior)
+ MAP = samples[j_MAP]
+
+ # Return a model determined by the MAP
+ set_params(MAP[:n_params], trial_model, param_map, minimizer_func, model_prior)
+ gains = MAP[n_params:(n_params+n_gains)]
+ leakage = MAP[(n_params+n_gains):]
+ update_leakage(leakage)
+
+ # Return the sampler
+ ret['trace'] = trace
+ ret['mean'] = mean
+ ret['map'] = MAP
+ ret['std'] = var**0.5
+ ret['samples'] = samples
+ ret['logposterior'] = logposterior
+
+ # Return a set of models from the posterior
+ posterior_models = []
+ for j in range(N_POSTERIOR_SAMPLES):
+ posterior_model = trial_model.copy()
+            set_params(samples[-(j+1)][:n_params], posterior_model, param_map, minimizer_func, model_prior)  # last N samples (samples[-0] would alias samples[0])
+ posterior_models.append(posterior_model)
+ ret['posterior_models'] = posterior_models
+
+ # Return data that has been rescaled based on 'natural' units for each parameter
+ import copy
+ samples_natural = copy.deepcopy(samples)
+ samples_natural[:,:n_params] /= np.array([_[4] for _ in param_map])
+ ret['samples_natural'] = samples_natural
+
+ # Return the names of the fitted parameters
+ labels = []
+ labels_natural = []
+ for _ in param_map:
+ labels.append(_[1].replace('_','-'))
+ labels_natural.append(_[1].replace('_','-'))
+ if _[3] != '':
+ labels_natural[-1] += ' (' + _[3] + ')'
+ for _ in gain_list:
+ labels.append(str(_[0]) + ' ' + _[1])
+ labels_natural.append(str(_[0]) + ' ' + _[1])
+ for _ in leakage_fit:
+ for coord in ['re','im']:
+ labels.append(_[0] + ',' + _[1] + ',' + coord)
+ labels_natural.append(_[0] + ',' + _[1] + ',' + coord)
+
+ ret['labels'] = labels
+ ret['labels_natural'] = labels_natural
+ elif minimizer_func in ['dynesty_static','dynesty_dynamic']:
+ ##########################
+ ## Sample using dynesty ##
+ ##########################
+ import dynesty
+ from dynesty import utils as dyfunc
+ # Define the functions that dynesty requires
+ def prior_transform(u):
+ # This function transforms samples from the unit hypercube (u) to the target prior (x)
+ global globdict
+ cparts = ['re','im']
+ model_params_u = u[:n_params]
+ gain_params_u = u[n_params:(n_params+n_gains)]
+ leakage_params_u = u[(n_params+n_gains):]
+ model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))]
+ gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))]
+ leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))]
+ return np.concatenate([model_params_x, gain_params_x, leakage_params_x])
+
+ def loglike(x):
+ return -objfunc(x)
+
+ def grad(x):
+ return -objgrad(x)
+
+ # Setup a multiprocessing pool if needed
+ if processes >= 0:
+ import pathos.multiprocessing as mp
+ from multiprocessing import cpu_count
+ if processes == 0: processes = int(cpu_count())
+
+            # Worker initializer: share globdict efficiently among the processes and give each its own copy of the trial model
+ def init(_globdict):
+ global globdict
+ globdict = _globdict
+ if processes >= 0:
+ globdict['trial_model'] = globdict['trial_model'].copy()
+
+ return
+
+ pool = mp.Pool(processes=processes, initializer=init, initargs=(globdict,))
+ if not quiet: print('Using a pool with %d processes' % processes)
+ else:
+ pool = processes = None
+
+ # Setup the sampler
+ if minimizer_func == 'dynesty_static':
+ sampler = dynesty.NestedSampler(loglike, prior_transform, ndim=len(param_init), gradient=grad, pool=pool, queue_size=processes, **minimizer_kwargs)
+ else:
+ sampler = dynesty.DynamicNestedSampler(loglike, prior_transform, ndim=len(param_init), gradient=grad, pool=pool, queue_size=processes, **minimizer_kwargs)
+
+ # Run the sampler
+ sampler.run_nested(**run_nested_kwargs)
+
+ # Print the sampler summary
+ res = sampler.results
+ if not quiet:
+ try: res.summary()
+            except Exception: pass
+
+ # Extract useful sampling diagnostics.
+ samples = res.samples # samples
+ weights = np.exp(res.logwt - res.logz[-1]) # normalized weights
+ mean, cov = dyfunc.mean_and_cov(samples, weights)
+
+ # Compute the log-posterior
+ if not quiet: print('Calculating the posterior values for the samples...')
+ if pool is not None:
+            def logpost(j):
+                return -objfunc(samples[j], force_posterior=True)
+
+            logposterior = np.array(pool.map(logpost, range(len(samples))))
+ else:
+ logposterior = np.array([-objfunc(x, force_posterior=True) for x in samples])
+
+ # Close the pool (this may not be the desired behavior if the sampling is to be iterative!)
+ if pool is not None:
+ pool.close()
+
+ # Select the MAP
+ j_MAP = np.argmax(logposterior)
+ MAP = samples[j_MAP]
+
+ # Resample from the posterior
+ samples = dyfunc.resample_equal(samples, weights)
+
+ # Return a model determined by the MAP
+ set_params(MAP[:n_params], trial_model, param_map, minimizer_func, model_prior)
+ gains = MAP[n_params:(n_params+n_gains)]
+ leakage = MAP[(n_params+n_gains):]
+ update_leakage(leakage)
+
+ # Return the sampler
+ ret['sampler'] = sampler
+ ret['mean'] = mean
+ ret['map'] = MAP
+ ret['std'] = cov.diagonal()**0.5
+ ret['samples'] = samples
+ ret['logposterior'] = logposterior
+
+ # Return a set of models from the posterior
+ posterior_models = []
+ for j in range(N_POSTERIOR_SAMPLES):
+ posterior_model = trial_model.copy()
+ set_params(samples[j][:n_params], posterior_model, param_map, minimizer_func, model_prior)
+ posterior_models.append(posterior_model)
+ ret['posterior_models'] = posterior_models
+
+ # Return data that has been rescaled based on 'natural' units for each parameter
+ import copy
+ res_natural = copy.deepcopy(res)
+ res_natural.samples[:,:n_params] /= np.array([_[4] for _ in param_map])
+ samples_natural = samples[:,:n_params]/np.array([_[4] for _ in param_map])
+ ret['res_natural'] = res_natural
+ ret['samples_natural'] = samples_natural
+
+ # Return the names of the fitted parameters
+ labels = []
+ labels_natural = []
+ for _ in param_map:
+ labels.append(_[1].replace('_','-'))
+ labels_natural.append(_[1].replace('_','-'))
+ if _[3] != '':
+ labels_natural[-1] += ' (' + _[3] + ')'
+ for _ in gain_list:
+ labels.append(str(_[0]) + ' ' + _[1])
+ labels_natural.append(str(_[0]) + ' ' + _[1])
+ for _ in leakage_fit:
+ for coord in ['re','im']:
+ labels.append(_[0] + ',' + _[1] + ',' + coord)
+ labels_natural.append(_[0] + ',' + _[1] + ',' + coord)
+
+ ret['labels'] = labels
+ ret['labels_natural'] = labels_natural
+ else:
+ raise Exception('Minimizer function ' + minimizer_func + ' is not recognized!')
+
+ # Format and print summary and fitted parameters
+ tstop = time.time()
+ trial_model = globdict['trial_model']
+
+ if not quiet:
+ print("\ntime: %f s" % (tstop - tstart))
+ print("\nFitted Parameters:")
+ if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']:
+ out = res.x
+ set_params(out[:n_params], trial_model, param_map, minimizer_func, model_prior)
+ gains = out[n_params:(n_params + n_gains)]
+ leakage = out[(n_params + n_gains):]
+ update_leakage(leakage)
+ tparams = transform_params(out[:n_params], param_map, minimizer_func, model_prior)
+ if not quiet:
+ cur_idx = -1
+ if len(param_map):
+ print('Model Parameters:')
+ for j in range(len(param_map)):
+ if param_map[j][0] != cur_idx:
+ cur_idx = param_map[j][0]
+ print(model_init.models[cur_idx] + ' (component %d/%d):' % (cur_idx+1,model_init.N_models()))
+ print(('\t' + param_map[j][1] + ': %f ' + param_map[j][3]) % (tparams[j] * param_map[j][2]/param_map[j][4]))
+ print('\n')
+
+ if len(leakage_fit):
+ print('Leakage (%; re, im):')
+ for j in range(len(leakage_fit)):
+ print('\t' + leakage_fit[j][0] + ', ' + leakage_fit[j][1] + ': %2.2f %2.2f' % (leakage[2*j]*100,leakage[2*j + 1]*100))
+ print('\n')
+
+ print("Final Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f Chi^2_4: %f" % chisq_list(gains))
+ print("J: %f" % res.fun)
+ print(res.message)
+ else:
+ if not quiet:
+ cur_idx = -1
+ if len(param_map):
+ print('Model Parameters (mean and std):')
+ for j in range(len(param_map)):
+ if param_map[j][0] != cur_idx:
+ cur_idx = param_map[j][0]
+ print(model_init.models[cur_idx] + ' (component %d/%d):' % (cur_idx+1,model_init.N_models()))
+ print(('\t' + param_map[j][1] + ': %f +/- %f ' + param_map[j][3]) % (ret['mean'][j] * param_map[j][2]/param_map[j][4], ret['std'][j] * param_map[j][2]/param_map[j][4]))
+ print('\n')
+
+ if len(leakage_fit):
+ print('Leakage (%; re, im):')
+ for j in range(len(leakage_fit)):
+ j2 = 2*j + n_params + n_gains
+                    # use ret entries here: 'cov' is undefined in the pymc3 branch
+                    print(('\t' + leakage_fit[j][0] + ', ' + leakage_fit[j][1]
+                           + ': %2.2f +/- %2.2f, %2.2f +/- %2.2f')
+                          % (ret['mean'][j2]*100, ret['std'][j2]*100,
+                             ret['mean'][j2+1]*100, ret['std'][j2+1]*100))
+ print('\n')
+
+ # Return fitted model
+ ret['model'] = trial_model
+ ret['param_map'] = param_map
+ ret['chisq_list'] = chisq_list(gains)
+ try: ret['res'] = res
+    except Exception: pass
+
+ if fit_gains:
+ ret['gains'] = gains
+
+ # Create and return a caltable
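+        # Each caldict row is (time, rscale, lscale): the fitted gains g are
+        # fractional corrections, stored as the same multiplicative factor
+        # (1 + g) for both the right and left hands.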
+ caldict = {}
+ for site in set(np.array(gain_list)[:,1]):
+ caldict[site] = []
+
+ for j in range(len(gains)):
+ caldict[gain_list[j][1]].append((gain_list[j][0], (1.0 + gains[j]), (1.0 + gains[j])))
+
+ for site in caldict.keys():
+ caldict[site] = np.array(caldict[site], dtype=DTCAL)
+
+ ct = caltable.Caltable(Obsdata.ra, Obsdata.dec, Obsdata.rf, Obsdata.bw, caldict, Obsdata.tarr,
+ source=Obsdata.source, mjd=Obsdata.mjd, timetype=Obsdata.timetype)
+ ret['caltable'] = ct
+
+ # If relevant, return useful quantities associated with the leakage
+ if station_leakages is not None:
+ ret['station_leakages'] = station_leakages
+ tarr = Obsdata.tarr.copy()
+ for s in station_leakages.keys():
+ if 'R' in station_leakages[s].keys(): tarr[Obsdata.tkey[s]]['dr'] = station_leakages[s]['R']
+ if 'L' in station_leakages[s].keys(): tarr[Obsdata.tkey[s]]['dl'] = station_leakages[s]['L']
+ ret['tarr'] = tarr
+
+ return ret
+
+##################################################################################################
+# Wrapper Functions
+##################################################################################################
+
+def chisq(model, uv, data, sigma, dtype, pol='I', jonesdict=None):
+ """return the chi^2 for the appropriate dtype
+ """
+
+ chisq = 1
+ if not dtype in DATATERMS:
+ return chisq
+
+ if dtype == 'vis':
+ chisq = chisq_vis(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'amp':
+ chisq = chisq_amp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'logamp':
+ chisq = chisq_logamp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'bs':
+ chisq = chisq_bs(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'cphase':
+ chisq = chisq_cphase(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'cphase_diag':
+ chisq = chisq_cphase_diag(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'camp':
+ chisq = chisq_camp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'logcamp':
+ chisq = chisq_logcamp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'logcamp_diag':
+ chisq = chisq_logcamp_diag(model, uv, data, sigma, pol=pol, jonesdict=jonesdict)
+ elif dtype == 'pvis':
+ chisq = chisq_pvis(model, uv, data, sigma, jonesdict=jonesdict)
+ elif dtype == 'm':
+ chisq = chisq_m(model, uv, data, sigma, jonesdict=jonesdict)
+ elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']:
+ chisq = chisq_fracpol(dtype[:2],dtype[2:],model, uv, data, sigma, jonesdict=jonesdict)
+ elif dtype == 'polclosure':
+ chisq = chisq_polclosure(model, uv, data, sigma, jonesdict=jonesdict)
+
+ return chisq
+
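+# Example usage (sketch): with visibility data prepared by chisqdata_vis below,
+#   (vis, sigma, uv, jonesdict) = chisqdata_vis(obs)   # obs: an Obsdata instance
+#   red_chisq = chisq(model, uv, vis, sigma, 'vis', jonesdict=jonesdict)
+# A well-fitting model gives a reduced chi^2 near 1 for each data term.
+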
+def chisqgrad(model, uv, data, sigma, jonesdict, dtype, param_mask, pol='I', fit_gains=False, gains=None, gains_t1=None, gains_t2=None, fit_pol=False, fit_cpol=False, fit_leakage=False):
+ """return the chi^2 gradient for the appropriate dtype
+ """
+ global globdict
+
+ n_chisqgrad = len(param_mask)
+ if fit_leakage:
+ n_chisqgrad += 2*len(globdict['leakage_fit'])
+
+ chisqgrad = np.zeros(n_chisqgrad)
+ if fit_gains:
+ gaingrad = np.zeros_like(gains)
+ else:
+ gaingrad = np.array([])
+
+ # Now we need to be sure to put the gradient in the correct order: model parameters, then gains, then leakage
+ param_mask_full = np.zeros(len(chisqgrad), dtype=bool)
+ leakage_mask_full = np.zeros(len(chisqgrad), dtype=bool)
+ param_mask_full[:len(param_mask)] = param_mask
+    leakage_mask_full[len(param_mask):] = True
+
+ if not dtype in DATATERMS:
+ return np.concatenate([chisqgrad[param_mask_full],gaingrad,chisqgrad[leakage_mask_full]])
+
+ if dtype == 'vis':
+ chisqgrad = chisqgrad_vis(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'amp':
+ chisqgrad = chisqgrad_amp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+
+ if fit_gains:
+ i1 = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict)
+ amp_samples = np.abs(i1)
+ amp = data
+ pp = ((amp - amp_samples) * amp_samples) / (sigma**2)
+ gaingrad = 2.0/(1.0 + np.array(gains)) * np.array([np.sum(pp[(np.array(gains_t1) == j) + (np.array(gains_t2) == j)]) for j in range(len(gains))])/len(data)
+ elif dtype == 'logamp':
+ chisqgrad = chisqgrad_logamp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'bs':
+ chisqgrad = chisqgrad_bs(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'cphase':
+ chisqgrad = chisqgrad_cphase(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'cphase_diag':
+ chisqgrad = chisqgrad_cphase_diag(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'camp':
+ chisqgrad = chisqgrad_camp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'logcamp':
+ chisqgrad = chisqgrad_logcamp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'logcamp_diag':
+ chisqgrad = chisqgrad_logcamp_diag(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'pvis':
+ chisqgrad = chisqgrad_pvis(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype == 'm':
+ chisqgrad = chisqgrad_m(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+ elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']:
+ chisqgrad = chisqgrad_fracpol(dtype[:2],dtype[2:],model, uv, data, sigma, jonesdict=jonesdict, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage)
+ elif dtype == 'polclosure':
+ chisqgrad = chisqgrad_polclosure(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict)
+
+ return np.concatenate([chisqgrad[param_mask_full],gaingrad,chisqgrad[leakage_mask_full]])
+
+def chisqdata(Obsdata, dtype, pol='I', **kwargs):
+    """Return the data, sigma, and matrices for the appropriate dtype
+    """
+
+ (data, sigma, uv, jonesdict) = (False, False, False, None)
+
+ if dtype == 'vis':
+ (data, sigma, uv, jonesdict) = chisqdata_vis(Obsdata, pol=pol, **kwargs)
+ elif dtype == 'amp' or dtype == 'logamp':
+ (data, sigma, uv, jonesdict) = chisqdata_amp(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'bs':
+ (data, sigma, uv, jonesdict) = chisqdata_bs(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'cphase':
+ (data, sigma, uv, jonesdict) = chisqdata_cphase(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'cphase_diag':
+ (data, sigma, uv, jonesdict) = chisqdata_cphase_diag(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'camp':
+ (data, sigma, uv, jonesdict) = chisqdata_camp(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'logcamp':
+ (data, sigma, uv, jonesdict) = chisqdata_logcamp(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'logcamp_diag':
+ (data, sigma, uv, jonesdict) = chisqdata_logcamp_diag(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'pvis':
+ (data, sigma, uv, jonesdict) = chisqdata_pvis(Obsdata, pol=pol,**kwargs)
+ elif dtype == 'm':
+ (data, sigma, uv, jonesdict) = chisqdata_m(Obsdata, pol=pol,**kwargs)
+ elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']:
+ (data, sigma, uv, jonesdict) = chisqdata_fracpol(Obsdata,dtype[:2],dtype[2:],jonesdict=jonesdict)
+ elif dtype == 'polclosure':
+ (data, sigma, uv, jonesdict) = chisqdata_polclosure(Obsdata,jonesdict=jonesdict)
+
+ return (data, sigma, uv, jonesdict)
+
+
+##################################################################################################
+# Chi-squared and Gradient Functions
+##################################################################################################
+
+def chisq_vis(model, uv, vis, sigma, pol='I', jonesdict=None):
+ """Visibility chi-squared"""
+
+ samples = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict)
+ return np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis))
+
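+# For chi^2 = (1/2N) sum |V_model - V|^2 / sigma^2, the gradient with respect to a
+# model parameter theta is
+#     dchi^2/dtheta = -(1/N) Re[ sum conj(dV_model/dtheta) * (V - V_model) / sigma^2 ],
+# which chisqgrad_vis evaluates below as a dot product with the sampled gradients.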
+def chisqgrad_vis(model, uv, vis, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the visibility chi-squared"""
+ samples = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict)
+ wdiff = (vis - samples)/(sigma**2)
+ grad = model.sample_grad_uv(uv[:,0],uv[:,1],fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+
+ out = -np.real(np.dot(grad.conj(), wdiff))/len(vis)
+ return out
+
+def chisq_amp(model, uv, amp, sigma, pol='I', jonesdict=None):
+ """Visibility Amplitudes (normalized) chi-squared"""
+
+ amp_samples = np.abs(model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict))
+ return np.sum(np.abs((amp - amp_samples)/sigma)**2)/len(amp)
+
+def chisqgrad_amp(model, uv, amp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the amplitude chi-squared"""
+
+    i1 = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict)
+ amp_samples = np.abs(i1)
+
+ pp = ((amp - amp_samples) * amp_samples) / (sigma**2) / i1
+ grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ out = (-2.0/len(amp)) * np.real(np.dot(grad, pp))
+ return out
+
+def chisq_bs(model, uv, bis, sigma, pol='I', jonesdict=None):
+ """Bispectrum chi-squared"""
+
+ bisamples = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) * model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) * model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ chisq= np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis))
+ return chisq
+
+def chisqgrad_bs(model, uv, bis, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the bispectrum chi-squared"""
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ bisamples = V1 * V2 * V3
+ wdiff = ((bis - bisamples).conj())/(sigma**2)
+ pt1 = wdiff * V2 * V3
+ pt2 = wdiff * V1 * V3
+ pt3 = wdiff * V1 * V2
+ out = -np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T))/len(bis)
+ return out
+
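+# The closure-phase chi^2 below uses 2(1 - cos(dphi))/sigma^2 per point rather than
+# (dphi/sigma)^2; the two agree for small residuals (1 - cos x ~ x^2/2), but the
+# former remains smooth and 2*pi-periodic under phase wraps.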
+def chisq_cphase(model, uv, clphase, sigma, pol='I', jonesdict=None):
+ """Closure Phases (normalized) chi-squared"""
+ clphase = clphase * DEGREE
+ sigma = sigma * DEGREE
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+
+ clphase_samples = np.angle(V1 * V2 * V3)
+ chisq= (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2))
+ return chisq
+
+def chisqgrad_cphase(model, uv, clphase, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the closure phase chi-squared"""
+ clphase = clphase * DEGREE
+ sigma = sigma * DEGREE
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
+ clphase_samples = np.angle(V1 * V2 * V3)
+
+ pref = np.sin(clphase - clphase_samples)/(sigma**2)
+ pt1 = pref/V1
+ pt2 = pref/V2
+ pt3 = pref/V3
+ out = -(2.0/len(clphase)) * np.imag(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T))
+ return out
+
+def chisq_cphase_diag(model, uv, clphase_diag, sigma, pol='I', jonesdict=None):
+ """Diagonalized closure phases (normalized) chi-squared"""
+ clphase_diag = np.concatenate(clphase_diag) * DEGREE
+ sigma = np.concatenate(sigma) * DEGREE
+
+ uv_diag = uv[0]
+ tform_mats = uv[1]
+
+ clphase_diag_samples = []
+ for iA, uv3 in enumerate(uv_diag):
+ i1 = model.sample_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,jonesdict=jonesdict)
+ i2 = model.sample_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,jonesdict=jonesdict)
+ i3 = model.sample_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,jonesdict=jonesdict)
+
+ clphase_samples = np.angle(i1 * i2 * i3)
+ clphase_diag_samples.append(np.dot(tform_mats[iA],clphase_samples))
+ clphase_diag_samples = np.concatenate(clphase_diag_samples)
+
+ chisq = (2.0/len(clphase_diag)) * np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2))
+ return chisq
+
+def chisqgrad_cphase_diag(model, uv, clphase_diag, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the diagonalized closure phase chi-squared"""
+ clphase_diag = clphase_diag * DEGREE
+ sigma = sigma * DEGREE
+
+ uv_diag = uv[0]
+ tform_mats = uv[1]
+
+ deriv = np.zeros(len(model.sample_grad_uv(0,0,pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)))
+ for iA, uv3 in enumerate(uv_diag):
+
+ i1 = model.sample_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,jonesdict=jonesdict)
+ i2 = model.sample_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,jonesdict=jonesdict)
+ i3 = model.sample_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,jonesdict=jonesdict)
+
+ i1_grad = model.sample_grad_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ i2_grad = model.sample_grad_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ i3_grad = model.sample_grad_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
+ clphase_samples = np.angle(i1 * i2 * i3)
+ clphase_diag_samples = np.dot(tform_mats[iA],clphase_samples)
+
+ clphase_diag_measured = clphase_diag[iA]
+ clphase_diag_sigma = sigma[iA]
+
+ term1 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i1)),i1_grad.T)
+ term2 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i2)),i2_grad.T)
+ term3 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i3)),i3_grad.T)
+ deriv += -2.0*np.imag(term1 + term2 + term3)
+
+    deriv *= 1.0/float(len(np.concatenate(clphase_diag)))  # np.float alias removed in NumPy >= 1.24
+
+ return deriv
+
+def chisq_camp(model, uv, clamp, sigma, pol='I', jonesdict=None):
+ """Closure Amplitudes (normalized) chi-squared"""
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict)
+
+ clamp_samples = np.abs(V1 * V2 / (V3 * V4))
+ chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp)
+ return chisq
+
+def chisqgrad_camp(model, uv, clamp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the closure amplitude chi-squared"""
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict)
+ V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V4_grad = model.sample_grad_uv(uv[3][:,0],uv[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
+ clamp_samples = np.abs((V1 * V2)/(V3 * V4))
+
+ pp = ((clamp - clamp_samples) * clamp_samples)/(sigma**2)
+ pt1 = pp/V1
+ pt2 = pp/V2
+ pt3 = -pp/V3
+ pt4 = -pp/V4
+ out = (-2.0/len(clamp)) * np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T) + np.dot(pt4, V4_grad.T))
+ return out
+
+def chisq_logcamp(model, uv, log_clamp, sigma, pol='I', jonesdict=None):
+ """Log Closure Amplitudes (normalized) chi-squared"""
+
+ a1 = np.abs(model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict))
+ a2 = np.abs(model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict))
+ a3 = np.abs(model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict))
+ a4 = np.abs(model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict))
+
+ samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4)
+ chisq = np.sum(np.abs((log_clamp - samples)/sigma)**2) / (len(log_clamp))
+ return chisq
+
+def chisqgrad_logcamp(model, uv, log_clamp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the Log closure amplitude chi-squared"""
+
+ V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)
+ V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)
+ V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)
+ V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict)
+ V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ V4_grad = model.sample_grad_uv(uv[3][:,0],uv[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
+ log_clamp_samples = np.log(np.abs(V1)) + np.log(np.abs(V2)) - np.log(np.abs(V3)) - np.log(np.abs(V4))
+
+ pp = (log_clamp - log_clamp_samples) / (sigma**2)
+ pt1 = pp / V1
+ pt2 = pp / V2
+ pt3 = -pp / V3
+ pt4 = -pp / V4
+ out = (-2.0/len(log_clamp)) * np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T) + np.dot(pt4, V4_grad.T))
+ return out
+
+def chisq_logcamp_diag(model, uv, log_clamp_diag, sigma, pol='I', jonesdict=None):
+ """Diagonalized log closure amplitudes (normalized) chi-squared"""
+
+ log_clamp_diag = np.concatenate(log_clamp_diag)
+ sigma = np.concatenate(sigma)
+
+ uv_diag = uv[0]
+ tform_mats = uv[1]
+
+ log_clamp_diag_samples = []
+ for iA, uv4 in enumerate(uv_diag):
+
+ a1 = np.abs(model.sample_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,jonesdict=jonesdict))
+ a2 = np.abs(model.sample_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,jonesdict=jonesdict))
+ a3 = np.abs(model.sample_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,jonesdict=jonesdict))
+ a4 = np.abs(model.sample_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,jonesdict=jonesdict))
+
+ log_clamp_samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4)
+ log_clamp_diag_samples.append(np.dot(tform_mats[iA],log_clamp_samples))
+
+ log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples)
+
+ chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) / (len(log_clamp_diag))
+ return chisq
+
+def chisqgrad_logcamp_diag(model, uv, log_clamp_diag, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the diagonalized log closure amplitude chi-squared"""
+
+ uv_diag = uv[0]
+ tform_mats = uv[1]
+
+ deriv = np.zeros(len(model.sample_grad_uv(0,0,pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)))
+ for iA, uv4 in enumerate(uv_diag):
+
+ i1 = model.sample_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,jonesdict=jonesdict)
+ i2 = model.sample_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,jonesdict=jonesdict)
+ i3 = model.sample_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,jonesdict=jonesdict)
+ i4 = model.sample_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,jonesdict=jonesdict)
+
+ i1_grad = model.sample_grad_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ i2_grad = model.sample_grad_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ i3_grad = model.sample_grad_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+ i4_grad = model.sample_grad_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
+ log_clamp_samples = np.log(np.abs(i1)) + np.log(np.abs(i2)) - np.log(np.abs(i3)) - np.log(np.abs(i4))
+ log_clamp_diag_samples = np.dot(tform_mats[iA],log_clamp_samples)
+
+ log_clamp_diag_measured = log_clamp_diag[iA]
+ log_clamp_diag_sigma = sigma[iA]
+
+ term1 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i1)),i1_grad.T)
+ term2 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i2)),i2_grad.T)
+ term3 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i3)),i3_grad.T)
+ term4 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i4)),i4_grad.T)
+ deriv += -2.0*np.real(term1 + term2 - term3 - term4)
+
+    deriv *= 1.0/float(len(np.concatenate(log_clamp_diag)))  # np.float alias removed in NumPy >= 1.24
+
+ return deriv
+
+def chisq_logamp(model, uv, amp, sigma, pol='I', jonesdict=None):
+ """Log Visibility Amplitudes (normalized) chi-squared"""
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ amp_samples = np.abs(model.sample_uv(uv[:,0],uv[:,1],pol=pol,jonesdict=jonesdict))
+ return np.sum(np.abs((np.log(amp) - np.log(amp_samples))/logsigma)**2)/len(amp)
+
+def chisqgrad_logamp(model, uv, amp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the Log amplitude chi-squared"""
+
+ # to lowest order the variance on the logarithm of a quantity x is
+ # sigma^2_log(x) = sigma^2/x^2
+ logsigma = sigma / amp
+
+ i1 = model.sample_uv(uv[:,0],uv[:,1],pol=pol,jonesdict=jonesdict)
+ amp_samples = np.abs(i1)
+
+ V_grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict)
+
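+    # Uses d log|V|/dtheta = Re[(dV/dtheta)/V]: the division by i1 below and the
+    # real part of the dot product implement that identity.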
+ pp = ((np.log(amp) - np.log(amp_samples))) / (logsigma**2) / i1
+ out = (-2.0/len(amp)) * np.real(np.dot(pp, V_grad.T))
+ return out
+
+
+def chisq_pvis(model, uv, pvis, psigma, jonesdict=None):
+ """Polarimetric visibility chi-squared
+ """
+
+ psamples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict)
+ return np.sum(np.abs((psamples-pvis)/psigma)**2)/(2*len(pvis))
+
+def chisqgrad_pvis(model, uv, pvis, psigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """Polarimetric visibility chi-squared gradient
+ """
+ samples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict)
+ wdiff = (pvis - samples)/(psigma**2)
+ grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol='P',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+
+ out = -np.real(np.dot(grad.conj(), wdiff))/len(pvis)
+ return out
+
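+# The visibility-domain polarimetric ratio m(u,v) = V_P(u,v)/V_I(u,v) compares the
+# linear polarization to the total intensity on each baseline, so amplitude
+# rescalings common to P and I cancel in the ratio.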
+def chisq_m(model, uv, m, msigma, jonesdict=None):
+ """Polarimetric ratio chi-squared
+ """
+
+ msamples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict)/model.sample_uv(uv[:,0],uv[:,1],pol='I',jonesdict=jonesdict)
+
+ return np.sum(np.abs((m - msamples))**2/(msigma**2)) / (2*len(m))
+
+def chisqgrad_m(model, uv, mvis, msigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the polarimetric ratio chisq
+ """
+
+ samp_P = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict)
+ samp_I = model.sample_uv(uv[:,0],uv[:,1],pol='I',jonesdict=jonesdict)
+ grad_P = model.sample_grad_uv(uv[:,0],uv[:,1],pol='P',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ grad_I = model.sample_grad_uv(uv[:,0],uv[:,1],pol='I',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+
+ msamples = samp_P/samp_I
+ wdiff = (mvis - msamples)/(msigma**2)
+ # Get the gradient from the quotient rule
+ grad = ( grad_P * samp_I - grad_I * samp_P)/samp_I**2
+
+ return -np.real(np.dot(grad.conj(), wdiff))/len(mvis)
+
+def chisq_fracpol(upper, lower, model, uv, m, msigma, jonesdict=None):
+ """Polarimetric ratio chi-squared
+ """
+
+ msamples = model.sample_uv(uv[:,0],uv[:,1],pol=upper.upper(),jonesdict=jonesdict)/model.sample_uv(uv[:,0],uv[:,1],pol=lower.upper(),jonesdict=jonesdict)
+
+ return np.sum(np.abs((m - msamples))**2/(msigma**2)) / (2*len(m))
+
+def chisqgrad_fracpol(upper, lower, model, uv, mvis, msigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the polarimetric ratio chisq
+ """
+
+ samp_upper = model.sample_uv(uv[:,0],uv[:,1],pol=upper.upper(),jonesdict=jonesdict)
+ samp_lower = model.sample_uv(uv[:,0],uv[:,1],pol=lower.upper(),jonesdict=jonesdict)
+ grad_upper = model.sample_grad_uv(uv[:,0],uv[:,1],pol=upper.upper(),fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ grad_lower = model.sample_grad_uv(uv[:,0],uv[:,1],pol=lower.upper(),fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+
+ msamples = samp_upper/samp_lower
+ wdiff = (mvis - msamples)/(msigma**2)
+ # Get the gradient from the quotient rule
+ grad = ( grad_upper * samp_lower - grad_lower * samp_upper)/samp_lower**2
+
+ return -np.real(np.dot(grad.conj(), wdiff))/len(mvis)
+
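+# The polarimetric closure quantity (RL*LR)/(RR*LL) is formed per baseline; the
+# complex station gains common to both feeds cancel exactly in this ratio, leaving
+# sensitivity to the D-terms and intrinsic polarization.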
+def chisq_polclosure(model, uv, vis, sigma, jonesdict=None):
+ """Polarimetric ratio chi-squared
+ """
+
+ RL = model.sample_uv(uv[:,0],uv[:,1],pol='RL',jonesdict=jonesdict)
+ LR = model.sample_uv(uv[:,0],uv[:,1],pol='LR',jonesdict=jonesdict)
+ RR = model.sample_uv(uv[:,0],uv[:,1],pol='RR',jonesdict=jonesdict)
+ LL = model.sample_uv(uv[:,0],uv[:,1],pol='LL',jonesdict=jonesdict)
+ samples = (RL * LR)/(RR * LL)
+
+ return np.sum(np.abs((vis - samples))**2/(sigma**2)) / (2*len(vis))
+
+def chisqgrad_polclosure(model, uv, vis, sigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None):
+ """The gradient of the polarimetric ratio chisq
+ """
+
+ RL = model.sample_uv(uv[:,0],uv[:,1],pol='RL',jonesdict=jonesdict)
+ LR = model.sample_uv(uv[:,0],uv[:,1],pol='LR',jonesdict=jonesdict)
+ RR = model.sample_uv(uv[:,0],uv[:,1],pol='RR',jonesdict=jonesdict)
+ LL = model.sample_uv(uv[:,0],uv[:,1],pol='LL',jonesdict=jonesdict)
+
+ dRL = model.sample_grad_uv(uv[:,0],uv[:,1],pol='RL',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ dLR = model.sample_grad_uv(uv[:,0],uv[:,1],pol='LR',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ dRR = model.sample_grad_uv(uv[:,0],uv[:,1],pol='RR',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+ dLL = model.sample_grad_uv(uv[:,0],uv[:,1],pol='LL',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict)
+
+ samples = (RL * LR)/(RR * LL)
+ wdiff = (vis - samples)/(sigma**2)
+
+ # Get the gradient from the quotient rule
+ samp_upper = RL * LR
+ samp_lower = RR * LL
+ grad_upper = RL * dLR + dRL * LR
+ grad_lower = RR * dLL + dRR * LL
+ grad = ( grad_upper * samp_lower - grad_lower * samp_upper)/samp_lower**2
+
+ return -np.real(np.dot(grad.conj(), wdiff))/len(vis)
+
+
+##################################################################################################
+# Chi^2 Data functions
+##################################################################################################
+def apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol):
+ """apply systematic noise to VISIBILITIES or AMPLITUDES
+ data_arr should have fields 't1','t2','u','v','vis','amp','sigma'
+
+ returns: (uv, vis, amp, sigma)
+ """
+
+ vtype=vis_poldict[pol]
+ atype=amp_poldict[pol]
+ etype=sig_poldict[pol]
+
+ t1 = data_arr['t1']
+ t2 = data_arr['t2']
+
+ sigma = data_arr[etype]
+ amp = data_arr[atype]
+ try:
+ vis = data_arr[vtype]
+    except (KeyError, ValueError):  # the field may be absent; the exception type depends on the numpy version
+ vis = amp.astype('c16')
+
+ snrmask = np.abs(amp/sigma) >= snrcut
+
+ if type(systematic_noise) is dict:
+ sys_level = np.zeros(len(t1))
+ for i in range(len(t1)):
+ if t1[i] in systematic_noise.keys():
+ t1sys = systematic_noise[t1[i]]
+ else:
+ t1sys = 0.
+ if t2[i] in systematic_noise.keys():
+ t2sys = systematic_noise[t2[i]]
+ else:
+ t2sys = 0.
+
+ if t1sys<0 or t2sys<0:
+ sys_level[i] = -1
+ else:
+ sys_level[i] = np.sqrt(t1sys**2 + t2sys**2)
+ else:
+ sys_level = np.sqrt(2)*systematic_noise*np.ones(len(t1))
+
+ mask = sys_level>=0.
+ mask = snrmask * mask
+
+ sigma = np.linalg.norm([sigma, sys_level*np.abs(amp)], axis=0)[mask]
+ vis = vis[mask]
+ amp = amp[mask]
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))[mask]
+ return (uv, vis, amp, sigma)
+
+def make_jonesdict(Obsdata, data_arr):
+ # Make a dictionary with entries needed to form the Jones matrices
+ # Currently, this only works for data types on a single baseline (e.g., closure quantities aren't supported yet)
+
+ # Get the names of each station for every measurement
+ t1 = data_arr['t1']
+ t2 = data_arr['t2']
+
+ # Get the elevation of each station
+ el1 = data_arr['el1']*np.pi/180.
+ el2 = data_arr['el2']*np.pi/180.
+
+ # Get the parallactic angle of each station
+ par1 = data_arr['par_ang1']*np.pi/180.
+ par2 = data_arr['par_ang2']*np.pi/180.
+
+    # Compute the full field rotation angle for each site, based on information in the Obsdata Array
+ fr_elev1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_elev'] for o in data_arr])
+ fr_elev2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_elev'] for o in data_arr])
+ fr_par1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_par'] for o in data_arr])
+ fr_par2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_par'] for o in data_arr])
+ fr_off1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_off'] for o in data_arr])
+ fr_off2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_off'] for o in data_arr])
+ fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180.
+ fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180.
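+    # The total feed rotation at each station is the site-specific combination
+    #   fr = fr_elev * elevation + fr_par * parallactic_angle + fr_off (deg -> rad),
+    # with the coefficients taken from the telescope array table.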
+
+ # Now populate the left and right D-term entries based on the Obsdata Array
+ DR1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['dr'] for o in data_arr])
+ DL1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['dl'] for o in data_arr])
+ DR2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['dr'] for o in data_arr])
+ DL2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['dl'] for o in data_arr])
+
+ return {'fr1':fr1,'fr2':fr2,'t1':t1,'t2':t2,
+ 'DR1':DR1, 'DR2':DR2, 'DL1':DL1, 'DL2':DL2}
+
+def chisqdata_vis(Obsdata, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrix for visibilities
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise',0.)
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data
+ vtype=vis_poldict[pol]
+ atype=amp_poldict[pol]
+ etype=sig_poldict[pol]
+ data_arr = Obsdata.unpack(['t1','t2','u','v',vtype,atype,etype,'el1','el2','par_ang1','par_ang2'], debias=debias)
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ jonesdict = make_jonesdict(Obsdata, data_arr)
+
+ return (vis, sigma, uv, jonesdict)
+
+def chisqdata_amp(Obsdata, pol='I',**kwargs):
+ """Return the data, sigmas, and fourier matrix for visibility amplitudes
+ """
+
+ # unpack keyword args
+ systematic_noise = kwargs.get('systematic_noise',0.)
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data
+ vtype=vis_poldict[pol]
+ atype=amp_poldict[pol]
+ etype=sig_poldict[pol]
+ if (Obsdata.amp is None) or (len(Obsdata.amp)==0) or pol!='I':
+ data_arr = Obsdata.unpack(['time','t1','t2','u','v',vtype,atype,etype,'el1','el2','par_ang1','par_ang2'], debias=debias)
+
+    else: # TODO: support pre-computed tables for polarizations other than Stokes I
+ print("Using pre-computed amplitude table in amplitude chi^2!")
+ if not type(Obsdata.amp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed amplitude table is not a numpy rec array!")
+ data_arr = Obsdata.amp
+
+
+ # apply systematic noise and SNR cut
+    # TODO: should these also be applied when using pre-computed tables?
+ (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol)
+
+ # data weighting
+ if weighting=='uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ jonesdict = make_jonesdict(Obsdata, data_arr)
+
+ return (amp, sigma, uv, jonesdict)
+
+def chisqdata_bs(Obsdata, pol='I',**kwargs):
+ """return the data, sigmas, and fourier matrices for bispectra
+ """
+
+ # unpack keyword args
+ #systematic_noise = kwargs.get('systematic_noise',0.) #this will break with a systematic noise dict
+ maxset = kwargs.get('maxset',False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data
+ vtype=vis_poldict[pol]
+ if (Obsdata.bispec is None) or (len(Obsdata.bispec)==0) or pol!='I':
+ biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count,snrcut=snrcut)
+
+    else: # TODO: support pre-computed tables for polarizations other than Stokes I
+        print("Using pre-computed bispectrum table in bispectrum chi^2!")
+ if not type(Obsdata.bispec) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed bispectrum table is not a numpy rec array!")
+ biarr = Obsdata.bispec
+ # reduce to a minimal set
+ if count!='max':
+ biarr = reduce_tri_minimal(Obsdata, biarr)
+
+ uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1)))
+ bi = biarr['bispec']
+ sigma = biarr['sigmab']
+
+ #add systematic noise
+ #sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0)
+
+ # data weighting
+ if weighting=='uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ return (bi, sigma, (uv1, uv2, uv3), None)
+
+def chisqdata_cphase(Obsdata, pol='I',**kwargs):
+ """Return the data, sigmas, and fourier matrices for closure phases
+ """
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset',False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+ systematic_cphase_noise = kwargs.get('systematic_cphase_noise',0.)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data
+ vtype=vis_poldict[pol]
+ if (Obsdata.cphase is None) or (len(Obsdata.cphase)==0) or pol!='I':
+ clphasearr = Obsdata.c_phases(mode="all", vtype=vtype, count=count, uv_min=uv_min, snrcut=snrcut)
+    else: # TODO: support pre-computed tables for polarizations other than Stokes I
+ print("Using pre-computed cphase table in cphase chi^2!")
+ if not type(Obsdata.cphase) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure phase table is not a numpy rec array!")
+ clphasearr = Obsdata.cphase
+ # reduce to a minimal set
+ if count!='max':
+ clphasearr = reduce_tri_minimal(Obsdata, clphasearr)
+
+ uv1 = np.hstack((clphasearr['u1'].reshape(-1,1), clphasearr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((clphasearr['u2'].reshape(-1,1), clphasearr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((clphasearr['u3'].reshape(-1,1), clphasearr['v3'].reshape(-1,1)))
+ clphase = clphasearr['cphase']
+ sigma = clphasearr['sigmacp']
+
+ #add systematic cphase noise (in DEGREES)
+ sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0)
+
+ # data weighting
+ if weighting=='uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ return (clphase, sigma, (uv1, uv2, uv3), None)
+
+def chisqdata_cphase_diag(Obsdata, pol='I',**kwargs):
+ """Return the data, sigmas, and fourier matrices for diagonalized closure phases
+ """
+
+ # unpack keyword args
+ maxset = kwargs.get('maxset',False)
+ uv_min = kwargs.get('cp_uv_min', False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+
+ # unpack data
+ vtype=vis_poldict[pol]
+ clphasearr = Obsdata.c_phases_diag(vtype=vtype,count=count,snrcut=snrcut,uv_min=uv_min)
+
+ # loop over timestamps
+ clphase_diag = []
+ sigma_diag = []
+ uv_diag = []
+ tform_mats = []
+ for ic, cl in enumerate(clphasearr):
+
+ # get diagonalized closure phases and errors
+ clphase_diag.append(cl[0]['cphase'])
+ sigma_diag.append(cl[0]['sigmacp'])
+
+ # get uv arrays
+ u1 = cl[2][:,0].astype('float')
+ v1 = cl[3][:,0].astype('float')
+ uv1 = np.hstack((u1.reshape(-1,1), v1.reshape(-1,1)))
+
+ u2 = cl[2][:,1].astype('float')
+ v2 = cl[3][:,1].astype('float')
+ uv2 = np.hstack((u2.reshape(-1,1), v2.reshape(-1,1)))
+
+ u3 = cl[2][:,2].astype('float')
+ v3 = cl[3][:,2].astype('float')
+ uv3 = np.hstack((u3.reshape(-1,1), v3.reshape(-1,1)))
+
+ # compute Fourier matrices
+ uv = (uv1,
+ uv2,
+ uv3
+ )
+ uv_diag.append(uv)
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ uvmatrices = (np.array(uv_diag),np.array(tform_mats))
+
+ return (np.array(clphase_diag), np.array(sigma_diag), uvmatrices, None)
+
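+# Editorial note: the transformation matrices returned above map ordinary closure
+# phases into the diagonalized basis, one matrix per timestamp. A hypothetical
+# per-timestamp model array 'cp_model' would be diagonalized the same way, e.g.:
+#
+# >>> cp_model_diag = [np.dot(tform_mats[i], cp_model[i]) for i in range(len(tform_mats))]
+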
+def chisqdata_camp(Obsdata, pol='I',**kwargs):
+ """Return the data, sigmas, and fourier matrices for closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset',False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data & mask low snr points
+ vtype=vis_poldict[pol]
+ if (Obsdata.camp is None) or (len(Obsdata.camp)==0) or pol!='I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count, vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut)
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed closure amplitude table in closure amplitude chi^2!")
+ if not type(Obsdata.camp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.camp
+ # reduce to a minimal set
+ if count!='max':
+ clamparr = reduce_quad_minimal(Obsdata, clamparr, ctype='camp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1,1), clamparr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1,1), clamparr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1,1), clamparr['v3'].reshape(-1,1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1,1), clamparr['v4'].reshape(-1,1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting=='uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ return (clamp, sigma, (uv1, uv2, uv3, uv4), None)
+
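+# Editorial note: the four uv rows returned above are the quadrangle baselines
+# entering the closure amplitude |V(uv1)| |V(uv2)| / (|V(uv3)| |V(uv4)|).
+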
+def chisqdata_logcamp(Obsdata, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset',False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+ weighting = kwargs.get('weighting','natural')
+
+ # unpack data & mask low snr points
+ vtype=vis_poldict[pol]
+ if (Obsdata.logcamp is None) or (len(Obsdata.logcamp)==0) or pol!='I':
+ clamparr = Obsdata.c_amplitudes(mode='all', count=count, vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut)
+ else: # TODO -- pre-computed with not stokes I?
+ print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!")
+ if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]:
+ raise Exception("pre-computed log closure amplitude table is not a numpy rec array!")
+ clamparr = Obsdata.logcamp
+ # reduce to a minimal set
+ if count!='max':
+ clamparr = reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp')
+
+ uv1 = np.hstack((clamparr['u1'].reshape(-1,1), clamparr['v1'].reshape(-1,1)))
+ uv2 = np.hstack((clamparr['u2'].reshape(-1,1), clamparr['v2'].reshape(-1,1)))
+ uv3 = np.hstack((clamparr['u3'].reshape(-1,1), clamparr['v3'].reshape(-1,1)))
+ uv4 = np.hstack((clamparr['u4'].reshape(-1,1), clamparr['v4'].reshape(-1,1)))
+ clamp = clamparr['camp']
+ sigma = clamparr['sigmaca']
+
+ # data weighting
+ if weighting=='uniform':
+ sigma = np.median(sigma) * np.ones(len(sigma))
+
+ return (clamp, sigma, (uv1, uv2, uv3, uv4), None)
+
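+# Editorial sketch, not part of the original module: log closure amplitudes are
+# real, so a standard Gaussian chi^2 applies directly; 'samples' is a hypothetical
+# array of model log closure amplitudes at the same uv points.
+#
+# >>> logcamp, sigma, uv, _ = chisqdata_logcamp(obs)
+# >>> chisq = np.sum(np.abs((logcamp - samples)/sigma)**2) / len(logcamp)
+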
+def chisqdata_logcamp_diag(Obsdata, pol='I', **kwargs):
+ """Return the data, sigmas, and fourier matrices for diagonalized log closure amplitudes
+ """
+ # unpack keyword args
+ maxset = kwargs.get('maxset',False)
+ if maxset: count='max'
+ else: count='min'
+
+ snrcut = kwargs.get('snrcut',0.)
+ debias = kwargs.get('debias',True)
+
+ # unpack data & mask low snr points
+ vtype=vis_poldict[pol]
+ clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype,count=count,debias=debias,snrcut=snrcut)
+
+ # loop over timestamps
+ clamp_diag = []
+ sigma_diag = []
+ uv_diag = []
+ tform_mats = []
+ for ic, cl in enumerate(clamparr):
+
+ # get diagonalized log closure amplitudes and errors
+ clamp_diag.append(cl[0]['camp'])
+ sigma_diag.append(cl[0]['sigmaca'])
+
+ # get uv arrays
+ u1 = cl[2][:,0].astype('float')
+ v1 = cl[3][:,0].astype('float')
+ uv1 = np.hstack((u1.reshape(-1,1), v1.reshape(-1,1)))
+
+ u2 = cl[2][:,1].astype('float')
+ v2 = cl[3][:,1].astype('float')
+ uv2 = np.hstack((u2.reshape(-1,1), v2.reshape(-1,1)))
+
+ u3 = cl[2][:,2].astype('float')
+ v3 = cl[3][:,2].astype('float')
+ uv3 = np.hstack((u3.reshape(-1,1), v3.reshape(-1,1)))
+
+ u4 = cl[2][:,3].astype('float')
+ v4 = cl[3][:,3].astype('float')
+ uv4 = np.hstack((u4.reshape(-1,1), v4.reshape(-1,1)))
+
+ # compute Fourier matrices
+ uv = (uv1,
+ uv2,
+ uv3,
+ uv4
+ )
+ uv_diag.append(uv)
+
+ # get transformation matrix for this timestamp
+ tform_mats.append(cl[4].astype('float'))
+
+ # combine Fourier and transformation matrices into tuple for outputting
+ uvmatrices = (np.array(uv_diag),np.array(tform_mats))
+
+ return (np.array(clamp_diag), np.array(sigma_diag), uvmatrices, None)
+
+def chisqdata_pvis(Obsdata, pol='I', **kwargs):
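+ """Return the data, sigmas, and uv points for polarimetric visibilities ('pvis')
+ """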
+ data_arr = Obsdata.unpack(['t1','t2','u','v','pvis','psigma','el1','el2','par_ang1','par_ang2'], conj=True)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ mask = np.isfinite(data_arr['pvis'] + data_arr['psigma']) # don't include nan (missing data) or inf (division by zero)
+ jonesdict = make_jonesdict(Obsdata, data_arr[mask])
+ return (data_arr['pvis'][mask], data_arr['psigma'][mask], uv[mask], jonesdict)
+
+def chisqdata_m(Obsdata, pol='I',**kwargs):
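+ """Return the data, sigmas, and uv points for fractional polarization data ('m')
+ """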
+ debias = kwargs.get('debias',True)
+ data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2'], conj=True, debias=debias)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+ mask = np.isfinite(data_arr['m'] + data_arr['msigma']) # don't include nan (missing data) or inf (division by zero)
+ jonesdict = make_jonesdict(Obsdata, data_arr[mask])
+ return (data_arr['m'][mask], data_arr['msigma'][mask], uv[mask], jonesdict)
+
+def chisqdata_fracpol(Obsdata, pol_upper,pol_lower,**kwargs):
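+ """Return the data, sigmas, and uv points for the visibility ratio pol_upper/pol_lower
+ """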
+ debias = kwargs.get('debias',True)
+ data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2','rrvis','rlvis','lrvis','llvis','rramp','rlamp','lramp','llamp','rrsigma','rlsigma','lrsigma','llsigma'], conj=False, debias=debias)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+
+ upper = data_arr[pol_upper + 'vis']
+ lower = data_arr[pol_lower + 'vis']
+ upper_amp = data_arr[pol_upper + 'amp']
+ lower_amp = data_arr[pol_lower + 'amp']
+ upper_sig = data_arr[pol_upper + 'sigma']
+ lower_sig = data_arr[pol_lower + 'sigma']
+
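+ # first-order error propagation for the complex ratio r = upper/lower:
+ # sigma_r^2 ~ (sigma_upper/|lower|)^2 + (|upper|*sigma_lower/|lower|^2)^2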
+ sig = ((upper_sig/lower_amp)**2 + (lower_sig*upper_amp/lower_amp**2)**2)**0.5
+
+ # Mask bad data
+ mask = np.isfinite(upper + lower + sig) # don't include nan (missing data) or inf (division by zero)
+ jonesdict = make_jonesdict(Obsdata, data_arr[mask])
+
+ return ((upper/lower)[mask], sig[mask], uv[mask], jonesdict)
+
+def chisqdata_polclosure(Obsdata, **kwargs):
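+ """Return the data, sigmas, and uv points for the polarimetric closure ratio (RL*LR)/(RR*LL)
+ """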
+ debias = kwargs.get('debias',True)
+ data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2','rrvis','rlvis','lrvis','llvis','rramp','rlamp','lramp','llamp','rrsigma','rlsigma','lrsigma','llsigma'], conj=False, debias=debias)
+ uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))
+
+ RL = data_arr['rlvis']
+ LR = data_arr['lrvis']
+ RR = data_arr['rrvis']
+ LL = data_arr['llvis']
+ vis = (RL * LR)/(RR * LL)
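+ # first-order error propagation on (RL*LR)/(RR*LL): the partial derivative
+ # with respect to each visibility times its sigma, summed in quadrature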
+ sig = (np.abs(LR/(LL*RR) * data_arr['rlsigma'])**2
+ +np.abs(RL/(LL*RR) * data_arr['lrsigma'])**2
+ +np.abs(LR*RL/(RR**2*LL) * data_arr['rrsigma'])**2
+ +np.abs(RL*LR/(LL**2*RR) * data_arr['llsigma'])**2)**0.5
+
+ # Mask bad data
+ mask = np.isfinite(vis + sig) # don't include nan (missing data) or inf (division by zero)
+ jonesdict = make_jonesdict(Obsdata, data_arr[mask])
+
+ return (vis[mask], sig[mask], uv[mask], jonesdict)
diff --git a/movie.py b/movie.py
new file mode 100644
index 00000000..d0855e15
--- /dev/null
+++ b/movie.py
@@ -0,0 +1,1935 @@
+# movie.py
+# an interferometric movie class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import string
+import numpy as np
+import scipy.interpolate
+import scipy.ndimage.filters as filt
+
+import ehtim.image
+import ehtim.obsdata
+import ehtim.observing.obs_simulate as simobs
+import ehtim.io.save
+import ehtim.io.load
+import ehtim.const_def as ehc
+import ehtim.observing.obs_helpers as obsh
+
+INTERPOLATION_KINDS = ['linear', 'nearest', 'zero', 'slinear',
+ 'quadratic', 'cubic', 'previous', 'next']
+
+###################################################################################################
+# Movie object
+###################################################################################################
+
+
+class Movie(object):
+
+ """A polarimetric movie (in units of Jy/pixel).
+
+ Attributes:
+ pulse (function): The function convolved with the pixel values for continuous image
+ psize (float): The pixel dimension in radians
+ xdim (int): The number of pixels along the x dimension
+ ydim (int): The number of pixels along the y dimension
+ mjd (int): The integer MJD of the image
+ source (str): The astrophysical source name
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The image frequency in Hz
+
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular
+ interp (str): Interpolation method, for scipy.interpolate.interp1d 'kind' keyword
+ (e.g. 'linear', 'nearest', 'quadratic', 'cubic', 'previous', 'next'...)
+ bounds_error (bool): if False, return nearest frame when outside [start_hr, stop_hr]
+
+
+ times (list): The list of frame time stamps in hours
+
+ _movdict (dict): The dictionary with the lists of frames
+ """
+
+ def __init__(self, frames, times, psize, ra, dec,
+ rf=ehc.RF_DEFAULT, polrep='stokes', pol_prim=None,
+ pulse=ehc.PULSE_DEFAULT, source=ehc.SOURCE_DEFAULT,
+ mjd=ehc.MJD_DEFAULT,
+ bounds_error=ehc.BOUNDS_ERROR, interp=ehc.INTERP_DEFAULT):
+ """A polarimetric image (in units of Jy/pixel).
+
+ Args:
+ frames (list): The list of 2D frames; each is a Jy/pixel array
+ times (list): The list of frame time stamps in hours
+ psize (float): The pixel dimension in radians
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The image frequency in Hz
+
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ pulse (function): The function convolved with the pixel values for continuous image.
+ source (str): The source name
+ mjd (int): The integer MJD of the image
+ interp (str): Interpolation method, for scipy.interpolate.interp1d 'kind' keyword
+ (e.g. 'linear', 'nearest', 'quadratic', 'cubic', 'previous', 'next'...)
+ bounds_error (bool): if False, return nearest frame when outside [start_hr, stop_hr]
+
+ Returns:
+ (Movie): the Movie object
+ """
+
+ if len(frames[0].shape) != 2:
+ raise Exception("frames must each be a 2D numpy array")
+
+ if len(frames) != len(times):
+ raise Exception("len(frames) != len(times) !")
+
+ if not (interp in INTERPOLATION_KINDS):
+ raise Exception(
+ "'interp' must be a valid argument for scipy.interpolate.interp1d: " +
+ ", ".join(INTERPOLATION_KINDS))
+
+ self.times = times
+ start_hr = np.min(self.times)
+ self.mjd = int(mjd)
+ if start_hr > 24:
+ self.mjd += int((start_hr - start_hr % 24)/24)
+ self.start_hr = float(start_hr % 24)
+ else:
+ self.start_hr = start_hr
+ self.stop_hr = np.max(self.times)
+ self.duration = self.stop_hr - self.start_hr
+
+ # frame shape parameters
+ self.nframes = len(frames)
+ self.polrep = polrep
+ self.pulse = pulse
+ self.psize = float(psize)
+ self.xdim = frames[0].shape[1]
+ self.ydim = frames[0].shape[0]
+
+ # the list of frames
+ frames = np.array([image.flatten() for image in frames])
+ self.interp = interp
+ self.bounds_error = bounds_error
+
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=interp,
+ bounds_error=bounds_error, fill_value=fill_value)
+
+ if polrep == 'stokes':
+ if pol_prim is None:
+ pol_prim = 'I'
+ if pol_prim == 'I':
+ self._movdict = {'I': frames, 'Q': [], 'U': [], 'V': []}
+ self._fundict = {'I': fun, 'Q': None, 'U': None, 'V': None}
+ elif pol_prim == 'V':
+ self._movdict = {'I': [], 'Q': [], 'U': [], 'V': frames}
+ self._fundict = {'I': None, 'Q': None, 'U': None, 'V': fun}
+ elif pol_prim == 'Q':
+ self._movdict = {'I': [], 'Q': frames, 'U': [], 'V': []}
+ self._fundict = {'I': None, 'Q': fun, 'U': None, 'V': None}
+ elif pol_prim == 'U':
+ self._movdict = {'I': [], 'Q': [], 'U': frames, 'V': []}
+ self._fundict = {'I': None, 'Q': None, 'U': fun, 'V': None}
+ else:
+ raise Exception("for polrep=='stokes', pol_prim must be 'I','Q','U', or 'V'!")
+
+ elif polrep == 'circ':
+ if pol_prim is None:
+ print("polrep is 'circ' and no pol_prim specified! Setting pol_prim='RR'")
+ pol_prim = 'RR'
+ if pol_prim == 'RR':
+ self._movdict = {'RR': frames, 'LL': [], 'RL': [], 'LR': []}
+ self._fundict = {'RR': fun, 'LL': None, 'RL': None, 'LR': None}
+ elif pol_prim == 'LL':
+ self._movdict = {'RR': [], 'LL': frames, 'RL': [], 'LR': []}
+ self._fundict = {'RR': None, 'LL': fun, 'RL': None, 'LR': None}
+ else:
+ raise Exception("for polrep=='circ', pol_prim must be 'RR' or 'LL'!")
+
+ self.pol_prim = pol_prim
+
+ self.ra = float(ra)
+ self.dec = float(dec)
+ self.rf = float(rf)
+ self.source = str(source)
+ self.pa = 0.0 # TODO: The pa needs to be properly implemented in the movie object
+ # TODO: What is this doing??
+
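+ # Editorial usage sketch, not part of the original file: build a short movie
+ # from 2D numpy frames and pull out an interpolated Image. The frame values,
+ # pixel size, times, and coordinates here are arbitrary illustrative values.
+ #
+ # >>> frames = [np.zeros((64, 64)), np.ones((64, 64))]
+ # >>> mov = Movie(frames, times=np.array([0.0, 1.0]), psize=1e-11, ra=12.51, dec=12.39)
+ # >>> im = mov.get_image(0.5)   # Image interpolated to t = 0.5 hr
+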
+ @property
+ def frames(self):
+ frames = self._movdict[self.pol_prim]
+ return frames
+
+ @frames.setter
+ def frames(self, frames):
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("imvec size is not consistent with xdim*ydim!")
+ # TODO -- more checks on consistency with the existing pol data???
+
+ frames = np.array(frames)
+ self._movdict[self.pol_prim] = frames
+
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict[self.pol_prim] = fun
+
+ @property
+ def iframes(self):
+
+ if self.polrep != 'stokes':
+ raise Exception(
+ "iframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()")
+
+ frames = self._movdict['I']
+ return frames
+
+ @iframes.setter
+ def iframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['I'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['I'] = fun
+
+ @property
+ def qframes(self):
+
+ if self.polrep != 'stokes':
+ raise Exception(
+ "qframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()")
+
+ frames = self._movdict['Q']
+ return frames
+
+ @qframes.setter
+ def qframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['Q'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['Q'] = fun
+
+ @property
+ def uframes(self):
+
+ if self.polrep != 'stokes':
+ raise Exception(
+ "uframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()")
+
+ frames = self._movdict['U']
+
+ return frames
+
+ @uframes.setter
+ def uframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['U'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['U'] = fun
+
+ @property
+ def vframes(self):
+
+ if self.polrep != 'stokes':
+ raise Exception(
+ "vframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()")
+
+ frames = self._movdict['V']
+
+ return frames
+
+ @vframes.setter
+ def vframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['V'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['V'] = fun
+
+ @property
+ def rrframes(self):
+
+ if self.polrep != 'circ':
+ raise Exception(
+ "rrframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()")
+
+ frames = self._movdict['RR']
+ return frames
+
+ @rrframes.setter
+ def rrframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['RR'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['RR'] = fun
+
+ @property
+ def llframes(self):
+
+ if self.polrep != 'circ':
+ raise Exception(
+ "llframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()")
+
+ frames = self._movdict['LL']
+ return frames
+
+ @llframes.setter
+ def llframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['LL'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['LL'] = fun
+
+ @property
+ def rlframes(self):
+
+ if self.polrep != 'circ':
+ raise Exception(
+ "rlframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()")
+
+ frames = self._movdict['RL']
+ return frames
+
+ @rlframes.setter
+ def rlframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['RL'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['RL'] = fun
+
+ @property
+ def lrframes(self):
+
+ if self.polrep != 'circ':
+ raise Exception(
+ "lrframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()")
+
+ frames = self._movdict['LR']
+ return frames
+
+ @lrframes.setter
+ def lrframes(self, frames):
+
+ if len(frames[0]) != self.xdim*self.ydim:
+ raise Exception("vec size is not consistent with xdim*ydim!")
+
+ # TODO -- more checks on the consistency of the imvec with the existing pol data???
+ frames = np.array(frames)
+ self._movdict['LR'] = frames
+ fill_value = (frames[0], frames[-1])
+ fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp,
+ bounds_error=self.bounds_error, fill_value=fill_value)
+ self._fundict['LR'] = fun
+
+ def movie_args(self):
+ """"Copy arguments for making a new Movie into a list and dictonary
+ """
+
+ frames2D = self.frames.reshape((self.nframes, self.ydim, self.xdim))
+ arglist = [frames2D.copy(), self.times.copy(), self.psize, self.ra, self.dec]
+ #arglist = [frames2D, self.times, self.psize, self.ra, self.dec]
+ argdict = {'rf': self.rf, 'polrep': self.polrep, 'pol_prim': self.pol_prim,
+ 'pulse': self.pulse, 'source': self.source,
+ 'mjd': self.mjd, 'interp': self.interp, 'bounds_error': self.bounds_error}
+
+ return (arglist, argdict)
+
+ def copy(self):
+ """Return a copy of the Movie object.
+
+ Args:
+
+ Returns:
+ (Movie): copy of the Movie.
+ """
+
+ arglist, argdict = self.movie_args()
+
+ # Make new movie with primary polarization
+ newmov = Movie(*arglist, **argdict)
+
+ # Copy over all polarization movies
+ for pol in list(self._movdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polframes = self._movdict[pol]
+ if len(polframes):
+ polframes = polframes.reshape((self.nframes, self.ydim, self.xdim))
+ newmov.add_pol_movie(polframes, pol)
+
+ return newmov
+
+ def reset_interp(self, interp=None, bounds_error=None):
+ """Reset the movie interpolation function to change the interp. type or change the frames
+
+ Args:
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr]
+ """
+
+ if interp is None:
+ interp = self.interp
+ if bounds_error is None:
+ bounds_error = self.bounds_error
+
+ # Copy over all polarization movies
+ for pol in list(self._movdict.keys()):
+ polframes = self._movdict[pol]
+ if len(polframes):
+ fill_value = (polframes[0], polframes[-1])
+ fun = scipy.interpolate.interp1d(self.times, polframes.T, kind=interp,
+ fill_value=fill_value, bounds_error=bounds_error)
+ self._fundict[pol] = fun
+ else:
+ self._fundict[pol] = None
+
+ self.interp = interp
+ self.bounds_error = bounds_error
+ return
+
+ def offset_time(self, t_offset):
+ """Offset the movie in time by t_offset
+
+ Args:
+ t_offset (float): offset time in hours
+ Returns:
+
+ """
+ mov = self.copy()
+ mov.start_hr += t_offset
+ mov.stop_hr += t_offset
+ mov.times += t_offset
+ mov.reset_interp(interp=mov.interp, bounds_error=mov.bounds_error)
+ return mov
+
+ def add_pol_movie(self, movie, pol):
+ """Add another movie polarization.
+
+ Args:
+ movie (list): list of 2D frames (possibly complex) in a Jy/pixel array
+ pol (str): The image type: 'I','Q','U','V' for stokes, 'RR','LL','RL','LR' for circ
+ """
+ if not(len(movie) == self.nframes):
+ raise Exception("new pol movies must have same length as primary movie!")
+
+ if pol == self.pol_prim:
+ raise Exception("new pol in add_pol_movie is the same as pol_prim!")
+ if np.any(np.array([image.shape != (self.ydim, self.xdim) for image in movie])):
+ raise Exception("add_pol_movie image shapes incompatible with primary image!")
+ if not (pol in list(self._movdict.keys())):
+ raise Exception("for polrep==%s, pol in add_pol_movie must be in " %
+ self.polrep + ",".join(list(self._movdict.keys())))
+
+ if self.polrep == 'stokes':
+ if pol == 'I':
+ self.iframes = [image.flatten() for image in movie]
+ elif pol == 'Q':
+ self.qframes = [image.flatten() for image in movie]
+ elif pol == 'U':
+ self.uframes = [image.flatten() for image in movie]
+ elif pol == 'V':
+ self.vframes = [image.flatten() for image in movie]
+
+ if len(self.iframes) > 0:
+ fill_value = (self.iframes[0], self.iframes[-1])
+ ifun = scipy.interpolate.interp1d(self.times, self.iframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ ifun = None
+ if len(self.vframes) > 0:
+ fill_value = (self.vframes[0], self.vframes[-1])
+ vfun = scipy.interpolate.interp1d(self.times, self.vframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ vfun = None
+ if len(self.qframes) > 0:
+ fill_value = (self.qframes[0], self.qframes[-1])
+ qfun = scipy.interpolate.interp1d(self.times, self.qframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ qfun = None
+ if len(self.uframes) > 0:
+ fill_value = (self.uframes[0], self.uframes[-1])
+ ufun = scipy.interpolate.interp1d(self.times, self.uframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ ufun = None
+
+ self._movdict = {'I': self.iframes, 'Q': self.qframes,
+ 'U': self.uframes, 'V': self.vframes}
+ self._fundict = {'I': ifun, 'Q': qfun, 'U': ufun, 'V': vfun}
+
+ elif self.polrep == 'circ':
+ if pol == 'RR':
+ self.rrframes = [image.flatten() for image in movie]
+ elif pol == 'LL':
+ self.llframes = [image.flatten() for image in movie]
+ elif pol == 'RL':
+ self.rlframes = [image.flatten() for image in movie]
+ elif pol == 'LR':
+ self.lrframes = [image.flatten() for image in movie]
+
+ if len(self.rrframes) > 0:
+ fill_value = (self.rrframes[0], self.rrframes[-1])
+ rrfun = scipy.interpolate.interp1d(self.times, self.rrframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ rrfun = None
+ if len(self.llframes) > 0:
+ fill_value = (self.llframes[0], self.llframes[-1])
+ llfun = scipy.interpolate.interp1d(self.times, self.llframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ llfun = None
+ if len(self.rlframes) > 0:
+ fill_value = (self.rlframes[0], self.rlframes[-1])
+ rlfun = scipy.interpolate.interp1d(self.times, self.rlframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ rlfun = None
+ if len(self.lrframes) > 0:
+ fill_value = (self.lrframes[0], self.lrframes[-1])
+ lrfun = scipy.interpolate.interp1d(self.times, self.lrframes.T, kind=self.interp,
+ fill_value=fill_value,
+ bounds_error=self.bounds_error)
+ else:
+ lrfun = None
+
+ self._movdict = {'RR': self.rrframes, 'LL': self.llframes,
+ 'RL': self.rlframes, 'LR': self.lrframes}
+ self._fundict = {'RR': rrfun, 'LL': llfun, 'RL': rlfun, 'LR': lrfun}
+ return
+
+ # TODO deprecated -- replace with generic add_pol_movie
+ def add_qu(self, qmovie, umovie):
+ """Add Stokes Q and U movies. self.polrep must be 'stokes'
+
+ Args:
+ qmovie (list): list of 2D Stokes Q frames in Jy/pixel array
+ umovie (list): list of 2D Stokes U frames in Jy/pixel array
+
+ Returns:
+ """
+
+ if self.polrep != 'stokes':
+ raise Exception("polrep must be 'stokes' for add_qu() !")
+ self.add_pol_movie(qmovie, 'Q')
+ self.add_pol_movie(umovie, 'U')
+
+ return
+
+ # TODO deprecated -- replace with generic add_pol_movie
+ def add_v(self, vmovie):
+ """Add Stokes V movie. self.polrep must be 'stokes'
+
+ Args:
+ vmovie (list): list of 2D Stokes V frames in Jy/pixel array
+
+ Returns:
+ """
+
+ if self.polrep != 'stokes':
+ raise Exception("polrep must be 'stokes' for add_v() !")
+ self.add_pol_movie(vmovie, 'V')
+
+ return
+
+ def switch_polrep(self, polrep_out='stokes', pol_prim_out=None):
+ """Return a new movie with the polarization representation changed
+
+ Args:
+ polrep_out (str): the polrep of the output data
+ pol_prim_out (str): The default movie: I,Q,U or V for Stokes,
+ RR,LL,LR,RL for Circular
+
+ Returns:
+ (Movie): new movie object with potentially different polrep
+ """
+
+ if polrep_out not in ['stokes', 'circ']:
+ raise Exception("polrep_out must be either 'stokes' or 'circ'")
+ if pol_prim_out is None:
+ if polrep_out == 'stokes':
+ pol_prim_out = 'I'
+ elif polrep_out == 'circ':
+ pol_prim_out = 'RR'
+
+ # Simply copy if the polrep is unchanged
+ if polrep_out == self.polrep and pol_prim_out == self.pol_prim:
+ return self.copy()
+
+ # Assemble a dictionary of new polarization vectors
+ framedim = (self.nframes, self.ydim, self.xdim)
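+ # conversion conventions between the two bases (editorial note):
+ # RR = I + V, LL = I - V, RL = Q + iU, LR = Q - iU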
+ if polrep_out == 'stokes':
+ if self.polrep == 'stokes':
+ movdict = {'I': self.iframes, 'Q': self.qframes,
+ 'U': self.uframes, 'V': self.vframes}
+ else:
+ if len(self.rrframes) == 0 or len(self.llframes) == 0:
+ iframes = []
+ vframes = []
+ else:
+ iframes = 0.5*(self.rrframes.reshape(framedim) +
+ self.llframes.reshape(framedim))
+ vframes = 0.5*(self.rrframes.reshape(framedim) -
+ self.llframes.reshape(framedim))
+
+ if len(self.rlframes) == 0 or len(self.lrframes) == 0:
+ qframes = []
+ uframes = []
+ else:
+ qframes = np.real(0.5*(self.lrframes.reshape(framedim) +
+ self.rlframes.reshape(framedim)))
+ uframes = np.real(0.5j*(self.lrframes.reshape(framedim) -
+ self.rlframes.reshape(framedim)))
+
+ movdict = {'I': iframes, 'Q': qframes, 'U': uframes, 'V': vframes}
+
+ elif polrep_out == 'circ':
+ if self.polrep == 'circ':
+ movdict = {'RR': self.rrframes, 'LL': self.llframes,
+ 'RL': self.rlframes, 'LR': self.lrframes}
+ else:
+ if len(self.iframes) == 0 or len(self.vframes) == 0:
+ rrframes = []
+ llframes = []
+ else:
+ rrframes = (self.iframes.reshape(framedim) + self.vframes.reshape(framedim))
+ llframes = (self.iframes.reshape(framedim) - self.vframes.reshape(framedim))
+
+ if len(self.qframes) == 0 or len(self.uframes) == 0:
+ rlframes = []
+ lrframes = []
+ else:
+ rlframes = (self.qframes.reshape(framedim) + 1j*self.uframes.reshape(framedim))
+ lrframes = (self.qframes.reshape(framedim) - 1j*self.uframes.reshape(framedim))
+
+ movdict = {'RR': rrframes, 'LL': llframes, 'RL': rlframes, 'LR': lrframes}
+
+ # Assemble the new movie
+ frames = movdict[pol_prim_out]
+ if len(frames) == 0:
+ raise Exception("switch_polrep to " +
+ "%s with pol_prim_out=%s, \n" % (polrep_out, pol_prim_out) +
+ "output movie is not defined")
+
+ # Make new movie with primary polarization
+ arglist, argdict = self.movie_args()
+ arglist[0] = frames
+ argdict['polrep'] = polrep_out
+ argdict['pol_prim'] = pol_prim_out
+ newmov = Movie(*arglist, **argdict)
+
+ # Add in any other polarizations
+ for pol in list(movdict.keys()):
+ if pol == pol_prim_out:
+ continue
+ polframes = movdict[pol]
+ if len(polframes):
+ polframes = polframes.reshape((self.nframes, self.ydim, self.xdim))
+ newmov.add_pol_movie(polframes, pol)
+
+ return newmov
+
+ def orth_chi(self):
+ """Rotate the EVPA 90 degrees
+
+ Args:
+
+ Returns:
+ (Movie): movie with rotated EVPA
+ """
+ mov = self.copy()
+ if mov.polrep == 'stokes':
+ mov.qframes *= -1
+ mov.uframes *= -1
+ elif mov.polrep == 'circ':
+ mov.lrframes *= -1
+ mov.rlframes *= -1
+
+ return mov
+
+ def fovx(self):
+ """Return the movie fov in x direction in radians.
+
+ Args:
+
+ Returns:
+ (float) : movie fov in x direction (radian)
+ """
+
+ return self.psize * self.xdim
+
+ def fovy(self):
+ """Returns the movie fov in y direction in radians.
+
+ Args:
+
+ Returns:
+ (float) : movie fov in y direction (radian)
+ """
+
+ return self.psize * self.ydim
+
+ @property
+ def lightcurve(self):
+ """Return the total flux over time of the image in Jy.
+
+ Args:
+
+ Returns:
+ (numpy.ndarray) : movie total flux (Jy) over time
+ """
+ if self.polrep == 'stokes':
+ flux = [np.sum(ivec) for ivec in self.iframes]
+ elif self.polrep == 'circ':
+ flux = [0.5*(np.sum(self.rrframes[i])+np.sum(self.llframes[i]))
+ for i in range(self.nframes)]
+
+ return np.array(flux)
+
+ def lin_polfrac_curve(self):
+ """Return the total fractional linear polarized flux over time
+
+ Args:
+
+ Returns:
+ (numpy.ndarray) : movie fractional linear polarized flux per frame
+ """
+ if self.polrep == 'stokes':
+ frac = [np.abs(np.sum(self.qframes[i] + 1j*self.uframes[i])) /
+ np.abs(np.sum(self.iframes[i]))
+ for i in range(self.nframes)]
+ elif self.polrep == 'circ':
+ frac = [2*np.abs(np.sum(self.rlframes[i])) /
+ np.abs(np.sum(self.rrframes[i]+self.llframes[i]))
+ for i in range(self.nframes)]
+ return np.array(frac)
+
+ def circ_polfrac_curve(self):
+ """Return the (signed) total fractional circular polarized flux over time
+
+ Args:
+
+ Returns:
+ (numpy.ndarray) : movie fractional circular polarized flux per frame
+ """
+ if self.polrep == 'stokes':
+ frac = [np.sum(self.vframes[i]) / np.abs(np.sum(self.iframes[i]))
+ for i in range(self.nframes)]
+ elif self.polrep == 'circ':
+ frac = [np.sum(self.rrframes[i]-self.llframes[i]) /
+ np.abs(np.sum(self.rrframes[i] + self.llframes[i]))
+ for i in range(self.nframes)]
+
+ return np.array(frac)
+
+ def get_image(self, time):
+ """Return an Image at time
+
+ Args:
+ time (float): the time in hours
+
+ Returns:
+ (Image): the Image object at the given time
+ """
+
+ if (time < self.start_hr):
+ if not(self.bounds_error):
+ pass
+ # print ("time %f before movie start time %f" % (time, self.start_hr))
+ # print ("returning constant frame 0! \n")
+ else:
+ raise Exception("time %f must be in the range %f - %f" %
+ (time, self.start_hr, self.stop_hr))
+
+ if (time > self.stop_hr):
+ if not(self.bounds_error):
+ pass
+ # print ("time %f after movie stop time %f" % (time, self.stop_hr))
+ # print ("returning constant frame -1! \n")
+ else:
+ raise Exception("time %f must be in the range %f - %f" %
+ (time, self.start_hr, self.stop_hr))
+
+ # interpolate the imvec to the given time
+ imvec = self._fundict[self.pol_prim](time)
+
+ # Make the primary image
+ imarr = imvec.reshape(self.ydim, self.xdim)
+ outim = ehtim.image.Image(imarr, self.psize, self.ra, self.dec, self.pa,
+ polrep=self.polrep, pol_prim=self.pol_prim, time=time,
+ rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse)
+
+ # Copy over the rest of the polarizations
+ for pol in list(self._movdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polframes = self._movdict[pol]
+ if len(polframes):
+ polvec = self._fundict[pol](time)
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ outim.add_pol_image(polarr, pol)
+
+ return outim
+
+ def get_frame(self, n):
+ """Return an Image of the nth frame
+
+ Args:
+ n (int): the frame number
+
+ Returns:
+ (Image): the Image object of the nth frame
+ """
+
+ if n < 0 or n >= len(self.frames):
+ raise Exception("n must be in the range 0 - %i" % self.nframes)
+
+ time = self.times[n]
+
+ # Make the primary image
+ imarr = self.frames[n].reshape(self.ydim, self.xdim)
+ outim = ehtim.image.Image(imarr, self.psize, self.ra, self.dec, self.pa,
+ polrep=self.polrep, pol_prim=self.pol_prim, time=time,
+ rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse)
+
+ # Copy over the rest of the polarizations
+ for pol in list(self._movdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polframes = self._movdict[pol]
+ if len(polframes):
+ polvec = polframes[n]
+ polarr = polvec.reshape(self.ydim, self.xdim).copy()
+ outim.add_pol_image(polarr, pol)
+
+ return outim
+
+ def im_list(self):
+ """Return a list of the movie frames
+
+ Args:
+
+ Returns:
+ (list): list of Image objects
+ """
+
+ return [self.get_frame(j) for j in range(self.nframes)]
+
+ def avg_frame(self):
+ """Coherently Average the movie frames into a single image.
+
+ Returns:
+ (Image) : averaged image of all frames
+ """
+
+ # Make the primary image
+ avg_imvec = np.mean(np.array(self.frames), axis=0)
+ avg_imarr = avg_imvec.reshape(self.ydim, self.xdim)
+ outim = ehtim.image.Image(avg_imarr, self.psize, self.ra, self.dec, self.pa,
+ polrep=self.polrep, pol_prim=self.pol_prim, time=self.start_hr,
+ rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse)
+
+ # Copy over the rest of the average polarizations
+ for pol in list(self._movdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polframes = self._movdict[pol]
+ if len(polframes):
+ avg_polvec = np.mean(np.array(polframes), axis=0)
+ avg_polarr = avg_polvec.reshape(self.ydim, self.xdim)
+ outim.add_pol_image(avg_polarr, pol)
+
+ return outim
+
+ def blur_circ(self, fwhm_x, fwhm_t, fwhm_x_pol=0):
+ """Apply a Gaussian filter to a list of images.
+
+ Args:
+ fwhm_x (float): circular beam size for spatial blurring in radians
+ fwhm_t (float): temporal blurring in frames
+ fwhm_x_pol (float): circular beam size for Stokes Q,U,V spatial blurring in radians
+ Returns:
+ (Movie): the blurred movie
+ """
+
+ # Unpack the frames
+ frames = self.im_list()
+
+ # Blur Stokes I
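+ # convert FWHM to Gaussian standard deviation: sigma = fwhm / (2*sqrt(2*ln(2)))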
+ sigma_x = fwhm_x / self.psize / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_t = fwhm_t / (2. * np.sqrt(2. * np.log(2.)))
+ sigma_x_pol = fwhm_x_pol / self.psize / (2. * np.sqrt(2. * np.log(2.)))
+
+ arr = np.array([im.imvec.reshape(self.ydim, self.xdim) for im in frames])
+ arr = filt.gaussian_filter(arr, (sigma_t, sigma_x, sigma_x))
+
+ # Make a new blurred movie
+ arglist, argdict = self.movie_args()
+ arglist[0] = arr
+ movie_blur = Movie(*arglist, **argdict)
+
+ # Process the remaining polarizations
+ for pol in list(self._movdict.keys()):
+ if pol == self.pol_prim:
+ continue
+ polframes = self._movdict[pol]
+
+ if len(polframes):
+ arr = np.array([imvec.reshape(self.ydim, self.xdim) for imvec in polframes])
+ arr = filt.gaussian_filter(arr, (sigma_t, sigma_x_pol, sigma_x_pol))
+ movie_blur.add_pol_movie(arr, pol)
+
+ return movie_blur
+
+ def observe_same_nonoise(self, obs, repeat=False, sgrscat=False,
+ ttype="nfft", fft_pad_factor=2,
+ zero_empty_pol=True, verbose=True):
+ """Observe the movie on the same baselines as an existing observation
+ without adding noise.
+
+ Args:
+ obs (Obsdata): existing observation with baselines where the FT will be sampled
+ repeat (bool): if True, repeat the movie to fill up the observation interval
+ sgrscat (bool): if True, the visibilites are blurred by the Sgr A* scattering kernel
+ ttype (str): if "fast", use FFT to produce visibilities. Else "direct" for DTFT
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist.
+ Otherwise return None
+ verbose (bool): Boolean value controls output prints.
+
+ Returns:
+ (Obsdata): an observation object
+ """
+
+ # Check for agreement in coordinates and frequency
+ tolerance = 1e-8
+ if (np.abs(self.ra - obs.ra) > tolerance) or (np.abs(self.dec - obs.dec) > tolerance):
+ raise Exception("Movie coordinates are not the same as observtion coordinates!")
+ if (np.abs(self.rf - obs.rf)/obs.rf > tolerance):
+ raise Exception("Movie frequency is not the same as observation frequency!")
+
+ if ttype == 'direct' or ttype == 'fast' or ttype == 'nfft':
+ if verbose: print("Producing clean visibilities from movie with " + ttype + " FT . . . ")
+ else:
+ raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft'" % ttype)
+
+ # Get data
+ obslist = obs.tlist()
+
+ obstimes = np.array([obsdata[0]['time'] for obsdata in obslist])
+
+ if (obstimes < self.start_hr).any():
+ if repeat:
+ print("Some observation times before movie start time %f" % self.start_hr)
+ print("Looping movie before start\n")
+ elif not(self.bounds_error):
+ print("Some observation times before movie start time %f" % self.start_hr)
+ print("bounds_error is False: using constant frame 0 before start_hr! \n")
+ else:
+ raise Exception("Some observation times before movie start time %f" % self.start_hr)
+ if (obstimes > self.stop_hr).any():
+ if repeat:
+ print("Some observation times after movie stop time %f" % self.stop_hr)
+ print("Looping movie after stop\n")
+ elif not(self.bounds_error):
+ print("Some observation times after movie stop time %f" % self.stop_hr)
+ print("bounds_error is False: using constant frame -1 after stop_hr! \n")
+ else:
+ raise Exception("Some observation times after movie stop time %f" % self.stop_hr)
+
+ # Observe nearest frame
+ obsdata_out = []
+
+ for i in range(len(obslist)):
+ obsdata = obslist[i]
+
+ # Frame number
+ time = obsdata[0]['time']
+
+ if self.bounds_error:
+ if (time < self.start_hr or time > self.stop_hr):
+ if repeat:
+ time = self.start_hr + np.mod(time - self.start_hr, self.duration)
+ else:
+ raise Exception("Obs time %f outside movie range %f--%f" %
+ (time, self.start_hr, self.stop_hr))
+
+ # Get the frame visibilities
+ uv = obsh.recarr_to_ndarr(obsdata[['u', 'v']], 'f8')
+
+ try:
+ im = self.get_image(time)
+ except ValueError:
+ raise Exception("Interpolation error for time %f: movie range %f--%f" %
+ (time, self.start_hr, self.stop_hr))
+
+ data = simobs.sample_vis(im, uv, sgrscat=sgrscat, polrep_obs=obs.polrep,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=zero_empty_pol, verbose=verbose)
+ verbose = False # only print for one frame
+
+ # Put visibilities into the obsdata
+ if obs.polrep == 'stokes':
+ obsdata['vis'] = data[0]
+ if not(data[1] is None):
+ obsdata['qvis'] = data[1]
+ obsdata['uvis'] = data[2]
+ obsdata['vvis'] = data[3]
+
+ elif obs.polrep == 'circ':
+ obsdata['rrvis'] = data[0]
+ if not(data[1] is None):
+ obsdata['llvis'] = data[1]
+ if not(data[2] is None):
+ obsdata['rlvis'] = data[2]
+ obsdata['lrvis'] = data[3]
+
+ if len(obsdata_out):
+ obsdata_out = np.hstack((obsdata_out, obsdata))
+ else:
+ obsdata_out = obsdata
+
+ obsdata_out = np.array(obsdata_out, dtype=obs.poltype)
+ obs_no_noise = ehtim.obsdata.Obsdata(self.ra, self.dec, self.rf, obs.bw,
+ obsdata_out, obs.tarr,
+ source=self.source, mjd=np.floor(obs.mjd),
+ polrep=obs.polrep,
+ ampcal=True, phasecal=True, opacitycal=True,
+ dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ return obs_no_noise
+
+ def observe_same(self, obs_in, repeat=False,
+ ttype='nfft', fft_pad_factor=2,
+ sgrscat=False, add_th_noise=True,
+ jones=False, inv_jones=False,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0.,rlphase_std=0.,
+ caltable_path=None, seed=False, sigmat=None, verbose=True):
+ """Observe the image on the same baselines as an existing observation object and add noise.
+
+ Args:
+ obs_in (Obsdata): existing observation with baselines where the FT will be sampled
+ repeat (bool): if True, repeat the movie to fill up the observation interval
+ ttype (str): "fast" or "nfft" or "direct"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to a priori calibrate data
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ dcal (bool): if False, time-dependent gaussian errors added to D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ sigmat (float): temporal std for a Gaussian Process used to generate gain noise.
+ if sigmat=None then an iid gain noise is applied.
+ verbose (bool): print updates and warnings
+
+ Returns:
+ (Obsdata): an observation object
+
+ """
+
+ if seed:
+ np.random.seed(seed=seed)
+
+ # print("Producing clean visibilities from movie . . . ")
+ obs = self.observe_same_nonoise(obs_in, repeat=repeat, sgrscat=sgrscat,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ zero_empty_pol=True, verbose=verbose)
+
+ # Jones Matrix Corruption & Calibration
+ if jones:
+ obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, frcal=frcal, dcal=dcal,
+ rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std, rlphase_std=rlphase_std,
+ caltable_path=caltable_path,
+ seed=seed, sigmat=sigmat, verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal,
+ dcal=dcal, frcal=frcal,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ if inv_jones:
+ obsdata = simobs.apply_jones_inverse(obs,
+ opacitycal=opacitycal, dcal=dcal, frcal=frcal,
+ verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ # No Jones Matrices, Add noise the old way
+ # TODO There is an asymmetry here - in the old way, we don't offer the ability to
+ # *not* unscale estimated noise.
+ else:
+ if caltable_path:
+ print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix')
+
+ obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise,
+ opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup, gain_offset=gain_offset, gainp=gainp,
+ caltable_path=caltable_path, seed=seed, sigmat=sigmat,
+ verbose=verbose)
+
+ obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr,
+ source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep,
+ ampcal=ampcal, phasecal=phasecal,
+ opacitycal=True, dcal=True, frcal=True,
+ timetype=obs.timetype, scantable=obs.scans)
+
+ return obs
+
+ def observe(self, array, tint, tadv, tstart, tstop, bw, repeat=False,
+ mjd=None, timetype='UTC', polrep_obs=None,
+ elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
+ no_elevcut_space=False,
+ ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False,
+ sgrscat=False, add_th_noise=True,
+ jones=False, inv_jones=False,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ tau=ehc.TAUDEF, taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0.,rlphase_std=0.,
+ caltable_path=None, seed=False, sigmat=None, verbose=True):
+ """Generate baselines from an array object and observe the movie.
+
+ Args:
+ array (Array): an array object containing sites with which to generate baselines
+ tint (float): the scan integration time in seconds
+ tadv (float): the uniform cadence between scans in seconds
+ tstart (float): the start time of the observation in hours
+ tstop (float): the end time of the observation in hours
+ bw (float): the observing bandwidth in Hz
+ repeat (bool): if True, repeat the movie to fill up the observation interval
+
+ mjd (int): the mjd of the observation, if set as different from the image mjd
+ timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC'
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+ elevmin (float): station minimum elevation in degrees
+ elevmax (float): station maximum elevation in degrees
+ no_elevcut_space (bool): if True, do not apply elevation cut to orbiters
+
+ ttype (str): "fast", "nfft" or "dtft"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in the FFT
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v
+
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ otherwise uses old formalism without D-terms
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to calibrate data
+
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrix.
+ dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ tau (float): the base opacity at all sites, or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ sigmat (float): temporal std for a Gaussian Process used to generate gain noise.
+ if sigmat=None then an iid gain noise is applied.
+ verbose (bool): print updates and warnings
+
+ Returns:
+ (Obsdata): an observation object
+
+ """
+
+ # Generate empty observation
+ if verbose: print("Generating empty observation file . . . ")
+ if mjd is None:
+ mjd = self.mjd
+ if polrep_obs is None:
+ polrep_obs = self.polrep
+
+ obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop,
+ mjd=mjd, polrep=polrep_obs,
+ tau=tau, timetype=timetype,
+ elevmin=elevmin, elevmax=elevmax,
+ no_elevcut_space=no_elevcut_space,
+ fix_theta_GMST=fix_theta_GMST)
+
+ # Observe on the same baselines as the empty observation and add noise
+ obs = self.observe_same(obs, repeat=repeat,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ sgrscat=sgrscat,
+ add_th_noise=add_th_noise,
+ jones=jones, inv_jones=inv_jones,
+ opacitycal=opacitycal, ampcal=ampcal,
+ phasecal=phasecal, dcal=dcal,
+ frcal=frcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std,rlphase_std=rlphase_std,
+ caltable_path=caltable_path, seed=seed, sigmat=sigmat,
+ verbose=verbose)
+
+ return obs
+
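+ # Editorial usage sketch, not part of the original file: observe the movie with
+ # a hypothetical Array object 'arr' (e.g. from ehtim.array.load_txt), scanning
+ # every 600 s with 60 s integrations from 0 to 4 UT at 4 GHz bandwidth.
+ #
+ # >>> obs = mov.observe(arr, tint=60.0, tadv=600.0, tstart=0.0, tstop=4.0, bw=4e9)
+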
+ def observe_vex(self, vex, source, synchronize_start=True, t_int=0.0,
+ polrep_obs=None, ttype='nfft', fft_pad_factor=2,
+ fix_theta_GMST=False,
+ sgrscat=False, add_th_noise=True,
+ jones=False, inv_jones=False,
+ opacitycal=True, ampcal=True, phasecal=True,
+ frcal=True, dcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ tau=ehc.TAUDEF, taup=ehc.GAINPDEF,
+ gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ dterm_offset=ehc.DTERMPDEF,
+ caltable_path=None, seed=False, sigmat=None, verbose=True):
+ """Generate baselines from a vex file and observe the movie.
+
+ Args:
+ vex (Vex): a Vex object containing sites and scan information
+ source (str): the source to observe
+ synchronize_start (bool): if True, the start of the movie is defined
+ as the start of the observations
+ t_int (float): if not zero, overrides the vex scan lengths
+
+ polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+ ttype (str): "fast" or "nfft" or "dtft"
+ fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v
+
+ sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added
+
+ jones (bool): if True, uses Jones matrix to apply mis-calibration effects
+ otherwise uses old formalism without D-terms
+ inv_jones (bool): if True, applies estimated inverse Jones matrix
+ (not including random terms) to calibrate data
+ opacitycal (bool): if False, time-dependent gaussian errors are added to opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to station gains
+ phasecal (bool): if False, time-dependent station-based random phases are added
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrix.
+ dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ tau (float): the base opacity at all sites,
+ or a dict giving one opacity per site
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ gainp (float): the fractional std. dev. of the random error on the gains
+ dterm_offset (float): the base dterm offset at all sites,
+ or a dict giving one dterm offset per site
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ sigmat (float): temporal std for a Gaussian Process used to generate gain noise.
+ if sigmat=None then an iid gain noise is applied.
+ verbose (bool): print updates and warnings
+
+ Returns:
+ (Obsdata): an observation object
+
+ """
+
+ if polrep_obs is None:
+ polrep_obs = self.polrep
+
+ obs_List = []
+ movie = self.copy()
+
+ if synchronize_start:
+ movie.mjd = vex.sched[0]['mjd_floor']
+ movie.start_hr = vex.sched[0]['start_hr']
+
+ movie_start = float(movie.mjd) + movie.start_hr/24.0
+ movie_end = float(movie.mjd) + movie.stop_hr/24.0
+
+ print("Movie MJD Range: ", movie_start, movie_end)
+
+ snapshot = 1.0
+ if t_int > 0.0:
+ snapshot = 0.0
+
+ for i_scan in range(len(vex.sched)):
+ if vex.sched[i_scan]['source'] != source:
+ continue
+ scankeys = list(vex.sched[i_scan]['scan'].keys())
+ subarray = vex.array.make_subarray([vex.sched[i_scan]['scan'][key]['site']
+ for key in scankeys])
+
+ if snapshot == 1.0:
+ t_int = np.max(np.array([vex.sched[i_scan]['scan'][site]
+ ['scan_sec'] for site in scankeys]))
+ print(t_int)
+
+ vex_scan_start_mjd = float(vex.sched[i_scan]['mjd_floor'])
+ vex_scan_start_mjd += vex.sched[i_scan]['start_hr']/24.0
+
+ vex_scan_length_mjd = vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0/24.0
+ vex_scan_stop_mjd = vex_scan_start_mjd + vex_scan_length_mjd
+
+ print("Scan MJD Range: ", vex_scan_start_mjd, vex_scan_stop_mjd)
+
+ if vex_scan_start_mjd < movie_start or vex_scan_stop_mjd > movie_end:
+ continue
+
+ t_start = vex.sched[i_scan]['start_hr']
+ t_stop = t_start + vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0 - ehc.EP
+
+ mjd = vex.sched[i_scan]['mjd_floor']
+ obs = subarray.obsdata(movie.ra, movie.dec, movie.rf, vex.bw_hz,
+ t_int, t_int, t_start, t_stop,
+ mjd=mjd, polrep=polrep_obs, tau=tau,
+ elevmin=.01, elevmax=89.99, timetype='UTC',
+ fix_theta_GMST=fix_theta_GMST)
+ obs_List.append(obs)
+
+ if len(obs_List) == 0:
+ raise Exception("Movie has no overlap with the vex file")
+
+ obs = ehtim.obsdata.merge_obs(obs_List)
+
+ obsout = movie.observe_same(obs, repeat=False,
+ ttype=ttype, fft_pad_factor=fft_pad_factor,
+ sgrscat=sgrscat, add_th_noise=add_th_noise,
+ jones=jones, inv_jones=inv_jones,
+ opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal,
+ frcal=frcal, dcal=dcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp,
+ neggains=neggains,
+ taup=taup,
+ gain_offset=gain_offset, gainp=gainp,
+ dterm_offset=dterm_offset,
+ caltable_path=caltable_path, seed=seed, sigmat=sigmat,
+ verbose=verbose)
+
+ return obsout
+
+ def save_txt(self, fname):
+ """Save the Movie data to individual text files with filenames basename + 00001, etc.
+
+ Args:
+ fname (str): basename of output files
+
+ Returns:
+ """
+
+ ehtim.io.save.save_mov_txt(self, fname)
+
+ return
+
+ def save_fits(self, fname):
+ """Save the Movie data to individual fits files with filenames basename + 00001, etc.
+
+ Args:
+ fname (str): basename of output files
+
+ Returns:
+ """
+
+ ehtim.io.save.save_mov_fits(self, fname)
+ return
+
+ def save_hdf5(self, fname):
+ """Save the Movie data to a single hdf5 file.
+
+ Args:
+ fname (str): output file name
+
+ Returns:
+ """
+
+ ehtim.io.save.save_mov_hdf5(self, fname)
+ return
+
+ def export_mp4(self, out='movie.mp4', fps=10, dpi=120,
+ interp='gaussian', scale='lin', dynamic_range=1000.0, cfun='afmhot',
+ nvec=20, pcut=0.01, plotp=False, gamma=0.5, frame_pad_factor=1,
+ label_time=False, verbose=False):
+ """Save the Movie to an mp4 file
+ """
+
+ import matplotlib
+ matplotlib.use('agg')
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ if self.polrep != 'stokes':
+ raise Exception("export_mp4 requires self.polrep=='stokes' -- try self.switch_polrep()")
+
+ if (interp in ['gauss', 'gaussian', 'Gaussian', 'Gauss']):
+ interp = 'gaussian'
+ else:
+ interp = 'nearest'
+
+ if scale == 'lin':
+ unit = 'Jy/pixel'
+ elif scale == 'log':
+ unit = 'log(Jy/pixel)'
+ elif scale == 'gamma':
+ unit = '(Jy/pixel)^gamma'
+ else:
+ raise Exception("Scale not recognized!")
+
+ fig = plt.figure()
+ maxi = np.max(np.concatenate([im for im in self.frames]))
+
+ if len(self.qframes) and plotp:
+ thin = self.xdim//nvec
+ mask = (self.frames[0]).reshape(self.ydim, self.xdim) > pcut * np.max(self.frames[0])
+ mask2 = mask[::thin, ::thin]
+ x = (np.array([[i for i in range(self.xdim)]
+ for j in range(self.ydim)])[::thin, ::thin])[mask2]
+ y = (np.array([[j for i in range(self.xdim)]
+ for j in range(self.ydim)])[::thin, ::thin])[mask2]
+ a = (-np.sin(np.angle(self.qframes[0]+1j*self.uframes[0]) /
+ 2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2]
+ b = (np.cos(np.angle(self.qframes[0]+1j*self.uframes[0]) /
+ 2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2]
+
+ m = (np.abs(self.qframes[0] + 1j*self.uframes[0]) /
+ self.frames[0]).reshape(self.ydim, self.xdim)
+ m[np.logical_not(mask)] = 0
+
+ Q1 = plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01*self.xdim, units='x', pivot='mid', color='k',
+ angles='uv', scale=1.0/thin)
+ Q2 = plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005*self.xdim, units='x', pivot='mid', color='w',
+ angles='uv', scale=1.1/thin)
+
+ def im_data(n):
+
+ n_data = int((n-n % frame_pad_factor)/frame_pad_factor)
+
+ if len(self.qframes) and plotp:
+ a = (-np.sin(np.angle(self.qframes[n_data]+1j*self.uframes[n_data]
+ )/2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2]
+ b = (np.cos(np.angle(self.qframes[n_data]+1j*self.uframes[n_data]
+ )/2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2]
+
+ Q1.set_UVC(a, b)
+ Q2.set_UVC(a, b)
+
+ if scale == 'lin':
+ return self.frames[n_data].reshape((self.ydim, self.xdim))
+ elif scale == 'log':
+ return np.log(self.frames[n_data].reshape(
+ (self.ydim, self.xdim)) + maxi/dynamic_range)
+ elif scale == 'gamma':
+ return (self.frames[n_data]**(gamma)).reshape((self.ydim, self.xdim))
+
+ plt_im = plt.imshow(im_data(0), cmap=plt.get_cmap(cfun), interpolation=interp)
+ plt.colorbar(plt_im, fraction=0.046, pad=0.04, label=unit)
+
+ if scale == 'lin':
+
+ plt_im.set_clim([0, maxi])
+ else:
+ plt_im.set_clim([np.log(maxi/dynamic_range), np.log(maxi)])
+
+ xticks = obsh.ticks(self.xdim, self.psize/ehc.RADPERAS/1e-6)
+ yticks = obsh.ticks(self.ydim, self.psize/ehc.RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel(r'Relative RA ($\mu$as)')
+ plt.ylabel(r'Relative Dec ($\mu$as)')
+
+ fig.set_size_inches([5, 5])
+ plt.tight_layout()
+
+ def update_img(n):
+ if verbose:
+ print("processing frame {0} of {1}".format(n, len(self.frames)*frame_pad_factor))
+ plt_im.set_data(im_data(n))
+
+ if label_time:
+ time = self.times[n]
+ time_str = ("%02d:%02d:%02d" % (int(time), (time*60) % 60, (time*3600) % 60))
+ fig.suptitle(time_str)
+
+ return plt_im
+
+ ani = animation.FuncAnimation(fig, update_img, len(
+ self.frames)*frame_pad_factor, interval=1e3/fps)
+ writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6)
+ ani.save(out, writer=writer, dpi=dpi)
+
+##################################################################################################
+# Movie creation functions
+##################################################################################################
+
+
+def export_multipanel_mp4(input_list, out='movie.mp4', start_hr=None, stop_hr=None, nframes=100,
+ fov=None, npix=None,
+ nrows=1, fps=10, dpi=120, verbose=False, titles=None,
+ panel_size=4.0, common_scale=False, scale='linear', label_type='scale',
+ has_cbar=False, **kwargs):
+ """Export a movie comparing multiple movies in a grid.
+
+ Args:
+ input_list (list): The list of input Movies or Images
+ out (string): The output filename
+ start_hr (float): The start time in hours. If None, defaults to first start time
+ stop_hr (float): The end time in hours. If None, defaults to the last stop time
+ nframes (int): The number of frames in the output movie
+ fov (float): If specified, use this field of view for all panels
+ npix (int): If specified, use this linear pixel dimension for all panels
+ nrows (int): Number of rows in movie
+ fps (int): Frames per second
+ titles (list): List of panel titles for input_list
+ panel_size (float): Size of individual panels (inches)
+
+ """
+ import matplotlib
+ matplotlib.use('agg')
+ import matplotlib.pyplot as plt
+ import matplotlib.animation as animation
+
+ if start_hr is None:
+ try:
+ start_hr = np.min([x.start_hr for x in input_list if hasattr(x, 'start_hr')])
+ except ValueError:
+ raise Exception("no movies in input_list!")
+
+ if stop_hr is None:
+ try:
+ stop_hr = np.max([x.stop_hr for x in input_list if hasattr(x, 'stop_hr')])
+ except ValueError:
+ raise Exception("no movies in input_list!")
+
+ print("%s will have %i frames in the range %f-%f hr" % (out, nframes, start_hr, stop_hr))
+
+ ncols = int(np.ceil(len(input_list)/nrows))
+ suptitle_space = 0.6 # inches
+ w = panel_size*ncols
+ h = panel_size*nrows + suptitle_space
+ tgap = suptitle_space / h
+ bgap = .1
+ rgap = .1
+ lgap = .1
+ subw = (1-lgap-rgap)/ncols
+ subh = (1-tgap-bgap)/nrows
+ print("Rows: " + str(nrows))
+ print("Cols: " + str(ncols))
+
+ fig = plt.figure(figsize=(w, h))
+ ax_all = [[] for j in range(nrows)]
+ for y in range(nrows):
+ for x in range(ncols):
+ ax = fig.add_axes([lgap+subw*x, bgap+subh*(nrows-y-1), subw, subh])
+ ax_all[y].append(ax)
+
+ times = np.linspace(start_hr, stop_hr, nframes)
+ hr_step = times[1]-times[0]
+ mjd_step = hr_step/24.
+
+ im_List_Set = [[x.get_image(time) if hasattr(x, 'get_image') else x.copy() for time in times]
+ for x in input_list]
+
+ if fov and npix:
+ im_List_Set = [[x.regrid_image(fov, npix) for x in y] for y in im_List_Set]
+ else:
+ print('not rescaling images to common fov and npix!')
+
+ maxi = [np.max([im.imvec for im in im_List_Set[j]]) for j in range(len(im_List_Set))]
+ if common_scale:
+ maxi = np.max(maxi) + 0.0 * np.array(maxi) # broadcast one common scale to all panels
+
+ i = 0
+ for y in range(nrows):
+ for x in range(ncols):
+ if i >= len(im_List_Set):
+ ax_all[y][x].set_visible(False)
+ else:
+ if (y == nrows-1 and x == 0) or fov is None:
+ label_type_cur = label_type
+ else:
+ label_type_cur = 'none'
+
+ im_List_Set[i][0].display(axis=ax_all[y][x], scale=scale,
+ label_type=label_type_cur, has_cbar=has_cbar, **kwargs)
+ if y == nrows-1 and x == 0:
+ plt.xlabel(r'Relative RA ($\mu$as)')
+ plt.ylabel(r'Relative Dec ($\mu$as)')
+ else:
+ plt.xlabel('')
+ plt.ylabel('')
+ if not titles:
+ ax_all[y][x].set_title('')
+ else:
+ ax_all[y][x].set_title(titles[i])
+ i = i+1
+
+ def im_data(i, n):
+ if scale == 'linear':
+ return im_List_Set[i][n].imvec.reshape((im_List_Set[i][n].ydim, im_List_Set[i][n].xdim))
+ else:
+ return np.log(im_List_Set[i][n].imvec.reshape(
+ (im_List_Set[i][n].ydim, im_List_Set[i][n].xdim)) + 1e-20)
+
+ def update_img(n):
+ if verbose:
+ print("processing frame {0} of {1}".format(n, len(im_List_Set[0])))
+ i = 0
+ for y in range(nrows):
+ for x in range(ncols):
+ ax_all[y][x].images[0].set_data(im_data(i, n))
+ i = i+1
+ if i >= len(im_List_Set):
+ break
+
+ if mjd_step > 0.1:
+ # , verticalalignment=verticalalignment)
+ fig.suptitle('MJD: ' + str(im_List_Set[0][n].mjd))
+ else:
+ time = im_List_Set[0][n].time
+ time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60))
+ fig.suptitle(time_str)
+
+ return
+
+ ani = animation.FuncAnimation(fig, update_img, len(im_List_Set[0]), interval=1e3/fps)
+ writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6)
+ ani.save(out, writer=writer, dpi=dpi)
+
+
+def merge_im_list(imlist, framedur=-1, interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR):
+ """Merge a list of image objects into a movie object.
+
+ Args:
+ imlist (list): list of Image objects
+ framedur (float): duration of a movie frame in seconds
+ use to override times in the individual movies
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr]
+
+ Returns:
+ (Movie): a Movie object assembled from the images
+ """
+ framelist = []
+ nframes = len(imlist)
+
+ print("\nMerging %i frames from MJD %i %.2f hr to MJD %i %.2f hr" % (
+ nframes, imlist[0].mjd, imlist[0].time, imlist[-1].mjd, imlist[-1].time))
+
+ for i in range(nframes):
+ im = imlist[i]
+ if i == 0:
+ polrep0 = im.polrep
+ pol_prim0 = im.pol_prim
+ movdict = {key: [] for key in list(im._imdict.keys())}
+ psize0 = im.psize
+ xdim0 = im.xdim
+ ydim0 = im.ydim
+ ra0 = im.ra
+ dec0 = im.dec
+ rf0 = im.rf
+ src0 = im.source
+ mjd0 = im.mjd
+ hour0 = im.time
+ pulse = im.pulse
+ times = [hour0]
+ else:
+ if (im.polrep != polrep0):
+ raise Exception("polrep of image %i != polrep of image 0!" % i)
+ if (im.psize != psize0):
+ raise Exception("psize of image %i != psize of image 0!" % i)
+ if (im.xdim != xdim0):
+ raise Exception("xdim of image %i != xdim of image 0!" % i)
+ if (im.ydim != ydim0):
+ raise Exception("ydim of image %i != ydim of image 0!" % i)
+ if (im.ra != ra0):
+ raise Exception("RA of image %i != RA of image 0!" % i)
+ if (im.dec != dec0):
+ raise Exception("DEC of image %i != DEC of image 0!" % i)
+ if (im.rf != rf0):
+ raise Exception("rf of image %i != rf of image 0!" % i)
+ if (im.source != src0):
+ raise Exception("source of image %i != src of image 0!" % i)
+ if (im.mjd < mjd0):
+ raise Exception("mjd of image %i < mjd of image 0!" % i)
+
+ hour = im.time
+ if im.mjd > mjd0:
+ hour += 24*(im.mjd - mjd0)
+ times.append(hour)
+
+ imarr = im.imvec.reshape(ydim0, xdim0)
+ framelist.append(imarr)
+
+ # Look for other polarizations
+ for pol in list(movdict.keys()):
+ polvec = im._imdict[pol]
+ if len(polvec):
+ polarr = polvec.reshape(ydim0, xdim0)
+ movdict[pol].append(polarr)
+ else:
+ if movdict[pol]:
+ raise Exception("all frames in merge_im_list must have the same pol layout: " +
+ "error in frame %i" % i)
+
+ # assume equispaced with a given framedur instead of reading the individual image times
+ if framedur != -1:
+ framedur_hr = framedur/3600.
+ tstart = hour0
+ tstop = hour0 + framedur_hr*(nframes-1) # so the np.linspace spacing equals framedur
+ times = np.linspace(tstart, tstop, nframes)
+
+ elif len(set(times)) < len(framelist):
+ raise Exception("image times have duplicates!")
+
+ # Make new movie with primary polarization
+ newmov = Movie(framelist, times,
+ psize0, ra0, dec0, interp=interp, bounds_error=bounds_error,
+ polrep=polrep0, pol_prim=pol_prim0,
+ rf=rf0, source=src0, mjd=mjd0, pulse=pulse)
+
+ # Copy over all polarization movies
+ for pol in list(movdict.keys()):
+ if pol == newmov.pol_prim:
+ continue
+ polframes = np.array(movdict[pol])
+ if len(polframes):
+ polframes = polframes.reshape((newmov.nframes, newmov.ydim, newmov.xdim))
+ newmov.add_pol_movie(polframes, pol)
+
+ return newmov
+
+
+def load_hdf5(file_name,
+ pulse=ehc.PULSE_DEFAULT, interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR):
+ """Read in a movie from an hdf5 file and create a Movie object.
+
+ Args:
+ file_name (str): The name of the hdf5 file.
+ pulse (function): The function convolved with the pixel values for continuous image
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ return ehtim.io.load.load_movie_hdf5(file_name, pulse=pulse, interp=interp,
+ bounds_error=bounds_error)
+
+
+def load_txt(basename, nframes,
+ framedur=-1, pulse=ehc.PULSE_DEFAULT,
+ polrep='stokes', pol_prim=None, zero_pol=True,
+ interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR):
+ """Read in a movie from text files and create a Movie object.
+
+ Args:
+ basename (str): The base name of individual movie frames.
+ Files should have names basename + 00001, etc.
+ nframes (int): The total number of frames
+ framedur (float): The frame duration in seconds
+ if framedur==-1, the frame duration is taken from the file headers
+ pulse (function): The function convolved with the pixel values for continuous image
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ return ehtim.io.load.load_movie_txt(basename, nframes, framedur=framedur, pulse=pulse,
+ polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol,
+ interp=interp, bounds_error=bounds_error)
+
+
+def load_fits(basename, nframes,
+ framedur=-1, pulse=ehc.PULSE_DEFAULT,
+ polrep='stokes', pol_prim=None, zero_pol=True,
+ interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR):
+ """Read in a movie from fits files and create a Movie object.
+
+ Args:
+ basename (str): The base name of individual movie frames.
+ Files should have names basename + 00001, etc.
+ nframes (int): The total number of frames
+ framedur (float): The frame duration in seconds.
+ if framedur==-1, the frame duration is taken from the file headers
+ pulse (function): The function convolved with the pixel values for continuous image
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular
+ zero_pol (bool): If True, loads any missing polarizations as zeros
+ interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword
+ bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr]
+
+ Returns:
+ Movie: a Movie object
+ """
+
+ return ehtim.io.load.load_movie_fits(basename, nframes, framedur=framedur, pulse=pulse,
+ polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol,
+ interp=interp, bounds_error=bounds_error)
diff --git a/obsdata.py b/obsdata.py
new file mode 100644
index 00000000..532b9217
--- /dev/null
+++ b/obsdata.py
@@ -0,0 +1,5006 @@
+# obsdata.py
+# an interferometric observation class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import string
+import copy
+import numpy as np
+import numpy.lib.recfunctions as rec
+import matplotlib.pyplot as plt
+import scipy.optimize as opt
+import scipy.spatial as spatial
+import itertools as it
+import sys
+
+try:
+ import pandas as pd
+except ImportError:
+ print("Warning: pandas not installed!")
+ print("Please install pandas to use statistics package!")
+
+
+import ehtim.image
+import ehtim.io.save
+import ehtim.io.load
+import ehtim.const_def as ehc
+import ehtim.observing.obs_helpers as obsh
+import ehtim.statistics.dataframes as ehdf
+
+import warnings
+warnings.filterwarnings("ignore",
+ message="Casting complex values to real discards the imaginary part")
+
+RAPOS = 0
+DECPOS = 1
+RFPOS = 2
+BWPOS = 3
+DATPOS = 4
+TARRPOS = 5
+
+##################################################################################################
+# Obsdata object
+##################################################################################################
+
+
+class Obsdata(object):
+
+ """A polarimetric VLBI observation of visibility amplitudes and phases (in Jy).
+
+ Attributes:
+ source (str): The source name
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ mjd (int): The integer MJD of the observation
+ tstart (float): The start time of the observation in hours
+ tstop (float): The end time of the observation in hours
+ rf (float): The observation frequency in Hz
+ bw (float): The observation bandwidth in Hz
+ timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+
+ tarr (numpy.recarray): The array of telescope data with datatype DTARR
+ tkey (dict): A dictionary of rows in the tarr for each site name
+ data (numpy.recarray): the basic data with datatype DTPOL_STOKES or DTPOL_CIRC
+ scantable (numpy.recarray): The array of scan information
+
+ ampcal (bool): True if amplitudes calibrated
+ phasecal (bool): True if phases calibrated
+ opacitycal (bool): True if time-dependent opacities correctly accounted for in sigmas
+ frcal (bool): True if feed rotation calibrated out of visibilities
+ dcal (bool): True if D terms calibrated out of visibilities
+
+ amp (numpy.recarray): An array of (averaged) visibility amplitudes
+ bispec (numpy.recarray): An array of (averaged) bispectra
+ cphase (numpy.recarray): An array of (averaged) closure phases
+ cphase_diag (numpy.recarray): An array of (averaged) diagonalized closure phases
+ camp (numpy.recarray): An array of (averaged) closure amplitudes
+ logcamp (numpy.recarray): An array of (averaged) log closure amplitudes
+ logcamp_diag (numpy.recarray): An array of (averaged) diagonalized log closure amps
+ """
+
+ def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None,
+ polrep='stokes', source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC',
+ ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True,
+ trial_speedups=False, reorder=True):
+ """A polarimetric VLBI observation of visibility amplitudes and phases (in Jy).
+
+ Args:
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The observation frequency in Hz
+ bw (float): The observation bandwidth in Hz
+
+ datatable (numpy.recarray): the basic data with datatype DTPOL_STOKES or DTPOL_CIRC
+ tarr (numpy.recarray): The array of telescope data with datatype DTARR
+ scantable (numpy.recarray): The array of scan information
+
+ polrep (str): polarization representation, either 'stokes' or 'circ'
+ source (str): The source name
+ mjd (int): The integer MJD of the observation
+ timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC'
+
+ ampcal (bool): True if amplitudes calibrated
+ phasecal (bool): True if phases calibrated
+ opacitycal (bool): True if time-dependent opacities correctly accounted for in sigmas
+ frcal (bool): True if feed rotation calibrated out of visibilities
+ dcal (bool): True if D terms calibrated out of visibilities
+
+ Returns:
+ obsdata (Obsdata): an Obsdata object
+ """
+
+ if len(datatable) == 0:
+ raise Exception("No data in input table!")
+ if not (datatable.dtype in [ehc.DTPOL_STOKES, ehc.DTPOL_CIRC]):
+ raise Exception("Data table dtype should be DTPOL_STOKES or DTPOL_CIRC")
+
+ # Polarization Representation
+ if polrep == 'stokes':
+ self.polrep = 'stokes'
+ self.poldict = ehc.POLDICT_STOKES
+ self.poltype = ehc.DTPOL_STOKES
+ elif polrep == 'circ':
+ self.polrep = 'circ'
+ self.poldict = ehc.POLDICT_CIRC
+ self.poltype = ehc.DTPOL_CIRC
+ else:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ # Set the various observation parameters
+ self.source = str(source)
+ self.ra = float(ra)
+ self.dec = float(dec)
+ self.rf = float(rf)
+ self.bw = float(bw)
+ self.ampcal = bool(ampcal)
+ self.phasecal = bool(phasecal)
+ self.opacitycal = bool(opacitycal)
+ self.dcal = bool(dcal)
+ self.frcal = bool(frcal)
+
+ if timetype not in ['GMST', 'UTC']:
+ raise Exception("timetype must be 'GMST' or 'UTC'")
+ self.timetype = timetype
+
+ # Save the data
+ self.data = datatable
+ self.scans = scantable
+
+ # Telescope array: default ordering is by sefd
+ self.tarr = tarr
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+ if np.any(self.tarr['sefdr'] != 0) or np.any(self.tarr['sefdl'] != 0):
+ self.reorder_tarr_sefd(reorder_baselines=False)
+
+ # reorder baselines to uvfits convention
+ if reorder:
+ self.reorder_baselines(trial_speedups=trial_speedups) # comment out for Closure Invariants
+ else:
+ self.data = np.array(sorted(self.data, key=lambda x: x['time']))
+
+
+ # Get tstart, mjd and tstop
+ times = self.unpack(['time'])['time']
+ self.tstart = times[0]
+ self.mjd = int(mjd)
+ self.tstop = times[-1]
+ if self.tstop < self.tstart:
+ self.tstop += 24.0
+
+ # Saved closure quantity arrays
+ self.amp = None
+ self.bispec = None
+ self.cphase = None
+ self.cphase_diag = None
+ self.camp = None
+ self.logcamp = None
+ self.logcamp_diag = None
+
+ @property
+ def tarr(self):
+ return self._tarr
+
+ @tarr.setter
+ def tarr(self, tarr):
+ self._tarr = tarr
+ self.tkey = {tarr[i]['site']: i for i in range(len(tarr))}
+
+ def obsdata_args(self):
+ """"Copy arguments for making a new Obsdata into a list and dictonary
+ """
+
+ arglist = [self.ra, self.dec, self.rf, self.bw, self.data, self.tarr]
+ argdict = {'scantable': self.scans, 'polrep': self.polrep, 'source': self.source,
+ 'mjd': self.mjd, 'timetype': self.timetype,
+ 'ampcal': self.ampcal, 'phasecal': self.phasecal, 'opacitycal': self.opacitycal,
+ 'dcal': self.dcal, 'frcal': self.frcal}
+ return (arglist, argdict)
+
+ def copy(self):
+ """Copy the observation object.
+
+ Args:
+
+ Returns:
+ (Obsdata): a copy of the Obsdata object.
+ """
+
+ # TODO: Do we want to copy over e.g. closure tables?
+ newobs = copy.deepcopy(self)
+
+ return newobs
+
+ def switch_timetype(self, timetype_out='UTC'):
+ """Return a new observation with the time type switched
+
+ Args:
+ timetype (str): "UTC" or "GMST"
+
+ Returns:
+ (Obsdata): new Obsdata object with potentially different timetype
+ """
+
+ if timetype_out not in ['GMST', 'UTC']:
+ raise Exception("timetype_out must be 'GMST' or 'UTC'")
+
+ out = self.copy()
+ if timetype_out == self.timetype:
+ return out
+
+ if timetype_out == 'UTC':
+ out.data['time'] = obsh.gmst_to_utc(out.data['time'], out.mjd)
+ if timetype_out == 'GMST':
+ out.data['time'] = obsh.utc_to_gmst(out.data['time'], out.mjd)
+
+ out.timetype = timetype_out
+ return out
+
+ def switch_polrep(self, polrep_out='stokes', allow_singlepol=True, singlepol_hand='R'):
+ """Return a new observation with the polarization representation changed
+
+ Args:
+ polrep_out (str): the polrep of the output data
+ allow_singlepol (bool): If True, treat single-polarization data as Stokes I
+ when converting from 'circ' polrep to 'stokes'
+ singlepol_hand (str): 'R' or 'L'; determines which parallel-hand is assumed
+ when converting 'stokes' to 'circ' if only I is present
+
+ Returns:
+ (Obsdata): new Obsdata object with potentially different polrep
+ """
+
+ if polrep_out not in ['stokes', 'circ']:
+ raise Exception("polrep_out must be either 'stokes' or 'circ'")
+ if polrep_out == self.polrep:
+ return self.copy()
+ elif polrep_out == 'stokes': # circ -> stokes
+ data = np.empty(len(self.data), dtype=ehc.DTPOL_STOKES)
+ rrmask = np.isnan(self.data['rrvis'])
+ llmask = np.isnan(self.data['llvis'])
+
+ for f in ehc.DTPOL_STOKES:
+ f = f[0]
+ if f in ['time', 'tint', 't1', 't2', 'tau1', 'tau2', 'u', 'v']:
+ data[f] = self.data[f]
+ elif f == 'vis':
+ data[f] = 0.5 * (self.data['rrvis'] + self.data['llvis'])
+ elif f == 'qvis':
+ data[f] = 0.5 * (self.data['lrvis'] + self.data['rlvis'])
+ elif f == 'uvis':
+ data[f] = 0.5j * (self.data['lrvis'] - self.data['rlvis'])
+ elif f == 'vvis':
+ data[f] = 0.5 * (self.data['rrvis'] - self.data['llvis'])
+ elif f in ['sigma', 'vsigma']:
+ data[f] = 0.5 * np.sqrt(self.data['rrsigma']**2 + self.data['llsigma']**2)
+ elif f in ['qsigma', 'usigma']:
+ data[f] = 0.5 * np.sqrt(self.data['rlsigma']**2 + self.data['lrsigma']**2)
+
+ if allow_singlepol:
+ # In cases where only one polarization is present
+ # use it as an estimator for Stokes I
+ data['vis'][rrmask] = self.data['llvis'][rrmask]
+ data['sigma'][rrmask] = self.data['llsigma'][rrmask]
+
+ data['vis'][llmask] = self.data['rrvis'][llmask]
+ data['sigma'][llmask] = self.data['rrsigma'][llmask]
+
+ elif polrep_out == 'circ': # stokes -> circ
+ data = np.empty(len(self.data), dtype=ehc.DTPOL_CIRC)
+ Vmask = np.isnan(self.data['vvis'])
+
+ for f in ehc.DTPOL_CIRC:
+ f = f[0]
+ if f in ['time', 'tint', 't1', 't2', 'tau1', 'tau2', 'u', 'v']:
+ data[f] = self.data[f]
+ elif f == 'rrvis':
+ data[f] = (self.data['vis'] + self.data['vvis'])
+ elif f == 'llvis':
+ data[f] = (self.data['vis'] - self.data['vvis'])
+ elif f == 'rlvis':
+ data[f] = (self.data['qvis'] + 1j * self.data['uvis'])
+ elif f == 'lrvis':
+ data[f] = (self.data['qvis'] - 1j * self.data['uvis'])
+ elif f in ['rrsigma', 'llsigma']:
+ data[f] = np.sqrt(self.data['sigma']**2 + self.data['vsigma']**2)
+ elif f in ['rlsigma', 'lrsigma']:
+ data[f] = np.sqrt(self.data['qsigma']**2 + self.data['usigma']**2)
+
+ if allow_singlepol:
+ # In cases where only Stokes I is present, copy it to a specified parallel-hand
+ prefix = singlepol_hand.lower() + singlepol_hand.lower() # rr or ll
+ if prefix not in ['rr', 'll']:
+ raise Exception('singlepol_hand must be R or L')
+
+ data[prefix + 'vis'][Vmask] = self.data['vis'][Vmask]
+ data[prefix + 'sigma'][Vmask] = self.data['sigma'][Vmask]
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = data
+ argdict['polrep'] = polrep_out
+ newobs = Obsdata(*arglist, **argdict)
+
+ return newobs
+
+ def reorder_baselines(self, trial_speedups=False):
+ """Reorder baselines to match uvfits convention, based on the telescope array ordering
+ """
+ if trial_speedups:
+ self.reorder_baselines_trial_speedups()
+
+ else: # original code
+
+ # Time partition the datatable
+ datatable = self.data.copy()
+ datalist = []
+ for key, group in it.groupby(datatable, lambda x: x['time']):
+ #print(key*60*60,len(list(group)))
+ datalist.append(np.array([obs for obs in group]))
+
+ # loop through all data
+ obsdata = []
+ for tlist in datalist:
+ blpairs = []
+ for dat in tlist:
+ # Remove conjugate baselines
+ if not (set((dat['t1'], dat['t2']))) in blpairs:
+
+ # Reverse the baseline in the right order for uvfits:
+ if(self.tkey[dat['t2']] < self.tkey[dat['t1']]):
+
+ (dat['t1'], dat['t2']) = (dat['t2'], dat['t1'])
+ (dat['tau1'], dat['tau2']) = (dat['tau2'], dat['tau1'])
+ dat['u'] = -dat['u']
+ dat['v'] = -dat['v']
+
+ if self.polrep == 'stokes':
+ dat['vis'] = np.conj(dat['vis'])
+ dat['qvis'] = np.conj(dat['qvis'])
+ dat['uvis'] = np.conj(dat['uvis'])
+ dat['vvis'] = np.conj(dat['vvis'])
+ elif self.polrep == 'circ':
+ dat['rrvis'] = np.conj(dat['rrvis'])
+ dat['llvis'] = np.conj(dat['llvis'])
+ # must switch l & r !!
+ rl = dat['rlvis'].copy()
+ lr = dat['lrvis'].copy()
+ dat['rlvis'] = np.conj(lr)
+ dat['lrvis'] = np.conj(rl)
+
+ # You also have to switch the errors for the coherency!
+ rlerr = dat['rlsigma'].copy()
+ lrerr = dat['lrsigma'].copy()
+ dat["rlsigma"] = lrerr
+ dat["lrsigma"] = rlerr
+
+ else:
+ raise Exception("polrep must be either 'stokes' or 'circ'")
+
+ # Append the data point
+ blpairs.append(set((dat['t1'], dat['t2'])))
+ obsdata.append(dat)
+
+ obsdata = np.array(obsdata, dtype=self.poltype)
+
+ # Timesort data
+ obsdata = obsdata[np.argsort(obsdata, order=['time', 't1'])]
+
+ # Save the data
+ self.data = obsdata
+
+ return
+
+ def reorder_baselines_trial_speedups(self):
+ """Reorder baselines to match uvfits convention, based on the telescope array ordering
+ """
+
+ dat = self.data.copy()
+
+ ############ Ensure correct baseline order
+ # TODO can these be faster?
+ t1nums = np.fromiter([self.tkey[t] for t in dat['t1']],int)
+ t2nums = np.fromiter([self.tkey[t] for t in dat['t2']],int)
+
+ # which entries are in the wrong telescope order?
+ ordermask = t2nums < t1nums
+
+ # flip the order of these entries
+ t1 = dat['t1'].copy()
+ t2 = dat['t2'].copy()
+ tau1 = dat['tau1'].copy()
+ tau2 = dat['tau2'].copy()
+
+ dat['t1'][ordermask] = t2[ordermask]
+ dat['t2'][ordermask] = t1[ordermask]
+ dat['tau1'][ordermask] = tau2[ordermask]
+ dat['tau2'][ordermask] = tau1[ordermask]
+ dat['u'][ordermask] *= -1
+ dat['v'][ordermask] *= -1
+
+ if self.polrep=='stokes':
+ dat['vis'][ordermask] = np.conj(dat['vis'][ordermask])
+ dat['qvis'][ordermask] = np.conj(dat['qvis'][ordermask])
+ dat['uvis'][ordermask] = np.conj(dat['uvis'][ordermask])
+ dat['vvis'][ordermask] = np.conj(dat['vvis'][ordermask])
+
+ elif self.polrep == 'circ':
+ dat['rrvis'][ordermask] = np.conj(dat['rrvis'][ordermask])
+ dat['llvis'][ordermask] = np.conj(dat['llvis'][ordermask])
+ rl = dat['rlvis'].copy()
+ lr = dat['lrvis'].copy()
+ dat['rlvis'][ordermask] = np.conj(lr[ordermask])
+ dat['lrvis'][ordermask] = np.conj(rl[ordermask])
+
+ # Also need to switch error matrix
+ rle = dat['rlsigma'].copy()
+ lre = dat['lrsigma'].copy()
+ dat['rlsigma'][ordermask] = lre[ordermask]
+ dat['lrsigma'][ordermask] = rle[ordermask]
+
+ else:
+ raise Exception("polrep must be either 'stokes' or 'circ'")
+
+ # Remove duplicate or conjugate entries at any timestep
+ # Since telescope order has been sorted conjugates should appear as duplicates
+ timeblcombos = np.vstack((dat['time'],t1nums,t2nums)).T
+ uniqdat, uniqdatinv = np.unique(timeblcombos,axis=0, return_inverse=True)
+
+ if len(uniqdat) != len(dat):
+ print("WARNING: removing duplicate/conjuagte points in reorder_baselines!")
+ deletemask = np.ones(len(dat)).astype(bool)
+
+ for j in range(len(uniqdat)):
+ idxs = np.argwhere(uniqdatinv==j)[:,0]
+ for idx in idxs[1:]: # delete all but the first occurrence
+ deletemask[idx] = False
+
+ # remove duplicates
+ dat = dat[deletemask]
+
+ # sort data
+ dat = dat[np.argsort(dat, order=['time', 't1'])]
+
+ # save the data
+ self.data = dat
+
+ return
+
+ def reorder_tarr_sefd(self, reorder_baselines=True):
+ """Reorder the telescope array by SEFD minimal to maximum.
+ """
+
+ sorted_list = sorted(self.tarr, key=lambda x: np.sqrt(x['sefdr']**2 + x['sefdl']**2))
+ self.tarr = np.array(sorted_list, dtype=ehc.DTARR)
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+ if reorder_baselines:
+ self.reorder_baselines()
+
+ return
+
+ def reorder_tarr_snr(self, reorder_baselines=True):
+ """Reorder the telescope array by median SNR maximal to minimal.
+ """
+
+ snr = self.unpack(['t1', 't2', 'snr'])
+ snr_median = [np.median(snr[(snr['t1'] == site) + (snr['t2'] == site)]['snr'])
+ for site in self.tarr['site']]
+ idx = np.argsort(snr_median)[::-1]
+ self.tarr = self.tarr[idx]
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+ if reorder_baselines:
+ self.reorder_baselines()
+
+ return
+
+ def reorder_tarr_random(self, reorder_baselines=True):
+ """Randomly reorder the telescope array.
+ """
+
+ idx = np.arange(len(self.tarr))
+ np.random.shuffle(idx)
+ self.tarr = self.tarr[idx]
+ self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))}
+ if reorder_baselines:
+ self.reorder_baselines()
+
+ return
+
+ def data_conj(self):
+ """Make a data array including all conjugate baselines.
+
+ Args:
+
+ Returns:
+ (numpy.recarray): a copy of the Obsdata.data table including all conjugate baselines.
+ """
+
+ data = np.empty(2 * len(self.data), dtype=self.poltype)
+
+ # Add the conjugate baseline data
+ for f in self.poltype:
+ f = f[0]
+ if f in ['t1', 't2', 'tau1', 'tau2']:
+ if f[-1] == '1':
+ f2 = f[:-1] + '2'
+ else:
+ f2 = f[:-1] + '1'
+ data[f] = np.hstack((self.data[f], self.data[f2]))
+
+ elif f in ['u', 'v']:
+ data[f] = np.hstack((self.data[f], -self.data[f]))
+
+ elif f in [self.poldict['vis1'], self.poldict['vis2'],
+ self.poldict['vis3'], self.poldict['vis4']]:
+ if self.polrep == 'stokes':
+ data[f] = np.hstack((self.data[f], np.conj(self.data[f])))
+ elif self.polrep == 'circ':
+ if f in ['rrvis', 'llvis']:
+ data[f] = np.hstack((self.data[f], np.conj(self.data[f])))
+ elif f == 'rlvis':
+ data[f] = np.hstack((self.data['rlvis'], np.conj(self.data['lrvis'])))
+ elif f == 'lrvis':
+ data[f] = np.hstack((self.data['lrvis'], np.conj(self.data['rlvis'])))
+
+ # ALSO SWITCH THE ERRORS!
+ else:
+ raise Exception("polrep must be either 'stokes' or 'circ'")
+ # The conjugate baselines need the transpose error terms.
+ elif f == "rlsigma":
+ data[f] = np.hstack((self.data["rlsigma"], self.data["lrsigma"]))
+ elif f == "lrsigma":
+ data[f] = np.hstack((self.data["lrsigma"], self.data["rlsigma"]))
+
+ else:
+ data[f] = np.hstack((self.data[f], self.data[f]))
+
+ # Sort the data by time
+ data = data[np.argsort(data['time'])]
+
+ return data
+
+ def tlist(self, conj=False, t_gather=0., scan_gather=False):
+ """Group the data in a list of equal time observation datatables.
+
+ Args:
+ conj (bool): True if tlist_out includes conjugate baselines.
+ t_gather (float): Grouping timescale (in seconds). 0.0 indicates no grouping.
+ scan_gather (bool): If true, gather data into scans
+
+ Returns:
+ (list): a list of data tables containing time-partitioned data
+ """
+
+ if conj:
+ data = self.data_conj()
+ else:
+ data = self.data
+
+ # partition the data by time
+ datalist = []
+
+ if t_gather <= 0.0 and not scan_gather:
+ # Only group measurements at the same time
+ for key, group in it.groupby(data, lambda x: x['time']):
+ datalist.append(np.array([obs for obs in group]))
+ elif t_gather > 0.0 and not scan_gather:
+ # Group measurements in time
+ for key, group in it.groupby(data, lambda x: int(x['time'] / (t_gather / 3600.0))):
+ datalist.append(np.array([obs for obs in group]))
+ else:
+ # Group measurements by scan
+ if ((self.scans is None) or
+ np.any([scan is None for scan in self.scans]) or
+ len(self.scans) == 0):
+ print("No scan table in observation. Adding scan table before gathering...")
+ self.add_scans()
+
+ for key, group in it.groupby(
+ data, lambda x: np.searchsorted(self.scans[:, 0], x['time'])):
+ datalist.append(np.array([obs for obs in group]))
+
+ # return np.array(datalist, dtype=object)
+ return datalist
+
+
+ def split_obs(self, t_gather=0., scan_gather=False):
+ """Split single observation into multiple observation files, one per scan..
+
+ Args:
+ t_gather (float): Grouping timescale (in seconds). 0.0 indicates no grouping.
+ scan_gather (bool): If true, gather data into scans
+
+ Returns:
+ (list): list of single-scan Obsdata objects
+ """
+
+ tlist = self.tlist(t_gather=t_gather, scan_gather=scan_gather)
+
+ print("Splitting Observation File into " + str(len(tlist)) + " times")
+ arglist, argdict = self.obsdata_args()
+
+ # note that the tarr of the output includes all sites,
+ # even those that don't participate in the scan
+ splitlist = []
+ for tdata in tlist:
+ arglist[DATPOS] = tdata
+ splitlist.append(Obsdata(*arglist, **argdict))
+
+ return splitlist
+
+
+ def getClosestScan(self, time, splitObs=None):
+ """Split observation by scan and grab scan closest to timestamp
+
+ Args:
+ time (float): Time (GMST) you want to find the scan closest to
+ splitObs (list): a list of Obsdata objects, output from split_obs, to save time
+
+ Returns:
+ (Obsdata): Obsdata object composed of scan closest to time
+ """
+
+ ## check if splitObs has been passed in already ##
+ if splitObs is None:
+ splitObs = self.split_obs()
+
+ ## check for the scan with the closest start time to time arg ##
+ ## TODO: allow user to choose start time, end time, or mid-time
+ closest_index = 0
+ delta_t = 1e22
+ for s, s_obs in enumerate(splitObs):
+ dt = abs(s_obs.tstart - time)
+ if dt < delta_t:
+ delta_t = dt
+ closest_index = s
+
+ print(f"Using scan with time {splitObs[closest_index].tstart}.")
+ return splitObs[closest_index]
+
+
+ def bllist(self, conj=False):
+ """Group the data in a list of same baseline datatables.
+
+ Args:
+ conj (bool): True if tlist_out includes conjugate baselines.
+
+ Returns:
+ (list): a list of data tables containing baseline-partitioned data
+ """
+
+ if conj:
+ data = self.data_conj()
+ else:
+ data = self.data
+
+ # partition the data by baseline
+ datalist = []
+ idx = np.lexsort((data['t2'], data['t1']))
+ for key, group in it.groupby(data[idx], lambda x: set((x['t1'], x['t2']))):
+ datalist.append(np.array([obs for obs in group]))
+
+ return np.array(datalist, dtype=object)
+
+ def unpack_bl(self, site1, site2, fields, ang_unit='deg', debias=False, timetype=False):
+ """Unpack the data over time on the selected baseline site1-site2.
+
+ Args:
+ site1 (str): First site name
+ site2 (str): Second site name
+ fields (list): list of unpacked quantities from available quantities in FIELDS
+ ang_unit (str): 'deg' for degrees and 'rad' for radian phases
+ debias (bool): True to debias visibility amplitudes
+ timetype (str): 'GMST' or 'UTC' changes what is returned for 'time'
+
+ Returns:
+ (numpy.recarray): unpacked numpy array with data in fields requested
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+
+ # If we only specify one field
+ if timetype not in ['GMST', 'UTC', 'utc', 'gmst']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+ allfields = ['time']
+
+ if not isinstance(fields, list):
+ allfields.append(fields)
+ else:
+ for i in range(len(fields)):
+ allfields.append(fields[i])
+
+ # Get the data from data table on the selected baseline
+ allout = []
+ tlist = self.tlist(conj=True)
+ for scan in tlist:
+ for obs in scan:
+ if (obs['t1'], obs['t2']) == (site1, site2):
+ obs = np.array([obs])
+ out = self.unpack_dat(obs, allfields, ang_unit=ang_unit,
+ debias=debias, timetype=timetype)
+
+ allout.append(out)
+
+ return np.array(allout)
+
+ def unpack(self, fields, mode='all', ang_unit='deg', debias=False, conj=False, timetype=False):
+ """Unpack the data for the whole observation .
+
+ Args:
+ fields (list): list of unpacked quantities from available quantities in FIELDS
+ mode (str): 'all' returns all data in single table,
+ 'time' groups output by equal time, 'bl' groups by baseline
+ ang_unit (str): 'deg' for degrees and 'rad' for radian phases
+ debias (bool): True to debias visibility amplitudes
+ conj (bool): True to include conjugate baselines
+ timetype (str): 'GMST' or 'UTC' changes what is returned for 'time'
+
+ Returns:
+ (numpy.recarray): unpacked numpy array with data in fields requested
+
+ """
+
+ if mode not in ('time', 'all', 'bl'):
+ raise Exception("possible options for mode are 'time', 'all' and 'bl'")
+
+ # If we only specify one field
+ if not isinstance(fields, list):
+ fields = [fields]
+
+ if mode == 'all':
+ if conj:
+ data = self.data_conj()
+ else:
+ data = self.data
+ allout = self.unpack_dat(data, fields, ang_unit=ang_unit,
+ debias=debias, timetype=timetype)
+
+ elif mode == 'time':
+ allout = []
+ tlist = self.tlist(conj=True)
+ for scan in tlist:
+ out = self.unpack_dat(scan, fields, ang_unit=ang_unit,
+ debias=debias, timetype=timetype)
+ allout.append(out)
+
+ elif mode == 'bl':
+ allout = []
+ bllist = self.bllist()
+ for bl in bllist:
+ out = self.unpack_dat(bl, fields, ang_unit=ang_unit,
+ debias=debias, timetype=timetype)
+ allout.append(out)
+
+ return allout
+
+ def unpack_dat(self, data, fields, conj=False, ang_unit='deg', debias=False, timetype=False):
+ """Unpack the data in a passed data recarray.
+
+ Args:
+ data (numpy.recarray): data recarray of format DTPOL_STOKES or DTPOL_CIRC
+ fields (list): list of unpacked quantities from available quantities in FIELDS
+ conj (bool): True to include conjugate baselines
+ ang_unit (str): 'deg' for degrees and 'rad' for radian phases
+ debias (bool): True to debias visibility amplitudes
+ timetype (str): 'GMST' or 'UTC' changes what is returned for 'time'
+
+ Returns:
+ (numpy.recarray): unpacked numpy array with data in fields requested
+
+ """
+
+ if ang_unit == 'deg':
+ angle = ehc.DEGREE
+ else:
+ angle = 1.0
+
+ # If we only specify one field
+ if isinstance(fields, str):
+ fields = [fields]
+
+ if not timetype:
+ timetype = self.timetype
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ # Get field data
+ allout = []
+ for field in fields:
+ if field in ["time", "time_utc", "time_gmst"]:
+ out = data['time']
+ ty = 'f8'
+ elif field in ["u", "v", "tint", "tau1", "tau2"]:
+ out = data[field]
+ ty = 'f8'
+ elif field in ["uvdist"]:
+ out = np.abs(data['u'] + 1j * data['v'])
+ ty = 'f8'
+ elif field in ["t1", "el1", "par_ang1", "hr_ang1"]:
+ sites = data["t1"]
+ keys = [self.tkey[site] for site in sites]
+ tdata = self.tarr[keys]
+ out = sites
+ ty = 'U32'
+ elif field in ["t2", "el2", "par_ang2", "hr_ang2"]:
+ sites = data["t2"]
+ keys = [self.tkey[site] for site in sites]
+ tdata = self.tarr[keys]
+ out = sites
+ ty = 'U32'
+ elif field in ['vis', 'amp', 'phase', 'snr', 'sigma', 'sigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['vis']
+ sig = data['sigma']
+ elif self.polrep == 'circ':
+ out = 0.5 * (data['rrvis'] + data['llvis'])
+ sig = 0.5 * np.sqrt(data['rrsigma']**2 + data['llsigma']**2)
+ elif field in ['qvis', 'qamp', 'qphase', 'qsnr', 'qsigma', 'qsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['qvis']
+ sig = data['qsigma']
+ elif self.polrep == 'circ':
+ out = 0.5 * (data['lrvis'] + data['rlvis'])
+ sig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2)
+ elif field in ['uvis', 'uamp', 'uphase', 'usnr', 'usigma', 'usigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['uvis']
+ sig = data['usigma']
+ elif self.polrep == 'circ':
+ out = 0.5j * (data['lrvis'] - data['rlvis'])
+ sig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2)
+ elif field in ['vvis', 'vamp', 'vphase', 'vsnr', 'vsigma', 'vsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['vvis']
+ sig = data['vsigma']
+ elif self.polrep == 'circ':
+ out = 0.5 * (data['rrvis'] - data['llvis'])
+ sig = 0.5 * np.sqrt(data['rrsigma']**2 + data['llsigma']**2)
+ elif field in ['pvis', 'pamp', 'pphase', 'psnr', 'psigma', 'psigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['qvis'] + 1j * data['uvis']
+ sig = np.sqrt(data['qsigma']**2 + data['usigma']**2)
+ elif self.polrep == 'circ':
+ out = data['rlvis']
+ sig = data['rlsigma']
+ elif field in ['m', 'mamp', 'mphase', 'msnr', 'msigma', 'msigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = (data['qvis'] + 1j * data['uvis']) / data['vis']
+ sig = obsh.merr(data['sigma'], data['qsigma'], data['usigma'], data['vis'], out)
+ elif self.polrep == 'circ':
+ out = 2 * data['rlvis'] / (data['rrvis'] + data['llvis'])
+ sig = obsh.merr2(data['rlsigma'], data['rrsigma'], data['llsigma'],
+ 0.5 * (data['rrvis'] + data['llvis']), out)
+ elif field in ['evis', 'eamp', 'ephase', 'esnr', 'esigma', 'esigma_phase']:
+ ty = 'c16'
+ ang = np.arctan2(data['u'], data['v']) # TODO: correct convention EofN?
+ if self.polrep == 'stokes':
+ q = data['qvis']
+ u = data['uvis']
+ qsig = data['qsigma']
+ usig = data['usigma']
+ elif self.polrep == 'circ':
+ q = 0.5 * (data['lrvis'] + data['rlvis'])
+ u = 0.5j * (data['lrvis'] - data['rlvis'])
+ qsig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2)
+ usig = qsig
+ out = (np.cos(2 * ang) * q + np.sin(2 * ang) * u)
+ sig = np.sqrt(0.5 * ((np.cos(2 * ang) * qsig)**2 + (np.sin(2 * ang) * usig)**2))
+ elif field in ['bvis', 'bamp', 'bphase', 'bsnr', 'bsigma', 'bsigma_phase']:
+ ty = 'c16'
+ ang = np.arctan2(data['u'], data['v']) # TODO: correct convention EofN?
+ if self.polrep == 'stokes':
+ q = data['qvis']
+ u = data['uvis']
+ qsig = data['qsigma']
+ usig = data['usigma']
+ elif self.polrep == 'circ':
+ q = 0.5 * (data['lrvis'] + data['rlvis'])
+ u = 0.5j * (data['lrvis'] - data['rlvis'])
+ qsig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2)
+ usig = qsig
+ out = (-np.sin(2 * ang) * q + np.cos(2 * ang) * u)
+ sig = np.sqrt(0.5 * ((np.sin(2 * ang) * qsig)**2 + (np.cos(2 * ang) * usig)**2))
+ elif field in ['rrvis', 'rramp', 'rrphase', 'rrsnr', 'rrsigma', 'rrsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['vis'] + data['vvis']
+ sig = np.sqrt(data['sigma']**2 + data['vsigma']**2)
+ elif self.polrep == 'circ':
+ out = data['rrvis']
+ sig = data['rrsigma']
+ elif field in ['llvis', 'llamp', 'llphase', 'llsnr', 'llsigma', 'llsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['vis'] - data['vvis']
+ sig = np.sqrt(data['sigma']**2 + data['vsigma']**2)
+ elif self.polrep == 'circ':
+ out = data['llvis']
+ sig = data['llsigma']
+ elif field in ['rlvis', 'rlamp', 'rlphase', 'rlsnr', 'rlsigma', 'rlsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['qvis'] + 1j * data['uvis']
+ sig = np.sqrt(data['qsigma']**2 + data['usigma']**2)
+ elif self.polrep == 'circ':
+ out = data['rlvis']
+ sig = data['rlsigma']
+ elif field in ['lrvis', 'lramp', 'lrphase', 'lrsnr', 'lrsigma', 'lrsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = data['qvis'] - 1j * data['uvis']
+ sig = np.sqrt(data['qsigma']**2 + data['usigma']**2)
+ elif self.polrep == 'circ':
+ out = data['lrvis']
+ sig = data['lrsigma']
+ elif field in ['rrllvis', 'rrllamp', 'rrllphase', 'rrllsnr',
+ 'rrllsigma', 'rrllsigma_phase']:
+ ty = 'c16'
+ if self.polrep == 'stokes':
+ out = (data['vis'] + data['vvis']) / (data['vis'] - data['vvis'])
+ sig = (2.0**0.5 * (np.abs(data['vis'])**2 + np.abs(data['vvis'])**2)**0.5
+ / np.abs(data['vis'] - data['vvis'])**2
+ * (data['sigma']**2 + data['vsigma']**2)**0.5)
+ elif self.polrep == 'circ':
+ out = data['rrvis'] / data['llvis']
+ sig = np.sqrt(np.abs(data['rrsigma'] / data['llvis'])**2
+ + np.abs(data['llsigma'] * data['rrvis'] / data['llvis'])**2)
+
+ else:
+ raise Exception("%s is not a valid field \n" % field +
+ "valid field values are: " + ' '.join(ehc.FIELDS))
+
+ if field in ["time_utc"] and self.timetype == 'GMST':
+ out = obsh.gmst_to_utc(out, self.mjd)
+ if field in ["time_gmst"] and self.timetype == 'UTC':
+ out = obsh.utc_to_gmst(out, self.mjd)
+ if field in ["time"] and self.timetype == 'GMST' and timetype == 'UTC':
+ out = obsh.gmst_to_utc(out, self.mjd)
+ if field in ["time"] and self.timetype == 'UTC' and timetype == 'GMST':
+ out = obsh.utc_to_gmst(out, self.mjd)
+
+ # Compute elevation and parallactic angles
+ if field in ["el1", "el2", "hr_ang1", "hr_ang2", "par_ang1", "par_ang2"]:
+ if self.timetype == 'GMST':
+ times_sid = data['time']
+ else:
+ times_sid = obsh.utc_to_gmst(data['time'], self.mjd)
+
+ thetas = np.mod((times_sid - self.ra) * ehc.HOUR, 2 * np.pi)
+ coords = obsh.recarr_to_ndarr(tdata[['x', 'y', 'z']], 'f8')
+ el_angle = obsh.elev(obsh.earthrot(coords, thetas), self.sourcevec())
+ latlon = obsh.xyz_2_latlong(coords)
+ hr_angles = obsh.hr_angle(times_sid * ehc.HOUR, latlon[:, 1], self.ra * ehc.HOUR)
+
+ if field in ["el1", "el2"]:
+ out = el_angle / angle
+ ty = 'f8'
+ if field in ["hr_ang1", "hr_ang2"]:
+ out = hr_angles / angle
+ ty = 'f8'
+ if field in ["par_ang1", "par_ang2"]:
+ par_ang = obsh.par_angle(hr_angles, latlon[:, 0], self.dec * ehc.DEGREE)
+ out = par_ang / angle
+ ty = 'f8'
+
+ # Get arg/amps/snr
+ if field in ["amp", "qamp", "uamp", "vamp", "pamp", "mamp", "bamp", "eamp",
+ "rramp", "llamp", "rlamp", "lramp", "rrllamp"]:
+ out = np.abs(out)
+ if debias:
+ out = obsh.amp_debias(out, sig)
+ ty = 'f8'
+ elif field in ["sigma", "qsigma", "usigma", "vsigma",
+ "psigma", "msigma", "bsigma", "esigma",
+ "rrsigma", "llsigma", "rlsigma", "lrsigma", "rrllsigma"]:
+ out = np.abs(sig)
+ ty = 'f8'
+ elif field in ["phase", "qphase", "uphase", "vphase", "pphase", "bphase", "ephase",
+ "mphase", "rrphase", "llphase", "lrphase", "rlphase", "rrllphase"]:
+ out = np.angle(out) / angle
+ ty = 'f8'
+ elif field in ["sigma_phase", "qsigma_phase", "usigma_phase", "vsigma_phase",
+ "psigma_phase", "msigma_phase", "bsigma_phase", "esigma_phase",
+ "rrsigma_phase", "llsigma_phase", "rlsigma_phase", "lrsigma_phase",
+ "rrllsigma_phase"]:
+ out = np.abs(sig) / np.abs(out) / angle
+ ty = 'f8'
+ elif field in ["snr", "qsnr", "usnr", "vsnr", "psnr", "bsnr", "esnr",
+ "msnr", "rrsnr", "llsnr", "rlsnr", "lrsnr", "rrllsnr"]:
+ out = np.abs(out) / np.abs(sig)
+ ty = 'f8'
+
+ # Reshape and stack with other fields
+ out = np.array(out, dtype=[(field, ty)])
+
+ if len(allout) > 0:
+ allout = rec.merge_arrays((allout, out), asrecarray=True, flatten=True)
+ else:
+ allout = out
+
+ return allout
+
+ def sourcevec(self):
+ """Return the source position vector in geocentric coordinates at 0h GMST.
+
+ Args:
+
+ Returns:
+ (numpy.array): normal vector pointing to source in geocentric coordinates (m)
+ """
+
+ sourcevec = np.array([np.cos(self.dec * ehc.DEGREE), 0, np.sin(self.dec * ehc.DEGREE)])
+
+ return sourcevec
+
+ def res(self):
+ """Return the nominal resolution (1/longest baseline) of the observation in radians.
+
+ Args:
+
+ Returns:
+ (float): normal array resolution in radians
+ """
+
+ res = 1.0 / np.max(self.unpack('uvdist')['uvdist'])
+
+ return res
+
+
+ def chisq(self, im_or_mov, dtype='vis', pol='I', ttype='nfft', mask=[], **kwargs):
+ """Give the reduced chi^2 of the observation for the specified image and datatype.
+
+ Args:
+ im_or_mov (Image or Movie): image or movie object on which to test chi^2
+ dtype (str): data type of chi^2 (e.g., 'vis', 'amp', 'bs', 'cphase')
+ pol (str): polarization type ('I', 'Q', 'U', 'V', 'LL', 'RR', 'LR', or 'RL')
+ mask (arr): mask of same dimension as im.imvec
+ ttype (str): "fast" or "nfft" or "direct"
+ fft_pad_factor (float): zero pad the image to (fft_pad_factor * image size) in FFT
+ conv_func (str): The convolving function for gridding; 'gaussian', 'pill', or 'cubic'
+ p_rad (int): The pixel radius for the convolving function
+ order (str): Interpolation order for sampling the FFT
+
+ systematic_noise (float): adds a fractional systematic noise tolerance to sigmas
+ snrcut (float): a snr cutoff for including data in the chi^2 sum
+ debias (bool): if True then apply debiasing to amplitudes/closure amplitudes
+ weighting (str): 'natural' or 'uniform'
+
+ systematic_cphase_noise (float): a value in degrees to add to closure phase sigmas
+ cp_uv_min (float): flag short baselines before forming closure quantities
+ maxset (bool): if True, use maximal set instead of minimal for closure quantities
+
+ Returns:
+ (float): image chi^2
+ """
+ if dtype not in ['vis', 'bs', 'amp', 'cphase',
+ 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'm']:
+ raise Exception("%s is not a supported dterms!" % dtype)
+
+ # TODO -- should import this at top, but the circular dependencies create a mess...
+ import ehtim.imaging.imager_utils as iu
+ import ehtim.modeling.modeling_utils as mu
+
+ # Movie -- weighted sum of all frame chi^2 values
+ if hasattr(im_or_mov, 'get_image'):
+ mov = im_or_mov
+ obs_list = self.split_obs()
+
+ chisq_list = []
+ num_list = []
+ for ii, obs in enumerate(obs_list):
+
+ im = mov.get_image(obs.data[0]['time']) # Get image at the observation time
+
+ if pol not in im._imdict.keys():
+ raise Exception(pol + ' is not in the current image.' +
+ ' Consider changing the polarization basis of the image.')
+
+ try:
+ (data, sigma, A) = iu.chisqdata(obs, im, mask, dtype,
+ pol=pol, ttype=ttype, **kwargs)
+
+ except IndexError: # Not enough data to form closures!
+ continue
+
+ imvec = im._imdict[pol]
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = imvec[mask]
+
+ chisq_list.append(iu.chisq(imvec, A, data, sigma, dtype, ttype=ttype, mask=mask))
+ num_list.append(len(data))
+
+ chisq = np.sum(np.array(num_list) * np.array(chisq_list)) / np.sum(num_list)
+
+ # Model -- single chi^2
+ elif hasattr(im_or_mov,'N_models'):
+ (data, sigma, uv, jonesdict) = mu.chisqdata(self, dtype, pol, **kwargs)
+ chisq = mu.chisq(im_or_mov, uv, data, sigma, dtype, jonesdict)
+
+ # Image -- single chi^2
+ else:
+ im = im_or_mov
+ if pol not in im._imdict.keys():
+ raise Exception(pol + ' is not in the current image.' +
+ ' Consider changing the polarization basis of the image.')
+
+ (data, sigma, A) = iu.chisqdata(self, im, mask, dtype, pol=pol, ttype=ttype, **kwargs)
+
+ imvec = im._imdict[pol]
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ imvec = imvec[mask]
+
+ chisq = iu.chisq(imvec, A, data, sigma, dtype, ttype=ttype, mask=mask)
+
+ return chisq
+
+ def polchisq(self, im, dtype='pvis', ttype='nfft', pol_trans=True, mask=[], **kwargs):
+ """Give the reduced chi^2 for the specified image and polarimetric datatype.
+
+ Args:
+ im (Image): image to test polarimetric chi^2
+ dtype (str): data type of polarimetric chi^2 ('pvis', 'm', 'pbs', 'vvis')
+ mask (arr): mask of same dimension as im.imvec
+ ttype (str): "fast" or "nfft" or "direct"
+ pol_trans (bool): True for I,m,chi, False for IQU
+ fft_pad_factor (float): zero pad the image to (fft_pad_factor * image size) in FFT
+ conv_func (str): The convolving function for gridding; 'gaussian', 'pill', or 'cubic'
+ p_rad (int): The pixel radius for the convolving function
+ order (str): Interpolation order for sampling the FFT
+
+ systematic_noise (float): adds a fractional systematic noise tolerance to sigmas
+ snrcut (float): a snr cutoff for including data in the chi^2 sum
+ debias (bool): if True then apply debiasing to amplitudes/closure amplitudes
+ weighting (str): 'natural' or 'uniform'
+
+ systematic_cphase_noise (float): value in degrees to add to closure phase sigmas
+ cp_uv_min (float): flag short baselines before forming closure quantities
+ maxset (bool): if True, use maximal set instead of minimal for closure quantities
+
+ Returns:
+ (float): image chi^2
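+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs`` and
+            # Image ``im`` (hypothetical names).
+            chi2_m = obs.polchisq(im, dtype='m', pol_trans=True)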
+ """
+
+        if dtype not in ['pvis', 'm', 'pbs', 'vvis']:
+            raise Exception("Only supported polarimetric data types are 'pvis', 'm', 'pbs', 'vvis'!")
+
+ # TODO -- should import this at top, but the circular dependencies create a mess...
+ import ehtim.imaging.pol_imager_utils as piu
+
+ # Unpack the necessary polarimetric data
+ (data, sigma, A) = piu.polchisqdata(self, im, mask, dtype, ttype=ttype, **kwargs)
+
+ # Pack the comparison image in the proper format
+ imstokes = im.switch_polrep(polrep_out='stokes', pol_prim_out='I')
+ if pol_trans:
+ ivec = imstokes.imvec
+ mvec = (np.abs(imstokes.qvec + 1j * imstokes.uvec) / ivec)
+ chivec = np.angle(imstokes.qvec + 1j * imstokes.uvec) / 2
+ vvec = imstokes.vvec/ivec
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ ivec = ivec[mask]
+ mvec = mvec[mask]
+ chivec = chivec[mask]
+ vvec = vvec[mask]
+            imtuple = np.array((ivec, mvec, chivec, vvec))
+ else:
+ ivec = imstokes.imvec
+ qvec = imstokes.qvec
+ uvec = imstokes.uvec
+ vvec = imstokes.vvec
+ if len(mask) > 0 and np.any(np.invert(mask)):
+ ivec = ivec[mask]
+ qvec = qvec[mask]
+ uvec = uvec[mask]
+ vvec = vvec[mask]
+            imtuple = np.array((ivec, qvec, uvec, vvec))
+
+ # Calculate the chi^2
+ chisq = piu.polchisq(imtuple, A, data, sigma, dtype,
+ ttype=ttype, mask=mask, pol_trans=pol_trans)
+
+ return chisq
+
+ def recompute_uv(self):
+ """Recompute u,v points using observation times and metadata
+
+ Args:
+
+ Returns:
+ (Obsdata): New Obsdata object containing the same data with recomputed u,v points
+ """
+
+ times = self.data['time']
+ site1 = self.data['t1']
+ site2 = self.data['t2']
+ arr = ehtim.array.Array(self.tarr)
+ print("Recomputing U,V Points using MJD %d \n RA %e \n DEC %e \n RF %e GHz"
+ % (self.mjd, self.ra, self.dec, self.rf / 1.e9))
+
+ (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times,
+ self.mjd, self.ra, self.dec, self.rf,
+ timetype=self.timetype,
+ elevmin=0, elevmax=90, no_elevcut_space=False)
+
+ if len(timesout) != len(times):
+ raise Exception(
+ "len(timesout) != len(times) in recompute_uv: check elevation limits!!")
+
+ datatable = self.data.copy()
+ datatable['u'] = uout
+ datatable['v'] = vout
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = np.array(datatable)
+ out = Obsdata(*arglist, **argdict)
+
+ return out
+
+ def avg_coherent(self, inttime, scan_avg=False, moving=False):
+ """Coherently average data along u,v tracks in chunks of length inttime (sec)
+
+ Args:
+ inttime (float): coherent integration time in seconds
+            scan_avg (bool): if True, average over scans in self.scans instead of inttime
+ moving (bool): averaging with moving window (boxcar width in seconds)
+ Returns:
+ (Obsdata): Obsdata object containing averaged data
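+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            obs_avg = obs.avg_coherent(60.)  # coherent 60 s averages
+            obs_scan = obs.avg_coherent(0., scan_avg=True)  # scan averages (needs add_scans)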
+ """
+
+ if (scan_avg) and (getattr(self.scans, "shape", None) is None or len(self.scans) == 0):
+ print('No scan data, ignoring scan_avg!')
+ scan_avg = False
+
+ if inttime <= 0.0 and scan_avg is False:
+ print('No averaging done!')
+ return self.copy()
+
+ if moving:
+ vis_avg = ehdf.coh_moving_avg_vis(self, dt=inttime, return_type='rec')
+ else:
+ vis_avg = ehdf.coh_avg_vis(self, dt=inttime, return_type='rec',
+ err_type='predicted', scan_avg=scan_avg)
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = vis_avg
+ out = Obsdata(*arglist, **argdict)
+
+ return out
+
+ def avg_incoherent(self, inttime, scan_avg=False, debias=True, err_type='predicted'):
+ """Incoherently average data along u,v tracks in chunks of length inttime (sec)
+
+ Args:
+ inttime (float): incoherent integration time in seconds
+            scan_avg (bool): if True, average over scans in self.scans instead of inttime
+ debias (bool): if True, debias the averaged amplitudes
+ err_type (str): 'predicted' or 'measured'
+
+ Returns:
+ (Obsdata): Obsdata object containing averaged data
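+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            obs_amp = obs.avg_incoherent(600., debias=True)  # 10-minute incoherent averages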
+ """
+
+ print('Incoherently averaging data, putting phases to zero!')
+ amp_rec = ehdf.incoh_avg_vis(self, dt=inttime, debias=debias, scan_avg=scan_avg,
+ return_type='rec', rec_type='vis', err_type=err_type)
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = amp_rec
+ out = Obsdata(*arglist, **argdict)
+
+ return out
+
+ def add_amp(self, avg_time=0, scan_avg=False, debias=True, err_type='predicted',
+ return_type='rec', round_s=0.1, snrcut=0.):
+ """Adds attribute self.amp: aan amplitude table with incoherently averaged amplitudes
+
+ Args:
+ avg_time (float): incoherent integration time in seconds
+            scan_avg (bool): if True, average over scans in self.scans instead of inttime
+ debias (bool): if True then apply debiasing
+ err_type (str): 'predicted' or 'measured'
+ return_type: data frame ('df') or recarray ('rec')
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag amplitudes with snr lower than this
+
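+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            obs.add_amp(avg_time=600., debias=True)  # 10-minute averaged amplitudes
+            amps = obs.amp
+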
+ """
+
+ # Get the spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))])))))) * 3600.
+ else:
+ tint0 = 0.0
+
+ if avg_time <= tint0:
+ adf = ehdf.make_amp(self, debias=debias, round_s=round_s)
+ if return_type == 'rec':
+ adf = ehdf.df_to_rec(adf, 'amp')
+ print("Updated self.amp: no averaging")
+ else:
+ adf = ehdf.incoh_avg_vis(self, dt=avg_time, debias=debias, scan_avg=scan_avg,
+ return_type=return_type, rec_type='amp', err_type=err_type)
+
+ # snr cut
+ adf = adf[adf['amp'] / adf['sigma'] > snrcut]
+ self.amp = adf
+ print("Updated self.amp: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_bispec(self, avg_time=0, return_type='rec', count='max', snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False):
+ """Adds attribute self.bispec: bispectra table with bispectra averaged for dt
+
+ Args:
+ avg_time (float): bispectrum averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ count (str): If 'min', return minimal set of bispectra,
+ if 'max' return all bispectra up to reordering
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag bispectra with snr lower than this
+
+ """
+
+ # Get spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))])))))) * 3600.
+ else:
+ tint0 = 0
+
+ if avg_time > tint0:
+ cdf = ehdf.make_bsp_df(self, mode='all', round_s=round_s, count=count,
+ snrcut=0., uv_min=uv_min)
+ cdf = ehdf.average_bispectra(cdf, avg_time, return_type=return_type,
+ num_samples=num_samples, snrcut=snrcut)
+ else:
+ cdf = ehdf.make_bsp_df(self, mode='all', round_s=round_s, count=count,
+ snrcut=snrcut, uv_min=uv_min)
+ print("Updated self.bispec: no averaging")
+ if return_type == 'rec':
+ cdf = ehdf.df_to_rec(cdf, 'bispec')
+
+ self.bispec = cdf
+ print("Updated self.bispec: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_cphase(self, avg_time=0, return_type='rec', count='max', snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False):
+ """Adds attribute self.cphase: cphase table averaged for dt
+
+ Args:
+ avg_time (float): closure phase averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ count (str): If 'min', return minimal set of phases,
+ if 'max' return all closure phases up to reordering
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag closure phases with snr lower than this
+
+ """
+
+ # Get spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))])))))) * 3600.
+ else:
+ tint0 = 0
+
+ if avg_time > tint0:
+ cdf = ehdf.make_cphase_df(self, mode='all', round_s=round_s, count=count,
+ snrcut=0., uv_min=uv_min)
+ cdf = ehdf.average_cphases(cdf, avg_time, return_type=return_type, err_type=err_type,
+ num_samples=num_samples, snrcut=snrcut)
+ else:
+ cdf = ehdf.make_cphase_df(self, mode='all', round_s=round_s, count=count,
+ snrcut=snrcut, uv_min=uv_min)
+ if return_type == 'rec':
+ cdf = ehdf.df_to_rec(cdf, 'cphase')
+ print("Updated self.cphase: no averaging")
+
+ self.cphase = cdf
+ print("updated self.cphase: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_cphase_diag(self, avg_time=0, return_type='rec', vtype='vis', count='min', snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False):
+ """Adds attribute self.cphase_diag: cphase_diag table averaged for dt
+
+ Args:
+ avg_time (float): closure phase averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ vtype (str): Visibility type (e.g., 'vis', 'llvis', 'rrvis', etc.)
+ count (str): If 'min', return minimal set of phases,
+ If 'max' return all closure phases up to reordering
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag closure phases with snr lower than this
+
+ """
+
+ # Get spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))]))))))
+ tint0 *= 3600
+ else:
+ tint0 = 0
+
+ # Dom TODO: implement averaging during diagonal closure phase creation
+ if avg_time > tint0:
+ print("Averaging while creating diagonal closure phases is not yet implemented!")
+ print("Proceeding for now without averaging.")
+ cdf = ehdf.make_cphase_diag_df(self, vtype=vtype, round_s=round_s,
+ count=count, snrcut=snrcut, uv_min=uv_min)
+ else:
+ cdf = ehdf.make_cphase_diag_df(self, vtype=vtype, round_s=round_s,
+ count=count, snrcut=snrcut, uv_min=uv_min)
+ if return_type == 'rec':
+ cdf = ehdf.df_to_rec(cdf, 'cphase_diag')
+ print("Updated self.cphase_diag: no averaging")
+
+ self.cphase_diag = cdf
+ print("updated self.cphase_diag: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_camp(self, avg_time=0, return_type='rec', ctype='camp',
+ count='max', debias=True, snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1):
+ """Adds attribute self.camp or self.logcamp: closure amplitudes table
+
+ Args:
+ avg_time (float): closure amplitude averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ ctype (str): The closure amplitude type ('camp' or 'logcamp')
+ debias (bool): If True, debias the closure amplitude
+ count (str): If 'min', return minimal set of amplitudes,
+ if 'max' return all closure amplitudes up to inverses
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag closure amplitudes with snr lower than this
+ """
+
+ # Get spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))]))))))
+ tint0 *= 3600
+ else:
+ tint0 = 0
+
+        if avg_time > tint0:
+            foo = self.avg_incoherent(avg_time, debias=debias, err_type=err_type)
+        else:
+            foo = self
+            if ctype == 'logcamp':
+                print("Updated self.logcamp: no averaging")
+            elif ctype == 'camp':
+                print("Updated self.camp: no averaging")
+        cdf = ehdf.make_camp_df(foo, ctype=ctype, debias=False,
+                                count=count, round_s=round_s, snrcut=snrcut)
+
+        if return_type == 'rec':
+            cdf = ehdf.df_to_rec(cdf, 'camp')
+
+ if ctype == 'logcamp':
+ self.logcamp = cdf
+ print("updated self.logcamp: avg_time %f s\n" % avg_time)
+ elif ctype == 'camp':
+ self.camp = cdf
+ print("updated self.camp: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_logcamp(self, avg_time=0, return_type='rec', ctype='camp',
+ count='max', debias=True, snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1):
+ """Adds attribute self.logcamp: closure amplitudes table
+
+ Args:
+ avg_time (float): closure amplitude averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ ctype (str): The closure amplitude type ('camp' or 'logcamp')
+ debias (bool): If True, debias the closure amplitude
+ count (str): If 'min', return minimal set of amplitudes,
+ if 'max' return all closure amplitudes up to inverses
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ """
+
+ self.add_camp(return_type=return_type, ctype='logcamp',
+ count=count, debias=debias, snrcut=snrcut,
+ avg_time=avg_time, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+
+ return
+
+ def add_logcamp_diag(self, avg_time=0, return_type='rec', count='min', snrcut=0.,
+ debias=True, err_type='predicted', num_samples=1000, round_s=0.1):
+ """Adds attribute self.logcamp_diag: logcamp_diag table averaged for dt
+
+ Args:
+ avg_time (float): diagonal log closure amplitude averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ debias (bool): If True, debias the diagonal log closure amplitude
+ count (str): If 'min', return minimal set of amplitudes,
+ If 'max' return all diagonal log closure amplitudes up to inverses
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag diagonal log closure amplitudes with snr lower than this
+
+ """
+
+ # Get spacing between datapoints in seconds
+ if len(set([x[0] for x in list(self.unpack('time'))])) > 1:
+ tint0 = np.min(np.diff(np.asarray(sorted(list(set(
+ [x[0] for x in list(self.unpack('time'))]))))))
+ tint0 *= 3600
+ else:
+ tint0 = 0
+
+ if avg_time > tint0:
+ foo = self.avg_incoherent(avg_time, debias=debias, err_type=err_type)
+            cdf = ehdf.make_logcamp_diag_df(foo, debias=False, count=count,
+ round_s=round_s, snrcut=snrcut)
+ else:
+ foo = self
+ cdf = ehdf.make_logcamp_diag_df(foo, debias=debias, count=count,
+ round_s=round_s, snrcut=snrcut)
+
+ if return_type == 'rec':
+ cdf = ehdf.df_to_rec(cdf, 'logcamp_diag')
+
+ self.logcamp_diag = cdf
+ print("updated self.logcamp_diag: avg_time %f s\n" % avg_time)
+
+ return
+
+ def add_all(self, avg_time=0, return_type='rec',
+ count='max', debias=True, snrcut=0.,
+ err_type='predicted', num_samples=1000, round_s=0.1):
+ """Adds tables of all all averaged derived quantities
+ self.amp,self.bispec,self.cphase,self.camp,self.logcamp
+
+ Args:
+ avg_time (float): closure amplitude averaging timescale
+ return_type: data frame ('df') or recarray ('rec')
+ debias (bool): If True, debias the closure amplitude
+ count (str): If 'min', return minimal set of closure quantities,
+ if 'max' return all closure quantities
+ err_type (str): 'predicted' or 'measured'
+ num_samples: number of bootstrap (re)samples if measuring error
+ round_s (float): accuracy of datetime object in seconds
+ snrcut (float): flag closure amplitudes with snr lower than this
+
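+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            obs.add_all(avg_time=600., count='max', snrcut=3.)
+            cp_table = obs.cphase  # e.g. the averaged closure-phase table
+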
+ """
+
+ self.add_amp(return_type=return_type, avg_time=avg_time, debias=debias, err_type=err_type)
+ self.add_bispec(return_type=return_type, count=count,
+ avg_time=avg_time, snrcut=snrcut, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+ self.add_cphase(return_type=return_type, count=count,
+ avg_time=avg_time, snrcut=snrcut, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+ self.add_cphase_diag(return_type=return_type, count='min',
+ avg_time=avg_time, snrcut=snrcut, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+ self.add_camp(return_type=return_type, ctype='camp',
+ count=count, debias=debias, snrcut=snrcut,
+ avg_time=avg_time, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+ self.add_camp(return_type=return_type, ctype='logcamp',
+ count=count, debias=debias, snrcut=snrcut,
+ avg_time=avg_time, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+ self.add_logcamp_diag(return_type=return_type, count='min',
+ debias=debias, avg_time=avg_time,
+ snrcut=snrcut, err_type=err_type,
+ num_samples=num_samples, round_s=round_s)
+
+ return
+
+ def add_scans(self, info='self', filepath='', dt=0.0165, margin=0.0001):
+ """Compute scans and add self.scans to Obsdata object.
+
+ Args:
+ info (str): 'self' to infer from data, 'txt' for text file,
+ 'vex' for vex schedule file
+ filepath (str): path to txt/vex file with scans info
+            dt (float): minimal time interval between scans in hours
+            margin (float): pad each scan by this time margin in hours
+
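+        Example:
+            # Illustrative sketch; the vex file path is hypothetical.
+            obs.add_scans()  # infer scans directly from the data
+            obs.add_scans(info='vex', filepath='track.vex')  # or from a vex schedule
+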
+ """
+
+ # infer scans directly from data
+ if info == 'self':
+ times_uni = np.asarray(sorted(list(set(self.data['time']))))
+ scans = np.zeros_like(times_uni)
+ scan_id = 0
+ for cou in range(len(times_uni) - 1):
+ scans[cou] = scan_id
+ if (times_uni[cou + 1] - times_uni[cou] > dt):
+ scan_id += 1
+ scans[-1] = scan_id
+ scanlist = np.asarray([np.asarray([
+ np.min(times_uni[scans == cou]) - margin,
+ np.max(times_uni[scans == cou]) + margin])
+ for cou in range(int(scans[-1]) + 1)])
+
+ # read in scans from a text file
+ elif info == 'txt':
+ scanlist = np.loadtxt(filepath)
+
+ # read in scans from a vex file
+ elif info == 'vex':
+ vex0 = ehtim.vex.Vex(filepath)
+ t_min = [vex0.sched[x]['start_hr'] for x in range(len(vex0.sched))]
+ duration = []
+ for x in range(len(vex0.sched)):
+ duration_foo = max([vex0.sched[x]['scan'][y]['scan_sec']
+ for y in range(len(vex0.sched[x]['scan']))])
+ duration.append(duration_foo)
+ t_max = [tmin + dur / 3600. for (tmin, dur) in zip(t_min, duration)]
+ scanlist = np.array([[tmin, tmax] for (tmin, tmax) in zip(t_min, t_max)])
+
+ else:
+ print("Parameter 'info' can only assume values 'self', 'txt' or 'vex'! ")
+ scanlist = None
+
+ self.scans = scanlist
+
+ return
+
+ def cleanbeam(self, npix, fov, pulse=ehc.PULSE_DEFAULT):
+ """Make an image of the observation clean beam.
+
+ Args:
+ npix (int): The pixel size of the square output image.
+ fov (float): The field of view of the square output image in radians.
+ pulse (function): The function convolved with the pixel values for continuous image.
+
+ Returns:
+ (Image): an Image object with the clean beam.
+ """
+
+ im = ehtim.image.make_square(self, npix, fov, pulse=pulse)
+ beamparams = self.fit_beam()
+ im = im.add_gauss(1.0, beamparams)
+
+ return im
+
+ def fit_beam(self, weighting='uniform', units='rad'):
+ """Fit a Gaussian to the dirty beam and return the parameters (fwhm_maj, fwhm_min, theta).
+
+ Args:
+ weighting (str): 'uniform' or 'natural'.
+            units (str): 'rad' returns values in radians,
+                         'natural' returns FWHMs in uas and theta in degrees
+
+        Returns:
+            (tuple): a tuple (fwhm_maj, fwhm_min, theta) of the dirty beam parameters
+                     in the specified units.
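+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            fwhm_maj, fwhm_min, theta = obs.fit_beam(units='natural')
+            # fwhm_maj and fwhm_min in uas, theta in degrees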
+ """
+
+ # Define the fit function that compares the quadratic expansion of the dirty image
+ # with the quadratic expansion of an elliptical gaussian
+ def fit_chisq(beamparams, db_coeff):
+
+ (fwhm_maj2, fwhm_min2, theta) = beamparams
+ a = 4 * np.log(2) * (np.cos(theta)**2 / fwhm_min2 + np.sin(theta)**2 / fwhm_maj2)
+ b = 4 * np.log(2) * (np.cos(theta)**2 / fwhm_maj2 + np.sin(theta)**2 / fwhm_min2)
+ c = 8 * np.log(2) * np.cos(theta) * np.sin(theta) * (1.0 / fwhm_maj2 - 1.0 / fwhm_min2)
+ gauss_coeff = np.array((a, b, c))
+
+ chisq = np.sum((np.array(db_coeff) - gauss_coeff)**2)
+
+ return chisq
+
+ # These are the coefficients (a,b,c) of a quadratic expansion of the dirty beam
+ # For a point (x,y) in the image plane, the dirty beam expansion is 1-ax^2-by^2-cxy
+ u = self.unpack('u')['u']
+ v = self.unpack('v')['v']
+ sigma = self.unpack('sigma')['sigma']
+
+ weights = np.ones(u.shape)
+ if weighting == 'natural':
+ weights = 1. / sigma**2
+
+ abc = np.array([np.sum(weights * u**2),
+ np.sum(weights * v**2),
+ 2 * np.sum(weights * u * v)])
+ abc *= (2. * np.pi**2 / np.sum(weights))
+ abc *= 1e-20 # Decrease size of coefficients
+
+ # Fit the beam
+ guess = [(50)**2, (50)**2, 0.0]
+ params = opt.minimize(fit_chisq, guess, args=(abc,), method='Powell')
+
+ # Return parameters, adjusting fwhm_maj and fwhm_min if necessary
+ if params.x[0] > params.x[1]:
+ fwhm_maj = 1e-10 * np.sqrt(params.x[0])
+ fwhm_min = 1e-10 * np.sqrt(params.x[1])
+ theta = np.mod(params.x[2], np.pi)
+ else:
+ fwhm_maj = 1e-10 * np.sqrt(params.x[1])
+ fwhm_min = 1e-10 * np.sqrt(params.x[0])
+ theta = np.mod(params.x[2] + np.pi / 2.0, np.pi)
+
+ gparams = np.array((fwhm_maj, fwhm_min, theta))
+
+ if units == 'natural':
+ gparams[0] /= ehc.RADPERUAS
+ gparams[1] /= ehc.RADPERUAS
+ gparams[2] *= 180. / np.pi
+
+ return gparams
+
+ def dirtybeam(self, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform'):
+ """Make an image of the observation dirty beam.
+
+ Args:
+ npix (int): The pixel size of the square output image.
+ fov (float): The field of view of the square output image in radians.
+ pulse (function): The function convolved with the pixel values for continuous image.
+ weighting (str): 'uniform' or 'natural'
+ Returns:
+ (Image): an Image object with the dirty beam.
+ """
+
+ pdim = fov / npix
+ sigma = self.unpack('sigma')['sigma']
+ u = self.unpack('u')['u']
+ v = self.unpack('v')['v']
+ if weighting == 'natural':
+ weights = 1. / sigma**2
+ else:
+ weights = np.ones(u.shape)
+
+ xlist = np.arange(0, -npix, -1) * pdim + (pdim * npix) / 2.0 - pdim / 2.0
+
+ # TODO -- use NFFT
+ # TODO -- different beam weightings
+ im = np.array([[np.mean(weights * np.cos(-2 * np.pi * (i * u + j * v)))
+ for i in xlist]
+ for j in xlist])
+
+ im = im[0:npix, 0:npix]
+ im = im / np.sum(im) # Normalize to a total beam power of 1
+
+ src = self.source + "_DB"
+ outim = ehtim.image.Image(im, pdim, self.ra, self.dec,
+ rf=self.rf, source=src, mjd=self.mjd, pulse=pulse)
+
+ return outim
+
+ def dirtyimage(self, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform'):
+ """Make the observation dirty image (direct Fourier transform).
+
+ Args:
+ npix (int): The pixel size of the square output image.
+ fov (float): The field of view of the square output image in radians.
+ pulse (function): The function convolved with the pixel values for continuous image.
+ weighting (str): 'uniform' or 'natural'
+ Returns:
+ (Image): an Image object with dirty image.
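+
+        Example:
+            # Illustrative sketch; assumes ``import ehtim as eh`` and an
+            # existing Obsdata ``obs``.
+            dim = obs.dirtyimage(128, 200 * eh.RADPERUAS, weighting='uniform')
+            dim.display()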
+ """
+
+ pdim = fov / npix
+ u = self.unpack('u')['u']
+ v = self.unpack('v')['v']
+ sigma = self.unpack('sigma')['sigma']
+ xlist = np.arange(0, -npix, -1) * pdim + (pdim * npix) / 2.0 - pdim / 2.0
+ if weighting == 'natural':
+ weights = 1. / sigma**2
+ else:
+ weights = np.ones(u.shape)
+
+ dim = np.array([[np.mean(weights * np.cos(-2 * np.pi * (i * u + j * v)))
+ for i in xlist]
+ for j in xlist])
+ normfac = 1. / np.sum(dim)
+
+ for label in ['vis1', 'vis2', 'vis3', 'vis4']:
+ visname = self.poldict[label]
+
+ vis = self.unpack(visname)[visname]
+
+ # TODO -- use NFFT
+ # TODO -- different beam weightings
+ im = np.array([[np.mean(weights * (np.real(vis) * np.cos(-2 * np.pi * (i * u + j * v)) -
+ np.imag(vis) * np.sin(-2 * np.pi * (i * u + j * v))))
+ for i in xlist]
+ for j in xlist])
+
+ # Final normalization
+ im = im * normfac
+ im = im[0:npix, 0:npix]
+
+ if label == 'vis1':
+ out = ehtim.image.Image(im, pdim, self.ra, self.dec, polrep=self.polrep,
+ rf=self.rf, source=self.source, mjd=self.mjd, pulse=pulse)
+ else:
+ pol = {ehc.vis_poldict[key]: key for key in ehc.vis_poldict.keys()}[visname]
+ out.add_pol_image(im, pol)
+
+ return out
+
+ def rescale_zbl(self, totflux, uv_max, debias=True):
+ """Rescale the short baselines to a new level of total flux.
+
+ Args:
+ totflux (float): new total flux to rescale to
+ uv_max (float): maximum baseline length to rescale
+ debias (bool): Debias amplitudes before computing original total flux from short bls
+
+ Returns:
+            (Obsdata): An Obsdata object with the rescaled short-baseline data.
+ """
+
+ # estimate the original total flux
+ obs_zerobl = self.flag_uvdist(uv_max=uv_max)
+        obs_zerobl.add_amp(debias=debias)
+ orig_totflux = np.sum(obs_zerobl.amp['amp'] * (1 / obs_zerobl.amp['sigma']**2))
+ orig_totflux /= np.sum(1 / obs_zerobl.amp['sigma']**2)
+
+ print('Rescaling zero baseline by ' + str(orig_totflux - totflux) + ' Jy' +
+ ' to ' + str(totflux) + ' Jy')
+
+ # Rescale short baselines to excise contributions from extended flux
+        # Note: this does not do the proper thing for fractional polarization
+ obs = self.copy()
+ for j in range(len(obs.data)):
+ if (obs.data['u'][j]**2 + obs.data['v'][j]**2)**0.5 < uv_max:
+ obs.data['vis'][j] *= totflux / orig_totflux
+ obs.data['qvis'][j] *= totflux / orig_totflux
+ obs.data['uvis'][j] *= totflux / orig_totflux
+ obs.data['vvis'][j] *= totflux / orig_totflux
+ obs.data['sigma'][j] *= totflux / orig_totflux
+ obs.data['qsigma'][j] *= totflux / orig_totflux
+ obs.data['usigma'][j] *= totflux / orig_totflux
+ obs.data['vsigma'][j] *= totflux / orig_totflux
+
+ return obs
+
+ def add_leakage_noise(self, Dterm_amp=0.1, min_noise=0.01, debias=False):
+ """Add estimated systematic noise from leakage at quadrature to thermal noise.
+ Requires cross-hand visibilities.
+ !! this operation is not currently tracked and should be applied with extreme caution!!
+
+ Args:
+ Dterm_amp (float): Estimated magnitude of leakage terms
+ min_noise (float): Minimum fractional systematic noise to add
+ debias (bool): Debias amplitudes before computing fractional noise
+
+ Returns:
+ (Obsdata): An Obsdata object with the inflated noise values.
+ """
+
+ # Extract visibility amplitudes
+ # Switch to Stokes for graceful handling of circular basis products missing RR or LL
+ amp = self.switch_polrep('stokes').unpack('amp', debias=debias)['amp']
+ rlamp = np.nan_to_num(self.switch_polrep('circ').unpack('rlamp', debias=debias)['rlamp'])
+ lramp = np.nan_to_num(self.switch_polrep('circ').unpack('lramp', debias=debias)['lramp'])
+
+ frac_noise = (Dterm_amp * rlamp / amp)**2 + (Dterm_amp * lramp / amp)**2
+ frac_noise = frac_noise * (frac_noise > min_noise) + min_noise * (frac_noise < min_noise)
+
+ out = self.copy()
+ for sigma in ['sigma1', 'sigma2', 'sigma3', 'sigma4']:
+ try:
+ field = self.poldict[sigma]
+ out.data[field] = (self.data[field]**2 + np.abs(frac_noise * amp)**2)**0.5
+ except KeyError:
+ continue
+
+ return out
+
+ def add_fractional_noise(self, frac_noise, debias=False):
+ """Add a constant fraction of amplitude at quadrature to thermal noise.
+ Effectively imposes a maximal signal-to-noise ratio.
+ !! this operation is not currently tracked and should be applied with extreme caution!!
+
+ Args:
+ frac_noise (float): The fraction of noise to add.
+ debias (bool): Whether or not to add frac_noise of debiased amplitudes.
+
+ Returns:
+ (Obsdata): An Obsdata object with the inflated noise values.
+ """
+
+ # Extract visibility amplitudes
+ # Switch to Stokes for graceful handling of circular basis products missing RR or LL
+ amp = self.switch_polrep('stokes').unpack('amp', debias=debias)['amp']
+
+ out = self.copy()
+ for sigma in ['sigma1', 'sigma2', 'sigma3', 'sigma4']:
+ try:
+ field = self.poldict[sigma]
+ out.data[field] = (self.data[field]**2 + np.abs(frac_noise * amp)**2)**0.5
+ except KeyError:
+ continue
+
+ return out
+
+ def find_amt_fractional_noise(self, im, dtype='vis', target=1.0, debias=False,
+ maxiter=200, ftol=1e-20, gtol=1e-20):
+ """Returns the amount of fractional sys error necessary
+ to make the image have a chisq close to the targeted value (1.0)
+ """
+
+ obs = self.copy()
+
+ def objfunc(frac_noise):
+ obs_tmp = obs.add_fractional_noise(frac_noise, debias=debias)
+ chisq = obs_tmp.chisq(im, dtype=dtype)
+ return np.abs(target - chisq)
+
+ optdict = {'maxiter': maxiter, 'ftol': ftol, 'gtol': gtol}
+ res = opt.minimize(objfunc, 0.0, method='L-BFGS-B', options=optdict)
+
+ return res.x
+
+ def rescale_noise(self, noise_rescale_factor=1.0):
+ """Rescale the thermal noise on all Stokes parameters by a constant factor.
+ This is useful for AIPS data, which has a missing factor relating 'weights' to noise.
+
+ Args:
+            noise_rescale_factor (float): The number to multiply the existing sigmas by.
+
+ Returns:
+ (Obsdata): An Obsdata object with the rescaled noise values.
+ """
+
+ datatable = self.data.copy()
+
+ for d in datatable:
+ d[-4] = d[-4] * noise_rescale_factor
+ d[-3] = d[-3] * noise_rescale_factor
+ d[-2] = d[-2] * noise_rescale_factor
+ d[-1] = d[-1] * noise_rescale_factor
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = np.array(datatable)
+ out = Obsdata(*arglist, **argdict)
+
+ return out
+
+ def estimate_noise_rescale_factor(self, max_diff_sec=0.0, min_num=10, median_snr_cut=0,
+ count='max', vtype='vis', print_std=False):
+ """Estimate a constant noise rescaling factor on all baselines, times, and polarizations.
+ Uses pairwise differences of closure phases relative to the expected scatter.
+ This is useful for AIPS data, which has a missing factor relating 'weights' to noise.
+
+ Args:
+            max_diff_sec (float): The maximum difference of adjacent closure phases (in seconds).
+                                  If 0, auto-estimates to 5x the median integration time.
+ min_num (int): The minimum number of closure phase differences for a triangle
+ to be included in the set of estimators.
+ median_snr_cut (float): Do not include a triangle if its median SNR is below this
+ count (str): If 'min', use minimal set of phases,
+ if 'max' use all closure phases up to reordering
+ vtype (str): Visibility type (e.g., 'vis', 'llvis', 'rrvis', etc.)
+ print_std (bool): Whether or not to print the std dev. for each closure triangle.
+
+ Returns:
+ (float): The rescaling factor.
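+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            factor = obs.estimate_noise_rescale_factor()
+            obs_rescaled = obs.rescale_noise(noise_rescale_factor=factor)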
+ """
+
+ if max_diff_sec == 0.0:
+ max_diff_sec = 5 * np.median(self.unpack('tint')['tint'])
+ print("estimated max_diff_sec: ", max_diff_sec)
+
+ # Now check the noise statistics on all closure phase triangles
+ c_phases = self.c_phases(vtype=vtype, mode='time', count=count, ang_unit='')
+
+ # First, just determine the set of closure phase triangles
+ all_triangles = []
+ for scan in c_phases:
+ for cphase in scan:
+ all_triangles.append((cphase[1], cphase[2], cphase[3]))
+ std_list = []
+ print("Estimating noise rescaling factor from %d triangles...\n" % len(set(all_triangles)))
+
+ # Now determine the differences of adjacent samples on each triangle,
+ # relative to the expected thermal noise
+ i_count = 0
+ for tri in set(all_triangles):
+ i_count = i_count + 1
+ if print_std:
+ sys.stdout.write('\rGetting noise for triangles %i/%i ' %
+ (i_count, len(set(all_triangles))))
+ sys.stdout.flush()
+ all_tri = np.array([[]])
+ for scan in c_phases:
+ for cphase in scan:
+ if (cphase[1] == tri[0] and cphase[2] == tri[1] and cphase[3] == tri[2] and
+                            not np.isnan(cphase[-2]) and not np.isnan(cphase[-1])):
+
+ all_tri = np.append(all_tri, ((cphase[0], cphase[-2], cphase[-1])))
+
+ all_tri = all_tri.reshape(int(len(all_tri) / 3), 3)
+
+ # See whether the triangle has sufficient SNR
+ if np.median(np.abs(all_tri[:, 1] / all_tri[:, 2])) < median_snr_cut:
+ if print_std:
+ print(tri, 'median snr too low (%6.4f)' %
+ np.median(np.abs(all_tri[:, 1] / all_tri[:, 2])))
+ continue
+
+ # Now go through and find studentized differences of adjacent points
+ s_list = np.array([])
+ for j in range(len(all_tri) - 1):
+ if (all_tri[j + 1, 0] - all_tri[j, 0]) * 3600.0 < max_diff_sec:
+ diff = (all_tri[j + 1, 1] - all_tri[j, 1]) % (2.0 * np.pi)
+ if diff > np.pi:
+ diff -= 2.0 * np.pi
+ s_list = np.append(
+ s_list, diff / (all_tri[j, 2]**2 + all_tri[j + 1, 2]**2)**0.5)
+
+ if len(s_list) > min_num:
+ std_list.append(np.std(s_list))
+ if print_std:
+ print(tri, '%6.4f [%d differences]' % (np.std(s_list), len(s_list)))
+ else:
+ if print_std and len(all_tri) > 0:
+ print(tri, '%d cphases found [%d differences < min_num = %d]' %
+ (len(all_tri), len(s_list), min_num))
+
+ if len(std_list) == 0:
+ print("No suitable closure phase differences! Try using a larger max_diff_sec.")
+ median = 1.0
+ else:
+ median = np.median(std_list)
+
+ return median
+
+ def flag_elev(self, elev_min=0.0, elev_max=90, output='kept'):
+ """Flag visibilities for which either station is outside a stated elevation range
+
+ Args:
+ elev_min (float): Minimum elevation (deg)
+ elev_max (float): Maximum elevation (deg)
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ el_pairs = self.unpack(['el1', 'el2'])
+ mask = (np.min((el_pairs['el1'], el_pairs['el2']), axis=0) > elev_min)
+ mask *= (np.max((el_pairs['el1'], el_pairs['el2']), axis=0) < elev_max)
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged':
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_large_fractional_pol(self, max_fractional_pol=1.0, output='kept'):
+ """Flag visibilities for which the fractional polarization is above a specified threshold
+
+ Args:
+ max_fractional_pol (float): Maximum fractional polarization
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ m = np.nan_to_num(self.unpack(['mamp'])['mamp'])
+ mask = m < max_fractional_pol
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged':
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_uvdist(self, uv_min=0.0, uv_max=1e12, output='kept'):
+ """Flag data points outside a given uv range
+
+ Args:
+ uv_min (float): remove points with uvdist less than this
+ uv_max (float): remove points with uvdist greater than this
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
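+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            obs_long = obs.flag_uvdist(uv_min=0.1e9)  # drop baselines below 0.1 Glambda
+            both = obs.flag_uvdist(uv_max=2.0e9, output='both')
+            kept, flagged = both['kept'], both['flagged']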
+ """
+
+ uvdist_list = self.unpack('uvdist')['uvdist']
+ mask = np.array([uv_min <= uvdist_list[j] <= uv_max for j in range(len(uvdist_list))])
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('U-V flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged':
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_sites(self, sites, output='kept'):
+ """Flag data points that include the specified sites
+
+ Args:
+ sites (list): list of sites to remove from the data
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ # This will remove all visibilities that include any of the specified sites
+
+ t1_list = self.unpack('t1')['t1']
+ t2_list = self.unpack('t2')['t2']
+ mask = np.array([t1_list[j] not in sites and t2_list[j] not in sites
+ for j in range(len(t1_list))])
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_bl(self, sites, output='kept'):
+ """Flag data points that include the specified baseline
+
+ Args:
+            sites (list): the two sites defining the baseline to remove from the data
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+        # This will remove all visibilities on the specified baseline
+ obs_out = self.copy()
+ t1_list = obs_out.unpack('t1')['t1']
+ t2_list = obs_out.unpack('t2')['t2']
+ mask = np.array([not(t1_list[j] in sites and t2_list[j] in sites)
+ for j in range(len(t1_list))])
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_low_snr(self, snr_cut=3, output='kept'):
+ """Flag low snr data points
+
+ Args:
+ snr_cut (float): remove points with snr lower than this
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ mask = self.unpack('snr')['snr'] > snr_cut
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('snr flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_high_sigma(self, sigma_cut=.005, sigma_type='sigma', output='kept'):
+ """Flag high sigma (thermal noise on Stoke I) data points
+
+ Args:
+ sigma_cut (float): remove points with sigma higher than this
+ sigma_type (str): sigma type (sigma, rrsigma, llsigma, etc.)
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ mask = self.unpack(sigma_type)[sigma_type] < sigma_cut
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('sigma flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flag_UT_range(self, UT_start_hour=0., UT_stop_hour=0.,
+ flag_type='all', flag_what='', output='kept'):
+ """Flag data points within a certain UT range
+
+ Args:
+ UT_start_hour (float): start of time window
+ UT_stop_hour (float): end of time window
+ flag_type (str): 'all', 'baseline', or 'station'
+ flag_what (str): baseline or station to flag
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ # This drops (or only keeps) points within a specified UT range
+
+ UT_mask = self.unpack('time')['time'] <= UT_start_hour
+ UT_mask = UT_mask + (self.unpack('time')['time'] >= UT_stop_hour)
+ if flag_type != 'all':
+ t1_list = self.unpack('t1')['t1']
+ t2_list = self.unpack('t2')['t2']
+ if flag_type == 'station':
+ station = flag_what
+ what_mask = np.array([not (t1_list[j] == station or t2_list[j] == station)
+ for j in range(len(t1_list))])
+ elif flag_type == 'baseline':
+ station1 = flag_what.split('-')[0]
+ station2 = flag_what.split('-')[1]
+ stations = [station1, station2]
+ what_mask = np.array([not ((t1_list[j] in stations) and (t2_list[j] in stations))
+ for j in range(len(t1_list))])
+ else:
+ what_mask = np.array([False for j in range(len(UT_mask))])
+ mask = UT_mask | what_mask
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('time flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data)))
+
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def flags_from_file(self, flagfile, flag_type='station'):
+ """Flagging data based on csv file
+
+ Args:
+ flagfile (str): path to csv file with mjds of flagging start / stop time,
+ and optionally baseline / stations
+ flag_type (str): 'all', 'baseline', or 'station'
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ df = pd.read_csv(flagfile)
+ mjd_start = list(df['mjd_start'])
+ mjd_stop = list(df['mjd_stop'])
+ if flag_type == 'station':
+ whatL = list(df['station'])
+ elif flag_type == 'baseline':
+ whatL = list(df['baseline'])
+ elif flag_type == 'all':
+ whatL = ['' for cou in range(len(mjd_start))]
+ obs = self.copy()
+ for cou in range(len(mjd_start)):
+ what = whatL[cou]
+ starth = (mjd_start[cou] % 1) * 24.
+ stoph = (mjd_stop[cou] % 1) * 24.
+ obs = obs.flag_UT_range(UT_start_hour=starth, UT_stop_hour=stoph,
+ flag_type=flag_type, flag_what=what, output='kept')
+
+ return obs
+
+ def flag_anomalous(self, field='snr', max_diff_seconds=100, robust_nsigma_cut=5, output='kept'):
+ """Flag anomalous data points
+
+ Args:
+ field (str): The quantity to test for
+ max_diff_seconds (float): The moving window size for testing outliers
+ robust_nsigma_cut (float): Outliers further than this from the mean are removed
+ output (str): returns 'kept', 'flagged', or 'both' (a dictionary)
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ stats = dict()
+ for t1 in set(self.data['t1']):
+ for t2 in set(self.data['t2']):
+ vals = self.unpack_bl(t1, t2, field)
+ if len(vals) > 0:
+ # nans will all be dropped, which can be problematic for polarimetric values
+ vals[field] = np.nan_to_num(vals[field])
+ for j in range(len(vals)):
+ near_vals_mask = np.abs(vals['time'] - vals['time']
+ [j]) < max_diff_seconds / 3600.0
+ fields = vals[field][near_vals_mask]
+
+ # Here, we use median absolute deviation as a robust proxy for standard
+ # deviation
+ dfields = np.median(np.abs(fields - np.median(fields)))
+ # Avoid problems when the MAD is zero (e.g., a single sample)
+ if dfields == 0.0:
+ dfields = 1.0
+ stat = np.abs(vals[field][j] - np.median(fields)) / dfields
+ stats[(vals['time'][j][0], tuple(sorted((t1, t2))))] = stat
+
+ mask = np.array([stats[(rec[0], tuple(sorted((rec[2], rec[3]))))][0] < robust_nsigma_cut
+ for rec in self.data])
+
+ datatable_kept = self.data.copy()
+ datatable_flagged = self.data.copy()
+
+ datatable_kept = datatable_kept[mask]
+ datatable_flagged = datatable_flagged[np.invert(mask)]
+ print('anomalous %s flagged %d/%d visibilities' %
+ (field, len(datatable_flagged), len(self.data)))
+
+ # Make new observations with all data first to avoid problems with empty arrays
+ obs_kept = self.copy()
+ obs_flagged = self.copy()
+ obs_kept.data = datatable_kept
+ obs_flagged.data = datatable_flagged
+
+ if output == 'flagged': # return only the points flagged as anomalous
+ return obs_flagged
+ elif output == 'both':
+ return {'kept': obs_kept, 'flagged': obs_flagged}
+ else:
+ return obs_kept
+
+ def filter_subscan_dropouts(self, perc=0, return_type='rec'):
+ """Filtration to drop data and ensure that we only average parts with same timestamp.
+ Potentially this could reduce risk of non-closing errors.
+
+ Args:
+            perc (float): drop a baseline from a scan if it has less than this fraction
+                          of the median number of samples per baseline in that scan
+ return_type (str): data frame ('df') or recarray ('rec')
+
+ Returns:
+            (Obsdata): an observation object with flagged data points removed
+ """
+
+ if not isinstance(self.scans, np.ndarray):
+            raise Exception('List of scans in ndarray format required! Add it with add_scans')
+
+ else:
+ # make df and add scan_id to data
+ df = ehdf.make_df(self)
+ tot_points = np.shape(df)[0]
+ bins, labs = ehdf.get_bins_labels(self.scans)
+ df['scan_id'] = list(pd.cut(df.time, bins, labels=labs))
+
+ # first flag baselines that are working for short part of scan
+ df['count_samples'] = 1
+ hm1 = df.groupby(['scan_id', 'baseline', 'polarization'])
+ hm1 = hm1.agg({'count_samples': np.sum}).reset_index()
+ hm1['count_baselines_before'] = 1
+ hm2 = hm1.groupby(['scan_id', 'polarization'])
+ hm2 = hm2.agg({'count_samples': lambda x: perc * np.median(x),
+ 'count_baselines_before': np.sum}).reset_index()
+
+ # dictionary with minimum acceptable number of samples per scan
+ dict_elem_in_scan = dict(zip(hm2.scan_id, hm2.count_samples))
+
+ # list of acceptable scans and baselines
+ hm1 = hm1[list(map(lambda x: x[1] >= dict_elem_in_scan[x[0]],
+ list(zip(hm1.scan_id, hm1.count_samples))))]
+ list_good_scans_baselines = list(zip(hm1.scan_id, hm1.baseline))
+
+ # filter out data
+ df_filtered = df[list(map(lambda x: x in list_good_scans_baselines,
+ list(zip(df.scan_id, df.baseline))))]
+
+ # how many baselines present during scan?
+ df_filtered['count_samples'] = 1
+ hm3 = df_filtered.groupby(['scan_id', 'baseline', 'polarization'])
+ hm3 = hm3.agg({'count_samples': np.sum}).reset_index()
+ hm3['count_baselines_after'] = 1
+ hm4 = hm3.groupby(['scan_id', 'polarization'])
+ hm4 = hm4.agg({'count_baselines_after': np.sum}).reset_index()
+ dict_how_many_baselines = dict(zip(hm4.scan_id, hm4.count_baselines_after))
+
+ # how many baselines present during each time?
+ df_filtered['count_baselines_per_time'] = 1
+ hm5 = df_filtered.groupby(['datetime', 'scan_id', 'polarization'])
+ hm5 = hm5.agg({'count_baselines_per_time': np.sum}).reset_index()
+ dict_datetime_num_baselines = dict(zip(hm5.datetime, hm5.count_baselines_per_time))
+
+ # only keep times when all baselines available
+            df_filtered2 = df_filtered[list(map(
+                lambda x: dict_datetime_num_baselines[x[1]] == dict_how_many_baselines[x[0]],
+                list(zip(df_filtered.scan_id, df_filtered.datetime))))]
+
+ remaining_points = np.shape(df_filtered2)[0]
+ print('Flagged out {} of {} datapoints'.format(
+ tot_points - remaining_points, tot_points))
+ if return_type == 'rec':
+ out_vis = ehdf.df_to_rec(df_filtered2, 'vis')
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = out_vis
+ out = Obsdata(*arglist, **argdict)
+
+ return out
+
+ def reverse_taper(self, fwhm):
+ """Reverse taper the observation with a circular Gaussian kernel
+
+ Args:
+ fwhm (float): real space fwhm size of convolution kernel in radian
+
+ Returns:
+ (Obsdata): a new reverse-tapered observation object
+ """
+ datatable = self.data.copy()
+ vis1 = datatable[self.poldict['vis1']]
+ vis2 = datatable[self.poldict['vis2']]
+ vis3 = datatable[self.poldict['vis3']]
+ vis4 = datatable[self.poldict['vis4']]
+ sigma1 = datatable[self.poldict['sigma1']]
+ sigma2 = datatable[self.poldict['sigma2']]
+ sigma3 = datatable[self.poldict['sigma3']]
+ sigma4 = datatable[self.poldict['sigma4']]
+ u = datatable['u']
+ v = datatable['v']
+
+ fwhm_sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
+ ker = np.exp(-2 * np.pi**2 * fwhm_sigma**2 * (u**2 + v**2))
+
+ datatable[self.poldict['vis1']] = vis1 / ker
+ datatable[self.poldict['vis2']] = vis2 / ker
+ datatable[self.poldict['vis3']] = vis3 / ker
+ datatable[self.poldict['vis4']] = vis4 / ker
+ datatable[self.poldict['sigma1']] = sigma1 / ker
+ datatable[self.poldict['sigma2']] = sigma2 / ker
+ datatable[self.poldict['sigma3']] = sigma3 / ker
+ datatable[self.poldict['sigma4']] = sigma4 / ker
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = datatable
+ obstaper = Obsdata(*arglist, **argdict)
+
+ return obstaper
+
+ def taper(self, fwhm):
+ """Taper the observation with a circular Gaussian kernel
+
+ Args:
+ fwhm (float): real space fwhm size of convolution kernel in radian
+
+ Returns:
+ (Obsdata): a new tapered observation object
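+
+        Example:
+            # Illustrative sketch; assumes ``import ehtim as eh`` and an
+            # existing Obsdata ``obs``.
+            obs_tapered = obs.taper(5 * eh.RADPERUAS)  # 5 uas FWHM Gaussian taper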
+ """
+ datatable = self.data.copy()
+
+ vis1 = datatable[self.poldict['vis1']]
+ vis2 = datatable[self.poldict['vis2']]
+ vis3 = datatable[self.poldict['vis3']]
+ vis4 = datatable[self.poldict['vis4']]
+ sigma1 = datatable[self.poldict['sigma1']]
+ sigma2 = datatable[self.poldict['sigma2']]
+ sigma3 = datatable[self.poldict['sigma3']]
+ sigma4 = datatable[self.poldict['sigma4']]
+ u = datatable['u']
+ v = datatable['v']
+
+ fwhm_sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
+ ker = np.exp(-2 * np.pi**2 * fwhm_sigma**2 * (u**2 + v**2))
+
+ datatable[self.poldict['vis1']] = vis1 * ker
+ datatable[self.poldict['vis2']] = vis2 * ker
+ datatable[self.poldict['vis3']] = vis3 * ker
+ datatable[self.poldict['vis4']] = vis4 * ker
+ datatable[self.poldict['sigma1']] = sigma1 * ker
+ datatable[self.poldict['sigma2']] = sigma2 * ker
+ datatable[self.poldict['sigma3']] = sigma3 * ker
+ datatable[self.poldict['sigma4']] = sigma4 * ker
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = datatable
+ obstaper = Obsdata(*arglist, **argdict)
+
+ return obstaper
+
+ def deblur(self):
+ """Deblur the observation obs by dividing by the Sgr A* scattering kernel.
+
+ Args:
+
+ Returns:
+ (Obsdata): a new deblurred observation object.
+ """
+
+ # make a copy of observation data
+ datatable = self.data.copy()
+
+ vis1 = datatable[self.poldict['vis1']]
+ vis2 = datatable[self.poldict['vis2']]
+ vis3 = datatable[self.poldict['vis3']]
+ vis4 = datatable[self.poldict['vis4']]
+ sigma1 = datatable[self.poldict['sigma1']]
+ sigma2 = datatable[self.poldict['sigma2']]
+ sigma3 = datatable[self.poldict['sigma3']]
+ sigma4 = datatable[self.poldict['sigma4']]
+ u = datatable['u']
+ v = datatable['v']
+
+ # divide visibilities by the scattering kernel
+ for i in range(len(vis1)):
+ ker = obsh.sgra_kernel_uv(self.rf, u[i], v[i])
+ vis1[i] = vis1[i] / ker
+ vis2[i] = vis2[i] / ker
+            vis3[i] = vis3[i] / ker
+ vis4[i] = vis4[i] / ker
+ sigma1[i] = sigma1[i] / ker
+ sigma2[i] = sigma2[i] / ker
+ sigma3[i] = sigma3[i] / ker
+ sigma4[i] = sigma4[i] / ker
+
+ datatable[self.poldict['vis1']] = vis1
+ datatable[self.poldict['vis2']] = vis2
+ datatable[self.poldict['vis3']] = vis3
+ datatable[self.poldict['vis4']] = vis4
+ datatable[self.poldict['sigma1']] = sigma1
+ datatable[self.poldict['sigma2']] = sigma2
+ datatable[self.poldict['sigma3']] = sigma3
+ datatable[self.poldict['sigma4']] = sigma4
+
+ arglist, argdict = self.obsdata_args()
+ arglist[DATPOS] = datatable
+ obsdeblur = Obsdata(*arglist, **argdict)
+
+ return obsdeblur
+
+ def reweight(self, uv_radius, weightdist=1.0):
+ """Reweight the sigmas based on the local density of uv points
+
+ Args:
+ uv_radius (float): radius in uv-plane to look for nearby points
+            weightdist (float): blending factor between the density-reweighted
+                                sigmas (1.0) and the original sigmas (0.0)
+
+ Returns:
+ (Obsdata): a new reweighted observation object.
+ """
+
+ obs_new = self.copy()
+ npts = len(obs_new.data)
+
+ uvpoints = np.vstack((obs_new.data['u'], obs_new.data['v'])).transpose()
+ uvpoints_tree1 = spatial.cKDTree(uvpoints)
+ uvpoints_tree2 = spatial.cKDTree(-uvpoints)
+
+ for i in range(npts):
+ matches1 = uvpoints_tree1.query_ball_point(uvpoints[i, :], uv_radius)
+ matches2 = uvpoints_tree2.query_ball_point(uvpoints[i, :], uv_radius)
+ nmatches = len(matches1) + len(matches2)
+
+ for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']:
+ obs_new.data[sigma][i] = np.sqrt(nmatches)
+
+ scale = np.mean(self.data['sigma']) / np.mean(obs_new.data['sigma'])
+ for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']:
+ obs_new.data[sigma] *= scale * weightdist
+
+ if weightdist < 1.0:
+ for i in range(npts):
+ for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']:
+ obs_new.data[sigma][i] += (1 - weightdist) * self.data[sigma][i]
+
+ return obs_new
+
+ def fit_gauss(self, flux=1.0, fittype='amp', paramguess=(
+ 100 * ehc.RADPERUAS, 100 * ehc.RADPERUAS, 0.)):
+ """Fit a gaussian to either Stokes I complex visibilities or Stokes I visibility amplitudes.
+
+ Args:
+ flux (float): total flux in the fitted gaussian
+ fitttype (str): "amp" to fit to visibilty amplitudes
+ paramguess (tuble): initial guess of fit Gaussian (fwhm_maj, fwhm_min, theta)
+
+ Returns:
+            (tuple): (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters in radians.
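+
+        Example:
+            # Illustrative sketch; assumes an existing Obsdata ``obs``.
+            fwhm_maj, fwhm_min, theta = obs.fit_gauss(flux=1.0, fittype='amp')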
+ """
+
+ # TODO this fit doesn't work very well!!
+ vis = self.data['vis']
+ u = self.data['u']
+ v = self.data['v']
+ sig = self.data['sigma']
+
+ # error function
+ if fittype == 'amp':
+ def errfunc(p):
+ vismodel = obsh.gauss_uv(u, v, flux, p, x=0., y=0.)
+ err = np.sum((np.abs(vis) - np.abs(vismodel))**2 / sig**2)
+ return err
+ else:
+ def errfunc(p):
+ vismodel = obsh.gauss_uv(u, v, flux, p, x=0., y=0.)
+ err = np.sum(np.abs(vis - vismodel)**2 / sig**2)
+ return err
+
+ optdict = {'maxiter': 5000} # minimizer params
+ res = opt.minimize(errfunc, paramguess, method='Powell', options=optdict)
+ gparams = res.x
+
+ return gparams
+
+ def ClosureInvariants(self):
+ """
+        Calculates copolar closure invariants for the visibilities, assuming an
+        n-element interferometer array, using method 1 of:
+
+        Thyagarajan, N., Nityananda, R., Samuel, J. 2022, "Invariants in copolar
+        interferometry: An Abelian gauge theory", Phys. Rev. D 105, 043019.
+        https://doi.org/10.1103/PhysRevD.105.043019
+
+        Args:
+
+ Returns:
+ ci (np.ndarray): closure invariants
+ """
+ tlist = self.tlist()
+ out_ci = np.array([])
+ for tdata in tlist:
+ num_antenna = len(np.unique(tdata['t1'])) + 1
+ if num_antenna < 3:
+ continue
+ vis = tdata['vis'].reshape(1,1,-1)
+ _, btriads = self.Triads(num_antenna)
+ C_oa = vis[:, :, btriads[:, 0]]
+ C_ab = vis[:, :, btriads[:, 1]]
+ C_bo = np.conjugate(vis[:, :, btriads[:, 2]])
+ A_oab = C_oa / np.conjugate(C_ab) * C_bo
+ A_oab = np.dstack((A_oab.real, A_oab.imag))
+ A_max = np.nanmax(np.abs(A_oab), axis=-1, keepdims=True)
+ ci = A_oab / A_max
+ ci = ci.reshape(-1)
+ out_ci = np.concatenate([out_ci, ci], axis=0)
+
+ return out_ci
+
+ def Triads(self, n:int):
+ """
+        Generates arrays of antenna and baseline indices that form triangular
+ loops pivoted around the 0th antenna. Used to calculate closure invariants
+ whereby specific baseline correlations need to be indexed according
+ to those triangular loops.
+ Baseline array format [ant1, ant2]:
+ [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6] ...
+ [1, 2], [1, 3], [1, 4], [1, 5], [1, 6] ...
+ [2, 3], [2, 4], [2, 5], [2, 6] ...
+ [3, 4], [3, 5], [3, 6] ...
+ [4, 5], [4, 6] ...
+ [5, 6] ...
+
+ Args:
+            n (int): number of antennas in the array
+
+ Returns:
+            atriads (np.ndarray): antenna triangular loop indices
+            btriads (np.ndarray): baseline triangular loop indices
+ """
+ ntriads = (n-1)*(n-2)//2
+ ant1 = np.zeros(ntriads, dtype=np.uint8)
+ ant2 = np.arange(1, n, dtype=np.uint8).reshape(n-1, 1) + np.zeros(n-2, dtype=np.uint8).reshape(1, n-2)
+ ant3 = np.arange(2, n, dtype=np.uint8).reshape(1, n-2) + np.zeros(n-1, dtype=np.uint8).reshape(n-1, 1)
+ anti = np.where(ant3 > ant2)
+ ant2, ant3 = ant2[anti], ant3[anti]
+ atriads = np.concatenate([ant1.reshape(-1, 1), ant2.reshape(-1, 1), ant3.reshape(-1, 1)], axis=-1)
+
+ ant_pairs_01 = list(zip(ant1, ant2))
+ ant_pairs_12 = list(zip(ant2, ant3))
+ ant_pairs_20 = list(zip(ant3, ant1))
+
+ t1 = np.arange(n, dtype=int).reshape(n, 1) + np.zeros(n, dtype=int).reshape(1, n)
+ t2 = np.arange(n, dtype=int).reshape(1, n) + np.zeros(n, dtype=int).reshape(n, 1)
+ bli = np.where(t2 > t1)
+ t1, t2 = t1[bli], t2[bli]
+ bl_pairs = list(zip(t1, t2))
+
+ bl_01 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_01])
+ bl_12 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_12])
+ bl_20 = np.asarray([bl_pairs.index(tuple(reversed(apair))) for apair in ant_pairs_20])
+ btriads = np.concatenate([bl_01.reshape(-1, 1), bl_12.reshape(-1, 1), bl_20.reshape(-1, 1)], axis=-1)
+ return atriads, btriads
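+
+        # Worked example (a sketch for n=4 antennas, with baselines indexed as
+        # (0,1)->0, (0,2)->1, (0,3)->2, (1,2)->3, (1,3)->4, (2,3)->5):
+        #     atriads, btriads = self.Triads(4)
+        #     # atriads -> [[0, 1, 2], [0, 1, 3], [0, 2, 3]]
+        #     # btriads -> [[0, 3, 1], [0, 4, 2], [1, 5, 2]]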
+
+ def bispectra(self, vtype='vis', mode='all', count='min',
+ timetype=False, uv_min=False, snrcut=0.):
+ """Return a recarray of the equal time bispectra.
+
+ Args:
+            vtype (str): The visibility type from which to assemble bispectra
+                         ('vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis')
+            mode (str): If 'time', return bispectra in a list of equal time arrays,
+                        if 'all', return all bispectra in a single array
+            count (str): If 'min', return minimal set of bispectra,
+                         if 'max' return all bispectra up to reordering,
+                         if 'min-cut0bl' return the minimal set after flagging zero-baselines
+            timetype (str): 'GMST' or 'UTC'
+            uv_min (float): flag baselines shorter than this before forming closure quantities
+            snrcut (float): flag bispectra with snr lower than this
+
+        Returns:
+            (numpy.recarray): A recarray of the bispectra values with datatype DTBIS
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if mode not in ('time', 'all'):
+ raise Exception("possible options for mode are 'time' and 'all'")
+ if count not in ('max', 'min', 'min-cut0bl'):
+ raise Exception("possible options for count are 'max', 'min', or 'min-cut0bl'")
+ if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'):
+ raise Exception("possible options for vtype are" +
+ " 'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'")
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ # Flag zero baselines
+ obsdata = self.copy()
+ if uv_min:
+ obsdata = obsdata.flag_uvdist(uv_min=uv_min)
+ # get which sites were flagged
+ obsdata_flagged = self.copy()
+ obsdata_flagged = obsdata_flagged.flag_uvdist(uv_max=uv_min)
+
+ # Generate the time-sorted data with conjugate baselines
+ tlist = obsdata.tlist(conj=True)
+ out = []
+ bis = []
+ tt = 1
+ for tdata in tlist:
+
+ # sys.stdout.write('\rGetting bispectra:: type %s, count %s, scan %i/%i ' %
+ # (vtype, count, tt, len(tlist)))
+ # sys.stdout.flush()
+
+ tt += 1
+
+ time = tdata[0]['time']
+ if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ time = obsh.utc_to_gmst(time, self.mjd)
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ time = obsh.gmst_to_utc(time, self.mjd)
+ sites = list(set(np.hstack((tdata['t1'], tdata['t2']))))
+
+ # Create a dictionary of baselines at the current time incl. conjugates;
+ l_dict = {}
+ for dat in tdata:
+ l_dict[(dat['t1'], dat['t2'])] = dat
+
+ # Determine the triangles in the time step
+ # Minimal Set
+ if count == 'min':
+ tris = obsh.tri_minimal_set(sites, self.tarr, self.tkey)
+
+ # Maximal Set
+ elif count == 'max':
+ tris = np.sort(list(it.combinations(sites, 3)))
+
+ elif count == 'min-cut0bl':
+ tris = obsh.tri_minimal_set(sites, self.tarr, self.tkey)
+
+ # if you cut the 0 baselines, add in triangles that now are not in the minimal set
+ if uv_min:
+ # get the reference site
+ sites_ordered = [x for x in self.tarr['site'] if x in sites]
+ ref = sites_ordered[0]
+
+ # check if the reference site was in a zero baseline
+ zerobls = np.vstack([obsdata_flagged.data['t1'], obsdata_flagged.data['t2']])
+ if np.sum(zerobls == ref):
+
+ # determine which sites were cut out of the minimal set
+ cutsites = np.unique(np.hstack([zerobls[1][zerobls[0] == ref],
+ zerobls[0][zerobls[1] == ref]]))
+
+ # we can only handle if there was 1 connecting site that was cut
+ if len(cutsites) > 1:
+ raise Exception("Cannot have the root node be in a clique" +
+ "with more than 2 sites sharing 0 baselines'")
+
+ # get the remaining sites
+ cutsite = cutsites[0]
+ sites_remaining = np.array(sites_ordered)[np.array(sites_ordered) != ref]
+ sites_remaining = sites_remaining[np.array(sites_remaining) != cutsite]
+ # get the next site in the list, ideally sorted by snr
+ second_ref = sites_remaining[0]
+
+ # add in additional triangles
+ for s2 in range(1, len(sites_remaining)):
+ tris.append((cutsite, second_ref, sites_remaining[s2]))
+
+ # Generate bispectra for each triangle
+ for tri in tris:
+
+ # Select triangle entries in the data dictionary
+ try:
+ l1 = l_dict[(tri[0], tri[1])]
+ l2 = l_dict[(tri[1], tri[2])]
+ l3 = l_dict[(tri[2], tri[0])]
+ except KeyError:
+ continue
+
+ (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep)
+
+ # Cut out low snr points
+ if np.abs(bi) / bisig < snrcut:
+ continue
+
+ # Append to the equal-time list
+ bis.append(np.array((time,
+ tri[0], tri[1], tri[2],
+ l1['u'], l1['v'],
+ l2['u'], l2['v'],
+ l3['u'], l3['v'],
+ bi, bisig), dtype=ehc.DTBIS))
+
+ # Append to outlist
+ if mode == 'time' and len(bis) > 0:
+ out.append(np.array(bis))
+ bis = []
+
+ if mode == 'all':
+ out = np.array(bis)
+
+ return out
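+
+        # Usage sketch (illustrative; 'obs' is an Obsdata instance):
+        #     bis = obs.bispectra(vtype='vis', mode='all', count='min', snrcut=3.)
+        #     print(bis['time'], bis['bispec'])   # fields follow ehc.DTBIS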
+
+ def c_phases(self, vtype='vis', mode='all', count='min', ang_unit='deg',
+ timetype=False, uv_min=False, snrcut=0.):
+ """Return a recarray of the equal time closure phases.
+
+ Args:
+            vtype (str): The visibility type from which to assemble closure phases
+                         ('vis','qvis','uvis','vvis','rrvis','lrvis','rlvis','llvis')
+            mode (str): If 'time', return phases in a list of equal time arrays,
+                        if 'all', return all phases in a single array
+            count (str): If 'min', return minimal set of phases,
+                         if 'max' return all closure phases up to reordering,
+                         if 'min-cut0bl' return the minimal set after flagging zero-baselines
+            ang_unit (str): If 'deg', return closure phases in degrees, else return in radians
+            timetype (str): 'UTC' or 'GMST'
+            uv_min (float): flag baselines shorter than this before forming closure quantities
+            snrcut (float): flag bispectra with snr lower than this
+
+        Returns:
+            (numpy.recarray): A recarray of the closure phases with datatype DTCPHASE
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if mode not in ('time', 'all'):
+ raise Exception("possible options for mode are 'time' and 'all'")
+ if count not in ('max', 'min', 'min-cut0bl'):
+ raise Exception("possible options for count are 'max', 'min', or 'min-cut0bl'")
+ if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'):
+ raise Exception("possible options for vtype are" +
+ " 'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'")
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ if ang_unit == 'deg':
+ angle = ehc.DEGREE
+ else:
+ angle = 1.0
+
+ # Get the bispectra data
+ bispecs = self.bispectra(vtype=vtype, mode='time', count=count,
+ timetype=timetype, uv_min=uv_min, snrcut=snrcut)
+
+ # Reformat into a closure phase list/array
+ out = []
+ cps = []
+
+ cpnames = ('time', 't1', 't2', 't3', 'u1', 'v1', 'u2',
+ 'v2', 'u3', 'v3', 'cphase', 'sigmacp')
+ for bis in bispecs:
+ for bi in bis:
+ if len(bi) == 0:
+ continue
+ bi.dtype.names = cpnames
+ bi['sigmacp'] = np.real(bi['sigmacp'] / np.abs(bi['cphase']) / angle)
+ bi['cphase'] = np.real((np.angle(bi['cphase']) / angle))
+ cps.append(bi.astype(np.dtype(ehc.DTCPHASE)))
+
+ if mode == 'time' and len(cps) > 0:
+ out.append(np.array(cps))
+ cps = []
+
+ if mode == 'all':
+ out = np.array(cps)
+
+ return out
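+
+        # Usage sketch (illustrative; 'obs' is an Obsdata instance):
+        #     cps = obs.c_phases(mode='all', count='min', ang_unit='deg', snrcut=3.)
+        #     print(cps['cphase'], cps['sigmacp'])   # fields follow ehc.DTCPHASE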
+
+ def c_phases_diag(self, vtype='vis', count='min', ang_unit='deg',
+ timetype=False, uv_min=False, snrcut=0.):
+ """Return a recarray of the equal time diagonalized closure phases.
+
+ Args:
+            vtype (str): The visibility type ('vis','qvis','uvis','vvis','rrvis',
+                         'lrvis','rlvis','llvis') from which to assemble closure phases
+ count (str): If 'min', return minimal set of phases,
+ If 'min-cut0bl' return minimal set after flagging zero-baselines
+ ang_unit (str): If 'deg', return closure phases in degrees, else return in radians
+ timetype (str): 'UTC' or 'GMST'
+ uv_min (float): flag baselines shorter than this before forming closure quantities
+ snrcut (float): flag bispectra with snr lower than this
+
+ Returns:
+            (numpy.recarray): A recarray of diagonalized closure phases (datatype DTCPHASEDIAG),
+ along with associated triangles and transformation matrices
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if count not in ('min', 'min-cut0bl'):
+ raise Exception(
+ "possible options for count are 'min' or 'min-cut0bl' for diagonal closure phases")
+ if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'):
+ raise Exception(
+ "possible options for vtype are 'vis', 'qvis', " +
+ "'uvis','vvis','rrvis','lrvis','rlvis','llvis'")
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ if ang_unit == 'deg':
+ angle = ehc.DEGREE
+ else:
+ angle = 1.0
+
+ # determine the appropriate sigmatype
+ if vtype in ["vis", "qvis", "uvis", "vvis"]:
+ if vtype == 'vis':
+ sigmatype = 'sigma'
+ if vtype == 'qvis':
+ sigmatype = 'qsigma'
+ if vtype == 'uvis':
+ sigmatype = 'usigma'
+ if vtype == 'vvis':
+ sigmatype = 'vsigma'
+ if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]:
+ if vtype == 'rrvis':
+ sigmatype = 'rrsigma'
+ if vtype == 'llvis':
+ sigmatype = 'llsigma'
+ if vtype == 'rlvis':
+ sigmatype = 'rlsigma'
+ if vtype == 'lrvis':
+ sigmatype = 'lrsigma'
+
+ # get the time-sorted visibility data including conjugate baselines
+ viss = np.concatenate(self.tlist(conj=True))
+
+ # get the closure phase data
+ cps = self.c_phases(vtype=vtype, mode='all', count=count, ang_unit=ang_unit,
+ timetype=timetype, uv_min=uv_min, snrcut=snrcut)
+
+ # get the unique timestamps for the closure phases
+ T_cps = np.unique(cps['time'])
+
+ # list of diagonalized closure phases and corresponding transformation matrices
+ dcps = []
+ dcp_errs = []
+ tfmats = []
+
+ tris = []
+ us = []
+ vs = []
+
+ # loop over the timestamps
+ for kk, t in enumerate(T_cps):
+
+ sys.stdout.write('\rDiagonalizing closure phases:: type %s, count %s, scan %i/%i ' %
+ (vtype, count, kk + 1, len(T_cps)))
+ sys.stdout.flush()
+
+ # index masks for this timestamp
+ mask_cp = (cps['time'] == t)
+ mask_vis = (viss['time'] == t)
+
+ # closure phases for this timestamp
+ cps_here = cps[mask_cp]
+
+ # visibilities for this timestamp
+ viss_here = viss[mask_vis]
+
+ # initialize the design matrix
+ design_mat = np.zeros((mask_cp.sum(), mask_vis.sum()))
+
+ # loop over the closure phases within this timestamp
+ trilist = []
+ ulist = []
+ vlist = []
+ for ic, cp in enumerate(cps_here):
+
+ trilist.append((cp['t1'], cp['t2'], cp['t3']))
+ ulist.append((cp['u1'], cp['u2'], cp['u3']))
+ vlist.append((cp['v1'], cp['v2'], cp['v3']))
+
+ # matrix entry for first leg of triangle
+ ind1 = ((viss_here['t1'] == cp['t1']) & (viss_here['t2'] == cp['t2']))
+ design_mat[ic, ind1] = 1.0
+
+ # matrix entry for second leg of triangle
+ ind2 = ((viss_here['t1'] == cp['t2']) & (viss_here['t2'] == cp['t3']))
+ design_mat[ic, ind2] = 1.0
+
+ # matrix entry for third leg of triangle
+ ind3 = ((viss_here['t1'] == cp['t3']) & (viss_here['t2'] == cp['t1']))
+ design_mat[ic, ind3] = 1.0
+
+ # construct the covariance matrix
+ visphase_err = viss_here[sigmatype] / np.abs(viss_here[vtype])
+ sigma_mat = np.diag(visphase_err**2.0)
+ covar_mat = np.matmul(np.matmul(design_mat, sigma_mat), np.transpose(design_mat))
+
+ # diagonalize via eigendecomposition
+ eigeninfo = np.linalg.eigh(covar_mat)
+ S_matrix = np.copy(eigeninfo[1]).transpose()
+ dcphase = np.matmul(S_matrix, cps_here['cphase'])
+ if ang_unit != 'deg':
+ dcphase *= angle
+ dcphase_err = np.sqrt(np.copy(eigeninfo[0])) / angle
+
+ dcps.append(dcphase)
+ dcp_errs.append(dcphase_err)
+ tfmats.append(S_matrix)
+ tris.append(trilist)
+ us.append(ulist)
+ vs.append(vlist)
+
+ # Reformat into a list
+ out = []
+ for kk, t in enumerate(T_cps):
+ dcparr = []
+ for idcp, dcp in enumerate(dcps[kk]):
+ dcparr.append((t, dcp, dcp_errs[kk][idcp]))
+ dcparr = np.array(dcparr, dtype=[('time', 'f8'), ('cphase', 'f8'), ('sigmacp', 'f8')])
+ out.append((dcparr,
+ np.array(tris[kk]).astype(np.dtype([('trianges', 'U2')])),
+ np.array(us[kk]).astype(np.dtype([('u', 'f8')])),
+ np.array(vs[kk]).astype(np.dtype([('v', 'f8')])),
+ tfmats[kk].astype(np.dtype([('tform_matrix', 'f8')]))))
+ print("\n")
+ return out
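+
+        # A minimal numpy sketch of the diagonalization step above (illustrative;
+        # A is a made-up 2x3 design matrix and err made-up phase errors):
+        #     A = np.array([[1., 1., 0.], [0., 1., 1.]])
+        #     err = np.array([0.1, 0.2, 0.1])
+        #     covar = A @ np.diag(err**2) @ A.T
+        #     eigval, eigvec = np.linalg.eigh(covar)
+        #     S = eigvec.T
+        #     # S @ covar @ S.T is diagonal with entries eigval, so the
+        #     # transformed closure phases S @ cphase have independent errors.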
+
+ def bispectra_tri(self, site1, site2, site3,
+ vtype='vis', timetype=False, snrcut=0., method='from_maxset',
+ bs=[], force_recompute=False):
+
+ """Return complex bispectrum over time on a triangle (1-2-3).
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+            vtype (str): The visibility type from which to assemble bispectra
+ ('vis','qvis','uvis','vvis','pvis')
+ timetype (str): 'UTC' or 'GMST'
+ snrcut (float): flag bispectra with snr lower than this
+
+ method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust)
+ bs (list): optionally pass in the precomputed, time-sorted bispectra
+ force_recompute (bool): if True, recompute bispectra instead of using saved data
+
+ Returns:
+            (numpy.recarray): A recarray of the bispectra on this triangle with datatype DTBIS
+ """
+ if timetype is False:
+ timetype = self.timetype
+
+ if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']):
+ print ("Warning! method='from_maxset' default in bispectra_tri() inconsistent with vtype=%s" % vtype)
+ print ("Switching to method='from_vis'")
+ method = 'from_vis'
+
+ tri = (site1, site2, site3)
+ outdata = []
+
+ # get selected bispectra from the maximal set
+ # TODO: verify consistency/performance of from_vis, and delete this method
+ if method=='from_maxset':
+
+ if ((len(bs) == 0) and not (self.bispec is None) and not (len(self.bispec) == 0) and
+ not force_recompute):
+ bs = self.bispec
+ elif (len(bs) == 0) or force_recompute:
+ bs = self.bispectra(mode='all', count='max', vtype=vtype,
+ timetype=timetype, snrcut=snrcut)
+
+ # Get requested bispectra over time
+ for obs in bs:
+ obstri = (obs['t1'], obs['t2'], obs['t3'])
+ if set(obstri) == set(tri):
+ t1 = copy.deepcopy(obs['t1'])
+ t2 = copy.deepcopy(obs['t2'])
+ t3 = copy.deepcopy(obs['t3'])
+ u1 = copy.deepcopy(obs['u1'])
+ u2 = copy.deepcopy(obs['u2'])
+ u3 = copy.deepcopy(obs['u3'])
+ v1 = copy.deepcopy(obs['v1'])
+ v2 = copy.deepcopy(obs['v2'])
+ v3 = copy.deepcopy(obs['v3'])
+
+ # Reorder baselines and flip the sign of the closure phase if necessary
+ if t1 == site1:
+ if t2 == site2:
+ pass
+ else:
+ obs['t2'] = t3
+ obs['t3'] = t2
+
+ obs['u1'] = -u3
+ obs['v1'] = -v3
+ obs['u2'] = -u2
+ obs['v2'] = -v2
+ obs['u3'] = -u1
+ obs['v3'] = -v1
+ obs['bispec'] = np.conjugate(obs['bispec'])
+
+ elif t1 == site2:
+ if t2 == site3:
+ obs['t1'] = t3
+ obs['t2'] = t1
+ obs['t3'] = t2
+
+ obs['u1'] = u3
+ obs['v1'] = v3
+ obs['u2'] = u1
+ obs['v2'] = v1
+ obs['u3'] = u2
+ obs['v3'] = v2
+
+ else:
+ obs['t1'] = t2
+ obs['t2'] = t1
+
+ obs['u1'] = -u1
+ obs['v1'] = -v1
+ obs['u2'] = -u3
+ obs['v2'] = -v3
+ obs['u3'] = -u2
+ obs['v3'] = -v2
+ obs['bispec'] = np.conjugate(obs['bispec'])
+
+ elif t1 == site3:
+ if t2 == site1:
+ obs['t1'] = t2
+ obs['t2'] = t3
+ obs['t3'] = t1
+
+ obs['u1'] = u2
+ obs['v1'] = v2
+ obs['u2'] = u3
+ obs['v2'] = v3
+ obs['u3'] = u1
+ obs['v3'] = v1
+
+ else:
+ obs['t1'] = t3
+ obs['t3'] = t1
+
+ obs['u1'] = -u2
+ obs['v1'] = -v2
+ obs['u2'] = -u1
+ obs['v2'] = -v1
+ obs['u3'] = -u3
+ obs['v3'] = -v3
+ obs['bispec'] = np.conjugate(obs['bispec'])
+
+ outdata.append(np.array(obs, dtype=ehc.DTBIS))
+ continue
+
+ # get selected bispectra from the visibilities directly
+ # taken from bispectra() method
+ elif method=='from_vis':
+
+ # get all equal-time data, and loop over to construct bispectra
+ tlist = self.tlist(conj=True)
+ for tdata in tlist:
+
+ time = tdata[0]['time']
+ if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ time = obsh.utc_to_gmst(time, self.mjd)
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ time = obsh.gmst_to_utc(time, self.mjd)
+
+ # Create a dictionary of baselines at the current time incl. conjugates;
+ l_dict = {}
+ for dat in tdata:
+ l_dict[(dat['t1'], dat['t2'])] = dat
+
+ # Select triangle entries in the data dictionary
+ try:
+ l1 = l_dict[(tri[0], tri[1])]
+ l2 = l_dict[(tri[1], tri[2])]
+ l3 = l_dict[(tri[2], tri[0])]
+ except KeyError:
+ continue
+
+ (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep)
+
+ # Cut out low snr points
+ if np.abs(bi) / bisig < snrcut:
+ continue
+
+ # Append to the equal-time list
+ outdata.append(np.array((time,
+ tri[0], tri[1], tri[2],
+ l1['u'], l1['v'],
+ l2['u'], l2['v'],
+ l3['u'], l3['v'],
+ bi,
+ bisig),
+ dtype=ehc.DTBIS))
+ else:
+ raise Exception("keyword 'method' in bispectra_tri() must be either 'from_cphase' or 'from_vis'")
+
+ outdata = np.array(outdata)
+ return outdata
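+
+        # Usage sketch (illustrative; station names are hypothetical):
+        #     bs = obs.bispectra_tri('ALMA', 'SMA', 'LMT', method='from_vis')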
+
+ def cphase_tri(self, site1, site2, site3, vtype='vis', ang_unit='deg',
+ timetype=False, snrcut=0., method='from_maxset',
+ cphases=[], force_recompute=False):
+ """Return closure phase over time on a triangle (1-2-3).
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+            vtype (str): The visibility type from which to assemble closure phases
+ (e.g., 'vis','qvis','uvis','vvis','pvis')
+ ang_unit (str): If 'deg', return closure phases in degrees, else return in radians
+ timetype (str): 'GMST' or 'UTC'
+ snrcut (float): flag bispectra with snr lower than this
+
+ method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust)
+ cphases (list): optionally pass in the precomputed time-sorted cphases
+            force_recompute (bool): if True, do not use saved closure phase tables
+
+ Returns:
+            (numpy.recarray): A recarray of the closure phases with datatype DTCPHASE
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+
+ if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']):
+ print ("Warning! method='from_maxset' default in cphase_tri() is inconsistent with vtype=%s" % vtype)
+ print ("Switching to method='from_vis'")
+ method = 'from_vis'
+
+ tri = (site1, site2, site3)
+ outdata = []
+
+ # get selected closure phases from the maximal set
+ # TODO: verify consistency/performance of from_vis, and delete this method
+ if method=='from_maxset':
+
+ # Get closure phases (maximal set)
+ if ((len(cphases) == 0) and not (self.cphase is None) and not (len(self.cphase) == 0) and
+ not force_recompute):
+ cphases = self.cphase
+
+ elif (len(cphases) == 0) or force_recompute:
+ cphases = self.c_phases(mode='all', count='max', vtype=vtype, ang_unit=ang_unit,
+ timetype=timetype, snrcut=snrcut)
+
+ # Get requested closure phases over time
+ for obs in cphases:
+ obstri = (obs['t1'], obs['t2'], obs['t3'])
+ if set(obstri) == set(tri):
+ t1 = copy.deepcopy(obs['t1'])
+ t2 = copy.deepcopy(obs['t2'])
+ t3 = copy.deepcopy(obs['t3'])
+ u1 = copy.deepcopy(obs['u1'])
+ u2 = copy.deepcopy(obs['u2'])
+ u3 = copy.deepcopy(obs['u3'])
+ v1 = copy.deepcopy(obs['v1'])
+ v2 = copy.deepcopy(obs['v2'])
+ v3 = copy.deepcopy(obs['v3'])
+
+ # Reorder baselines and flip the sign of the closure phase if necessary
+ if t1 == site1:
+ if t2 == site2:
+ pass
+ else:
+ obs['t2'] = t3
+ obs['t3'] = t2
+
+ obs['u1'] = -u3
+ obs['v1'] = -v3
+ obs['u2'] = -u2
+ obs['v2'] = -v2
+ obs['u3'] = -u1
+ obs['v3'] = -v1
+ obs['cphase'] *= -1
+
+ elif t1 == site2:
+ if t2 == site3:
+ obs['t1'] = t3
+ obs['t2'] = t1
+ obs['t3'] = t2
+
+ obs['u1'] = u3
+ obs['v1'] = v3
+ obs['u2'] = u1
+ obs['v2'] = v1
+ obs['u3'] = u2
+ obs['v3'] = v2
+
+ else:
+ obs['t1'] = t2
+ obs['t2'] = t1
+
+ obs['u1'] = -u1
+ obs['v1'] = -v1
+ obs['u2'] = -u3
+ obs['v2'] = -v3
+ obs['u3'] = -u2
+ obs['v3'] = -v2
+ obs['cphase'] *= -1
+
+ elif t1 == site3:
+ if t2 == site1:
+ obs['t1'] = t2
+ obs['t2'] = t3
+ obs['t3'] = t1
+
+ obs['u1'] = u2
+ obs['v1'] = v2
+ obs['u2'] = u3
+ obs['v2'] = v3
+ obs['u3'] = u1
+ obs['v3'] = v1
+
+ else:
+ obs['t1'] = t3
+ obs['t3'] = t1
+
+ obs['u1'] = -u2
+ obs['v1'] = -v2
+ obs['u2'] = -u1
+ obs['v2'] = -v1
+ obs['u3'] = -u3
+ obs['v3'] = -v3
+ obs['cphase'] *= -1
+
+ outdata.append(np.array(obs, dtype=ehc.DTCPHASE))
+ continue
+
+ # get selected closure phases from the visibilities directly
+ # taken from bispectra() method
+ elif method=='from_vis':
+ if ang_unit == 'deg': angle = ehc.DEGREE
+ else: angle = 1.0
+
+ # get all equal-time data, and loop over to construct closure phase
+ tlist = self.tlist(conj=True)
+ for tdata in tlist:
+
+ time = tdata[0]['time']
+ if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ time = obsh.utc_to_gmst(time, self.mjd)
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ time = obsh.gmst_to_utc(time, self.mjd)
+
+ # Create a dictionary of baselines at the current time incl. conjugates;
+ l_dict = {}
+ for dat in tdata:
+ l_dict[(dat['t1'], dat['t2'])] = dat
+
+ # Select triangle entries in the data dictionary
+ try:
+ l1 = l_dict[(tri[0], tri[1])]
+ l2 = l_dict[(tri[1], tri[2])]
+ l3 = l_dict[(tri[2], tri[0])]
+ except KeyError:
+ continue
+
+ (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep)
+
+ # Cut out low snr points
+ if np.abs(bi) / bisig < snrcut:
+ continue
+
+ # Append to the equal-time list
+ outdata.append(np.array((time,
+ tri[0], tri[1], tri[2],
+ l1['u'], l1['v'],
+ l2['u'], l2['v'],
+ l3['u'], l3['v'],
+ np.real(np.angle(bi) / angle),
+ np.real(bisig / np.abs(bi) / angle)),
+ dtype=ehc.DTCPHASE))
+ else:
+ raise Exception("keyword 'method' in cphase_tri() must be either 'from_cphase' or 'from_vis'")
+
+ outdata = np.array(outdata)
+ return outdata
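+
+        # Usage sketch (illustrative; station names are hypothetical):
+        #     cp = obs.cphase_tri('ALMA', 'SMA', 'LMT', vtype='vis', snrcut=3.)
+        #     plt.plot(cp['time'], cp['cphase'], '.')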
+
+ def c_amplitudes(self, vtype='vis', mode='all', count='min', ctype='camp', debias=True,
+ timetype=False, snrcut=0.):
+ """Return a recarray of the equal time closure amplitudes.
+
+ Args:
+            vtype (str): The visibility type from which to assemble closure amplitudes
+ ('vis','qvis','uvis','vvis','pvis')
+ ctype (str): The closure amplitude type ('camp' or 'logcamp')
+ mode (str): If 'time', return amplitudes in a list of equal time arrays,
+ if 'all', return all amplitudes in a single array
+ count (str): If 'min', return minimal set of amplitudes,
+ if 'max' return all closure amplitudes up to inverses
+ debias (bool): If True, debias the closure amplitude
+ timetype (str): 'GMST' or 'UTC'
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ Returns:
+            (numpy.recarray): A recarray of the closure amplitudes with datatype DTCAMP
+
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if mode not in ('time', 'all'):
+ raise Exception("possible options for mode are 'time' and 'all'")
+ if count not in ('max', 'min'):
+ raise Exception("possible options for count are 'max' and 'min'")
+ if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'):
+ raise Exception("possible options for vtype are " +
+ "'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'")
+        if ctype not in ['camp', 'logcamp']:
+ raise Exception("closure amplitude type must be 'camp' or 'logcamp'!")
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ # Get data sorted by time
+ tlist = self.tlist(conj=True)
+ out = []
+ cas = []
+ tt = 1
+ for tdata in tlist:
+
+ # sys.stdout.write('\rGetting closure amps:: type %s %s , count %s, scan %i/%i' %
+ # (vtype, ctype, count, tt, len(tlist)))
+ # sys.stdout.flush()
+ tt += 1
+
+ time = tdata[0]['time']
+ if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ time = obsh.utc_to_gmst(time, self.mjd)
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ time = obsh.gmst_to_utc(time, self.mjd)
+
+ sites = np.array(list(set(np.hstack((tdata['t1'], tdata['t2'])))))
+ if len(sites) < 4:
+ continue
+
+ # Create a dictionary of baseline data at the current time including conjugates;
+ l_dict = {}
+ for dat in tdata:
+ l_dict[(dat['t1'], dat['t2'])] = dat
+
+ # Minimal set
+ if count == 'min':
+ quadsets = obsh.quad_minimal_set(sites, self.tarr, self.tkey)
+
+ # Maximal Set
+ elif count == 'max':
+ # Find all quadrangles
+ quadsets = np.sort(list(it.combinations(sites, 4)))
+ # Include 3 closure amplitudes on each quadrangle
+ quadsets = np.array([(q, [q[0], q[2], q[1], q[3]], [q[0], q[1], q[3], q[2]])
+ for q in quadsets]).reshape((-1, 4))
+
+ # Loop over all closure amplitudes
+ for quad in quadsets:
+ # Blue is numerator, red is denominator
+ if (quad[0], quad[1]) not in l_dict.keys():
+ continue
+ if (quad[2], quad[3]) not in l_dict.keys():
+ continue
+ if (quad[1], quad[2]) not in l_dict.keys():
+ continue
+ if (quad[0], quad[3]) not in l_dict.keys():
+ continue
+
+ try:
+ blue1 = l_dict[quad[0], quad[1]]
+ blue2 = l_dict[quad[2], quad[3]]
+ red1 = l_dict[quad[0], quad[3]]
+ red2 = l_dict[quad[1], quad[2]]
+ except KeyError:
+ continue
+
+ # Compute the closure amplitude and the error
+ (camp, camperr) = obsh.make_closure_amplitude(blue1, blue2, red1, red2, vtype,
+ polrep=self.polrep,
+ ctype=ctype, debias=debias)
+
+ if ctype == 'camp' and camp / camperr < snrcut:
+ continue
+ elif ctype == 'logcamp' and 1. / camperr < snrcut:
+ continue
+
+ # Add the closure amplitudes to the equal-time list
+ # Our site convention is (12)(34)/(14)(23)
+ cas.append(np.array((time,
+ quad[0], quad[1], quad[2], quad[3],
+ blue1['u'], blue1['v'], blue2['u'], blue2['v'],
+ red1['u'], red1['v'], red2['u'], red2['v'],
+ camp, camperr),
+ dtype=ehc.DTCAMP))
+
+ # Append all equal time closure amps to outlist
+ if mode == 'time':
+ out.append(np.array(cas))
+ cas = []
+
+ if mode == 'all':
+ out = np.array(cas)
+
+ return out
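+
+        # Usage sketch (illustrative; 'obs' is an Obsdata instance). Site ordering
+        # follows the (12)(34)/(14)(23) convention noted above:
+        #     camps = obs.c_amplitudes(mode='all', count='min', ctype='logcamp')
+        #     print(camps['camp'], camps['sigmaca'])   # fields follow ehc.DTCAMP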
+
+ def c_log_amplitudes_diag(self, vtype='vis', mode='all', count='min',
+ debias=True, timetype=False, snrcut=0.):
+ """Return a recarray of the equal time diagonalized log closure amplitudes.
+
+ Args:
+            vtype (str): The visibility type ('vis','qvis','uvis','vvis','rrvis',
+                         'lrvis','rlvis','llvis') from which to assemble closure amplitudes
+            mode (str): If 'time', return amplitudes in a list of equal time arrays,
+                        If 'all', return all amplitudes in a single array
+            count (str): Only 'min' (the minimal set of amplitudes) is supported
+                         for diagonal log closure amplitudes
+ debias (bool): If True, debias the closure amplitude -
+ The individual visibility amplitudes are always debiased.
+ timetype (str): 'GMST' or 'UTC'
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ Returns:
+            (numpy.recarray): A recarray of diagonalized closure amps with datatype DTLOGCAMPDIAG
+
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if mode not in ('time', 'all'):
+ raise Exception("possible options for mode are 'time' and 'all'")
+        if count not in ('min',):
+ raise Exception("count can only be 'min' for diagonal log closure amplitudes")
+ if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'):
+ raise Exception(
+ "possible options for vtype are 'vis', 'qvis', 'uvis', " +
+ "'vvis','rrvis','lrvis','rlvis','llvis'")
+ if timetype not in ['GMST', 'UTC', 'gmst', 'utc']:
+ raise Exception("timetype should be 'GMST' or 'UTC'!")
+
+ # determine the appropriate sigmatype
+ if vtype in ["vis", "qvis", "uvis", "vvis"]:
+ if vtype == 'vis':
+ sigmatype = 'sigma'
+ if vtype == 'qvis':
+ sigmatype = 'qsigma'
+ if vtype == 'uvis':
+ sigmatype = 'usigma'
+ if vtype == 'vvis':
+ sigmatype = 'vsigma'
+ if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]:
+ if vtype == 'rrvis':
+ sigmatype = 'rrsigma'
+ if vtype == 'llvis':
+ sigmatype = 'llsigma'
+ if vtype == 'rlvis':
+ sigmatype = 'rlsigma'
+ if vtype == 'lrvis':
+ sigmatype = 'lrsigma'
+
+ # get the time-sorted visibility data including conjugate baselines
+ viss = np.concatenate(self.tlist(conj=True))
+
+ # get the log closure amplitude data
+ lcas = self.c_amplitudes(vtype=vtype, mode=mode, count=count,
+ ctype='logcamp', debias=debias, timetype=timetype, snrcut=snrcut)
+
+ # get the unique timestamps for the log closure amplitudes
+ T_lcas = np.unique(lcas['time'])
+
+ # list of diagonalized log closure camplitudes and corresponding transformation matrices
+ dlcas = []
+ dlca_errs = []
+ tfmats = []
+
+ quads = []
+ us = []
+ vs = []
+
+ # loop over the timestamps
+ for kk, t in enumerate(T_lcas):
+
+ printstr = ('\rDiagonalizing log closure amplitudes:: type %s, count %s, scan %i/%i ' %
+ (vtype, count, kk + 1, len(T_lcas)))
+ sys.stdout.write(printstr)
+ sys.stdout.flush()
+
+ # index masks for this timestamp
+ mask_lca = (lcas['time'] == t)
+ mask_vis = (viss['time'] == t)
+
+ # log closure amplitudes for this timestamp
+ lcas_here = lcas[mask_lca]
+
+ # visibilities for this timestamp
+ viss_here = viss[mask_vis]
+
+ # initialize the design matrix
+ design_mat = np.zeros((mask_lca.sum(), mask_vis.sum()))
+
+ # loop over the log closure amplitudes within this timestamp
+ quadlist = []
+ ulist = []
+ vlist = []
+ for il, lca in enumerate(lcas_here):
+
+ quadlist.append((lca['t1'], lca['t2'], lca['t3'], lca['t4']))
+ ulist.append((lca['u1'], lca['u2'], lca['u3'], lca['u4']))
+ vlist.append((lca['v1'], lca['v2'], lca['v3'], lca['v4']))
+
+ # matrix entry for first leg of quadrangle
+ ind1 = ((viss_here['t1'] == lca['t1']) & (viss_here['t2'] == lca['t2']))
+ design_mat[il, ind1] = 1.0
+
+ # matrix entry for second leg of quadrangle
+ ind2 = ((viss_here['t1'] == lca['t3']) & (viss_here['t2'] == lca['t4']))
+ design_mat[il, ind2] = 1.0
+
+ # matrix entry for third leg of quadrangle
+ ind3 = ((viss_here['t1'] == lca['t1']) & (viss_here['t2'] == lca['t4']))
+ design_mat[il, ind3] = -1.0
+
+ # matrix entry for fourth leg of quadrangle
+ ind4 = ((viss_here['t1'] == lca['t2']) & (viss_here['t2'] == lca['t3']))
+ design_mat[il, ind4] = -1.0
+
+ # construct the covariance matrix
+ logvisamp_err = viss_here[sigmatype] / np.abs(viss_here[vtype])
+ sigma_mat = np.diag(logvisamp_err**2.0)
+ covar_mat = np.matmul(np.matmul(design_mat, sigma_mat), np.transpose(design_mat))
+
+ # diagonalize via eigendecomposition
+ eigeninfo = np.linalg.eigh(covar_mat)
+ T_matrix = np.copy(eigeninfo[1]).transpose()
+ dlogcamp = np.matmul(T_matrix, lcas_here['camp'])
+ dlogcamp_err = np.sqrt(np.copy(eigeninfo[0]))
+
+ dlcas.append(dlogcamp)
+ dlca_errs.append(dlogcamp_err)
+ tfmats.append(T_matrix)
+ quads.append(quadlist)
+ us.append(ulist)
+ vs.append(vlist)
+
+ # Reformat into a list
+ out = []
+ for kk, t in enumerate(T_lcas):
+ dlcaarr = []
+ for idlca, dlca in enumerate(dlcas[kk]):
+ dlcaarr.append((t, dlca, dlca_errs[kk][idlca]))
+ dlcaarr = np.array(dlcaarr, dtype=[('time', 'f8'), ('camp', 'f8'), ('sigmaca', 'f8')])
+ out.append((dlcaarr,
+ np.array(quads[kk]).astype(np.dtype([('quadrangles', 'U2')])),
+ np.array(us[kk]).astype(np.dtype([('u', 'f8')])),
+ np.array(vs[kk]).astype(np.dtype([('v', 'f8')])),
+ tfmats[kk].astype(np.dtype([('tform_matrix', 'f8')]))))
+ print("\n")
+ return out
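+
+        # Usage sketch (illustrative; 'obs' is an Obsdata instance): each list entry
+        # packs (diag log camps, quadrangles, u, v, transformation matrix) per scan:
+        #     dlca = obs.c_log_amplitudes_diag(vtype='vis', count='min')
+        #     dcamp_rec, quads, us, vs, tmat = dlca[0]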
+
+ def camp_quad(self, site1, site2, site3, site4,
+ vtype='vis', ctype='camp', debias=True, timetype=False, snrcut=0.,
+ method='from_maxset',
+ camps=[], force_recompute=False):
+ """Return closure phase over time on a quadrange (1-2)(3-4)/(1-4)(2-3).
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+ site4 (str): station 4 name
+
+            vtype (str): The visibility type from which to assemble closure amplitudes
+ ('vis','qvis','uvis','vvis','pvis')
+ ctype (str): The closure amplitude type ('camp' or 'logcamp')
+ debias (bool): If True, debias the closure amplitude
+ timetype (str): 'UTC' or 'GMST'
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust)
+ camps (list): optionally pass in the time-sorted, precomputed camps
+            force_recompute (bool): if True, do not use saved closure amplitude data
+
+ Returns:
+            (numpy.recarray): A recarray of the closure amplitudes with datatype DTCAMP
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+
+
+ if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']):
+ print ("Warning! method='from_maxset' default in camp_quad() is inconsistent with vtype=%s" % vtype)
+ print ("Switching to method='from_vis'")
+ method = 'from_vis'
+
+ quad = (site1, site2, site3, site4)
+ outdata = []
+
+ # get selected closure amplitudes from the maximal set
+ # TODO: verify consistency/performance of from_vis, and delete this method
+ if method=='from_maxset':
+ if (((ctype == 'camp') and (len(camps) == 0)) and not (self.camp is None) and
+ not (len(self.camp) == 0) and not force_recompute):
+ camps = self.camp
+ elif (((ctype == 'logcamp') and (len(camps) == 0)) and not (self.logcamp is None) and
+ not (len(self.logcamp) == 0) and not force_recompute):
+ camps = self.logcamp
+ elif (len(camps) == 0) or force_recompute:
+ camps = self.c_amplitudes(mode='all', count='max', vtype=vtype, ctype=ctype,
+ debias=debias, timetype=timetype, snrcut=snrcut)
+
+ # blue bls in numerator, red in denominator
+ b1 = set((site1, site2))
+ b2 = set((site3, site4))
+ r1 = set((site1, site4))
+ r2 = set((site2, site3))
+
+ for obs in camps: # camps does not contain inverses!
+
+ num = [set((obs['t1'], obs['t2'])), set((obs['t3'], obs['t4']))]
+ denom = [set((obs['t1'], obs['t4'])), set((obs['t2'], obs['t3']))]
+
+ obsquad = (obs['t1'], obs['t2'], obs['t3'], obs['t4'])
+ if set(quad) == set(obsquad):
+
+ # is this either the closure amplitude or inverse?
+ rightup = (b1 in num) and (b2 in num) and (r1 in denom) and (r2 in denom)
+ wrongup = (b1 in denom) and (b2 in denom) and (r1 in num) and (r2 in num)
+ if not (rightup or wrongup):
+ continue
+
+ # flip the inverse closure amplitudes
+ if wrongup:
+ t1old = copy.deepcopy(obs['t1'])
+ u1old = copy.deepcopy(obs['u1'])
+ v1old = copy.deepcopy(obs['v1'])
+ t2old = copy.deepcopy(obs['t2'])
+ u2old = copy.deepcopy(obs['u2'])
+ v2old = copy.deepcopy(obs['v2'])
+ t3old = copy.deepcopy(obs['t3'])
+ u3old = copy.deepcopy(obs['u3'])
+ v3old = copy.deepcopy(obs['v3'])
+ t4old = copy.deepcopy(obs['t4'])
+ u4old = copy.deepcopy(obs['u4'])
+ v4old = copy.deepcopy(obs['v4'])
+ campold = copy.deepcopy(obs['camp'])
+ csigmaold = copy.deepcopy(obs['sigmaca'])
+
+ obs['t1'] = t1old
+ obs['t2'] = t4old
+ obs['t3'] = t3old
+ obs['t4'] = t2old
+
+ obs['u1'] = u3old
+ obs['v1'] = v3old
+
+ obs['u2'] = -u4old
+ obs['v2'] = -v4old
+
+ obs['u3'] = u1old
+ obs['v3'] = v1old
+
+ obs['u4'] = -u2old
+ obs['v4'] = -v2old
+
+ if ctype == 'logcamp':
+ obs['camp'] = -campold
+ obs['sigmaca'] = csigmaold
+ else:
+ obs['camp'] = 1. / campold
+ obs['sigmaca'] = csigmaold / (campold**2)
+
+ t1old = copy.deepcopy(obs['t1'])
+ u1old = copy.deepcopy(obs['u1'])
+ v1old = copy.deepcopy(obs['v1'])
+ t2old = copy.deepcopy(obs['t2'])
+ u2old = copy.deepcopy(obs['u2'])
+ v2old = copy.deepcopy(obs['v2'])
+ t3old = copy.deepcopy(obs['t3'])
+ u3old = copy.deepcopy(obs['u3'])
+ v3old = copy.deepcopy(obs['v3'])
+ t4old = copy.deepcopy(obs['t4'])
+ u4old = copy.deepcopy(obs['u4'])
+ v4old = copy.deepcopy(obs['v4'])
+
+ # this is all same closure amplitude, but the ordering of labels is different
+ # return the label ordering that the user requested!
+ if (obs['t2'], obs['t1'], obs['t4'], obs['t3']) == quad:
+ obs['t1'] = t2old
+ obs['t2'] = t1old
+ obs['t3'] = t4old
+ obs['t4'] = t3old
+
+ obs['u1'] = -u1old
+ obs['v1'] = -v1old
+
+ obs['u2'] = -u2old
+ obs['v2'] = -v2old
+
+ obs['u3'] = u4old
+ obs['v3'] = v4old
+
+ obs['u4'] = u3old
+ obs['v4'] = v3old
+
+ elif (obs['t3'], obs['t4'], obs['t1'], obs['t2']) == quad:
+ obs['t1'] = t3old
+ obs['t2'] = t4old
+ obs['t3'] = t1old
+ obs['t4'] = t2old
+
+ obs['u1'] = u2old
+ obs['v1'] = v2old
+
+ obs['u2'] = u1old
+ obs['v2'] = v1old
+
+ obs['u3'] = -u4old
+ obs['v3'] = -v4old
+
+ obs['u4'] = -u3old
+ obs['v4'] = -v3old
+
+ elif (obs['t4'], obs['t3'], obs['t2'], obs['t1']) == quad:
+ obs['t1'] = t4old
+ obs['t2'] = t3old
+ obs['t3'] = t2old
+ obs['t4'] = t1old
+
+ obs['u1'] = -u2old
+ obs['v1'] = -v2old
+
+ obs['u2'] = -u1old
+ obs['v2'] = -v1old
+
+ obs['u3'] = -u3old
+ obs['v3'] = -v3old
+
+ obs['u4'] = -u4old
+ obs['v4'] = -v4old
+
+ # append to output array
+ outdata.append(np.array(obs, dtype=ehc.DTCAMP))
+
+ # get selected bispectra from the visibilities directly
+        # taken from c_amplitudes() method
+ elif method=='from_vis':
+
+ # get all equal-time data, and loop over to construct closure amplitudes
+ tlist = self.tlist(conj=True)
+ for tdata in tlist:
+
+ time = tdata[0]['time']
+ if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC':
+ time = obsh.utc_to_gmst(time, self.mjd)
+ if timetype in ['UTC', 'utc'] and self.timetype == 'GMST':
+ time = obsh.gmst_to_utc(time, self.mjd)
+ sites = np.array(list(set(np.hstack((tdata['t1'], tdata['t2'])))))
+ if len(sites) < 4:
+ continue
+
+ # Create a dictionary of baselines at the current time incl. conjugates;
+ l_dict = {}
+ for dat in tdata:
+ l_dict[(dat['t1'], dat['t2'])] = dat
+
+ # Select quadrangle entries in the data dictionary
+ # Blue is numerator, red is denominator
+ if (quad[0], quad[1]) not in l_dict.keys():
+ continue
+ if (quad[2], quad[3]) not in l_dict.keys():
+ continue
+ if (quad[1], quad[2]) not in l_dict.keys():
+ continue
+ if (quad[0], quad[3]) not in l_dict.keys():
+ continue
+
+ try:
+ blue1 = l_dict[quad[0], quad[1]]
+ blue2 = l_dict[quad[2], quad[3]]
+ red1 = l_dict[quad[0], quad[3]]
+ red2 = l_dict[quad[1], quad[2]]
+ except KeyError:
+ continue
+
+ # Compute the closure amplitude and the error
+ (camp, camperr) = obsh.make_closure_amplitude(blue1, blue2, red1, red2, vtype,
+ polrep=self.polrep,
+ ctype=ctype, debias=debias)
+
+ if ctype == 'camp' and camp / camperr < snrcut:
+ continue
+ elif ctype == 'logcamp' and 1. / camperr < snrcut:
+ continue
+
+ # Add the closure amplitudes to the equal-time list
+ # Our site convention is (12)(34)/(14)(23)
+ outdata.append(np.array((time,
+ quad[0], quad[1], quad[2], quad[3],
+ blue1['u'], blue1['v'], blue2['u'], blue2['v'],
+ red1['u'], red1['v'], red2['u'], red2['v'],
+ camp, camperr),
+ dtype=ehc.DTCAMP))
+
+ else:
+ raise Exception("keyword 'method' in camp_quad() must be either 'from_cphase' or 'from_vis'")
+
+ outdata = np.array(outdata)
+ return outdata
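+
+        # Usage sketch (illustrative; station names are hypothetical):
+        #     ca = obs.camp_quad('ALMA', 'SMA', 'LMT', 'SMT',
+        #                        ctype='logcamp', method='from_vis')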
+
+ def plotall(self, field1, field2,
+ conj=False, debias=False, tag_bl=False, ang_unit='deg', timetype=False,
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None,
+ grid=True, ebar=True, axislabels=True, legend=False,
+ show=True, export_pdf=""):
+ """Plot two fields against each other.
+
+ Args:
+ field1 (str): x-axis field (from FIELDS)
+ field2 (str): y-axis field (from FIELDS)
+
+            conj (bool): Plot conjugate baseline data points if True
+ debias (bool): If True, debias amplitudes.
+ tag_bl (bool): if True, label each baseline
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+
+ color (str): color for scatterplot points
+ marker (str): matplotlib plot marker
+ markersize (int): size of plot markers
+ label (str): plot legend label
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+
+ show (bool): Display the plot if true
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+
+ # Determine if fields are valid
+ field1 = field1.lower()
+ field2 = field2.lower()
+        if (field1 not in ehc.FIELDS) or (field2 not in ehc.FIELDS):
+ raise Exception("valid fields are " + ' '.join(ehc.FIELDS))
+
+ if 'amp' in [field1, field2] and not (self.amp is None):
+ print("Warning: plotall is not using amplitudes in Obsdata.amp array!")
+
+ # Label individual baselines
+ # ANDREW TODO this is way too slow, make it faster??
+ if tag_bl:
+ clist = ehc.SCOLORS
+
+ # make a color coding dictionary
+ cdict = {}
+ ii = 0
+ baselines = np.sort(list(it.combinations(self.tarr['site'], 2)))
+ for baseline in baselines:
+ cdict[(baseline[0], baseline[1])] = clist[ii % len(clist)]
+ cdict[(baseline[1], baseline[0])] = clist[ii % len(clist)]
+ ii += 1
+
+ # get unique baselines -- TODO easier way? separate function?
+ alldata = []
+ allsigx = []
+ allsigy = []
+ bllist = []
+ colors = []
+ bldata = self.bllist(conj=conj)
+ for bl in bldata:
+ t1 = bl['t1'][0]
+ t2 = bl['t2'][0]
+
+ bllist.append((t1, t2))
+ colors.append(cdict[(t1, t2)])
+
+ # Unpack data
+ dat = self.unpack_dat(bl, [field1, field2],
+ ang_unit=ang_unit, debias=debias, timetype=timetype)
+ alldata.append(dat)
+
+ # X error bars
+ if obsh.sigtype(field1):
+ allsigx.append(self.unpack_dat(bl, [obsh.sigtype(field1)],
+ ang_unit=ang_unit)[obsh.sigtype(field1)])
+ else:
+ allsigx.append(None)
+
+ # Y error bars
+ if obsh.sigtype(field2):
+ allsigy.append(self.unpack_dat(bl, [obsh.sigtype(field2)],
+ ang_unit=ang_unit)[obsh.sigtype(field2)])
+ else:
+ allsigy.append(None)
+
+ # Don't Label individual baselines
+ else:
+ bllist = [['All', 'All']]
+ colors = [color]
+
+ # unpack data
+ alldata = [self.unpack([field1, field2],
+ conj=conj, ang_unit=ang_unit, debias=debias, timetype=timetype)]
+
+ # X error bars
+ if obsh.sigtype(field1):
+                allsigx = self.unpack(obsh.sigtype(field1), conj=conj, ang_unit=ang_unit)
+ allsigx = [allsigx[obsh.sigtype(field1)]]
+ else:
+ allsigx = [None]
+
+ # Y error bars
+ if obsh.sigtype(field2):
+ allsigy = self.unpack(obsh.sigtype(field2), conj=conj, ang_unit=ang_unit)
+ allsigy = [allsigy[obsh.sigtype(field2)]]
+ else:
+ allsigy = [None]
+
+ # make plot(s)
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ xmins = []
+ xmaxes = []
+ ymins = []
+ ymaxes = []
+ for i in range(len(alldata)):
+ data = alldata[i]
+ sigy = allsigy[i]
+ sigx = allsigx[i]
+ color = colors[i]
+ bl = bllist[i]
+
+ # Flag out nans (to avoid problems determining plotting limits)
+ mask = ~(np.isnan(data[field1]) + np.isnan(data[field2]))
+
+ # Flag out due to snrcut
+ if snrcut > 0.:
+ sigs = [sigx, sigy]
+ for jj, field in enumerate([field1, field2]):
+ if field in ehc.FIELDS_AMPS:
+ fmask = data[field] / sigs[jj] > snrcut
+ elif field in ehc.FIELDS_PHASE:
+ fmask = sigs[jj] < (180. / np.pi / snrcut)
+ elif field in ehc.FIELDS_SNRS:
+ fmask = data[field] > snrcut
+ else:
+ fmask = np.ones(mask.shape).astype(bool)
+ mask *= fmask
+
+ data = data[mask]
+ if sigy is not None:
+ sigy = sigy[mask]
+ if sigx is not None:
+ sigx = sigx[mask]
+ if len(data) == 0:
+ continue
+
+ xmins.append(np.min(data[field1]))
+ xmaxes.append(np.max(data[field1]))
+ ymins.append(np.min(data[field2]))
+ ymaxes.append(np.max(data[field2]))
+
+ # Plot the data
+ tolerance = len(data[field2])
+
+ if label is None:
+ labelstr = "%s-%s" % ((str(bl[0]), str(bl[1])))
+
+ else:
+ labelstr = str(label)
+
+ if ebar and (np.any(sigy) or np.any(sigx)):
+ x.errorbar(data[field1], data[field2], xerr=sigx, yerr=sigy, label=labelstr,
+ fmt=marker, markersize=markersize, color=color, picker=tolerance)
+ else:
+ x.plot(data[field1], data[field2], marker, markersize=markersize, color=color,
+ label=labelstr, picker=tolerance)
+
+ # Data ranges
+ if not rangex:
+ rangex = [np.min(xmins) - 0.2 * np.abs(np.min(xmins)),
+ np.max(xmaxes) + 0.2 * np.abs(np.max(xmaxes))]
+ if np.any(np.isnan(np.array(rangex))):
+ print("Warning: NaN in data x range: specifying rangex to default")
+ rangex = [-100, 100]
+
+ if not rangey:
+ rangey = [np.min(ymins) - 0.2 * np.abs(np.min(ymins)),
+ np.max(ymaxes) + 0.2 * np.abs(np.max(ymaxes))]
+ if np.any(np.isnan(np.array(rangey))):
+ print("Warning: NaN in data y range: specifying rangey to default")
+ rangey = [-100, 100]
+
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ # label and save
+ if axislabels:
+ try:
+ x.set_xlabel(ehc.FIELD_LABELS[field1])
+ x.set_ylabel(ehc.FIELD_LABELS[field2])
+ except KeyError:
+ x.set_xlabel(field1.capitalize())
+ x.set_ylabel(field2.capitalize())
+ if legend and tag_bl:
+ plt.legend(ncol=2)
+ elif legend:
+ plt.legend()
+ if grid:
+ x.grid()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if export_pdf != "" and axis:
+ fig = plt.gcf()
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return x
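+
+        # Usage sketch (illustrative; 'obs' is an Obsdata instance, and field names
+        # must come from ehc.FIELDS):
+        #     ax = obs.plotall('uvdist', 'amp', debias=True)   # radius-amplitude plot
+        #     obs.plotall('u', 'v', conj=True)                 # uv coverage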
+
+ def plot_bl(self, site1, site2, field,
+ debias=False, ang_unit='deg', timetype=False,
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None,
+ grid=True, ebar=True, axislabels=True, legend=False,
+ show=True, export_pdf=""):
+ """Plot a field over time on a baseline site1-site2.
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ field (str): y-axis field (from FIELDS)
+
+ debias (bool): If True, debias amplitudes.
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+
+ color (str): color for scatterplot points
+ marker (str): matplotlib plot marker
+ markersize (int): size of plot markers
+ label (str): plot legend label
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+
+ field = field.lower()
+ if field == 'amp' and not (self.amp is None):
+ print("Warning: plot_bl is not using amplitudes in Obsdata.amp array!")
+
+ if label is None:
+ label = str(self.source)
+ else:
+ label = str(label)
+
+ # Determine if fields are valid
+ if field not in ehc.FIELDS:
+ raise Exception("valid fields are " + string.join(ehc.FIELDS))
+
+ plotdata = self.unpack_bl(site1, site2, field, ang_unit=ang_unit,
+ debias=debias, timetype=timetype)
+ sigmatype = obsh.sigtype(field)
+ if obsh.sigtype(field):
+ errdata = self.unpack_bl(site1, site2, obsh.sigtype(field),
+ ang_unit=ang_unit, debias=debias)
+ else:
+ errdata = None
+
+ # Flag out nans (to avoid problems determining plotting limits)
+ mask = ~np.isnan(plotdata[field][:, 0])
+
+ # Flag out due to snrcut
+ if snrcut > 0.:
+ if field in ehc.FIELDS_AMPS:
+ fmask = plotdata[field] / errdata[sigmatype] > snrcut
+ elif field in ehc.FIELDS_PHASE:
+ fmask = errdata[sigmatype] < (180. / np.pi / snrcut)
+ elif field in ehc.FIELDS_SNRS:
+ fmask = plotdata[field] > snrcut
+ else:
+ fmask = np.ones(mask.shape).astype(bool)
+ fmask = fmask[:, 0]
+ mask *= fmask
+
+ plotdata = plotdata[mask]
+ if errdata is not None:
+ errdata = errdata[mask]
+
+ if not rangex:
+ rangex = [self.tstart, self.tstop]
+ if np.any(np.isnan(np.array(rangex))):
+ print("Warning: NaN in data x range: specifying rangex to default")
+ rangex = [0, 24]
+ if not rangey:
+ rangey = [np.min(plotdata[field]) - 0.2 * np.abs(np.min(plotdata[field])),
+ np.max(plotdata[field]) + 0.2 * np.abs(np.max(plotdata[field]))]
+ if np.any(np.isnan(np.array(rangey))):
+ print("Warning: NaN in data y range: specifying rangex to default")
+ rangey = [-100, 100]
+
+ # Plot the data
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ if ebar and obsh.sigtype(field) is not False:
+ x.errorbar(plotdata['time'][:, 0], plotdata[field][:, 0],
+ yerr=errdata[obsh.sigtype(field)][:, 0],
+ fmt=marker, markersize=markersize, color=color,
+ linestyle='none', label=label)
+ else:
+ x.plot(plotdata['time'][:, 0], plotdata[field][:, 0], marker, markersize=markersize,
+ color=color, label=label, linestyle='none')
+
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ if axislabels:
+ x.set_xlabel(timetype + ' (hr)')
+ try:
+ x.set_ylabel(ehc.FIELD_LABELS[field])
+ except KeyError:
+ x.set_ylabel(field.capitalize())
+ x.set_title('%s - %s' % (site1, site2))
+
+ if grid:
+ x.grid()
+ if legend:
+ plt.legend()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if export_pdf != "" and axis:
+ fig = plt.gcf()
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+ return x
+
+ def plot_cphase(self, site1, site2, site3,
+ vtype='vis', cphases=[], force_recompute=False,
+ ang_unit='deg', timetype=False, snrcut=0.,
+ axis=False, rangex=False, rangey=False,
+ color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None,
+ grid=True, ebar=True, axislabels=True, legend=False,
+ show=True, export_pdf=""):
+ """Plot a closure phase over time on a triangle site1-site2-site3.
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+ vtype (str): The visibilty type from which to assemble closure phases
+ ('vis','qvis','uvis','vvis','pvis')
+            cphases (list): optionally pass in the precomputed, time-sorted closure phases
+            force_recompute (bool): if True, do not use the stored closure phase table
+            snrcut (float): flag closure phases with snr lower than this
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+
+ color (str): color for scatterplot points
+ marker (str): matplotlib plot marker
+ markersize (int): size of plot markers
+ label (str): plot legend label
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+
+ show (bool): Display the plot if True
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if ang_unit == 'deg':
+ angle = 1.0
+ else:
+ angle = ehc.DEGREE
+
+ if label is None:
+ label = str(self.source)
+ else:
+ label = str(label)
+
+ # Get closure phases (maximal set)
+ if (len(cphases) == 0) and (self.cphase is not None) and not force_recompute:
+ cphases = self.cphase
+
+ cpdata = self.cphase_tri(site1, site2, site3, vtype=vtype, timetype=timetype,
+ cphases=cphases, force_recompute=force_recompute, snrcut=snrcut)
+ plotdata = np.array([[obs['time'], obs['cphase'] * angle, obs['sigmacp']]
+ for obs in cpdata])
+
+ nan_mask = np.isnan(plotdata[:, 1])
+ plotdata = plotdata[~nan_mask]
+
+ if len(plotdata) == 0:
+ print("%s %s %s : No closure phases on this triangle!" % (site1, site2, site3))
+ return
+
+ # Plot the data
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ # Data ranges
+ if not rangex:
+ rangex = [self.tstart, self.tstop]
+ if np.any(np.isnan(np.array(rangex))):
+ print("Warning: NaN in data x range: specifying rangex to default")
+ rangex = [0, 24]
+
+ if not rangey:
+ if ang_unit == 'deg':
+ rangey = [-190, 190]
+ else:
+ rangey = [-1.1 * np.pi, 1.1 * np.pi]
+
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ if ebar and np.any(plotdata[:, 2]):
+ x.errorbar(plotdata[:, 0], plotdata[:, 1], yerr=plotdata[:, 2],
+ fmt=marker, markersize=markersize,
+ color=color, linestyle='none', label=label)
+ else:
+ x.plot(plotdata[:, 0], plotdata[:, 1], marker, markersize=markersize,
+ color=color, linestyle='none', label=label)
+
+ if axislabels:
+ x.set_xlabel(self.timetype + ' (h)')
+ if ang_unit == 'deg':
+ x.set_ylabel(r'Closure Phase $(^\circ)$')
+ else:
+ x.set_ylabel(r'Closure Phase (radian)')
+
+ x.set_title('%s - %s - %s' % (site1, site2, site3))
+
+ if grid:
+ x.grid()
+ if legend:
+ plt.legend()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if export_pdf != "" and axis:
+ fig = plt.gcf()
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return x
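+
+        # Usage sketch (illustrative; station names are hypothetical):
+        #     ax = obs.plot_cphase('ALMA', 'SMA', 'LMT', rangey=[-180, 180])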
+
+ def plot_camp(self, site1, site2, site3, site4,
+ vtype='vis', ctype='camp', camps=[], force_recompute=False,
+ debias=False, timetype=False, snrcut=0.,
+ axis=False, rangex=False, rangey=False,
+ color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None,
+ grid=True, ebar=True, axislabels=True, legend=False,
+ show=True, export_pdf=""):
+ """Plot closure amplitude over time on a quadrangle (1-2)(3-4)/(1-4)(2-3).
+
+ Args:
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+ site4 (str): station 4 name
+
+ vtype (str): The visibilty type from which to assemble closure amplitudes
+ ('vis','qvis','uvis','vvis','pvis')
+ ctype (str): The closure amplitude type ('camp' or 'logcamp')
+ camps (list): optionally pass in camps so they don't have to be recomputed
+ force_recompute (bool): if True, recompute camps instead of using stored data
+ snrcut (float): flag closure amplitudes with snr lower than this
+
+ debias (bool): If True, debias the closure amplitude
+ timetype (str): 'GMST' or 'UTC'
+
+            axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+ color (str): color for scatterplot points
+ marker (str): matplotlib plot marker
+ markersize (int): size of plot markers
+ label (str): plot legend label
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+
+ show (bool): Display the plot if True
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+
+ if timetype is False:
+ timetype = self.timetype
+ if label is None:
+ label = str(self.source)
+
+ else:
+ label = str(label)
+
+ # Get closure amplitudes (maximal set)
+ if ((ctype == 'camp') and (len(camps) == 0) and (self.camp is not None) and
+ not (len(self.camp) == 0) and not force_recompute):
+ camps = self.camp
+ elif ((ctype == 'logcamp') and (len(camps) == 0) and (self.logcamp is not None) and
+ not (len(self.logcamp) == 0) and not force_recompute):
+ camps = self.logcamp
+
+ # Get closure amplitudes (maximal set)
+ cpdata = self.camp_quad(site1, site2, site3, site4,
+ vtype=vtype, ctype=ctype, snrcut=snrcut,
+ debias=debias, timetype=timetype,
+ camps=camps, force_recompute=force_recompute)
+
+ if len(cpdata) == 0:
+            print('No closure amplitudes on this quadrangle!')
+ return
+
+        plotdata = np.array([[obs['time'], obs['camp'], obs['sigmaca']] for obs in cpdata])
+
+ nan_mask = np.isnan(plotdata[:, 1])
+ plotdata = plotdata[~nan_mask]
+
+ if len(plotdata) == 0:
+ print("No closure amplitudes on this quadrangle!")
+ return
+
+ # Data ranges
+ if not rangex:
+ rangex = [self.tstart, self.tstop]
+ if np.any(np.isnan(np.array(rangex))):
+ print("Warning: NaN in data x range: specifying rangex to default")
+ rangex = [0, 24]
+
+ if not rangey:
+ rangey = [np.min(plotdata[:, 1]) - 0.2 * np.abs(np.min(plotdata[:, 1])),
+ np.max(plotdata[:, 1]) + 0.2 * np.abs(np.max(plotdata[:, 1]))]
+ if np.any(np.isnan(np.array(rangey))):
+ print("Warning: NaN in data y range: specifying rangey to default")
+ if ctype == 'camp':
+ rangey = [0, 100]
+ if ctype == 'logcamp':
+ rangey = [-10, 10]
+
+ # Plot the data
+ if axis:
+ x = axis
+ else:
+ fig = plt.figure()
+ x = fig.add_subplot(1, 1, 1)
+
+ if ebar and np.any(plotdata[:, 2]):
+ x.errorbar(plotdata[:, 0], plotdata[:, 1], yerr=plotdata[:, 2],
+ fmt=marker, markersize=markersize,
+ color=color, linestyle='none', label=label)
+ else:
+ x.plot(plotdata[:, 0], plotdata[:, 1], marker, markersize=markersize,
+ color=color, linestyle='none', label=label)
+
+ x.set_xlim(rangex)
+ x.set_ylim(rangey)
+
+ if axislabels:
+ x.set_xlabel(self.timetype + ' (h)')
+ if ctype == 'camp':
+ x.set_ylabel('Closure Amplitude')
+ elif ctype == 'logcamp':
+ x.set_ylabel('Log Closure Amplitude')
+ x.set_title('(%s - %s)(%s - %s)/(%s - %s)(%s - %s)' % (site1, site2, site3, site4,
+ site1, site4, site2, site3))
+ if grid:
+ x.grid()
+ if legend:
+ plt.legend()
+ if export_pdf != "" and not axis:
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if export_pdf != "" and axis:
+ fig = plt.gcf()
+ fig.savefig(export_pdf, bbox_inches='tight')
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+ return
+ else:
+ return x
+
+ def save_txt(self, fname):
+ """Save visibility data to a text file.
+
+ Args:
+ fname (str): path to output text file
+ """
+
+ ehtim.io.save.save_obs_txt(self, fname)
+
+ return
+
+ def save_uvfits(self, fname, force_singlepol=False, polrep_out='circ'):
+ """Save visibility data to uvfits file.
+ Args:
+ fname (str): path to output uvfits file.
+ force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL'
+ polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file
+ """
+
+ if (force_singlepol is not False) and (self.polrep != 'stokes'):
+ raise Exception(
+ "force_singlepol is incompatible with polrep!='stokes'")
+
+        ehtim.io.save.save_obs_uvfits(self, fname,
+                                      force_singlepol=force_singlepol, polrep_out=polrep_out)
+
+ return
+
+ def make_hdulist(self, force_singlepol=False, polrep_out='circ'):
+ """Returns an hdulist in the same format as in a saved .uvfits file.
+ Args:
+ force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL'
+ polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file
+ Returns:
+ hdulist (astropy.io.fits.HDUList)
+ """
+
+ if (force_singlepol is not False) and (self.polrep != 'stokes'):
+ raise Exception(
+ "force_singlepol is incompatible with polrep!='stokes'")
+
+ hdulist = ehtim.io.save.save_obs_uvfits(self, None,
+ force_singlepol=force_singlepol, polrep_out=polrep_out)
+ return hdulist
+
+
+ def save_oifits(self, fname, flux=1.0):
+ """ Save visibility data to oifits. Polarization data is NOT saved.
+
+ Args:
+ fname (str): path to output text file
+ flux (float): normalization total flux
+ """
+
+ if self.polrep != 'stokes':
+ raise Exception("save_oifits not yet implemented for polreps other than 'stokes'")
+
+ # Antenna diameters are currently incorrect
+ # the exact times are also not correct in the datetime object
+ ehtim.io.save.save_obs_oifits(self, fname, flux=flux)
+
+ return
+
+##################################################################################################
+# Observation creation functions
+##################################################################################################
+
+
+def merge_obs(obs_List, force_merge=False):
+ """Merge a list of observations into a single observation file.
+
+ Args:
+ obs_List (list): list of split observation Obsdata objects.
+ force_merge (bool): forces the observations to merge even if parameters are different
+
+ Returns:
+ mergeobs (Obsdata): merged Obsdata object containing all scans in input list
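+
+    Example (illustrative; the file names are hypothetical):
+        obs1 = ehtim.obsdata.load_uvfits('day1.uvfits')
+        obs2 = ehtim.obsdata.load_uvfits('day2.uvfits')
+        merged = merge_obs([obs1, obs2])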
+ """
+
+    if len(set([obs.polrep for obs in obs_List])) > 1:
+        raise Exception("All observations must have the same polarization representation!")
+
+    if np.any([obs.timetype == 'GMST' for obs in obs_List]):
+        raise Exception("merge_obs only works for observations with obs.timetype='UTC'!")
+
+ if not force_merge:
+ if (len(set([obs.ra for obs in obs_List])) > 1 or
+ len(set([obs.dec for obs in obs_List])) > 1 or
+ len(set([obs.rf for obs in obs_List])) > 1 or
+ len(set([obs.bw for obs in obs_List])) > 1 or
+ len(set([obs.source for obs in obs_List])) > 1):
+ # or len(set([np.floor(obs.mjd) for obs in obs_List])) > 1):
+
+ raise Exception("All observations must have the same parameters!")
+ return
+
+ # the reference observation is the one with the minimum mjd
+ obs_idx = np.argmin([obs.mjd for obs in obs_List])
+ obs_ref = obs_List[obs_idx]
+
+ # re-reference times to new mjd
+ # must be in UTC!
+ mjd_ref = obs_ref.mjd
+ for obs in obs_List:
+ mjd_offset = obs.mjd - mjd_ref
+ obs.data['time'] += mjd_offset * 24
+ if not(obs.scans is None or obs.scans == []):
+ obs.scans += mjd_offset * 24
+
+ # merge the data
+ data_merge = np.hstack([obs.data for obs in obs_List])
+
+ # merge the scan list
+ scan_merge = None
+ for obs in obs_List:
+ if (obs.scans is None or obs.scans == []):
+ continue
+ if (scan_merge is None or scan_merge == []):
+ scan_merge = [obs.scans]
+ else:
+ scan_merge.append(obs.scans)
+
+ if not (scan_merge is None or scan_merge == []):
+ scan_merge = np.vstack(scan_merge)
+ _idxsort = np.argsort(scan_merge[:, 0])
+ scan_merge = scan_merge[_idxsort]
+
+ # merge the list of telescopes
+ tarr_merge = np.unique(np.concatenate([obs.tarr for obs in obs_List]))
+
+ arglist, argdict = obs_ref.obsdata_args()
+ arglist[DATPOS] = data_merge
+ arglist[TARRPOS] = tarr_merge
+ argdict['scantable'] = scan_merge
+ mergeobs = Obsdata(*arglist, **argdict)
+
+ return mergeobs
+
+
+def load_txt(fname, polrep='stokes'):
+ """Read an observation from a text file.
+
+ Args:
+ fname (str): path to input text file
+ polrep (str): load data as either 'stokes' or 'circ'
+
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ return ehtim.io.load.load_obs_txt(fname, polrep=polrep)
+
+
+def load_uvfits(fname, flipbl=False, remove_nan=False, force_singlepol=None,
+ channel=all, IF=all, polrep='stokes', allow_singlepol=True,
+ ignore_pzero_date=True,
+ trial_speedups=False):
+ """Load observation data from a uvfits file.
+
+ Args:
+ fname (str or HDUList): path to input text file or HDUList object
+ flipbl (bool): flip baseline phases if True.
+ remove_nan (bool): True to remove nans from missing polarizations
+ polrep (str): load data as either 'stokes' or 'circ'
+ force_singlepol (str): 'R' or 'L' to load only 1 polarization
+ channel (list): list of channels to average in the import. channel=all averages all
+        IF (list): list of IFs to average in the import. IF=all averages all IFs
+
+ ignore_pzero_date (bool): if True, ignore the offset parameters in DATE field
+ TODO: what is the correct behavior per AIPS memo 117?
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ return ehtim.io.load.load_obs_uvfits(fname, flipbl=flipbl, force_singlepol=force_singlepol,
+ channel=channel, IF=IF, polrep=polrep,
+ remove_nan=remove_nan, allow_singlepol=allow_singlepol,
+ ignore_pzero_date=ignore_pzero_date,
+ trial_speedups=trial_speedups)
+
+
+def load_oifits(fname, flux=1.0):
+ """Load data from an oifits file. Does NOT currently support polarization.
+
+ Args:
+ fname (str): path to input text file
+ flux (float): normalization total flux
+
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ return ehtim.io.load.load_obs_oifits(fname, flux=flux)
+
+
+def load_maps(arrfile, obsspec, ifile, qfile=0, ufile=0, vfile=0,
+ src=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, ampcal=False, phasecal=False):
+ """Read an observation from a maps text file and return an Obsdata object.
+
+ Args:
+ arrfile (str): path to input array file
+ obsspec (str): path to input obs spec file
+ ifile (str): path to input Stokes I data file
+ qfile (str): path to input Stokes Q data file
+ ufile (str): path to input Stokes U data file
+ vfile (str): path to input Stokes V data file
+ src (str): source name
+ mjd (int): integer observation MJD
+ ampcal (bool): True if amplitude calibrated
+ phasecal (bool): True if phase calibrated
+
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
+ """
+
+ return ehtim.io.load.load_obs_maps(arrfile, obsspec, ifile,
+ qfile=qfile, ufile=ufile, vfile=vfile,
+ src=src, mjd=mjd, ampcal=ampcal, phasecal=phasecal)
+
+def load_obs(
+ fname,
+ polrep='stokes',
+ flipbl=False,
+ remove_nan=False,
+ force_singlepol=None,
+ channel=all,
+ IF=all,
+ allow_singlepol=True,
+ flux=1.0,
+ obsspec=None,
+ ifile=None,
+ qfile=None,
+ ufile=None,
+ vfile=None,
+ src=ehc.SOURCE_DEFAULT,
+ mjd=ehc.MJD_DEFAULT,
+ ampcal=False,
+ phasecal=False
+ ):
+ """Smart obs read-in, detects file type and loads appropriately.
+
+ Args:
+ fname (str): path to input text file
+ polrep (str): load data as either 'stokes' or 'circ'
+ flipbl (bool): flip baseline phases if True.
+ remove_nan (bool): True to remove nans from missing polarizations
+ force_singlepol (str): 'R' or 'L' to load only 1 polarization
+ channel (list): list of channels to average in the import. channel=all averages all
+        IF (list): list of IFs to average in the import. IF=all averages all IFs
+ flux (float): normalization total flux
+ obsspec (str): path to input obs spec file
+ ifile (str): path to input Stokes I data file
+ qfile (str): path to input Stokes Q data file
+ ufile (str): path to input Stokes U data file
+ vfile (str): path to input Stokes V data file
+ src (str): source name
+ mjd (int): integer observation MJD
+ ampcal (bool): True if amplitude calibrated
+ phasecal (bool): True if phase calibrated
+
+ Returns:
+ obs (Obsdata): Obsdata object loaded from file
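+
+    Example (illustrative; the file name is hypothetical):
+        obs = load_obs('data.uvfits')  # dispatches on the extension to load_uvfits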
+ """
+
+ ## grab file ending ##
+ fname_extension = fname.split('.')[-1]
+ print(f"Extension is {fname_extension}.")
+
+ ## check extension ##
+ if fname_extension.lower() == 'uvfits':
+        return load_uvfits(fname, flipbl=flipbl, remove_nan=remove_nan,
+                           force_singlepol=force_singlepol, channel=channel,
+                           IF=IF, polrep=polrep, allow_singlepol=allow_singlepol)
+
+ elif fname_extension.lower() in ['txt', 'text']:
+ return load_txt(fname, polrep=polrep)
+
+ elif fname_extension.lower() == 'oifits':
+ return load_oifits(fname, flux=flux)
+
+    else:
+        if obsspec is not None and ifile is None:
+            print("You have provided a value for obsspec but no value for ifile!")
+            return
+        elif obsspec is None and ifile is not None:
+            print("You have provided a value for ifile but no value for obsspec!")
+            return
+
+ elif obsspec is not None and ifile is not None:
+ return load_maps(fname, obsspec, ifile, qfile=qfile, ufile=ufile, vfile=vfile,
+ src=src, mjd=mjd, ampcal=ampcal, phasecal=phasecal)
+
diff --git a/observing/__init__.py b/observing/__init__.py
new file mode 100644
index 00000000..213aefd6
--- /dev/null
+++ b/observing/__init__.py
@@ -0,0 +1,11 @@
+"""
+.. module:: ehtim.observing
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: simulated observation functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from . import pulses
+from . import obs_helpers
+from . import obs_simulate
diff --git a/observing/obs_helpers.py b/observing/obs_helpers.py
new file mode 100644
index 00000000..6dc1dbaa
--- /dev/null
+++ b/observing/obs_helpers.py
@@ -0,0 +1,1776 @@
+# obs_helpers.py
+# helper functions for simulating and manipulating observations
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+try:
+ from sgp4.api import Satrec, WGS72
+ import skyfield.api
+except ImportError:
+ print("Warning: skyfield not installed: cannot simulate space VLBI")
+
+try:
+ from pynfft.nfft import NFFT
+except ImportError:
+ print("Warning: No NFFT installed!")
+
+import astropy.time as at
+import astropy.coordinates as coords
+import numpy as np
+import itertools as it
+import scipy.ndimage as nd
+import scipy.spatial.distance
+import copy
+import sys
+
+import ehtim.const_def as ehc
+
+import warnings
+warnings.filterwarnings("ignore", message="divide by zero encountered in double_scalars")
+
+##################################################################################################
+# Other Functions
+##################################################################################################
+
+
+def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype='UTC',
+ elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False,
+ fix_theta_GMST=False):
+
+ """Compute u,v coordinates for an array at a given time for a source at a given ra,dec,rf
+ """
+
+ if not isinstance(time, np.ndarray):
+ time = np.array([time]).flatten()
+ if not isinstance(site1, np.ndarray):
+ site1 = np.array([site1]).flatten()
+ if not isinstance(site2, np.ndarray):
+ site2 = np.array([site2]).flatten()
+
+ if len(site1) == len(site2) == 1:
+ site1 = np.array([site1[0] for i in range(len(time))])
+ site2 = np.array([site2[0] for i in range(len(time))])
+ elif not (len(site1) == len(site2) == len(time)):
+ raise Exception("site1, site2, and time not the same dimension in compute_uv_coordinates!")
+
+ # Source vector
+ sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)])
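+    # projU and projV are unit vectors spanning the u,v plane perpendicular to
+    # the source direction: u increases to the east, v toward the projected
+    # north celestial pole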
+ projU = np.cross(np.array([0, 0, 1]), sourcevec)
+ projU = projU/np.linalg.norm(projU)
+ projV = -np.cross(projU, sourcevec)
+
+ # Wavelength
+ wvl = ehc.C/rf
+
+ if timetype == 'GMST':
+ time_sidereal = time
+ # time_utc = gmst_to_utc(time, mjd)
+ elif timetype == 'UTC':
+ time_sidereal = utc_to_gmst(time, mjd)
+ # time_utc = time
+ else:
+ raise Exception("timetype must be UTC or GMST!")
+
+ fracmjd = np.floor(mjd) + time/24.
+ dto = (at.Time(fracmjd, format='mjd')).datetime
+ theta = np.mod((time_sidereal - ra)*ehc.HOUR, 2*np.pi)
+    if not isinstance(fix_theta_GMST, bool):
+ theta = np.mod((fix_theta_GMST - ra)*ehc.HOUR, 2*np.pi)
+
+ i1 = np.array([array.tkey[site] for site in site1])
+ i2 = np.array([array.tkey[site] for site in site2])
+
+ coord1 = np.vstack((array.tarr[i1]['x'], array.tarr[i1]['y'], array.tarr[i1]['z'])).T
+ coord2 = np.vstack((array.tarr[i2]['x'], array.tarr[i2]['y'], array.tarr[i2]['z'])).T
+
+ # Satellites: new method
+ spacemask1 = [np.all(coord == (0., 0., 0.)) for coord in coord1]
+ spacemask2 = [np.all(coord == (0., 0., 0.)) for coord in coord2]
+
+ satnames = array.ephem.keys()
+    satdict = {satname: sat_skyfield_from_ephementry(satname, array.ephem, mjd)
+               for satname in satnames}
+ for satname in satnames:
+ sat = satdict[satname]
+
+        mask1 = (site1 == satname)
+ c1 = orbit_skyfield(sat, fracmjd[mask1], whichout='itrs')
+ coord1[mask1] = c1.T
+
+        mask2 = (site2 == satname)
+ c2 = orbit_skyfield(sat, fracmjd[mask2], whichout='itrs')
+ coord2[mask2] = c2.T
+
+ # Satellites: old method
+ """
+ if np.any(spacemask1):
+ if timetype == 'GMST':
+ raise Exception("Spacecraft ephemeris only work with UTC!")
+
+ site1space_list = site1[spacemask1]
+ site1space_fracmjdlist = fracmjd[spacemask1]
+ site1space_dtolist = dto[spacemask1]
+ coord1space = []
+
+ for k in range(len(site1space_list)):
+
+
+ # old method with pyephem
+
+ site1space = site1space_list[k]
+ dto_now = site1space_dtolist[k]
+ sat = ephem.readtle(array.ephem[site1space][0],
+ array.ephem[site1space][1], array.ephem[site1space][2])
+ sat.compute(dto_now) # often complains if ephemeris out of date!
+ elev = sat.elevation
+ lat = sat.sublat / ehc.DEGREE
+ lon = sat.sublong / ehc.DEGREE
+ # pyephem doesn't use an ellipsoid earth model!
+ c1 = coords.EarthLocation.from_geodetic(lon, lat, elev, ellipsoid=None)
+ c1 = np.array((c1.x.value, c1.y.value, c1.z.value))
+ coord1space.append(c1)
+
+ coord1space = np.array(coord1space)
+ coord1[spacemask1] = coord1space
+
+ # use spacecraft ephemeris to get position of site 2
+ spacemask2 = [np.all(coord == (0., 0., 0.)) for coord in coord2]
+ if np.any(spacemask2):
+ if timetype == 'GMST':
+ raise Exception("Spacecraft ephemeris only work with UTC!")
+
+ site2space_list = site2[spacemask2]
+ site2space_fracmjdlist = fracmjd[spacemask2]
+ site2space_dtolist = dto[spacemask2]
+ coord2space = []
+ for k in range(len(site2space_list)):
+
+ site2space = site2space_list[k]
+ dto_now = site2space_dtolist[k]
+ sat = ephem.readtle(array.ephem[site2space][0],
+ array.ephem[site2space][1],
+ array.ephem[site2space][2])
+ sat.compute(dto_now) # often complains if ephemeris out of date!
+ elev = sat.elevation
+ lat = sat.sublat / ehc.DEGREE
+ lon = sat.sublong / ehc.DEGREE
+ # pyephem doesn't use an ellipsoid earth model!
+ c2 = coords.EarthLocation.from_geodetic(lon, lat, elev, ellipsoid=None)
+ c2 = np.array((c2.x.value, c2.y.value, c2.z.value))
+ coord2space.append(c2)
+
+ coord2space = np.array(coord2space)
+ coord2[spacemask2] = coord2space
+ """
+
+ # rotate the station coordinates with the earth
+ coord1 = earthrot(coord1, theta)
+ coord2 = earthrot(coord2, theta)
+
+ # u,v coordinates
+ u = np.dot((coord1 - coord2)/wvl, projU) # u (lambda)
+ v = np.dot((coord1 - coord2)/wvl, projV) # v (lambda)
+
+ # mask out below elevation cut
+ mask_elev_1 = elevcut(coord1, sourcevec, elevmin=elevmin, elevmax=elevmax)
+ mask_elev_2 = elevcut(coord2, sourcevec, elevmin=elevmin, elevmax=elevmax)
+
+ # do NOT apply elevation cut for space orbiters
+ if no_elevcut_space:
+ mask_elev_1[spacemask1] = 1
+ mask_elev_2[spacemask2] = 1
+
+ # apply elevation mask
+ mask = mask_elev_1 * mask_elev_2
+
+ time = time[mask]
+ u = u[mask]
+ v = v[mask]
+
+ # return times and uv points where we have data
+ return (time, u, v)
+
+
+def make_bispectrum(l1, l2, l3, vistype, polrep='stokes'):
+ """Make a list of bispectra and errors
+ l1,l2,l3 are full datatables of visibility entries
+ vtype is visibility types
+ """
+
+ if vistype == 'pvis':
+ vtype = 'rlvis'
+ else:
+ vtype = vistype
+
+ # Choose the appropriate polarization and compute the bs and err
+ if polrep == 'stokes':
+ if vtype in ["vis", "qvis", "uvis", "vvis"]:
+ if vtype == 'vis':
+ sigmatype = 'sigma'
+ elif vtype == 'qvis':
+ sigmatype = 'qsigma'
+ elif vtype == 'uvis':
+ sigmatype = 'usigma'
+ elif vtype == 'vvis':
+ sigmatype = 'vsigma'
+
+ p1 = l1[vtype]
+ p2 = l2[vtype]
+ p3 = l3[vtype]
+
+ var1 = l1[sigmatype]**2
+ var2 = l2[sigmatype]**2
+ var3 = l3[sigmatype]**2
+
+ elif vtype == "rrvis":
+ p1 = l1['vis'] + l1['vvis']
+ p2 = l2['vis'] + l2['vvis']
+ p3 = l3['vis'] + l3['vvis']
+
+ var1 = l1['sigma']**2 + l1['vsigma']**2
+ var2 = l2['sigma']**2 + l2['vsigma']**2
+ var3 = l3['sigma']**2 + l3['vsigma']**2
+
+ elif vtype == "llvis":
+ p1 = l1['vis'] - l1['vvis']
+ p2 = l2['vis'] - l2['vvis']
+ p3 = l3['vis'] - l3['vvis']
+
+ var1 = l1['sigma']**2 + l1['vsigma']**2
+ var2 = l2['sigma']**2 + l2['vsigma']**2
+ var3 = l3['sigma']**2 + l3['vsigma']**2
+
+ elif vtype == "lrvis":
+ p1 = l1['qvis'] - 1j*l1['uvis']
+ p2 = l2['qvis'] - 1j*l2['uvis']
+ p3 = l3['qvis'] - 1j*l3['uvis']
+
+ var1 = l1['qsigma']**2 + l1['usigma']**2
+ var2 = l2['qsigma']**2 + l2['usigma']**2
+ var3 = l3['qsigma']**2 + l3['usigma']**2
+
+ elif vtype == "rlvis":
+ p1 = l1['qvis'] + 1j*l1['uvis']
+ p2 = l2['qvis'] + 1j*l2['uvis']
+ p3 = l3['qvis'] + 1j*l3['uvis']
+
+ var1 = l1['qsigma']**2 + l1['usigma']**2
+ var2 = l2['qsigma']**2 + l2['usigma']**2
+ var3 = l3['qsigma']**2 + l3['usigma']**2
+
+ elif polrep == 'circ':
+ if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]:
+
+
+
+ if vtype == 'rrvis':
+ sigmatype = 'rrsigma'
+ elif vtype == 'llvis':
+ sigmatype = 'llsigma'
+ elif vtype == 'rlvis':
+ sigmatype = 'rlsigma'
+ elif vtype == 'lrvis':
+ sigmatype = 'lrsigma'
+
+ p1 = l1[vtype]
+ p2 = l2[vtype]
+ p3 = l3[vtype]
+
+ var1 = l1[sigmatype]**2
+ var2 = l2[sigmatype]**2
+ var3 = l3[sigmatype]**2
+
+ elif vtype == "vis":
+ p1 = 0.5*(l1['rrvis'] + l1['llvis'])
+ p2 = 0.5*(l2['rrvis'] + l2['llvis'])
+ p3 = 0.5*(l3['rrvis'] + l3['llvis'])
+
+ var1 = 0.25*(l1['rrsigma']**2 + l1['llsigma']**2)
+ var2 = 0.25*(l2['rrsigma']**2 + l2['llsigma']**2)
+ var3 = 0.25*(l3['rrsigma']**2 + l3['llsigma']**2)
+
+ elif vtype == "vvis":
+ p1 = 0.5*(l1['rrvis'] - l1['llvis'])
+ p2 = 0.5*(l2['rrvis'] - l2['llvis'])
+ p3 = 0.5*(l3['rrvis'] - l3['llvis'])
+
+ var1 = 0.25*(l1['rrsigma']**2 + l1['llsigma']**2)
+ var2 = 0.25*(l2['rrsigma']**2 + l2['llsigma']**2)
+ var3 = 0.25*(l3['rrsigma']**2 + l3['llsigma']**2)
+
+ elif vtype == "qvis":
+ p1 = 0.5*(l1['lrvis'] + l1['rlvis'])
+ p2 = 0.5*(l2['lrvis'] + l2['rlvis'])
+ p3 = 0.5*(l3['lrvis'] + l3['rlvis'])
+
+ var1 = 0.25*(l1['lrsigma']**2 + l1['rlsigma']**2)
+ var2 = 0.25*(l2['lrsigma']**2 + l2['rlsigma']**2)
+ var3 = 0.25*(l3['lrsigma']**2 + l3['rlsigma']**2)
+
+ elif vtype == "uvis":
+ p1 = 0.5j*(l1['lrvis'] - l1['rlvis'])
+ p2 = 0.5j*(l2['lrvis'] - l2['rlvis'])
+ p3 = 0.5j*(l3['lrvis'] - l3['rlvis'])
+
+ var1 = 0.25*(l1['lrsigma']**2 + l1['rlsigma']**2)
+ var2 = 0.25*(l2['lrsigma']**2 + l2['rlsigma']**2)
+ var3 = 0.25*(l3['lrsigma']**2 + l3['rlsigma']**2)
+ else:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ # Make the bispectrum and its uncertainty
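+    # (first-order error propagation: the fractional errors of the three
+    # visibilities add in quadrature)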
+ bi = p1*p2*p3
+ bisig = np.abs(bi) * np.sqrt(var1/np.abs(p1)**2 +
+ var2/np.abs(p2)**2 +
+ var3/np.abs(p3)**2)
+
+ return (bi, bisig)
+
+
+def make_closure_amplitude(blue1, blue2, red1, red2, vtype,
+ ctype='camp', debias=True, polrep='stokes'):
+ """Make a list of closure amplitudes and errors
+       blue1 and blue2 are full datatables of numerator entries
+ red1 and red2 are full datatables of denominator entries
+ vtype is the visibility type
+ """
+
+ if not (ctype in ['camp', 'logcamp']):
+ raise Exception("closure amplitude type must be 'camp' or 'logcamp'!")
+
+ if polrep == 'stokes':
+
+ if vtype in ["vis", "qvis", "uvis", "vvis"]:
+ if vtype == 'vis':
+ sigmatype = 'sigma'
+ if vtype == 'qvis':
+ sigmatype = 'qsigma'
+ if vtype == 'uvis':
+ sigmatype = 'usigma'
+ if vtype == 'vvis':
+ sigmatype = 'vsigma'
+
+ sig1 = blue1[sigmatype]
+ sig2 = blue2[sigmatype]
+ sig3 = red1[sigmatype]
+ sig4 = red2[sigmatype]
+
+ p1 = np.abs(blue1[vtype])
+ p2 = np.abs(blue2[vtype])
+ p3 = np.abs(red1[vtype])
+ p4 = np.abs(red2[vtype])
+
+ elif vtype == "rrvis":
+ sig1 = np.sqrt(blue1['sigma']**2 + blue1['vsigma']**2)
+ sig2 = np.sqrt(blue2['sigma']**2 + blue2['vsigma']**2)
+ sig3 = np.sqrt(red1['sigma']**2 + red1['vsigma']**2)
+ sig4 = np.sqrt(red2['sigma']**2 + red2['vsigma']**2)
+
+ p1 = np.abs(blue1['vis'] + blue1['vvis'])
+ p2 = np.abs(blue2['vis'] + blue2['vvis'])
+ p3 = np.abs(red1['vis'] + red1['vvis'])
+ p4 = np.abs(red2['vis'] + red2['vvis'])
+
+ elif vtype == "llvis":
+ sig1 = np.sqrt(blue1['sigma']**2 + blue1['vsigma']**2)
+ sig2 = np.sqrt(blue2['sigma']**2 + blue2['vsigma']**2)
+ sig3 = np.sqrt(red1['sigma']**2 + red1['vsigma']**2)
+ sig4 = np.sqrt(red2['sigma']**2 + red2['vsigma']**2)
+
+ p1 = np.abs(blue1['vis'] - blue1['vvis'])
+ p2 = np.abs(blue2['vis'] - blue2['vvis'])
+ p3 = np.abs(red1['vis'] - red1['vvis'])
+ p4 = np.abs(red2['vis'] - red2['vvis'])
+
+ elif vtype == "lrvis":
+ sig1 = np.sqrt(blue1['qsigma']**2 + blue1['usigma']**2)
+ sig2 = np.sqrt(blue2['qsigma']**2 + blue2['usigma']**2)
+ sig3 = np.sqrt(red1['qsigma']**2 + red1['usigma']**2)
+ sig4 = np.sqrt(red2['qsigma']**2 + red2['usigma']**2)
+
+ p1 = np.abs(blue1['qvis'] - 1j*blue1['uvis'])
+ p2 = np.abs(blue2['qvis'] - 1j*blue2['uvis'])
+ p3 = np.abs(red1['qvis'] - 1j*red1['uvis'])
+ p4 = np.abs(red2['qvis'] - 1j*red2['uvis'])
+
+ elif vtype in ["pvis", "rlvis"]:
+ sig1 = np.sqrt(blue1['qsigma']**2 + blue1['usigma']**2)
+ sig2 = np.sqrt(blue2['qsigma']**2 + blue2['usigma']**2)
+ sig3 = np.sqrt(red1['qsigma']**2 + red1['usigma']**2)
+ sig4 = np.sqrt(red2['qsigma']**2 + red2['usigma']**2)
+
+ p1 = np.abs(blue1['qvis'] + 1j*blue1['uvis'])
+ p2 = np.abs(blue2['qvis'] + 1j*blue2['uvis'])
+ p3 = np.abs(red1['qvis'] + 1j*red1['uvis'])
+ p4 = np.abs(red2['qvis'] + 1j*red2['uvis'])
+
+ elif polrep == 'circ':
+ if vtype in ["rrvis", "llvis", "rlvis", "lrvis", 'pvis']:
+ if vtype == 'pvis':
+ vtype = 'rlvis' # p = rl
+
+ if vtype == 'rrvis':
+ sigmatype = 'rrsigma'
+ if vtype == 'llvis':
+ sigmatype = 'llsigma'
+ if vtype == 'rlvis':
+ sigmatype = 'rlsigma'
+ if vtype == 'lrvis':
+ sigmatype = 'lrsigma'
+
+ sig1 = blue1[sigmatype]
+ sig2 = blue2[sigmatype]
+ sig3 = red1[sigmatype]
+ sig4 = red2[sigmatype]
+
+ p1 = np.abs(blue1[vtype])
+ p2 = np.abs(blue2[vtype])
+ p3 = np.abs(red1[vtype])
+ p4 = np.abs(red2[vtype])
+
+ elif vtype == "vis":
+ sig1 = 0.5*np.sqrt(blue1['rrsigma']**2 + blue1['llsigma']**2)
+ sig2 = 0.5*np.sqrt(blue2['rrsigma']**2 + blue2['llsigma']**2)
+ sig3 = 0.5*np.sqrt(red1['rrsigma']**2 + red1['llsigma']**2)
+ sig4 = 0.5*np.sqrt(red2['rrsigma']**2 + red2['llsigma']**2)
+
+ p1 = 0.5*np.abs(blue1['rrvis'] + blue1['llvis'])
+ p2 = 0.5*np.abs(blue2['rrvis'] + blue2['llvis'])
+ p3 = 0.5*np.abs(red1['rrvis'] + red1['llvis'])
+ p4 = 0.5*np.abs(red2['rrvis'] + red2['llvis'])
+
+ elif vtype == "vvis":
+ sig1 = 0.5*np.sqrt(blue1['rrsigma']**2 + blue1['llsigma']**2)
+ sig2 = 0.5*np.sqrt(blue2['rrsigma']**2 + blue2['llsigma']**2)
+ sig3 = 0.5*np.sqrt(red1['rrsigma']**2 + red1['llsigma']**2)
+ sig4 = 0.5*np.sqrt(red2['rrsigma']**2 + red2['llsigma']**2)
+
+ p1 = 0.5*np.abs(blue1['rrvis'] - blue1['llvis'])
+ p2 = 0.5*np.abs(blue2['rrvis'] - blue2['llvis'])
+ p3 = 0.5*np.abs(red1['rrvis'] - red1['llvis'])
+ p4 = 0.5*np.abs(red2['rrvis'] - red2['llvis'])
+
+ elif vtype == "qvis":
+ sig1 = 0.5*np.sqrt(blue1['lrsigma']**2 + blue1['rlsigma']**2)
+ sig2 = 0.5*np.sqrt(blue2['lrsigma']**2 + blue2['rlsigma']**2)
+ sig3 = 0.5*np.sqrt(red1['lrsigma']**2 + red1['rlsigma']**2)
+ sig4 = 0.5*np.sqrt(red2['lrsigma']**2 + red2['rlsigma']**2)
+
+ p1 = 0.5*np.abs(blue1['lrvis'] + blue1['rlvis'])
+ p2 = 0.5*np.abs(blue2['lrvis'] + blue2['rlvis'])
+ p3 = 0.5*np.abs(red1['lrvis'] + red1['rlvis'])
+ p4 = 0.5*np.abs(red2['lrvis'] + red2['rlvis'])
+
+ elif vtype == "uvis":
+ sig1 = 0.5*np.sqrt(blue1['lrsigma']**2 + blue1['rlsigma']**2)
+ sig2 = 0.5*np.sqrt(blue2['lrsigma']**2 + blue2['rlsigma']**2)
+ sig3 = 0.5*np.sqrt(red1['lrsigma']**2 + red1['rlsigma']**2)
+ sig4 = 0.5*np.sqrt(red2['lrsigma']**2 + red2['rlsigma']**2)
+
+ p1 = 0.5*np.abs(blue1['lrvis'] - blue1['rlvis'])
+ p2 = 0.5*np.abs(blue2['lrvis'] - blue2['rlvis'])
+ p3 = 0.5*np.abs(red1['lrvis'] - red1['rlvis'])
+ p4 = 0.5*np.abs(red2['lrvis'] - red2['rlvis'])
+ else:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ # Debias the amplitude
+ if debias:
+ p1 = amp_debias(p1, sig1, force_nonzero=True)
+ p2 = amp_debias(p2, sig2, force_nonzero=True)
+ p3 = amp_debias(p3, sig3, force_nonzero=True)
+ p4 = amp_debias(p4, sig4, force_nonzero=True)
+ else:
+ p1 = np.abs(p1)
+ p2 = np.abs(p2)
+ p3 = np.abs(p3)
+ p4 = np.abs(p4)
+
+ # Get snrs
+ snr1 = p1/sig1
+ snr2 = p2/sig2
+ snr3 = p3/sig3
+ snr4 = p4/sig4
+
+ # Compute the closure amplitude and its uncertainty
+ if ctype == 'camp':
+ camp = np.abs((p1*p2)/(p3*p4))
+ camperr = camp * np.sqrt(1./(snr1**2) + 1./(snr2**2) + 1./(snr3**2) + 1./(snr4**2))
+
+ # Debias
+ if debias:
+ camp = camp_debias(camp, snr3, snr4)
+
+ elif ctype == 'logcamp':
+ camp = np.log(np.abs(p1)) + np.log(np.abs(p2)) - np.log(np.abs(p3)) - np.log(np.abs(p4))
+ camperr = np.sqrt(1./(snr1**2) + 1./(snr2**2) + 1./(snr3**2) + 1./(snr4**2))
+
+ # Debias
+ if debias:
+ camp = logcamp_debias(camp, snr1, snr2, snr3, snr4)
+
+ return (camp, camperr)
+
+
+def amp_debias(amp, sigma, force_nonzero=False):
+ """Return debiased visibility amplitudes
+ """
+
+ deb2 = np.abs(amp)**2 - np.abs(sigma)**2
+
+ # puts amplitude at 0 if snr < 1
+ deb2 *= (np.nan_to_num(np.abs(amp)) > np.nan_to_num(np.abs(sigma)))
+
+ # raises amplitude to sigma to force nonzero
+ if force_nonzero:
+ deb2 += (np.nan_to_num(np.abs(amp)) < np.nan_to_num(np.abs(sigma))) * np.abs(sigma)**2
+ out = np.sqrt(deb2)
+
+ return out
+
+
+def camp_debias(camp, snr3, snr4):
+ """Debias closure amplitudes
+ snr3 and snr4 are snr of visibility amplitudes #3 and 4.
+ """
+
+ camp_debias = camp / (1 + 1./(snr3**2) + 1./(snr4**2))
+ return camp_debias
+
+
+def logcamp_debias(log_camp, snr1, snr2, snr3, snr4):
+ """Debias log closure amplitudes
+ The snrs are the snr of the component visibility amplitudes
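+       Adds 0.5*(1/snr1^2 + 1/snr2^2 - 1/snr3^2 - 1/snr4^2), the leading-order
+       bias correction for log|V| of the numerator and denominator amplitudes.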
+ """
+
+ log_camp_debias = log_camp + 0.5*(1./(snr1**2) + 1./(snr2**2) - 1./(snr3**2) - 1./(snr4**2))
+ return log_camp_debias
+
+
+def gauss_uv(u, v, flux, beamparams, x=0., y=0.):
+ """Return the value of the Gaussian FT with
+ beamparams is [FWHMmaj, FWHMmin, theta, x, y], all in radian
+ x,y are the center coordinates
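+
+       Example (illustrative): a 1 Jy circular Gaussian with 40 uas FWHM
+       evaluated at zero baseline returns the total flux:
+       gauss_uv(np.array([0.]), np.array([0.]), 1.0,
+                [40*ehc.RADPERUAS, 40*ehc.RADPERUAS, 0.]) -> array([1.+0.j])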
+ """
+
+ sigma_maj = beamparams[0]/(2*np.sqrt(2*np.log(2)))
+ sigma_min = beamparams[1]/(2*np.sqrt(2*np.log(2)))
+ theta = -beamparams[2] # theta needs to be negative in this convention!
+
+ # Covariance matrix
+ a = (sigma_min * np.cos(theta))**2 + (sigma_maj*np.sin(theta))**2
+ b = (sigma_maj * np.cos(theta))**2 + (sigma_min*np.sin(theta))**2
+ c = (sigma_min**2 - sigma_maj**2) * np.cos(theta) * np.sin(theta)
+ m = np.array([[a, c], [c, b]])
+
+ uv = np.array([[u[i], v[i]] for i in range(len(u))])
+ x2 = np.array([np.dot(uvi, np.dot(m, uvi)) for uvi in uv])
+
+ g = np.exp(-2 * np.pi**2 * x2)
+ p = np.exp(-2j * np.pi * (u*x + v*y))
+
+ return flux * g * p
+
+
+def rbf_kernel_covariance(x, sigma):
+ """Compute a covariance matrix from an RBF kernel
+
+ Args:
+ x (ndarray): 1D data points for which to compute the covariance
+ sigma (float): std for the covariance. Controls correlation length / time.
+
+ Returns:
+ cov (ndarray): Covariance matrix
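+
+    Example (illustrative):
+        cov = rbf_kernel_covariance(np.linspace(0, 24, 100), sigma=2.0)
+        # cov.shape == (100, 100); each column sums to 1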
+ """
+ x = np.expand_dims(x, 1) if x.ndim == 1 else x
+ norm = -0.5 * scipy.spatial.distance.cdist(x, x, 'sqeuclidean') / sigma**2
+ cov = np.exp(norm)
+ cov *= 1.0 / cov.sum(axis=0)
+ return cov
+
+
+def sgra_kernel_uv(rf, u, v):
+ """Return the value of the Sgr A* scattering kernel at a given u,v (in lambda)
+
+ Args:
+ rf (float): The observation frequency in Hz
+ u (float or ndarray): an array of u coordinates
+ v (float or ndarray): an array of v coordinates
+
+ Returns:
+ g (float ndarray): Sgr A* scattering kernel
+ """
+ u = np.array(u)
+ v = np.array(v)
+ assert u.size == v.size, 'u and v should have the same size'
+
+ lcm = (ehc.C / rf) * 100 # in cm
+ sigma_maj = ehc.FWHM_MAJ * (lcm ** 2) / (2 * np.sqrt(2 * np.log(2))) * ehc.RADPERUAS
+ sigma_min = ehc.FWHM_MIN * (lcm ** 2) / (2 * np.sqrt(2 * np.log(2))) * ehc.RADPERUAS
+ theta = -ehc.POS_ANG * ehc.DEGREE # theta needs to be negative in this convention!
+
+ # Covariance matrix
+ a = (sigma_min * np.cos(theta)) ** 2 + (sigma_maj * np.sin(theta)) ** 2
+ b = (sigma_maj * np.cos(theta)) ** 2 + (sigma_min * np.sin(theta)) ** 2
+ c = (sigma_min ** 2 - sigma_maj ** 2) * np.cos(theta) * np.sin(theta)
+ m = np.array([[a, c], [c, b]])
+ uv = np.array([u, v])
+
+ x2 = (uv * np.dot(m, uv)).sum(axis=0)
+ g = np.exp(-2 * np.pi ** 2 * x2)
+
+ return g
+
+
+def sgra_kernel_params(rf):
+ """Return elliptical gaussian parameters in radian for the Sgr A* scattering ellipse
+ at a given frequency rf
+ """
+
+ lcm = (ehc.C/rf) * 100 # in cm
+ fwhm_maj_rf = ehc.FWHM_MAJ * (lcm**2) * ehc.RADPERUAS
+ fwhm_min_rf = ehc.FWHM_MIN * (lcm**2) * ehc.RADPERUAS
+ theta = ehc.POS_ANG * ehc.DEGREE
+
+ return np.array([fwhm_maj_rf, fwhm_min_rf, theta])
+
+
+def blnoise(sefd1, sefd2, tint, bw):
+ """Determine the standard deviation of Gaussian thermal noise on a baseline
+ This is the noise on the rr/ll/rl/lr product, not the Stokes parameter
+ 2-bit quantization is responsible for the 0.88 factor
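+       Example: two stations with SEFD = 1e4 Jy, tint = 10 s, bw = 2 GHz give
+       np.sqrt(1e4*1e4/(2*2e9*10))/0.88, approximately 0.057 Jy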
+ """
+
+ noise = np.sqrt(sefd1*sefd2/(2*bw*tint))/0.88
+ # noise = np.sqrt(sefd1*sefd2/(bw*tint))/0.88
+
+ return noise
+
+
+def merr(sigma, qsigma, usigma, I, m):
+ """Return the error in mbreve real and imaginary parts given stokes input
+ """
+
+ err = np.sqrt((qsigma**2 + usigma**2 + (sigma*np.abs(m))**2)/(np.abs(I) ** 2))
+
+ return err
+
+
+def merr2(rlsigma, rrsigma, llsigma, I, m):
+ """Return the error in mbreve real and imaginary parts given polprod input
+ """
+
+ err = np.sqrt((rlsigma**2 + (rrsigma**2 + llsigma**2)*np.abs(m)**2)/(np.abs(I) ** 2))
+
+ return err
+
+
+def cerror(sigma):
+ """Return a complex number drawn from a circular complex Gaussian of zero mean
+ """
+
+ noise = np.random.normal(loc=0, scale=sigma) + 1j*np.random.normal(loc=0, scale=sigma)
+ return noise
+
+
+def cerror_hash(sigma, *args):
+ """Return a complex number drawn from a circular complex Gaussian of zero mean
+ """
+
+ reargs = list(args)
+ reargs.append('re')
+ np.random.seed(hash(",".join(map(repr, reargs))) % 4294967295)
+ re = np.random.randn()
+
+ imargs = list(args)
+ imargs.append('im')
+ np.random.seed(hash(",".join(map(repr, imargs))) % 4294967295)
+ im = np.random.randn()
+
+ err = sigma * (re + 1j*im)
+
+ return err
+
+
+def hashmultivariaterandn(size, cov, *args):
+ """set the seed according to a collection of arguments and return random multivariate gaussian var
+ """
+ np.random.seed(hash(",".join(map(repr, args))) % 4294967295)
+ mean = np.zeros(size)
+ noise = np.random.multivariate_normal(mean, cov, check_valid='ignore')
+ return noise
+
+
+def hashrandn(*args):
+ """set the seed according to a collection of arguments and return random gaussian var
+ """
+
+ np.random.seed(hash(",".join(map(repr, args))) % 4294967295)
+ noise = np.random.randn()
+ return noise
+
+
+def hashrand(*args):
+ """set the seed according to a collection of arguments and return random number in 0,1
+ """
+
+ np.random.seed(hash(",".join(map(repr, args))) % 4294967295)
+ noise = np.random.rand()
+ return noise
+
+
+def image_centroid(im):
+ """Return the image centroid (in radians)
+ """
+
+ xlist = np.arange(0, -im.xdim, -1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0
+ ylist = np.arange(0, -im.ydim, -1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0
+
+ x0 = np.sum(np.outer(0.0*ylist+1.0, xlist).ravel()*im.imvec)/np.sum(im.imvec)
+ y0 = np.sum(np.outer(ylist, 0.0*xlist+1.0).ravel()*im.imvec)/np.sum(im.imvec)
+
+ return np.array([x0, y0])
+
+
+def ftmatrix(pdim, xdim, ydim, uvlist, pulse=ehc.PULSE_DEFAULT, mask=[]):
+ """Return a DFT matrix for the xdim*ydim image with pixel width pdim
+ that extracts spatial frequencies of the uv points in uvlist.
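+       Illustrative usage: for an image vector imvec of length xdim*ydim,
+       vis = np.dot(ftmatrix(pdim, xdim, ydim, uvlist), imvec)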
+ """
+
+ xlist = np.arange(0, -xdim, -1)*pdim + (pdim*xdim)/2.0 - pdim/2.0
+ ylist = np.arange(0, -ydim, -1)*pdim + (pdim*ydim)/2.0 - pdim/2.0
+
+ # original sign convention
+ # ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") *
+ # np.outer(np.exp(-2j*np.pi*ylist*uv[1]), np.exp(-2j*np.pi*xlist*uv[0]))
+ # for uv in uvlist]
+
+ # changed the sign convention to agree with BU data (Jan 2017)
+ # this is correct for a u,v definition from site 1-2 as (x1-x2)/lambda
+ ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") *
+ np.outer(np.exp(2j*np.pi*ylist*uv[1]), np.exp(2j*np.pi*xlist*uv[0]))
+ for uv in uvlist]
+ ftmatrices = np.reshape(np.array(ftmatrices), (len(uvlist), xdim*ydim))
+
+ if len(mask):
+ ftmatrices = ftmatrices[:, mask]
+
+ return ftmatrices
+
+
+def ftmatrix_centered(im, pdim, xdim, ydim, uvlist, pulse=ehc.PULSE_DEFAULT):
+ """Return a DFT matrix for the xdim*ydim image with pixel width pdim
+ that extracts spatial frequencies of the uv points in uvlist.
+ in this version, it puts the image centroid at the origin
+ """
+
+ # TODO : there is a residual value for the center being around 0,
+ # maybe we should chop this off to be exactly 0?
+ xlist = np.arange(0, -xdim, -1)*pdim + (pdim*xdim)/2.0 - pdim/2.0
+ ylist = np.arange(0, -ydim, -1)*pdim + (pdim*ydim)/2.0 - pdim/2.0
+ x0 = np.sum(np.outer(0.0*ylist+1.0, xlist).ravel()*im)/np.sum(im)
+ y0 = np.sum(np.outer(ylist, 0.0*xlist+1.0).ravel()*im)/np.sum(im)
+
+ # Now shift the lists
+ xlist = xlist - x0
+ ylist = ylist - y0
+
+ # list of matrices at each spatial freq
+ ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") *
+ np.outer(np.exp(-2j*np.pi*ylist*uv[1]), np.exp(-2j*np.pi*xlist*uv[0]))
+ for uv in uvlist]
+ ftmatrices = np.reshape(np.array(ftmatrices), (len(uvlist), xdim*ydim))
+ return ftmatrices
+
+
+def ticks(axisdim, psize, nticks=8):
+ """Return a list of ticklocs and ticklabels
+ psize should be in desired units
+ """
+
+ axisdim = int(axisdim)
+ nticks = int(nticks)
+ if not axisdim % 2:
+ axisdim += 1
+ if nticks % 2:
+ nticks -= 1
+ tickspacing = float((axisdim-1))/nticks
+ ticklocs = np.arange(0, axisdim+1, tickspacing) - 0.5
+ ticklabels = np.around(psize * np.arange((axisdim-1)/2.0, -
+ (axisdim)/2.0, -tickspacing), decimals=1)
+
+ return (ticklocs, ticklabels)
+
+
+def power_of_two(target):
+ """Finds the next greatest power of two
+ """
+ cur = 1
+ if target > 1:
+ for i in range(0, int(target)):
+ if (cur >= target):
+ return cur
+ else:
+ cur *= 2
+ else:
+ return 1
+
+
+def paritycompare(perm1, perm2):
+ """Compare the parity of two permutations.
+ Assume both lists are equal length and with same elements
+ Copied from: http://stackoverflow.com/questions/1503072/how-to-check-if-permutations-have-equal-parity
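+       Example: paritycompare([1, 2, 3], [1, 2, 3]) -> 1 (same parity);
+       paritycompare([1, 2, 3], [2, 1, 3]) -> -1 (opposite parity)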
+ """
+
+ perm2 = list(perm2)
+ perm2_map = dict((v, i) for i, v in enumerate(perm2))
+ transCount = 0
+ for loc, p1 in enumerate(perm1):
+ p2 = perm2[loc]
+ if p1 != p2:
+ sloc = perm2_map[p1]
+ perm2[loc], perm2[sloc] = p1, p2
+ perm2_map[p1], perm2_map[p2] = sloc, loc
+ transCount += 1
+
+ if not (transCount % 2):
+ return 1
+ else:
+ return -1
+
+
+def sigtype(datatype):
+ """Return the type of noise corresponding to the data type
+ """
+
+ datatype = str(datatype)
+ if datatype in ['vis', 'amp']:
+ sigmatype = 'sigma'
+ elif datatype in ['qvis', 'qamp']:
+ sigmatype = 'qsigma'
+ elif datatype in ['uvis', 'uamp']:
+ sigmatype = 'usigma'
+ elif datatype in ['vvis', 'vamp']:
+ sigmatype = 'vsigma'
+ elif datatype in ['pvis', 'pamp']:
+ sigmatype = 'psigma'
+ elif datatype in ['evis', 'eamp']:
+ sigmatype = 'esigma'
+ elif datatype in ['bvis', 'bamp']:
+ sigmatype = 'esigma'
+ elif datatype in ['rrvis', 'rramp']:
+ sigmatype = 'rrsigma'
+ elif datatype in ['llvis', 'llamp']:
+ sigmatype = 'llsigma'
+ elif datatype in ['rlvis', 'rlamp']:
+ sigmatype = 'rlsigma'
+ elif datatype in ['lrvis', 'lramp']:
+ sigmatype = 'lrsigma'
+ elif datatype in ['rrllvis', 'rrllamp']:
+ sigmatype = 'rrllsigma'
+ elif datatype in ['m', 'mamp']:
+ sigmatype = 'msigma'
+ elif datatype in ['phase']:
+ sigmatype = 'sigma_phase'
+ elif datatype in ['qphase']:
+ sigmatype = 'qsigma_phase'
+ elif datatype in ['uphase']:
+ sigmatype = 'usigma_phase'
+ elif datatype in ['vphase']:
+ sigmatype = 'vsigma_phase'
+ elif datatype in ['pphase']:
+ sigmatype = 'psigma_phase'
+ elif datatype in ['ephase']:
+ sigmatype = 'esigma_phase'
+ elif datatype in ['bphase']:
+ sigmatype = 'bsigma_phase'
+ elif datatype in ['mphase']:
+ sigmatype = 'msigma_phase'
+ elif datatype in ['rrphase']:
+ sigmatype = 'rrsigma_phase'
+ elif datatype in ['llphase']:
+ sigmatype = 'llsigma_phase'
+ elif datatype in ['rlphase']:
+ sigmatype = 'rlsigma_phase'
+ elif datatype in ['lrphase']:
+ sigmatype = 'lrsigma_phase'
+ elif datatype in ['rrllphase']:
+ sigmatype = 'rrllsigma_phase'
+
+ else:
+ sigmatype = False
+
+ return sigmatype
+
+
+def rastring(ra):
+ """Convert a ra in fractional hours to formatted string
+ """
+ h = int(ra)
+ m = int((ra-h)*60.)
+ s = (ra-h-m/60.)*3600.
+ out = "%2i h %2i m %2.4f s" % (h, m, s)
+
+ return out
+
+
+def decstring(dec):
+ """Convert a dec in fractional degrees to formatted string
+ """
+
+ deg = int(dec)
+ m = int((abs(dec)-abs(deg))*60.)
+ s = (abs(dec)-abs(deg)-m/60.)*3600.
+ out = "%2i deg %2i m %2.4f s" % (deg, m, s)
+
+ return out
+
+
+def gmtstring(gmt):
+ """Convert a gmt in fractional hours to formatted string
+ """
+
+ if gmt > 24.0:
+ gmt = gmt-24.0
+ h = int(gmt)
+ m = int((gmt-h)*60.)
+ s = (gmt-h-m/60.)*3600.
+ out = "%02i:%02i:%2.4f" % (h, m, s)
+
+ return out
+
+# TODO fix this hacky way to do it!!
+
+
+def gmst_to_utc(gmst, mjd):
+ """Convert gmst times in hours to utc hours using astropy
+ """
+
+ mjd = int(mjd)
+ time_obj_ref = at.Time(mjd, format='mjd', scale='utc')
+ time_sidereal_ref = time_obj_ref.sidereal_time('mean', 'greenwich').hour
+ time_utc = (gmst - time_sidereal_ref) * 0.9972695601848
+
+ return time_utc
+
+
+def utc_to_gmst(utc, mjd):
+ """Convert utc times in hours to gmst using astropy
+ """
+
+ mjd = int(mjd) # MJD should always be an integer, but was float in older versions of the code
+ time_obj = at.Time(utc/24.0 + np.floor(mjd), format='mjd', scale='utc')
+ time_sidereal = time_obj.sidereal_time('mean', 'greenwich').hour
+
+ return time_sidereal
+
+
+def earthrot(vecs, thetas):
+ """Rotate a vector / array of vectors about the z-direction by theta / array of thetas (radian)
+ """
+
+ if len(vecs.shape) == 1:
+ vecs = np.array([vecs])
+ if np.isscalar(thetas):
+ thetas = np.array([thetas for i in range(len(vecs))])
+
+ # equal numbers of sites and angles
+ if len(thetas) == len(vecs):
+ rotvec = np.array([np.dot(np.array(((np.cos(thetas[i]), -np.sin(thetas[i]), 0),
+ (np.sin(thetas[i]), np.cos(thetas[i]), 0),
+ (0, 0, 1))),
+ vecs[i])
+ for i in range(len(vecs))])
+
+ # only one rotation angle, many sites
+ elif len(thetas) == 1:
+ rotvec = np.array([np.dot(np.array(((np.cos(thetas[0]), -np.sin(thetas[0]), 0),
+ (np.sin(thetas[0]), np.cos(thetas[0]), 0),
+ (0, 0, 1))),
+ vecs[i])
+ for i in range(len(vecs))])
+
+ # only one site, many angles
+ elif len(vecs) == 1:
+ rotvec = np.array([np.dot(np.array(((np.cos(thetas[i]), -np.sin(thetas[i]), 0),
+ (np.sin(thetas[i]), np.cos(thetas[i]), 0),
+ (0, 0, 1))),
+ vecs[0])
+ for i in range(len(thetas))])
+ else:
+ raise Exception("Unequal numbers of vectors and angles in earthrot(vecs, thetas)!")
+
+ return rotvec
+
+
+def elev(obsvecs, sourcevec):
+ """Return the elevation of a source with respect to an observer/observers in radians
+       obsvecs can be an array of vectors but sourcevec can ONLY be a single vector
+ """
+
+ if len(obsvecs.shape) == 1:
+ obsvecs = np.array([obsvecs])
+
+ anglebtw = np.array([np.dot(obsvec, sourcevec)/np.linalg.norm(obsvec) /
+ np.linalg.norm(sourcevec) for obsvec in obsvecs])
+ el = 0.5*np.pi - np.arccos(anglebtw)
+
+ return el
+
+
+def elevcut(obsvecs, sourcevec, elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH):
+ """Return True if a source is observable by a telescope vector
+ """
+
+ angles = elev(obsvecs, sourcevec)/ehc.DEGREE
+
+ return (angles > elevmin) * (angles < elevmax)
+
+
+def hr_angle(gst, lon, ra):
+ """Computes the hour angle for a source at RA, observer at longitude long, and GMST time gst
+ gst in hours, ra & lon ALL in radian, longitude positive east
+ """
+
+ hr_angle = np.mod(gst + lon - ra, 2*np.pi)
+
+ return hr_angle
+
+
+def par_angle(hr_angle, lat, dec):
+ """Compute the parallactic angle for a source at hr_angle and dec
+ for an observer with latitude lat.
+ All angles in radian
+ """
+
+ num = np.sin(hr_angle)*np.cos(lat)
+ denom = np.sin(lat)*np.cos(dec) - np.cos(lat)*np.sin(dec)*np.cos(hr_angle)
+
+ return np.arctan2(num, denom)
+
+
+def xyz_2_latlong(obsvecs):
+ """Compute the (geocentric) latitude and longitude of a site at geocentric position x,y,z
+ The output is in radians
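+       Example: xyz_2_latlong(np.array([0., 0., 1.])) -> [[pi/2, 0.]] (the north pole)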
+ """
+
+ if len(obsvecs.shape) == 1:
+ obsvecs = np.array([obsvecs])
+ out = []
+ for obsvec in obsvecs:
+ x = obsvec[0]
+ y = obsvec[1]
+ z = obsvec[2]
+ lon = np.array(np.arctan2(y, x))
+ lat = np.array(np.arctan2(z, np.sqrt(x**2+y**2)))
+ out.append([lat, lon])
+
+ out = np.array(out)
+
+ return out
+
+
+def tri_minimal_set(sites, tarr, tkey):
+ """returns a minimal set of triangles for bispectra and closure phase"""
+
+    # determine ordering and reference site based on the order of tarr
+ sites_ordered = [x for x in tarr['site'] if x in sites]
+ ref = sites_ordered[0]
+ sites_ordered.remove(ref)
+
+ # Find all triangles that contain the ref
+ tris = list(it.combinations(sites_ordered, 2))
+ tris = [(ref, t[0], t[1]) for t in tris]
+
+ return tris
+
+
+def quad_minimal_set(sites, tarr, tkey):
+ """returns a minimal set of quadrangels for closure amplitude"""
+
+    # determine ordering and reference site based on the order of tarr
+ sites_ordered = np.array([x for x in tarr['site'] if x in sites])
+ ref = sites_ordered[0]
+
+ # Loop over other sites >=3 and form minimal closure amplitude set
+ quads = []
+ for i in range(3, len(sites_ordered)):
+ for j in range(1, i):
+ if j == i-1:
+ k = 1
+ else:
+ k = j+1
+
+            # convention is (12)(34)/(14)(23)
+ quad = (ref, sites_ordered[i], sites_ordered[j], sites_ordered[k])
+ quads.append(quad)
+
+ return quads
+
+
+# TODO This returns A minimal set if the input is maximal, but it is not necessarily the same
+# minimal set as we would get from calling c_phases(count='min'). This is because of sign flips.
+def reduce_tri_minimal(obs, datarr):
+ """reduce a bispectrum or closure phase data array to a minimal set
+ datarr can be either a bispectrum array of type DTBIS
+ or a closure phase array of type DTCPHASE, or a time sorted list of either
+ """
+
+ # time sort or not
+ if not (type(datarr) is list):
+ datalist = []
+ dtype = datarr.dtype
+ for key, group in it.groupby(datarr, lambda x: x['time']):
+ datalist.append(np.array([gp for gp in group], dtype=dtype))
+ returnType = 'all'
+ else:
+ dtype = datarr[0].dtype
+ datalist = datarr
+ returnType = 'time'
+
+ out = []
+
+ for timegroup in datalist:
+ if returnType == 'all':
+ outgroup = out
+ else:
+ outgroup = []
+
+        # determine a minimal set of triangles
+ sites = list(set(np.hstack((timegroup['t1'], timegroup['t2'], timegroup['t3']))))
+ tris = tri_minimal_set(sites, obs.tarr, obs.tkey)
+ tris = [set(tri) for tri in tris]
+
+ # add data points from original array to new array if in minimal set
+ for dp in timegroup:
+ # TODO: sign flips?
+ if set((dp['t1'], dp['t2'], dp['t3'])) in tris:
+ outgroup.append(dp)
+
+ if returnType == 'time':
+ out.append(np.array(outgroup, dtype=dtype))
+ else:
+ out = outgroup
+
+ if returnType == 'all':
+ out = np.array(out, dtype=dtype)
+ return out
+
+# TODO This returns A minimal set if the input is maximal, but it is not necessarily the same
+# minimal set as we would get from calling c_amplitudes(count='min'). This is because of inverses.
+
+
+def reduce_quad_minimal(obs, datarr, ctype='camp'):
+ """Reduce a closure amplitude or log closure amplitude array
+ FROM a maximal set TO a minimal set
+ """
+
+ if ctype not in ['camp', 'logcamp']:
+ raise Exception("ctype must be 'camp' or 'logcamp'")
+
+ # time sort or not
+ if not (type(datarr) is list):
+ datalist = []
+ dtype = datarr.dtype
+ for key, group in it.groupby(datarr, lambda x: x['time']):
+ datalist.append(np.array([x for x in group]))
+ returnType = 'all'
+ else:
+ dtype = datarr[0].dtype
+ datalist = datarr
+ returnType = 'time'
+
+ out = []
+ for timegroup in datalist:
+ if returnType == 'all':
+ outgroup = out
+ else:
+ outgroup = []
+
+ # determine a minimal set of quadrangles
+ sites = np.array(list(set(np.hstack((timegroup['t1'],
+ timegroup['t2'],
+ timegroup['t3'],
+ timegroup['t4'])))))
+ if len(sites) < 4:
+ continue
+ quads = quad_minimal_set(sites, obs.tarr, obs.tkey)
+
+ # add data points from original camp array to new array if in minimal set
+ # ANDREW TODO: do we need to change the ordering ??
+ for dp in timegroup:
+
+ # this is all same closure amplitude, but the ordering of labels is different
+ if ((dp['t1'], dp['t2'], dp['t3'], dp['t4']) in quads or
+ (dp['t2'], dp['t1'], dp['t4'], dp['t3']) in quads or
+ (dp['t3'], dp['t4'], dp['t1'], dp['t2']) in quads or
+ (dp['t4'], dp['t3'], dp['t2'], dp['t1']) in quads):
+
+ outgroup.append(np.array(dp, dtype=ehc.DTCAMP))
+
+ # flip the inverse closure amplitude
+ elif ((dp['t1'], dp['t4'], dp['t3'], dp['t2']) in quads or
+ (dp['t2'], dp['t3'], dp['t4'], dp['t1']) in quads or
+ (dp['t3'], dp['t2'], dp['t1'], dp['t4']) in quads or
+ (dp['t4'], dp['t1'], dp['t2'], dp['t3']) in quads):
+
+ dp2 = copy.deepcopy(dp)
+ campold = dp['camp']
+ sigmaold = dp['sigmaca']
+ t1old = dp['t1']
+ t2old = dp['t2']
+ t3old = dp['t3']
+ t4old = dp['t4']
+ u1old = dp['u1']
+ u2old = dp['u2']
+ u3old = dp['u3']
+ u4old = dp['u4']
+ v1old = dp['v1']
+ v2old = dp['v2']
+ v3old = dp['v3']
+ v4old = dp['v4']
+
+ dp2['t1'] = t1old
+ dp2['t2'] = t4old
+ dp2['t3'] = t3old
+ dp2['t4'] = t2old
+
+ dp2['u1'] = u3old
+ dp2['v1'] = v3old
+
+ dp2['u2'] = -u4old
+ dp2['v2'] = -v4old
+
+ dp2['u3'] = u1old
+ dp2['v3'] = v1old
+
+ dp2['u4'] = -u2old
+ dp2['v4'] = -v2old
+
+ if ctype == 'camp':
+ dp2['camp'] = 1./campold
+ dp2['sigmaca'] = sigmaold/(campold**2)
+
+ elif ctype == 'logcamp':
+ dp2['camp'] = -campold
+ dp2['sigmaca'] = sigmaold
+
+ outgroup.append(dp2)
+
+ if returnType == 'time':
+ out.append(np.array(outgroup, dtype=dtype))
+ else:
+ out = outgroup
+
+ if returnType == 'all':
+ out = np.array(out, dtype=dtype)
+ return out
+
+
+def qimage(iimage, mimage, chiimage):
+ """Return the Q image from m and chi"""
+ return iimage * mimage * np.cos(2*chiimage)
+
+
+def uimage(iimage, mimage, chiimage):
+ """Return the U image from m and chi"""
+ return iimage * mimage * np.sin(2*chiimage)
+
+
+##################################################################################################
+# FFT & NFFT helper functions
+##################################################################################################
+class NFFTInfo(object):
+ def __init__(self, xdim, ydim, psize, pulse, npad, p_rad, uv):
+ self.xdim = int(xdim)
+ self.ydim = int(ydim)
+ self.psize = psize
+ self.pulse = pulse
+
+ self.npad = int(npad)
+ self.p_rad = int(p_rad)
+ self.uv = uv
+ self.uvdim = len(uv)
+
+ # set nfft plan
+ uv_scaled = uv*psize
+ nfft_plan = NFFT([xdim, ydim], self.uvdim, m=p_rad, n=[npad, npad])
+ nfft_plan.x = uv_scaled
+ nfft_plan.precompute()
+ self.plan = nfft_plan
+
+ # compute phase and pulsefac
+ phases = np.exp(-1j*np.pi*(uv_scaled[:, 0]+uv_scaled[:, 1]))
+ pulses = np.fromiter((pulse(2*np.pi*uv_scaled[i, 0], 2*np.pi*uv_scaled[i, 1], 1., dom="F")
+ for i in range(self.uvdim)), 'c16')
+ self.pulsefac = (pulses*phases)
+
+class SamplerInfo(object):
+ def __init__(self, order, uv, pulsefac):
+ self.order = int(order)
+ self.uv = uv
+ self.pulsefac = pulsefac
+
+
+class GridderInfo(object):
+ def __init__(self, npad, func, p_rad, coords, weights):
+ self.npad = int(npad)
+ self.conv_func = func
+ self.p_rad = int(p_rad)
+ self.coords = coords
+ self.weights = weights
+
+
+class ImInfo(object):
+ def __init__(self, xdim, ydim, npad, psize, pulse):
+ self.xdim = int(xdim)
+ self.ydim = int(ydim)
+ self.npad = int(npad)
+ self.psize = psize
+ self.pulse = pulse
+
+ padvalx1 = padvalx2 = int(np.floor((npad - xdim)/2.0))
+ if xdim % 2:
+ padvalx2 += 1
+ padvaly1 = padvaly2 = int(np.floor((npad - ydim)/2.0))
+ if ydim % 2:
+ padvaly2 += 1
+
+ self.padvalx1 = padvalx1
+ self.padvalx2 = padvalx2
+ self.padvaly1 = padvaly1
+ self.padvaly2 = padvaly2
+
+
+def conv_func_pill(x, y):
+ if abs(x) < 0.5 and abs(y) < 0.5:
+ out = 1.
+ else:
+ out = 0.
+ return out
+
+
+def conv_func_gauss(x, y):
+ return np.exp(-(x**2 + y**2))
+
+
+def conv_func_cubicspline(x, y):
+ if abs(x) <= 1:
+ fx = 1.5*abs(x)**3 - 2.5*abs(x)**2 + 1
+ elif abs(x) < 2:
+ fx = -0.5*abs(x)**3 + 2.5*abs(x)**2 - 4*abs(x) + 2
+ else:
+ fx = 0
+
+ if abs(y) <= 1:
+ fy = 1.5*abs(y)**3 - 2.5*abs(y)**2 + 1
+ elif abs(y) < 2:
+ fy = -0.5*abs(y)**3 + 2.5*abs(y)**2 - 4*abs(y) + 2
+ else:
+ fy = 0
+
+ return fx*fy
+
+# There's a bug in scipy spheroidal function of order 0! - gives nans for eta<1
+# def conv_func_spheroidal(x,y,p,m):
+# etax = 2.*x/float(p)
+# etay = 2.*x/float(p)
+# psix = abs(1-etax**2)**m * scipy.special.pro_rad1(m,0,0.5*np.pi*p,etax)[0]
+# psiy = abs(1-etay**2)**m * scipy.special.pro_rad1(m,0,0.5*np.pi*p,etay)[0]
+# return psix*psiy
+
+
+def fft_imvec(imvec, im_info):
+ """
+ Returns fft of imvec on grid
+ im_info = (xdim, ydim, npad, psize, pulse)
+ order is the order of the spline interpolation
+ """
+
+ xdim = im_info.xdim
+ ydim = im_info.ydim
+ padvalx1 = im_info.padvalx1
+ padvalx2 = im_info.padvalx2
+ padvaly1 = im_info.padvaly1
+ padvaly2 = im_info.padvaly2
+
+ imarr = imvec.reshape(ydim, xdim)
+ imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)),
+ 'constant', constant_values=0.0)
+
+ if imarr.shape[0] != imarr.shape[1]:
+ raise Exception("FFT padding did not return a square image!")
+
+ # FFT for visibilities
+ vis_im = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(imarr)))
+
+ return vis_im
+
+
+def sampler(griddata, sampler_info_list, sample_type="vis"):
+ """
+ Samples griddata (e.g. the FFT of an image) at uv points
+ the griddata should already be rotated so u,v = 0,0 in the center
+    sampler_info_list is an appropriately ordered list of SamplerInfo objects
+    (1 for 'vis', 3 for 'bs', 4 for 'camp'); each carries the order of the
+    spline interpolation
+ """
+ if sample_type not in ["vis", "bs", "camp"]:
+ raise Exception("sampler sample_type should be either 'vis','bs',or 'camp'!")
+ if griddata.shape[0] != griddata.shape[1]:
+ raise Exception("griddata should be a square array!")
+
+ dataset = []
+ for sampler_info in sampler_info_list:
+
+ vu2 = sampler_info.uv
+ pulsefac = sampler_info.pulsefac
+
+ datare = nd.map_coordinates(np.real(griddata), vu2, order=sampler_info.order)
+ dataim = nd.map_coordinates(np.imag(griddata), vu2, order=sampler_info.order)
+
+ data = datare + 1j*dataim
+ data = data * pulsefac
+
+ dataset.append(data)
+
+ if sample_type == "vis":
+ out = dataset[0]
+ if sample_type == "bs":
+ out = dataset[0]*dataset[1]*dataset[2]
+ if sample_type == "camp":
+ out = np.abs((dataset[0]*dataset[1])/(dataset[2]*dataset[3]))
+ return out
+
+
+def gridder(data_list, gridder_info_list):
+ """
+ Grid the data sampled at uv points on a square array
+    gridder_info_list is a list of GridderInfo objects
+ """
+
+ if len(data_list) != len(gridder_info_list):
+ raise Exception("length of data_list in gridder() " +
+ "is not equal to length of gridder_info_list!")
+
+ npad = gridder_info_list[0].npad
+ datagrid = np.zeros((npad, npad)).astype('c16')
+
+ for k in range(len(gridder_info_list)):
+ gridder_info = gridder_info_list[k]
+ data = data_list[k]
+
+ if gridder_info.npad != npad:
+ raise Exception("npad values not consistent in gridder_info_list!")
+
+ p_rad = gridder_info.p_rad
+ coords = gridder_info.coords
+ weights = gridder_info.weights
+
+ p_rad = int(p_rad)
+ for i in range(2*p_rad+1):
+ dy = i - p_rad
+ for j in range(2*p_rad+1):
+ dx = j - p_rad
+ weight = weights[i][j]
+ np.add.at(datagrid, tuple(map(tuple, (coords + [dy, dx]).transpose())), data*weight)
+
+ return datagrid
+
+
+def make_gridder_and_sampler_info(im_info, uv, conv_func=ehc.GRIDDER_CONV_FUNC_DEFAULT,
+ p_rad=ehc.GRIDDER_P_RAD_DEFAULT, order=ehc.FFT_INTERP_DEFAULT):
+ """
+ Prep norms and weights for gridding data sampled at uv points on a square array
+ im_info tuple contains (xdim, ydim, npad, psize, pulse) of the grid
+ conv_func is the convolution function: current options are "pillbox", "gaussian"
+    p_rad is the pixel radius inside which the conv_func is nonzero
+ """
+
+ if not (conv_func in ['pillbox', 'gaussian', 'cubic']):
+ raise Exception("conv_func must be either 'pillbox', 'gaussian', or, 'cubic'")
+
+ npad = im_info.npad
+ psize = im_info.psize
+ pulse = im_info.pulse
+
+ # compute grid u,v coordinates
+ vu2 = np.hstack((uv[:, 1].reshape(-1, 1), uv[:, 0].reshape(-1, 1)))
+ du = 1.0/(npad*psize)
+ vu2 = (vu2/du + 0.5*npad)
+
+ coords = np.round(vu2).astype(int)
+ dcoords = vu2 - np.round(vu2).astype(int)
+ vu2 = vu2.T
+
+    # TODO: phase rotations should be done separately for x and y if the image isn't square
+ phase = np.exp(-1j*np.pi*psize*((1+im_info.xdim % 2)*uv[:, 0] + (1+im_info.ydim % 2)*uv[:, 1]))
+
+ pulsefac = np.fromiter(
+ (pulse(2*np.pi*uvpt[0], 2*np.pi*uvpt[1], psize, dom="F") for uvpt in uv), 'c16')
+ pulsefac = pulsefac * phase
+
+ # compute gridder norm
+ weights = []
+    norm = np.zeros(len(coords))
+ for i in range(2*p_rad+1):
+ weights.append([])
+ dy = i - p_rad
+ for j in range(2*p_rad+1):
+ dx = j - p_rad
+ if conv_func == 'gaussian':
+ norm = norm + conv_func_gauss(dy - dcoords[:, 0], dx - dcoords[:, 1])
+ elif conv_func == 'pillbox':
+ norm = norm + conv_func_pill(dy - dcoords[:, 0], dx - dcoords[:, 1])
+ elif conv_func == 'cubic':
+ norm = norm + conv_func_cubicspline(dy - dcoords[:, 0], dx - dcoords[:, 1])
+
+ weights[i].append(None)
+
+ # compute weights for gridding
+ for i in range(2*p_rad+1):
+ dy = i - p_rad
+ for j in range(2*p_rad+1):
+ dx = j - p_rad
+ if conv_func == 'gaussian':
+ weight = conv_func_gauss(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm
+ elif conv_func == 'pillbox':
+ weight = conv_func_pill(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm
+ elif conv_func == 'cubic':
+ weight = conv_func_cubicspline(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm
+
+ weights[i][j] = weight
+
+ # output the coordinates, norms, and weights
+ sampler_info = SamplerInfo(order, vu2, pulsefac)
+ gridder_info = GridderInfo(npad, conv_func, p_rad, coords, weights)
+ return (sampler_info, gridder_info)
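+
+# Sketch of the FFT sampling pipeline these helpers support (illustrative;
+# assumes im_info is an ImInfo object and uv is an (N,2) array of u,v points):
+#   sampler_info, gridder_info = make_gridder_and_sampler_info(im_info, uv)
+#   visgrid = fft_imvec(imvec, im_info)
+#   vis = sampler(visgrid, [sampler_info], sample_type='vis')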
+
+
+##################################################################################################
+# miscellaneous functions
+##################################################################################################
+
+# TODO this makes a copy -- is there a faster robust way?
+def recarr_to_ndarr(x, typ):
+ """converts a record array x to a normal ndarray with all fields converted to datatype typ
+ """
+
+ fields = x.dtype.names
+ shape = x.shape + (len(fields),)
+ dt = [(name, typ) for name in fields]
+ y = x.astype(dt).view(typ).reshape(shape)
+ return y
+
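+# Illustrative example: extract two string columns from an observation record
+# array, as done in add_noise in obs_simulate.py:
+#
+#   sites = recarr_to_ndarr(obsdata[['t1', 't2']], 'U32')   # -> (N, 2) ndarray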
+
+def prog_msg(nscan, totscans, msgtype='bar', nscan_last=0):
+    """print a progress message for calibration
+ """
+ complete_percent_last = int(100*float(nscan_last)/float(totscans))
+ complete_percent = int(100*float(nscan)/float(totscans))
+ ndigit = str(len(str(totscans)))
+
+ if msgtype == 'bar':
+ bar_width = 30
+ progress = int(bar_width * complete_percent/float(100))
+ barparams = (nscan, totscans, ("-"*progress) +
+ (" " * (bar_width-progress)), complete_percent)
+
+ printstr = "\rScan %0"+ndigit+"i/%i : [%s]%i%%"
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
+ elif msgtype == 'bar2':
+ bar_width = 30
+ progress = int(bar_width * complete_percent/float(100))
+ barparams = (nscan, totscans, ("/"*progress) +
+ (" " * (bar_width-progress)), complete_percent)
+
+ printstr = "\rScan %0"+ndigit+"i/%i : [%s]%i%%"
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
+ elif msgtype == 'casa':
+ message_list = [".", ".", ".", "10", ".", ".", ".", "20",
+ ".", ".", ".", "30", ".", ".", ".", "40",
+ ".", ".", ".", "50", ".", ".", ".", "60",
+ ".", ".", ".", "70", ".", ".", ".", "80",
+ ".", ".", ".", "90", ".", ".", ".", "DONE"]
+ bar_width = len(message_list)
+ progress = int(bar_width * complete_percent/float(100))
+ message = ''.join(message_list[:progress])
+
+ barparams = (nscan, totscans, message)
+ printstr = "\rScan %0"+ndigit+"i/%i : %s"
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
+ elif msgtype == 'itcrowd':
+ message_list = ["0", "1", "1", "8", " ", "9", "9", "9", " ", "8", "8", "1", "9", "9", " ",
+ "9", "1", "1", "9", " ", "7", "2", "5", " ", " ", " ", "3"]
+ bar_width = len(message_list)
+ progress = int(bar_width * complete_percent/float(100))
+ message = ''.join(message_list[:progress])
+ if complete_percent < 100:
+ message += "."
+ message += " "*(bar_width-progress-1)
+
+ barparams = (nscan, totscans, message)
+
+ printstr = "\rScan %0"+ndigit+"i/%i : [%s]"
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
+ elif msgtype == 'bh':
+ message_all = ehc.BHIMAGE
+ bar_width = len(message_all)
+ progress = int(np.floor(bar_width * complete_percent/float(100)))-1
+ progress_last = int(np.floor(bar_width * complete_percent_last/float(100)))-1
+ if progress > progress_last:
+ for i in range(progress_last+1, progress+1):
+ message_line = ''.join(message_all[i])
+ message_line = '%03i' % int(complete_percent) + message_line
+ print(message_line)
+
+ elif msgtype == 'eht':
+ message_all = ehc.EHTIMAGE
+ bar_width = len(message_all)
+ progress = int(np.floor(bar_width * complete_percent/float(100)))-1
+ progress_last = int(np.floor(bar_width * complete_percent_last/float(100)))-1
+ if progress > progress_last:
+ for i in range(progress_last+1, progress+1):
+ message_line = ''.join(message_all[i])
+ message_line = '%03i' % int(complete_percent) + message_line
+ print(message_line)
+
+ elif msgtype == 'dots':
+ sys.stdout.write('.')
+ sys.stdout.flush()
+
+ else: # msgtype=='default':
+ barparams = (nscan, totscans, complete_percent)
+ printstr = "\rScan %0"+ndigit+"i/%i : %i%% done . . ."
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
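+# Example: prog_msg(50, 100) redraws a 30-character progress bar at 50%;
+# msgtype='dots' prints one dot per scan, while 'bh' and 'eht' print the
+# BHIMAGE/EHTIMAGE ASCII art one chunk of lines per step of progress.
+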
+def sat_skyfield_from_elements(satname, epoch_mjd, perigee_mjd,
+ period_days, eccentricity,
+ inclination, arg_perigee, long_ascending):
+    """skyfield EarthSatellite object from Keplerian orbital elements.
+    A perfect Keplerian orbit is assumed (no drag or higher-order derivatives).
+    epoch and perigee times are given in MJD,
+    the period is given in days,
+    inclination, arg_perigee, and long_ascending are given in degrees.
+ """
+
+    if not (0 <= eccentricity < 1):
+        raise Exception("eccentricity must be between 0 and 1")
+    if not (0 <= inclination <= 180):
+        raise Exception("inclination must be between 0 and 180")
+    if not (0 <= arg_perigee <= 180):
+        raise Exception("arg_perigee must be between 0 and 180")
+    if not (0 <= long_ascending <= 360):
+        raise Exception("long_ascending must be between 0 and 360")
+
+ satrec = Satrec()
+ ts = skyfield.api.load.timescale()
+ ref_mjd = 33281. # mjd 1949 December 31 00:00 UT
+ epoch_wrt_ref = epoch_mjd - ref_mjd
+
+ inclination_rad = inclination*ehc.DEGREE
+ arg_perigee_rad = arg_perigee*ehc.DEGREE
+ long_ascending_rad = long_ascending*ehc.DEGREE
+
+ mean_motion = 2*np.pi/(period_days*24.*60.) # radians/minute
+
+    mean_anomaly = mean_motion*(epoch_mjd - perigee_mjd)*24.*60.  # rad/min x minutes since perigee
+ mean_anomaly = np.mod(mean_anomaly, 2*np.pi)
+
+ satrec.sgp4init(
+ WGS72, # gravity model
+ 'i', # 'a' = old AFSPC mode, 'i' = improved mode
+ 1, # satnum: Satellite number
+ epoch_wrt_ref, # epoch: days since 1949 December 31 00:00 UT
+ 0.0, # bstar: drag coefficient (/earth radii)
+ 0.0, # ndot: ballistic coefficient (revs/day)
+ 0.0, # nddot: second derivative of mean motion (revs/day^3)
+ eccentricity, # ecco: eccentricity
+ arg_perigee_rad, # argpo: argument of perigee (radians)
+ inclination_rad, # inclo: inclination (radians)
+ mean_anomaly, # mo: mean anomaly (radians)
+ mean_motion, # no_kozai: mean motion (radians/minute)
+ long_ascending_rad, # nodeo: right ascension of ascending node (radians)
+ )
+ sat_skyfield = skyfield.api.EarthSatellite.from_satrec(satrec, ts)
+
+ return sat_skyfield
+
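+# Illustrative sketch (all element values are hypothetical): a half-day orbit
+#
+#   sat = sat_skyfield_from_elements('TESTSAT', epoch_mjd=60000.0, perigee_mjd=60000.0,
+#                                    period_days=0.5, eccentricity=0.001,
+#                                    inclination=45., arg_perigee=90., long_ascending=0.)
+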
+def sat_skyfield_from_tle(satname, line1, line2):
+ ts = skyfield.api.load.timescale()
+ sat_skyfield = skyfield.api.EarthSatellite(line1, line2, satname, ts)
+ return sat_skyfield
+
+def sat_skyfield_from_ephementry(satname, ephem, epoch_mjd):
+ if len(ephem[satname])==3: # TLE
+ line1 = ephem[satname][1]
+ line2 = ephem[satname][2]
+ sat = sat_skyfield_from_tle(satname, line1, line2)
+ elif len(ephem[satname])==6: #keplerian elements
+ elements = ephem[satname]
+ sat = sat_skyfield_from_elements(satname, epoch_mjd,
+ elements[0],elements[1],elements[2],elements[3],elements[4],elements[5])
+ else:
+        raise Exception("ephemeris format not recognized for %s" % satname)
+
+ return sat
+
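+# The expected `ephem` entry formats (illustrative; TLE strings truncated):
+#
+#   ephem = {'TESTSAT': ['TESTSAT', '1 25544U ...', '2 25544 ...']}   # name + two TLE lines
+#   ephem = {'TESTSAT': [60000.0, 0.5, 0.001, 45., 90., 0.]}          # six Keplerian elements
+#   sat = sat_skyfield_from_ephementry('TESTSAT', ephem, epoch_mjd=60000.0)
+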
+def orbit_skyfield(sat, fracmjds, whichout='itrs'):
+    """uses skyfield to propagate an Earth-satellite orbit and return x,y,z coordinates
+    sat is a skyfield.sgp4lib.EarthSatellite object
+    fracmjds is an array of fractional MJDs
+    whichout is 'itrs' (co-rotating earth frame) or 'gcrs' (fixed x-axis to equinox)
+    """
+
+ #fractional days of orbit in jd
+ mjd_to_jd = 2400000.5
+ ts = skyfield.api.load.timescale()
+ t = ts.ut1_jd(mjd_to_jd+fracmjds)
+
+ # propagate orbit
+ time_data = sat.at(t)
+
+ if whichout=='gcrs':
+        # GCRS coordinates in meters
+ positions = time_data.xyz.m
+
+ elif whichout=='itrs':
+        # get coordinates in earth frame (WGS84 ellipsoid)
+ geographic_position = skyfield.api.wgs84.geographic_position_of(time_data)
+ positions = geographic_position.itrs_xyz.m
+
+ else:
+        raise Exception("orbit_skyfield whichout must be 'itrs' or 'gcrs'")
+
+ return positions
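+
+# Illustrative sketch: co-rotating (ITRS) positions at three times, in meters
+#
+#   fracmjds = np.array([60000.0, 60000.25, 60000.5])
+#   xyz = orbit_skyfield(sat, fracmjds, whichout='itrs')   # shape (3, len(fracmjds))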
+
diff --git a/observing/obs_simulate.py b/observing/obs_simulate.py
new file mode 100644
index 00000000..1afaf71c
--- /dev/null
+++ b/observing/obs_simulate.py
@@ -0,0 +1,1500 @@
+# obs_simulate.py
+# functions to simulate interferometric observations
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import time as ttime
+import scipy.ndimage as nd
+from scipy.interpolate import interp1d
+import numpy as np
+import copy
+try:
+ from pynfft.nfft import NFFT
+except ImportError:
+ pass
+ #print("Warning: No NFFT installed!")
+
+from . import obs_helpers as obsh
+import ehtim.const_def as ehc
+
+##################################################################################################
+# Generate U-V Points
+##################################################################################################
+
+
+def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop,
+ polrep='stokes',
+ mjd=ehc.MJD_DEFAULT, tau=ehc.TAUDEF,
+ elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH,
+ no_elevcut_space=False,
+ timetype='UTC', fix_theta_GMST=False):
+ """Generate u,v points and baseline sigmas for a given array.
+
+ Args:
+ array (Array): the array object
+ ra (float): The source Right Ascension in fractional hours
+ dec (float): The source declination in fractional degrees
+ rf (float): The observation frequency in Hz
+ bw (float): The observation bandwidth in Hz
+ tint (float): the scan integration time in seconds
+ tadv (float): the uniform cadence between scans in seconds
+ tstart (float): the start time of the observation in hours
+ tstop (float): the end time of the observation in hours
+        polrep (str): 'stokes' or 'circ' sets the data polarimetric representation
+ mjd (int): the mjd of the observation, if different from the image mjd
+ tau (float): the base opacity at all sites, or a dict giving one opacity per site
+ elevmin (float): station minimum elevation in degrees
+ elevmax (float): station maximum elevation in degrees
+ no_elevcut_space (bool): if True, do not apply elevation cut to orbiters
+ timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC'
+ fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v points
+ Returns:
+        (np.array): an array of observation data with all visibilities zeroed
+ """
+
+ if polrep == 'stokes':
+ poltype = ehc.DTPOL_STOKES
+ elif polrep == 'circ':
+ poltype = ehc.DTPOL_CIRC
+ else:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ # Set up time start and steps
+ tstep = tadv/3600.0
+ if tstop < tstart:
+ tstop = tstop + 24.0
+
+ # Observing times
+ times = np.arange(tstart, tstop, tstep)
+ if timetype not in ['UTC', 'GMST']:
+ print("Time Type Not Recognized! Assuming UTC!")
+ timetype = 'UTC'
+
+ # Generate uv points at all times
+ outlist = []
+ blpairs = []
+
+ for i1 in range(len(array.tarr)):
+ for i2 in range(len(array.tarr)):
+ if (i1 != i2 and
+ i1 < i2 and # This is the right condition for uvfits save order
+ not ((i2, i1) in blpairs)): # This cuts out the conjugate baselines
+
+ blpairs.append((i1, i2))
+
+ # sites
+ site1 = array.tarr[i1]['site']
+ site2 = array.tarr[i2]['site']
+ coord1 = ((array.tarr[i1]['x'], array.tarr[i1]['y'], array.tarr[i1]['z']))
+ coord2 = ((array.tarr[i2]['x'], array.tarr[i2]['y'], array.tarr[i2]['z']))
+
+ # Optical Depth
+ if type(tau) == dict:
+ try:
+ tau1 = tau[site1]
+ tau2 = tau[site2]
+ except KeyError:
+ tau1 = tau2 = ehc.TAUDEF
+ else:
+ tau1 = tau2 = tau
+ # no optical depth for space sites
+ if coord1 == (0., 0., 0.):
+ tau1 = 0.
+ if coord2 == (0., 0., 0.):
+ tau2 = 0.
+
+ # Noise on the correlations
+ if np.any(array.tarr['sefdr'] <= 0) or np.any(array.tarr['sefdl'] <= 0):
+ print("Warning!: in make_uvpoints, some SEFDs are <= 0!")
+
+ sig_rr = obsh.blnoise(array.tarr[i1]['sefdr'], array.tarr[i2]['sefdr'], tint, bw)
+ sig_ll = obsh.blnoise(array.tarr[i1]['sefdl'], array.tarr[i2]['sefdl'], tint, bw)
+ sig_rl = obsh.blnoise(array.tarr[i1]['sefdr'], array.tarr[i2]['sefdl'], tint, bw)
+ sig_lr = obsh.blnoise(array.tarr[i1]['sefdl'], array.tarr[i2]['sefdr'], tint, bw)
+ if polrep == 'stokes':
+ sig_iv = 0.5*np.sqrt(sig_rr**2 + sig_ll**2)
+ sig_qu = 0.5*np.sqrt(sig_rl**2 + sig_lr**2)
+ sig1 = sig_iv
+ sig2 = sig_qu
+ sig3 = sig_qu
+ sig4 = sig_iv
+ elif polrep == 'circ':
+ sig1 = sig_rr
+ sig2 = sig_ll
+ sig3 = sig_rl
+ sig4 = sig_lr
+
+ uvdat = obsh.compute_uv_coordinates(array, site1, site2, times, mjd,
+ ra, dec, rf, timetype=timetype,
+ elevmin=elevmin, elevmax=elevmax,
+ no_elevcut_space=no_elevcut_space,
+ fix_theta_GMST=fix_theta_GMST)
+
+ (timesout, uout, vout) = uvdat
+ for k in range(len(timesout)):
+ outlist.append(np.array((
+ timesout[k],
+ tint, # Integration
+ site1, # Station 1
+ site2, # Station 2
+ tau1, # Station 1 zenith optical depth
+                        tau2,       # Station 2 zenith optical depth
+ uout[k], # u (lambda)
+ vout[k], # v (lambda)
+ 0.0, # 1st Visibility (Jy)
+ 0.0, # 2nd Visibility
+ 0.0, # 3rd Visibility
+ 0.0, # 4th Visibility
+ sig1, # 1st Sigma (Jy)
+ sig2, # 2nd Sigma
+ sig3, # 3rd Sigma
+ sig4 # 4th Sigma
+ ), dtype=poltype
+ ))
+
+ obsarr = np.array(outlist)
+
+ if not len(obsarr):
+ raise Exception("No mutual visibilities in the specified time range!")
+
+ return obsarr
+
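+# Illustrative usage sketch (the array file path is hypothetical):
+#
+#   import ehtim as eh
+#   arr = eh.array.load_txt('arrays/EHT2017.txt')
+#   obsarr = make_uvpoints(arr, ra=17.76, dec=-29.0, rf=230e9, bw=4e9,
+#                          tint=60., tadv=600., tstart=0., tstop=24.)
+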
+##################################################################################################
+# Observe w/o noise
+##################################################################################################
+
+
+def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes',
+ ttype="nfft", cache=False, fft_pad_factor=2, zero_empty_pol=True, verbose=True):
+    """Observe an image on given baselines with no noise.
+
+ Args:
+        im_org (Image): the image to be observed
+        uv (ndarray): an array of u,v coordinates
+        sgrscat (bool): if True, the visibilities are blurred by the Sgr A* scattering kernel
+        polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation
+        ttype (str): 'direct' (DTFT), 'fast' (FFT + interpolation), 'nfft',
+                     or 'DFT'/'DFT_i' for the explicit wide-field DFT below
+        cache (bool or str): if 'auto', cache and reuse the FFT of each polarization
+                             (used with ttype='fast')
+        fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT
+        zero_empty_pol (bool): if True, return a zero vector if the polarization doesn't exist.
+                               Otherwise return None
+        verbose (bool): Boolean value controls output prints.
+
+ Returns:
+        (list): a list of four visibility arrays, one per polarization in the chosen polrep
+ """
+
+ if polrep_obs == 'stokes':
+ im = im_org.switch_polrep('stokes', 'I')
+        pollist = ['I', 'Q', 'U', 'V']  # TODO what if we only have the I image?
+ elif polrep_obs == 'circ':
+ im = im_org.switch_polrep('circ', 'RR')
+        pollist = ['RR', 'LL', 'RL', 'LR']  # TODO what if we only have the RR image?
+ else:
+ raise Exception("only 'stokes' and 'circ' are supported polreps!")
+
+ uv = np.array(uv)
+ if uv.shape[1] != 2:
+ raise Exception("When given as a list of uv points, " +
+ "the obs should be a list of pairs of u-v coordinates!")
+ if im.pa != 0.0:
+ c = np.cos(im.pa)
+ s = np.sin(im.pa)
+ u = uv[:, 0]
+ v = uv[:, 1]
+ uv = np.column_stack([c * u - s * v,
+ s * u + c * v])
+
+# umin = np.min(np.sqrt(uv[:,0]**2 + uv[:,1]**2))
+# umax = np.max(np.sqrt(uv[:,0]**2 + uv[:,1]**2))
+# if not im.psize < 1.0/(2.0*umax):
+# print(" Warning!: longest baseline > 1/2 x maximum image spatial wavelength!")
+# if not im.psize*np.sqrt(im.xdim*im.ydim) > 1.0/(0.5*umin):
+# print(" Warning!: shortest baseline < 2 x minimum image spatial wavelength!")
+
+ obsdata = []
+
+ # Get visibilities from straightforward FFT
+ if ttype == "fast":
+
+ # Padded image size
+ npad = fft_pad_factor * np.max((im.xdim, im.ydim))
+ npad = obsh.power_of_two(npad)
+
+ padvalx1 = padvalx2 = int(np.floor((npad - im.xdim)/2.0))
+ if im.xdim % 2:
+ padvalx2 += 1
+ padvaly1 = padvaly2 = int(np.floor((npad - im.ydim)/2.0))
+ if im.ydim % 2:
+ padvaly2 += 1
+
+ imarr = im.imvec.reshape(im.ydim, im.xdim)
+ imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)),
+ 'constant', constant_values=0.0)
+ npad = imarr.shape[0]
+ if imarr.shape[0] != imarr.shape[1]:
+ raise Exception("FFT padding did not return a square image!")
+
+ # Scaled uv points
+ du = 1.0/(npad*im.psize)
+ uv2 = np.hstack((uv[:, 1].reshape(-1, 1), uv[:, 0].reshape(-1, 1)))
+ uv2 = (uv2/du + 0.5*npad).T
+
+ # Extra phase to match centroid convention
+ phase = np.exp(-1j*np.pi*im.psize*((1+im.xdim % 2)*uv[:, 0] + (1+im.ydim % 2)*uv[:, 1]))
+
+ # Pulse function
+ pulsefac = np.fromiter(
+ (im.pulse(2*np.pi*uvpt[0], 2*np.pi*uvpt[1], im.psize, dom="F") for uvpt in uv), 'c16')
+
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ # FFT for visibilities
+ if pol in im_org.cached_fft:
+ vis_im = im_org.cached_fft[pol]
+ else:
+ imarr = imvec.reshape(im.ydim, im.xdim)
+ imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)),
+ 'constant', constant_values=0.0)
+ vis_im = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(imarr)))
+ if cache == 'auto':
+ im_org.cached_fft[pol] = vis_im
+
+ # Sample the visibilities
+ # default is cubic spline interpolation
+ visre = nd.map_coordinates(np.real(vis_im), uv2)
+ visim = nd.map_coordinates(np.imag(vis_im), uv2)
+ vis = visre + 1j*visim
+
+ # Extra phase and pulse factor
+ vis = vis * phase * pulsefac
+
+ # Return visibilities
+ obsdata.append(vis)
+
+ # Get visibilities from the NFFT
+ elif ttype == "nfft":
+
+ uvdim = len(uv)
+ if (im.xdim % 2 or im.ydim % 2):
+ raise Exception("NFFT doesn't work with odd image dimensions!")
+
+ npad = fft_pad_factor * np.max((im.xdim, im.ydim))
+
+ # TODO what is a good kernel size??
+        nker = int(np.floor(np.min((im.xdim, im.ydim))/5))
+        if nker > 50:
+            nker = 50
+        elif (im.xdim < 50 or im.ydim < 50):
+            nker = int(np.min((im.xdim, im.ydim))//2)
+
+ # TODO are y & x reversed?
+ plan = NFFT([im.xdim, im.ydim], uvdim, m=nker, n=[npad, npad])
+
+ # Sampled uv points
+ uvlist = uv*im.psize
+
+ # Precompute
+ plan.x = uvlist
+ plan.precompute()
+
+ # Extra phase and pulsefac
+ phase = np.exp(-1j*np.pi*(uvlist[:, 0] + uvlist[:, 1]))
+ pulsefac = np.fromiter((im.pulse(2*np.pi*uvlist[i, 0], 2*np.pi*uvlist[i, 1], 1., dom="F")
+ for i in range(uvdim)), 'c16')
+
+ # Compute the uniform --> nonuniform transform for different polarizations
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ plan.f_hat = imvec.copy().reshape((im.ydim, im.xdim)).T
+ plan.trafo()
+ vis = plan.f.copy()*phase*pulsefac
+
+ obsdata.append(vis)
+
+ elif ttype == "DFT":
+        # im.psize is in radians; dividing by 4.84813681109536e-12 (rad per uas)
+        # gives the field of view in microarcseconds, as DFT expects
+        xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ imarr = imvec.reshape(im.ydim, im.xdim)
+ vis = DFT(imarr, uv, xfov=xfov, yfov=yfov)
+ obsdata.append(vis)
+
+    elif ttype == "DFT_i":
+        xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12
+        # uv swap hack; swap once outside the loop so that multiple populated
+        # polarizations are not swapped back and forth
+        uv = np.array([uv[:, 1], uv[:, 0]]).T
+        for i in range(4):
+            pol = pollist[i]
+            imvec = im._imdict[pol]
+            if imvec is None or len(imvec) == 0:
+                if zero_empty_pol:
+                    obsdata.append(np.zeros(len(uv)))
+                else:
+                    obsdata.append(None)
+            else:
+                imarr = imvec.reshape(im.ydim, im.xdim)
+                vis = DFT(imarr, uv, xfov=xfov, yfov=yfov)
+                obsdata.append(vis)
+
+
+ # Get visibilities from DTFT
+ else:
+ # Construct Fourier matrix
+ mat = obsh.ftmatrix(im.psize, im.xdim, im.ydim, uv, pulse=im.pulse)
+
+ # Compute DTFT for different polarizations
+ for i in range(4):
+ pol = pollist[i]
+ imvec = im._imdict[pol]
+ if imvec is None or len(imvec) == 0:
+ if zero_empty_pol:
+ obsdata.append(np.zeros(len(uv)))
+ else:
+ obsdata.append(None)
+ else:
+ vis = np.dot(mat, imvec)
+ obsdata.append(vis)
+
+ # Scatter the visibilities with the SgrA* kernel
+ if sgrscat:
+ if verbose:
+ print('Scattering Visibilities with Sgr A* kernel!')
+ ker = obsh.sgra_kernel_uv(im.rf, uv[:, 0], uv[:, 1])
+ for data in obsdata:
+ if data is None:
+ continue
+ data *= ker
+
+ return obsdata
+
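+# Illustrative sketch: sample Stokes visibilities of an image at two uv points
+# (the image file path is hypothetical):
+#
+#   import ehtim as eh
+#   im = eh.image.load_txt('models/avery_sgra_eofn.txt')
+#   uv = np.array([[1e9, 0.], [0., 1e9]])    # (u, v) in wavelengths
+#   vis_i, vis_q, vis_u, vis_v = sample_vis(im, uv, ttype='direct')
+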
+
+def DFT(data, uv, xfov=225, yfov=225):
+    """Explicit direct Fourier transform of an image or stack of images.
+    data has trailing (ny, nx) image axes; uv is an (N, 2) array of baseline
+    coordinates in wavelengths; xfov, yfov are fields of view in microarcseconds.
+    Uses l = sin(x), m = sin(y) rather than the small-angle approximation.
+    """
+ if data.ndim == 2:
+ data = data[np.newaxis,...]
+ out_shape = (uv.shape[0],)
+ elif data.ndim > 2:
+ data = data.reshape((-1,) + data.shape[-2:])
+ out_shape = data.shape[:-2] + (uv.shape[0],)
+ ny, nx = data.shape[-2:]
+ dx = xfov*4.84813681109536e-12 / nx
+ dy = yfov*4.84813681109536e-12 / ny
+ angx = (np.arange(nx) - nx//2) * dx
+ angy = (np.arange(ny) - ny//2) * dy
+ lvect = np.sin(angx)
+ mvect = np.sin(angy)
+ l, m = np.meshgrid(lvect, mvect)
+ lm = np.concatenate([l.reshape(1,-1), m.reshape(1,-1)], axis=0)
+ imgvect = data.reshape((data.shape[0],-1))
+ x = -2*np.pi*np.dot(uv,lm)[np.newaxis, ...]
+ visr = np.sum(imgvect[:, np.newaxis, :] * np.cos(x, dtype=np.float32), axis=-1)
+ visi = np.sum(imgvect[:, np.newaxis, :] * np.sin(x, dtype=np.float32), axis=-1)
+    # data always has a leading batch axis at this point, so a single reshape
+    # restores the requested output shape in both cases
+    vis = (visr + 1j*visi).reshape(out_shape)
+    return vis
+
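+# Illustrative sketch: a minimal call on a random 32x32 image; because
+# l = sin(x) and m = sin(y) above, the transform remains valid for wide fields
+#
+#   img = np.random.rand(32, 32)
+#   vis = DFT(img, np.array([[2e9, 1e9]]), xfov=200, yfov=200)   # fov in uas
+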
+
+##################################################################################################
+# Noise + miscalibration functions
+##################################################################################################
+
+
+def make_jones(obs, opacitycal=True, ampcal=True, phasecal=True, dcal=True,
+ frcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False, neggains=False,
+ taup=ehc.GAINPDEF,
+ gainp=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF,
+ phase_std=-1,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0., rlphase_std=0.,
+               sigmat=None, phasesigmat=None, rlgsigmat=None, rlpsigmat=None,
+ caltable_path=None, seed=False):
+ """Computes Jones Matrices for a list of times (non repeating), with gain and dterm errors.
+
+ Args:
+ obs (Obsdata): the observation with scans for the Jones matrices to be computed
+ opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains
+ phasecal (bool): if False, time-dependent random phases are added to complex station gains
+ dcal (bool): if False, time-dependent gaussian errors are added to D-terms.
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ phase_std (float): std. dev. of LCP phase,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ sigmat (float): temporal std for a Gaussian Process used to generate gains.
+ If sigmat=None then an iid gain noise is applied.
+ phasesigmat (float): temporal std for a Gaussian Process used to generate phases.
+ If phasesigmat=None then an iid gain noise is applied.
+ rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios.
+ If rlgsigmat=None then an iid gain noise is applied.
+ rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff.
+ If rlpsigmat=None then an iid gain noise is applied.
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed : a seed for the random number generators, uses system time if false
+ Returns:
+ (dict): a nested dictionary of matrices indexed by the site, then by the time
+ """
+
+ obs_tmp = obs.copy()
+ tlist = obs_tmp.tlist()
+ tarr = obs_tmp.tarr
+ ra = obs_tmp.ra
+ dec = obs_tmp.dec
+ sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)])
+
+ # Create a dictionary of taus and a list of unique times
+ nsites = len(obs_tmp.tarr['site'])
+ taudict = {site: np.array([]) for site in obs_tmp.tarr['site']}
+ times = np.array([])
+ for scan in tlist:
+ time = scan['time'][0]
+ times = np.append(times, time)
+ sites_in = np.array([])
+ for bl in scan:
+
+ # Should we screen for conflicting same-time measurements of tau?
+ if len(sites_in) >= nsites:
+ break
+
+ if (not len(sites_in)) or (not bl['t1'] in sites_in):
+ taudict[bl['t1']] = np.append(taudict[bl['t1']], bl['tau1'])
+ sites_in = np.append(sites_in, bl['t1'])
+
+ if (not len(sites_in)) or (not bl['t2'] in sites_in):
+ taudict[bl['t2']] = np.append(taudict[bl['t2']], bl['tau2'])
+ sites_in = np.append(sites_in, bl['t2'])
+
+ if len(sites_in) < nsites:
+ for site in obs_tmp.tarr['site']:
+ if site not in sites_in:
+ taudict[site] = np.append(taudict[site], 0.0)
+
+ # Now define a list that accounts for periods where the phase or amplitude errors
+ # are stable (e.g., over scans if stabilize_scan_phase==True)
+ times_stable_phase = times.copy()
+ times_stable_amp = times.copy()
+ times_stable = times.copy()
+ if stabilize_scan_phase is True or stabilize_scan_amp is True:
+ scans = obs_tmp.scans
+        if scans is None or len(scans) == 0:
+ obs_scans = obs.copy()
+ obs_scans.add_scans()
+ scans = obs_scans.scans
+ for j in range(len(times_stable)):
+ for scan in scans:
+ if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]:
+ times_stable[j] = scan[0]
+ break
+
+ if stabilize_scan_phase is True:
+ times_stable_phase = times_stable.copy()
+ if stabilize_scan_amp is True:
+ times_stable_amp = times_stable.copy()
+
+ # Compute Sidereal Times
+ if obs.timetype == 'GMST':
+ times_sid = times
+ else:
+ times_sid = obsh.utc_to_gmst(times, obs.mjd)
+
+ # Seed for random number generators
+ if seed is False:
+ seed = str(ttime.time())
+
+ # Generate Jones Matrices at each time for each telescope
+ out = {}
+ datatables = {}
+ for i in range(len(tarr)):
+ site = tarr[i]['site']
+ coords = np.array([tarr[i]['x'], tarr[i]['y'], tarr[i]['z']])
+ latlon = obsh.xyz_2_latlong(coords)
+
+ # Elevation Angles
+ thetas = np.mod((times_sid - ra)*ehc.HOUR, 2*np.pi)
+ el_angles = obsh.elev(obsh.earthrot(coords, thetas), sourcevec)
+
+ # Parallactic Angles
+ hr_angles = obsh.hr_angle(times_sid*ehc.HOUR, latlon[:, 1], ra*ehc.HOUR)
+ par_angles = obsh.par_angle(hr_angles, latlon[:, 0], dec*ehc.DEGREE)
+
+ # gain offset: time independent part
+ if type(gain_offset) == dict:
+ goff = gain_offset[site]
+ else:
+ goff = gain_offset
+
+ # gain_mult: time dependent part
+ if type(gainp) == dict:
+ gain_mult = gainp[site]
+ else:
+ gain_mult = gainp
+
+ # phase mult - phase std deviation (-1 = uniform)
+ if type(phase_std) == dict:
+ phase_mult = phase_std[site]
+ else:
+ phase_mult = phase_std
+
+ # gainratio_mult: time dependent R/L gain offset
+ if type(rlratio_std) == dict:
+ gainratio_mult = rlratio_std[site]
+ else:
+ gainratio_mult = rlratio_std
+
+ # phasediff_mult: time dependent R-L phase offset
+ if type(rlphase_std) == dict:
+ phasediff_mult = rlphase_std[site]
+ else:
+ phasediff_mult = rlphase_std
+
+ # correlation timescales
+ if type(sigmat) == dict:
+ sigt_g = sigmat[site]
+ else:
+ sigt_g = sigmat
+
+ if type(phasesigmat) == dict:
+ sigt_p = phasesigmat[site]
+ else:
+ sigt_p = phasesigmat
+
+ if type(rlgsigmat) == dict:
+ sigt_rlg = rlgsigmat[site]
+ else:
+ sigt_rlg = rlgsigmat
+
+ if type(rlpsigmat) == dict:
+ sigt_rlp = rlpsigmat[site]
+ else:
+ sigt_rlp = rlpsigmat
+
+ # Amplitude gains
+ gainR = gainL = np.ones(len(times))
+ if not ampcal:
+
+ # mean LCP gain
+ gainL_constant = goff * obsh.hashrandn(site, 'gain', str(goff), seed)
+
+            # Enforce a mean gain below 1
+ if neggains:
+ gainL_constant = -np.abs(gainL_constant)
+
+ # LCP gain
+ if sigt_g is None: # iid sampling in time
+
+ gainL = np.sqrt(np.abs(np.fromiter((
+ (1.0 + gainL_constant) *
+ (1.0 + gain_mult * obsh.hashrandn(site, 'gain', str(time), str(gain_mult), seed))
+ for time in times_stable_amp
+ ), float)))
+
+ elif sigt_g <=0: # single sample in time
+
+ gainL = np.sqrt(np.abs(np.fromiter(((1.0 + gainL_constant)
+ for time in times_stable_amp), float)))
+
+ else: # correlated sampling in time
+ scan_start_times = scans[:, 0]
+ cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_g)
+ randLx = obsh.hashmultivariaterandn(len(scan_start_times), cov, site,
+ 'gain', str(time), str(gain_mult), seed)
+ gainL = np.sqrt(np.abs((1.0 + gainL_constant) * (1.0 + gain_mult * randLx)))
+
+
+                gainL_interpolator = interp1d(scan_start_times, gainL, kind='zero')
+                gainL = gainL_interpolator(times_stable_amp)
+
+ # R/L gain offset (if present)
+ if rlgaincal:
+ gain_RLratio = 1.
+ else:
+ if sigt_rlg is None: #iid sampling in time
+
+ gain_RLratio = np.abs(np.fromiter((
+ (1.0 + gainratio_mult * obsh.hashrandn(site, 'gainratio', str(time),
+ str(gainratio_mult), seed))
+ for time in times_stable_amp), float))
+
+ elif sigt_rlg <=0: # single sample in time
+ gain_RLratio = np.abs(np.fromiter((
+ (1.0 + gainratio_mult * obsh.hashrandn(site, 'gainratio',
+ str(gainratio_mult), seed))
+ for time in times_stable_amp), float))
+
+ else: #correlated sampling in time
+ scan_start_times = scans[:, 0]
+ cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_rlg)
+ randRLx = obsh.hashmultivariaterandn(len(scan_start_times), cov, site,
+ 'gainratio', str(time), str(gainratio_mult),
+ seed)
+ gain_RLratio = np.abs(1.0 + gainratio_mult * randRLx)
+                    gainRLratio_interpolator = interp1d(scan_start_times, gain_RLratio, kind='zero')
+                    gain_RLratio = gainRLratio_interpolator(times_stable_amp)
+
+ # RCP gain
+ gainR = gain_RLratio * gainL
+
+ # enforce gains < 1
+ # TODO -- will this mess up gain offset priors?
+ if neggains:
+ gainR = np.exp(-np.abs(np.log(gainR)))
+ gainL = np.exp(-np.abs(np.log(gainL)))
+
+ # Opacity attenuation of amplitude gain
+ if not opacitycal:
+ taus = np.abs(np.fromiter((
+ (taudict[site][j]) *
+ (1.0 + taup * obsh.hashrandn(site, 'tau', times_stable_amp[j], seed))
+ for j in range(len(times))), float))
+ atten = np.exp(-taus/(ehc.EP + 2.0*np.sin(el_angles)))
+
+ gainR = gainR * atten
+ gainL = gainL * atten
+
+ # Atmospheric Phase
+ if not phasecal:
+
+ # Gaussian distribution of LCP phase
+ if phase_mult >=0:
+
+ if sigt_p is None: #iid sampling in time
+ phaseL = np.fromiter((phase_mult * obsh.hashrandn(site, 'phase', str(time),
+ str(phase_mult), seed)
+ for time in times_stable_phase), float)
+
+ elif sigt_p <=0: # single sample in time
+ phaseL = np.fromiter((phase_mult * obsh.hashrandn(site, 'phase',
+ str(phase_mult), seed)
+ for time in times_stable_phase), float)
+ else: #correlated sampling in time
+ scan_start_times = scans[:, 0]
+ cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_p)
+ phaseL = phase_mult* obsh.hashmultivariaterandn(len(scan_start_times), cov, site,
+ 'phase', str(time), str(phase_mult),
+ seed)
+                phaseL_interpolator = interp1d(scan_start_times, phaseL, kind='zero')
+                phaseL = phaseL_interpolator(times_stable_phase)
+
+ # flat distribution of LCP phase, iid in time
+ # TODO correlated sampling with flat phases?
+ else:
+
+ phaseL = np.fromiter((2 * np.pi * obsh.hashrand(site, 'phase', time, seed)
+ for time in times_stable_phase), float)
+
+ # R-L phase offset
+ if rlgaincal:
+ phaseRLdiff = 0.
+ else:
+ # Gaussian distributed phase difference
+ if phasediff_mult >=0:
+ if sigt_rlp is None: #iid sampling in time
+ phaseRLdiff = np.fromiter((phasediff_mult *
+ obsh.hashrandn(site, 'phasediff', str(time),
+ str(phasediff_mult), seed)
+ for time in times_stable_phase), float)
+ elif sigt_rlp <=0: # single sample in time
+ phaseRLdiff = np.fromiter((phasediff_mult *
+ obsh.hashrandn(site, 'phasediff',
+ str(phasediff_mult), seed)
+ for time in times_stable_phase), float)
+ else: #correlated sampling in time
+ scan_start_times = scans[:, 0]
+ cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_rlp)
+ phaseRLdiff = phasediff_mult * obsh.hashmultivariaterandn(len(scan_start_times),
+ cov, site, 'phasediff',
+ str(time), str(phasediff_mult),
+ seed)
+                    phaseRL_interpolator = interp1d(scan_start_times, phaseRLdiff, kind='zero')
+                    phaseRLdiff = phaseRL_interpolator(times_stable_phase)
+
+ # flat distribution phase difference, iid in time
+ # TODO correlated sampling with flat phases?
+ else:
+                phaseRLdiff = np.fromiter((2 * np.pi * obsh.hashrand(site, 'phasediff', time, seed)
+                                           for time in times_stable_phase), float)
+ phaseRLdiff -= np.pi
+
+ # Complex gains
+ gainL = gainL * np.exp(1j*phaseL)
+ gainR = gainR * np.exp(1j*(phaseL + phaseRLdiff))
+
+
+ # D Term errors
+ dR = dL = 0.0
+ if not dcal:
+
+ # D terms are always time-independent
+ if type(dterm_offset) == dict:
+ doff = dterm_offset[site]
+ else:
+ doff = dterm_offset
+
+ dR = tarr[i]['dr']
+ dL = tarr[i]['dl']
+
+ dR += doff * (obsh.hashrandn(site, 'dRre', seed) +
+ 1j * obsh.hashrandn(site, 'dRim', seed))
+ dL += doff * (obsh.hashrandn(site, 'dLre', seed) +
+ 1j * obsh.hashrandn(site, 'dLim', seed))
+
+ # Feed Rotation Angles
+ fr_angle = np.zeros(len(times))
+ fr_angle_D = np.zeros(len(times))
+
+ # Field rotation has not been corrected
+ if not frcal:
+ fr_angle = tarr[i]['fr_elev']*el_angles + \
+ tarr[i]['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE
+
+ # If field rotation has been corrected, but leakage has NOT been corrected,
+ # the leakage needs to rotate doubly
+ elif frcal and not dcal:
+ fr_angle_D = 2.0*(tarr[i]['fr_elev']*el_angles + tarr[i]
+ ['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE)
+
+ # Assemble the Jones Matrices and save to dictionary
+ j_matrices = {times[j]: np.array([
+ [np.exp(-1j*fr_angle[j])*gainR[j],
+ np.exp(1j*(fr_angle[j]+fr_angle_D[j]))*dR*gainR[j]],
+ [np.exp(-1j*(fr_angle[j]+fr_angle_D[j]))*dL*gainL[j],
+ np.exp(1j*fr_angle[j])*gainL[j]]
+ ])
+ for j in range(len(times))
+ }
+
+ out[site] = j_matrices
+
+ if caltable_path:
+ obs_tmp.tarr[i]['dr'] = dR
+ obs_tmp.tarr[i]['dl'] = dL
+ datatable = []
+ for j in range(len(times)):
+ datatable.append(np.array((times[j], gainR[j], gainL[j]), dtype=ehc.DTCAL))
+ datatables[site] = np.array(datatable)
+
+ # Save a calibration table with the synthetic gains and dterms added
+ if caltable_path and len(datatables) > 0:
+        from ehtim.caltable import Caltable  # deferred import to avoid a circular import
+ caltable = Caltable(obs_tmp.ra, obs_tmp.dec, obs_tmp.rf, obs_tmp.bw,
+ datatables, obs_tmp.tarr, source=obs_tmp.source,
+ mjd=obs_tmp.mjd, timetype=obs_tmp.timetype)
+
+ caltable.save_txt(obs_tmp, datadir=caltable_path+'_simdata_caltable')
+
+ return out
+
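+# Illustrative sketch (site name hypothetical): corrupting Jones matrices with
+# 10% gain errors and random atmospheric phases
+#
+#   jm_dict = make_jones(obs, ampcal=False, phasecal=False, gainp=0.1, seed=42)
+#   j_site_t0 = jm_dict['ALMA'][obs.data['time'][0]]   # a 2x2 complex matrix
+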
+
+def make_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True):
+ """Computes inverse Jones Matrices for a list of times (non repeating),
+ with NO gain and dterm errors.
+
+ Args:
+ obs (Obsdata): the observation with scans for the inverse Jones matrices to be computed
+ opacitycal (bool): if False, estimated opacity terms are applied in the inverse gains
+ dcal (bool): if False, estimated d-terms are applied to the inverse Jones matrices
+ frcal (bool): if False, inverse feed rotation angle terms are applied to Jones matrices.
+
+ Returns:
+ (dict): a nested dictionary of matrices indexed by the site, then by the time
+ """
+
+ # Get data
+ tlist = obs.tlist()
+ tarr = obs.tarr
+ ra = obs.ra
+ dec = obs.dec
+ sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)])
+
+ # Create a dictionary of taus and a list of unique times
+ nsites = len(obs.tarr['site'])
+ taudict = {site: np.array([]) for site in obs.tarr['site']}
+ times = np.array([])
+ for scan in tlist:
+ time = scan['time'][0]
+ times = np.append(times, time)
+ sites_in = np.array([])
+ for bl in scan:
+
+ # Should we screen for conflicting same-time measurements of tau?
+ if len(sites_in) >= nsites:
+ break
+
+ if (not len(sites_in)) or (not bl['t1'] in sites_in):
+ taudict[bl['t1']] = np.append(taudict[bl['t1']], bl['tau1'])
+ sites_in = np.append(sites_in, bl['t1'])
+
+ if (not len(sites_in)) or (not bl['t2'] in sites_in):
+ taudict[bl['t2']] = np.append(taudict[bl['t2']], bl['tau2'])
+ sites_in = np.append(sites_in, bl['t2'])
+ if len(sites_in) < nsites:
+ for site in obs.tarr['site']:
+ if site not in sites_in:
+ taudict[site] = np.append(taudict[site], 0.0)
+
+ # Compute Sidereal Times
+ if obs.timetype == 'GMST':
+ times_sid = times
+ else:
+ times_sid = obsh.utc_to_gmst(times, obs.mjd)
+
+ # Make inverse Jones Matrices
+ out = {}
+ for i in range(len(tarr)):
+ site = tarr[i]['site']
+ coords = np.array([tarr[i]['x'], tarr[i]['y'], tarr[i]['z']])
+ latlon = obsh.xyz_2_latlong(coords)
+
+ # Elevation Angles
+ thetas = np.mod((times_sid - ra)*ehc.HOUR, 2*np.pi)
+ el_angles = obsh.elev(obsh.earthrot(coords, thetas), sourcevec)
+
+ # Parallactic Angles (positive longitude EAST)
+ hr_angles = obsh.hr_angle(times_sid*ehc.HOUR, latlon[:, 1], ra*ehc.HOUR)
+ par_angles = obsh.par_angle(hr_angles, latlon[:, 0], dec*ehc.DEGREE)
+
+ # Amplitude gain assumed 1
+ gainR = gainL = np.ones(len(times))
+
+ # Opacity attenuation of amplitude gain
+ if not opacitycal:
+ taus = np.abs(np.array(taudict[site]))
+ atten = np.exp(-taus/(ehc.EP + 2.0*np.sin(el_angles)))
+
+ gainR = gainR * atten
+ gainL = gainL * atten
+
+ # D Terms
+ dR = dL = 0.0
+ if not dcal:
+ dR = tarr[i]['dr']
+ dL = tarr[i]['dl']
+
+ # Feed Rotation Angles
+ fr_angle = np.zeros(len(times))
+ # This is for when field rotation is corrected but not leakage
+ fr_angle_D = np.zeros(len(times))
+ if not frcal:
+ # Total Angle (Radian)
+ fr_angle = (tarr[i]['fr_elev']*el_angles + tarr[i]['fr_par']
+ * par_angles + tarr[i]['fr_off']*ehc.DEGREE)
+
+ elif frcal and not dcal:
+ # If the field rotation angle has been removed but leakage hasn't,
+ # we still need to rotate the leakage terms appropriately
+ # by *twice* the field rotation angle
+ fr_angle_D = 2.0*(tarr[i]['fr_elev']*el_angles + tarr[i]
+ ['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE)
+
+ # Assemble the inverse Jones Matrices and save to dictionary
+ pref = 1.0/(gainL*gainR*(1.0 - dL*dR))
+ j_matrices_inv = {times[j]: pref[j]*np.array([
+ [np.exp(1j*fr_angle[j])*gainL[j],
+ -np.exp(1j*(fr_angle[j] + fr_angle_D[j]))*dR*gainR[j]],
+ [-np.exp(-1j*(fr_angle[j] + fr_angle_D[j]))*dL*gainL[j],
+ np.exp(-1j*fr_angle[j])*gainR[j]]
+ ]) for j in range(len(times))
+ }
+
+ out[site] = j_matrices_inv
+
+ return out
+
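+# Illustrative sketch: with everything already calibrated the inverse matrices
+# reduce exactly to the identity
+#
+#   inv_jm = make_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True)
+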
+
+def add_jones_and_noise(obs, add_th_noise=True,
+ opacitycal=True, ampcal=True, phasecal=True, dcal=True,
+ frcal=True, rlgaincal=True,
+ stabilize_scan_phase=False, stabilize_scan_amp=False,
+ neggains=False,
+ taup=ehc.GAINPDEF,
+                        gainp=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF,
+ phase_std=-1,
+ dterm_offset=ehc.DTERMPDEF,
+ rlratio_std=0., rlphase_std=0.,
+                        sigmat=None, phasesigmat=None, rlgsigmat=None, rlpsigmat=None,
+ caltable_path=None, seed=False, verbose=True):
+ """Corrupt visibilities in obs with jones matrices and add thermal noise
+
+ Args:
+ obs (Obsdata): the original observation
+ add_th_noise (bool): if True, baseline-dependent thermal noise is added to each data point
+ opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains
+ phasecal (bool): if False, time-dependent random phases are added to complex station gains
+ dcal (bool): if False, time-dependent gaussian errors are added to D-terms.
+ frcal (bool): if False, feed rotation angle terms are added to Jones matrices.
+ rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol
+
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gainp (float): the fractional std. dev. of the random error on the gains
+ or a dict giving one std. dev. per site
+
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ phase_std (float): std. dev. of LCP phase,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+ dterm_offset (float): the base std. dev. of random additive error at all sites,
+ or a dict giving one std. dev. per site
+
+ rlratio_std (float): the fractional std. dev. of the R/L gain offset
+ or a dict giving one std. dev. per site
+ rlphase_std (float): std. dev. of R/L phase offset,
+ or a dict giving one std. dev. per site
+ a negative value samples from uniform
+
+ sigmat (float): temporal std for a Gaussian Process used to generate gains.
+ If sigmat=None then an iid gain noise is applied.
+ phasesigmat (float): temporal std for a Gaussian Process used to generate phases.
+ If phasesigmat=None then an iid gain noise is applied.
+ rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios.
+ If rlgsigmat=None then an iid gain noise is applied.
+ rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff.
+ If rlpsigmat=None then an iid gain noise is applied.
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ verbose (bool): print updates and warnings
+ Returns:
+ (np.array): an observation data array
+ """
+
+ if verbose:
+ print("Applying Jones Matrices to data . . . ")
+
+ # Build Jones Matrices
+ jm_dict = make_jones(obs,
+ ampcal=ampcal, opacitycal=opacitycal, phasecal=phasecal,
+ dcal=dcal, frcal=frcal, rlgaincal=rlgaincal,
+ stabilize_scan_phase=stabilize_scan_phase,
+ stabilize_scan_amp=stabilize_scan_amp, neggains=neggains,
+ taup=taup,
+ gainp=gainp, gain_offset=gain_offset,
+ phase_std=phase_std,
+ dterm_offset=dterm_offset,
+ rlratio_std=rlratio_std, rlphase_std=rlphase_std,
+                         sigmat=sigmat, phasesigmat=phasesigmat,
+                         rlgsigmat=rlgsigmat, rlpsigmat=rlpsigmat,
+ caltable_path=caltable_path, seed=seed)
+
+ # Change pol rep:
+ obs_circ = obs.switch_polrep('circ')
+ obsdata = copy.copy(obs_circ.data)
+
+ times = obsdata['time']
+ t1 = obsdata['t1']
+ t2 = obsdata['t2']
+ tints = obsdata['tint']
+ rr = obsdata['rrvis']
+ ll = obsdata['llvis']
+ rl = obsdata['rlvis']
+ lr = obsdata['lrvis']
+
+ # Recompute the noise std. deviations from the SEFDs
+ if np.any(obs.tarr['sefdr'] <= 0) or np.any(obs.tarr['sefdl'] <= 0):
+ if verbose:
+ print("Warning!: in add_jones_and_noise, some SEFDs are <= 0!")
+ print("Resorting to data point sigmas, which may add too much systematic noise!")
+ sig_rr = obsdata['rrsigma']
+ sig_ll = obsdata['llsigma']
+ sig_rl = obsdata['rlsigma']
+ sig_lr = obsdata['lrsigma']
+ else:
+ sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdr'],
+ tints[i], obs.bw)
+ for i in range(len(rr))), float)
+ sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdl'],
+ tints[i], obs.bw)
+ for i in range(len(ll))), float)
+ sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdl'],
+ tints[i], obs.bw)
+ for i in range(len(rl))), float)
+ sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdr'],
+ tints[i], obs.bw)
+ for i in range(len(lr))), float)
+
+ if verbose and not opacitycal:
+ print(" Applying opacity attenuation: opacitycal-->False")
+ if verbose and not ampcal:
+ print(" Applying gain corruption: ampcal-->False")
+ if verbose and not phasecal:
+ print(" Applying atmospheric phase corruption: phasecal-->False")
+ if verbose and not dcal:
+ print(" Applying D Term mixing: dcal-->False")
+ if verbose and not frcal:
+ print(" Applying Field Rotation: frcal-->False")
+ if verbose and add_th_noise:
+ print("Adding thermal noise to data . . . ")
+
+ # Corrupt each IQUV visibility set with the jones matrices and add noise
+ for i in range(len(times)):
+ # Form the visibility correlation matrix
+ corr_matrix = np.array([[rr[i], rl[i]], [lr[i], ll[i]]])
+
+ # Get the jones matrices and corrupt the corr_matrix
+ j1 = jm_dict[t1[i]][times[i]]
+ j2 = jm_dict[t2[i]][times[i]]
+
+ corr_matrix_corrupt = np.dot(j1, np.dot(corr_matrix, np.conjugate(j2.T)))
+
+ # Add noise
+ if add_th_noise:
+ noise_matrix = np.array([[obsh.cerror(sig_rr[i]), obsh.cerror(sig_rl[i])],
+ [obsh.cerror(sig_lr[i]), obsh.cerror(sig_ll[i])]])
+ corr_matrix_corrupt += noise_matrix
+
+ # Put the corrupted data back into the data table
+ obsdata['rrvis'][i] = corr_matrix_corrupt[0][0]
+ obsdata['llvis'][i] = corr_matrix_corrupt[1][1]
+ obsdata['rlvis'][i] = corr_matrix_corrupt[0][1]
+ obsdata['lrvis'][i] = corr_matrix_corrupt[1][0]
+
+ # Put the recomputed sigmas back into the data table
+ obsdata['rrsigma'][i] = sig_rr[i]
+ obsdata['llsigma'][i] = sig_ll[i]
+ obsdata['rlsigma'][i] = sig_rl[i]
+ obsdata['lrsigma'][i] = sig_lr[i]
+
+ # put back into input polvec
+ obs_circ.data = obsdata
+ obs_back = obs_circ.switch_polrep(obs.polrep)
+ obsdata_back = obs_back.data
+
+ # Return observation data
+ return obsdata_back
+
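+# Illustrative sketch: corrupt an observation and wrap the returned data table
+# back into an Obsdata object
+#
+#   datatable = add_jones_and_noise(obs, ampcal=False, phasecal=False, gainp=0.1)
+#   obs_corrupt = obs.copy()
+#   obs_corrupt.data = datatable
+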
+
+def apply_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True, verbose=True):
+ """Apply inverse jones matrices to an observation
+
+    Args:
+        obs (Obsdata): the observation to calibrate
+        opacitycal (bool): if False, estimated opacity terms are applied in the inverse gains
+        dcal (bool): if False, estimated d-terms are applied to the inverse Jones matrices
+        frcal (bool): if False, inverse feed rotation angle terms are applied to Jones matrices.
+        verbose (bool): print updates and warnings
+
+ Returns:
+ (np.array): an observation data array
+ """
+
+ if verbose:
+ print("Applying a priori calibration with estimated Jones matrices . . . ")
+
+ # Build Inverse Jones Matrices
+ jm_dict = make_jones_inverse(obs, opacitycal=opacitycal, dcal=dcal, frcal=frcal)
+
+ # Change pol rep:
+ obs_circ = obs.switch_polrep('circ')
+
+ # Get data
+ obsdata = copy.deepcopy(obs_circ.data)
+ times = obsdata['time']
+ t1 = obsdata['t1']
+ t2 = obsdata['t2']
+ tints = obsdata['tint']
+ rr = obsdata['rrvis']
+ ll = obsdata['llvis']
+ rl = obsdata['rlvis']
+ lr = obsdata['lrvis']
+
+ # Recompute the noise std. deviations from the SEFDs
+ if np.any(obs.tarr['sefdr'] <= 0) or np.any(obs.tarr['sefdl'] <= 0):
+ if verbose:
+            print("Warning!: in apply_jones_inverse, some SEFDs are <= 0!")
+            print("Resorting to data point sigmas, which may add too much systematic noise!")
+ sig_rr = obsdata['rrsigma']
+ sig_ll = obsdata['llsigma']
+ sig_rl = obsdata['rlsigma']
+ sig_lr = obsdata['lrsigma']
+ else:
+ # TODO why are there sqrt(2)s here and not below?
+ sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdr'],
+ tints[i], obs.bw)
+ for i in range(len(rr))), float)
+ sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdl'],
+ tints[i], obs.bw)
+ for i in range(len(ll))), float)
+ sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdl'],
+ tints[i], obs.bw)
+ for i in range(len(rl))), float)
+ sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'],
+ obs.tarr[obs.tkey[t2[i]]]['sefdr'],
+ tints[i], obs.bw)
+ for i in range(len(lr))), float)
+
+ if not opacitycal:
+ if verbose:
+ print(" Applying opacity corrections: opacitycal-->True")
+ opacitycal = True
+ if not dcal:
+ if verbose:
+ print(" Applying D Term corrections: dcal-->True")
+ dcal = True
+ if not frcal:
+ if verbose:
+ print(" Applying Field Rotation corrections: frcal-->True")
+ frcal = True
+
+ # Apply the inverse Jones matrices to each visibility
+ for i in range(len(times)):
+
+ # Get the inverse jones matrices
+ inv_j1 = jm_dict[t1[i]][times[i]]
+ inv_j2 = jm_dict[t2[i]][times[i]]
+
+ # Form the visibility correlation matrix
+ corr_matrix = np.array([[rr[i], rl[i]], [lr[i], ll[i]]])
+
+ # Form the sigma matrices
+ sig_rr_matrix = np.array([[sig_rr[i], 0.0], [0.0, 0.0]])
+ sig_ll_matrix = np.array([[0.0, 0.0], [0.0, sig_ll[i]]])
+ sig_rl_matrix = np.array([[0.0, sig_rl[i]], [0.0, 0.0]])
+ sig_lr_matrix = np.array([[0.0, 0.0], [sig_lr[i], 0.0]])
+
+ # Apply the inverse Jones Matrices to the visibility correlation matrix and sigma matrices
+ corr_matrix_new = np.dot(inv_j1, np.dot(corr_matrix, np.conjugate(inv_j2.T)))
+
+ sig_rr_matrix_new = np.dot(inv_j1, np.dot(sig_rr_matrix, np.conjugate(inv_j2.T)))
+ sig_ll_matrix_new = np.dot(inv_j1, np.dot(sig_ll_matrix, np.conjugate(inv_j2.T)))
+ sig_rl_matrix_new = np.dot(inv_j1, np.dot(sig_rl_matrix, np.conjugate(inv_j2.T)))
+ sig_lr_matrix_new = np.dot(inv_j1, np.dot(sig_lr_matrix, np.conjugate(inv_j2.T)))
+
+ # Get the final sigma matrix as a quadrature sum
+ sig_matrix_new = np.sqrt(np.abs(sig_rr_matrix_new)**2 + np.abs(sig_ll_matrix_new)**2 +
+ np.abs(sig_rl_matrix_new)**2 + np.abs(sig_lr_matrix_new)**2)
+
+ # Put the corrupted data back into the data table
+ obsdata['rrvis'][i] = corr_matrix_new[0][0]
+ obsdata['llvis'][i] = corr_matrix_new[1][1]
+ obsdata['rlvis'][i] = corr_matrix_new[0][1]
+ obsdata['lrvis'][i] = corr_matrix_new[1][0]
+
+ # Put the recomputed sigmas back into the data table
+ obsdata['rrsigma'][i] = sig_matrix_new[0][0]
+ obsdata['llsigma'][i] = sig_matrix_new[1][1]
+ obsdata['rlsigma'][i] = sig_matrix_new[0][1]
+ obsdata['lrsigma'][i] = sig_matrix_new[1][0]
+
+ # put back into input polvec
+ obs_circ.data = obsdata
+ obs_back = obs_circ.switch_polrep(obs.polrep)
+ obsdata_back = obs_back.data
+
+ # Return observation data
+ return obsdata_back
+
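+# Illustrative sketch: undo the a-priori-correctable corruptions (opacity,
+# known d-terms, feed rotation) on a corrupted observation
+#
+#   obs_cal = obs_corrupt.copy()
+#   obs_cal.data = apply_jones_inverse(obs_corrupt, opacitycal=False,
+#                                      dcal=False, frcal=False)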
+
+
+# The old noise-generating function, superseded by add_jones_and_noise
+def add_noise(obs, add_th_noise=True, th_noise_factor=1, opacitycal=True, ampcal=True, phasecal=True,
+ stabilize_scan_amp=False, stabilize_scan_phase=False,
+ neggains=False,
+ taup=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF,
+ caltable_path=None, seed=False, sigmat=None,
+ verbose=True):
+ """Add thermal noise and gain & phase calibration errors to a dataset.
+ Old routine replaced by add_jones_and_noise.
+
+ Args:
+ obs (Obsdata): the original observation
+        add_th_noise (bool): if True, baseline-dependent thermal noise is added to each data point
+        th_noise_factor (float): scale factor applied to the thermal noise before it is added
+ opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities
+ ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains
+ phasecal (bool): if False, time-dependent random phases are added to complex station gains
+ stabilize_scan_phase (bool): if True, random phase errors are constant over scans
+ stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans
+ neggains (bool): if True, force the applied gains to be <1
+ taup (float): the fractional std. dev. of the random error on the opacities
+ gain_offset (float): the base gain offset at all sites,
+ or a dict giving one gain offset per site
+ gainp (float): the fractional std. dev. of the random error on the gains
+
+ caltable_path (string): If not None, path and prefix for saving the applied caltable
+ NOT SUPPORTED for add_noise.
+ seed (int): seeds the random component of the noise terms. DO NOT set to 0!
+ sigmat (float): temporal std for a Gaussian Process used to generate gains.
+ NOT SUPPORTED for add_noise
+
+ verbose (bool): print updates and warnings
+ Returns:
+ (np.array): an observation data array
+ """
+
+ if caltable_path:
+ print("caltable saving not implemented for old add_noise function!")
+ if verbose:
+ print("Adding gain + phase errors to data and applying a priori calibration . . . ")
+
+ if verbose and not opacitycal:
+ print(" Applying opacity attenuation AND estimated opacity corrections: opacitycal-->True")
+ if verbose and not ampcal:
+ print(" Applying gain corruption: ampcal-->False")
+ if verbose and not phasecal:
+ print(" Applying atmospheric phase corruption: phasecal-->False")
+ if verbose and add_th_noise:
+ print("Adding thermal noise to data . . . ")
+
+ # Get data
+ obsdata = copy.deepcopy(obs.data)
+
+ sites = obsh.recarr_to_ndarr(obsdata[['t1', 't2']], 'U32')
+ taus = np.abs(obsh.recarr_to_ndarr(obsdata[['tau1', 'tau2']], 'f8'))
+ elevs = obsh.recarr_to_ndarr(obs.unpack(['el1', 'el2'], ang_unit='deg'), 'f8')
+ times = obsdata['time']
+ tint = obsdata['tint']
+ vis1 = obsdata[obs.poldict['vis1']]
+ vis2 = obsdata[obs.poldict['vis2']]
+ vis3 = obsdata[obs.poldict['vis3']]
+ vis4 = obsdata[obs.poldict['vis4']]
+
+ times_stable_phase = times.copy()
+ times_stable_amp = times.copy()
+ times_stable = times.copy()
+
+ if stabilize_scan_phase is True or stabilize_scan_amp is True:
+ scans = obs.scans
+        if scans is None or len(scans) == 0:
+ if verbose:
+ print("Adding scan table")
+ obs_scans = obs.copy()
+ obs_scans.add_scans()
+ scans = obs_scans.scans
+ for j in range(len(times_stable)):
+ for scan in scans:
+ if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]:
+ times_stable[j] = scan[0]
+ break
+
+ if stabilize_scan_phase is True:
+ times_stable_phase = times_stable.copy()
+ if stabilize_scan_amp is True:
+ times_stable_amp = times_stable.copy()
+
+ # Recompute perfect sigmas from SEFDs
+ bw = obs.bw
+ if np.any(obs.tarr['sefdr'] <= 0):
+ if verbose:
+ print("Warning!: in add_noise, some SEFDs are <= 0!")
+ print("NOT recomputing sigmas, which may result in double systematic noise")
+ sigma_perf1 = obsdata[obs.poldict['sigma1']]
+ sigma_perf2 = obsdata[obs.poldict['sigma2']]
+ sigma_perf3 = obsdata[obs.poldict['sigma3']]
+ sigma_perf4 = obsdata[obs.poldict['sigma4']]
+ else:
+ sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'],
+ obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw)
+ for i in range(len(tint))), float)
+ sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'],
+ obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw)
+ for i in range(len(tint))), float)
+ sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'],
+ obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw)
+ for i in range(len(tint))), float)
+ sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'],
+ obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw)
+ for i in range(len(tint))), float)
+ if obs.polrep == 'stokes':
+ sig_iv = 0.5*np.sqrt(sig_rr**2 + sig_ll**2)
+ sig_qu = 0.5*np.sqrt(sig_rl**2 + sig_lr**2)
+ sigma_perf1 = sig_iv
+ sigma_perf2 = sig_qu
+ sigma_perf3 = sig_qu
+ sigma_perf4 = sig_iv
+ elif obs.polrep == 'circ':
+ sigma_perf1 = sig_rr
+ sigma_perf2 = sig_ll
+ sigma_perf3 = sig_rl
+ sigma_perf4 = sig_lr
+
+ # Seed for random number generators
+ if seed is False:
+ seed = str(ttime.time())
+
+ # Add gain and opacity fluctuations to the TRUE noise
+ if not ampcal:
+ # Amplitude gain
+ if type(gain_offset) == dict:
+ goff1 = np.fromiter((gain_offset[sites[i, 0]] for i in range(len(times))), float)
+ goff2 = np.fromiter((gain_offset[sites[i, 1]] for i in range(len(times))), float)
+ else:
+ goff1 = np.fromiter((gain_offset for i in range(len(times))), float)
+ goff2 = np.fromiter((gain_offset for i in range(len(times))), float)
+
+ if type(gainp) is dict:
+ gain_mult_1 = np.fromiter((gainp[sites[i, 0]] for i in range(len(times))), float)
+ gain_mult_2 = np.fromiter((gainp[sites[i, 1]] for i in range(len(times))), float)
+ else:
+ gain_mult_1 = np.fromiter((gainp for i in range(len(times))), float)
+ gain_mult_2 = np.fromiter((gainp for i in range(len(times))), float)
+
+ gain1_constant = np.fromiter((goff1[i] * obsh.hashrandn(sites[i, 0], 'gain', str(goff1[i]), seed)
+ for i in range(len(times))), float)
+ gain2_constant = np.fromiter((goff2[i] * obsh.hashrandn(sites[i, 1], 'gain', str(goff2[i]), seed)
+ for i in range(len(times))), float)
+
+ if neggains:
+ gain1_constant = -np.abs(gain1_constant)
+ gain2_constant = -np.abs(gain2_constant)
+
+ if sigmat is None:
+ gain1_var = np.fromiter((gain_mult_1[i] * obsh.hashrandn(sites[i, 0], 'gain',
+ times_stable_amp[i],
+ str(gain_mult_1[i]), seed)
+ for i in range(len(times))), float)
+ gain2_var = np.fromiter((gain_mult_2[i] * obsh.hashrandn(sites[i, 1], 'gain',
+ times_stable_amp[i],
+ str(gain_mult_2[i]), seed)
+ for i in range(len(times))), float)
+ else:
+ raise Exception("correlated gains not supported in old add_noise! Use jones=True")
+
+ gain1 = np.abs((1.0 + gain1_constant)*(1.0 + gain1_var))
+ gain2 = np.abs((1.0 + gain2_constant)*(1.0 + gain2_var))
+ if neggains:
+ gain1 = np.exp(-np.abs(np.log(gain1)))
+ gain2 = np.exp(-np.abs(np.log(gain2)))
+
+ gain_true = np.sqrt(gain1 * gain2)
+ else:
+ gain_true = 1
+
+ if not opacitycal:
+
+ # Use estimated opacity to compute the ESTIMATED noise
+ tau_est = np.sqrt(np.exp(taus[:, 0]/(ehc.EP+np.sin(elevs[:, 0]*ehc.DEGREE)) +
+ taus[:, 1]/(ehc.EP+np.sin(elevs[:, 1]*ehc.DEGREE))))
+
+ # Opacity Errors
+ tau1 = np.abs(np.fromiter((taus[i, 0] * (1.0 + taup * obsh.hashrandn(sites[i, 0], 'tau', times_stable_amp[i], seed))
+ for i in range(len(times))), float))
+ tau2 = np.abs(np.fromiter((taus[i, 1] * (1.0 + taup * obsh.hashrandn(sites[i, 1], 'tau', times_stable_amp[i], seed))
+ for i in range(len(times))), float))
+
+ # Correct noise RMS for opacity
+ tau_true = np.sqrt(np.exp(tau1/(ehc.EP+np.sin(elevs[:, 0]*ehc.DEGREE)) +
+ tau2/(ehc.EP+np.sin(elevs[:, 1]*ehc.DEGREE))))
+ else:
+ tau_true = tau_est = 1
+
+ # Add the noise
+ sigma_true1 = sigma_perf1
+ sigma_true2 = sigma_perf2
+ sigma_true3 = sigma_perf3
+ sigma_true4 = sigma_perf4
+
+ sigma_est1 = sigma_perf1 * gain_true * tau_est
+ sigma_est2 = sigma_perf2 * gain_true * tau_est
+ sigma_est3 = sigma_perf3 * gain_true * tau_est
+ sigma_est4 = sigma_perf4 * gain_true * tau_est
+
+ if add_th_noise:
+ vis1 = (vis1 + th_noise_factor*obsh.cerror(sigma_true1))
+ vis2 = (vis2 + th_noise_factor*obsh.cerror(sigma_true2))
+ vis3 = (vis3 + th_noise_factor*obsh.cerror(sigma_true3))
+ vis4 = (vis4 + th_noise_factor*obsh.cerror(sigma_true4))
+
+ # Add the gain error to the true visibilities
+ vis1 = vis1 * gain_true * tau_est / tau_true
+ vis2 = vis2 * gain_true * tau_est / tau_true
+ vis3 = vis3 * gain_true * tau_est / tau_true
+ vis4 = vis4 * gain_true * tau_est / tau_true
+
+ # Add random atmospheric phases
+ if not phasecal:
+ phase1 = np.fromiter((2 * np.pi * obsh.hashrand(sites[i, 0], 'phase', times_stable_phase[i], seed)
+ for i in range(len(times))), float)
+ phase2 = np.fromiter((2 * np.pi * obsh.hashrand(sites[i, 1], 'phase', times_stable_phase[i], seed)
+ for i in range(len(times))), float)
+
+ vis1 *= np.exp(1j * (phase2-phase1))
+ vis2 *= np.exp(1j * (phase2-phase1))
+ vis3 *= np.exp(1j * (phase2-phase1))
+ vis4 *= np.exp(1j * (phase2-phase1))
+
+    # Put the visibilities and estimated errors back in the obsdata array
+ obsdata[obs.poldict['vis1']] = vis1
+ obsdata[obs.poldict['vis2']] = vis2
+ obsdata[obs.poldict['vis3']] = vis3
+ obsdata[obs.poldict['vis4']] = vis4
+
+ obsdata[obs.poldict['sigma1']] = sigma_est1
+ obsdata[obs.poldict['sigma2']] = sigma_est2
+ obsdata[obs.poldict['sigma3']] = sigma_est3
+ obsdata[obs.poldict['sigma4']] = sigma_est4
+
+ # Return observation data
+ return obsdata
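+
+
+# The corruption model applied above, as standalone arithmetic (an
+# illustrative sketch with made-up numbers, not part of the ehtim API):
+# thermal noise is drawn at the perfect sigma, the noisy visibility is then
+# scaled by the geometric-mean gain and by the ratio of estimated-to-true
+# opacity attenuation, and the reported sigma carries the gain and the
+# *estimated* opacity only.
+#
+#   import numpy as np
+#   vis_perf = 1.0 + 0.2j                # noiseless visibility (Jy)
+#   sigma_perf = 0.01                    # thermal noise from the SEFDs (Jy)
+#   gain_true = 0.9                      # sqrt(gain1 * gain2)
+#   tau_est, tau_true = 1.05, 1.10      # estimated / true opacity factors
+#   vis = vis_perf + sigma_perf*(np.random.randn() + 1j*np.random.randn())
+#   vis_obs = vis * gain_true * tau_est / tau_true
+#   sigma_est = sigma_perf * gain_true * tau_est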
diff --git a/observing/pulses.py b/observing/pulses.py
new file mode 100644
index 00000000..35935b16
--- /dev/null
+++ b/observing/pulses.py
@@ -0,0 +1,197 @@
+# pulses.py
+# image restoring pulse functions
+#
+# Copyright (C) 2018 Katie Bouman
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+# If dom="I", we are in real space, if dom="F" we are in Fourier (uv) space
+# Coordinates in real space are in radian, coordinates in Fourier space are in lambda
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import math
+import numpy as np
+# import scipy.special as spec
+
+###################################################################################################
+# Delta Function Pulse
+###################################################################################################
+
+
+def deltaPulse2D(x, y, pdim, dom='F'):
+ if dom == 'I':
+ if x == y == 0.0:
+ return 1.0
+ else:
+ return 0.0
+ elif dom == 'F':
+ return 1.0
+
+###################################################################################################
+# Square Wave Pulse
+###################################################################################################
+
+
+def rectPulse2D(x, y, pdim, dom='F'):
+ if dom == 'I':
+ return rectPulse_I(x, pdim) * rectPulse_I(y, pdim)
+ elif dom == 'F':
+ return rectPulse_F(x, pdim) * rectPulse_F(y, pdim)
+
+
+def rectPulse_I(x, pdim):
+ if abs(x) >= pdim/2.0:
+ return 0.0
+ else:
+ return 1.0/pdim
+
+
+def rectPulse_F(omega, pdim):
+ if (omega == 0):
+ return 1.0
+ else:
+ return (2.0/(pdim*omega)) * math.sin((pdim*omega)/2.0)
+
+###################################################################################################
+# Triangle Wave Pulse
+###################################################################################################
+
+
+def trianglePulse2D(x, y, pdim, dom='F'):
+ if dom == 'I':
+ return trianglePulse_I(x, pdim) * trianglePulse_I(y, pdim)
+ elif dom == 'F':
+ return trianglePulse_F(x, pdim)*trianglePulse_F(y, pdim)
+
+
+def trianglePulse_I(x, pdim):
+ if abs(x) > pdim:
+ return 0.0
+ else:
+ return -(1.0/(pdim**2))*abs(x) + 1.0/pdim
+
+
+def trianglePulse_F(omega, pdim):
+ if (omega == 0):
+ return 1.0
+ else:
+ return (4.0/(pdim**2 * omega**2)) * (math.sin((pdim * omega)/2.0))**2
+
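+# A quick numerical sanity check (an illustrative sketch, not part of the
+# module): each image-domain pulse integrates to unity over its support,
+# matching the omega -> 0 limit of its Fourier-domain counterpart.
+#
+#   pdim = 1.0
+#   xs = np.linspace(-2*pdim, 2*pdim, 100001)
+#   np.trapz([trianglePulse_I(x, pdim) for x in xs], xs)   # ~1.0
+#   trianglePulse_F(0, pdim)                               # 1.0
+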
+###################################################################################################
+# Gaussian Pulse
+###################################################################################################
+
+
+def GaussPulse2D(x, y, pdim, dom='F'):
+ sigma = pdim/3. # Gaussian SD (sigma) vs pixelwidth (pdim)
+ a = 1./2./sigma/sigma
+ if dom == 'I':
+ return (a/np.pi)*np.exp(-a*(x**2 + y**2))
+
+ elif dom == 'F':
+ return np.exp(-(x**2 + y**2)/4./a)
+
+###################################################################################################
+# Cubic Pulse
+###################################################################################################
+
+
+def cubicPulse2D(x, y, pdim, dom='F'):
+ if dom == 'I':
+ return cubicPulse_I(x, pdim) * cubicPulse_I(y, pdim)
+
+ elif dom == 'F':
+ return cubicPulse_F(x, pdim)*cubicPulse_F(y, pdim)
+
+
+def cubicPulse_I(x, pdim):
+ if abs(x) < pdim:
+ return (1.5*(abs(x)/pdim)**3 - 2.5*(x/pdim)**2 + 1.)/pdim
+ elif abs(abs(x)-1.5*pdim) <= 0.5*pdim:
+ return (-0.5*(abs(x)/pdim)**3 + 2.5*(x/pdim)**2 - 4.*(abs(x)/pdim) + 2.)/pdim
+ else:
+ return 0.
+
+
+def cubicPulse_F(omega, pdim):
+ if (omega == 0):
+ return 1.0
+ else:
+ p1 = ((3./omega/pdim)*math.sin(omega*pdim/2.)-math.cos(omega*pdim/2.))
+ p2 = ((2./omega/pdim)*math.sin(omega*pdim/2.))**3
+ return 2.*p1*p2
+
+
+###################################################################################################
+# Sinc Pulse
+###################################################################################################
+
+def sincPulse2D(x, y, pdim, dom='F'):
+ if dom == 'I':
+ return sincPulse_I(x, pdim) * sincPulse_I(y, pdim)
+
+ elif dom == 'F':
+ return sincPulse_F(x, pdim) * sincPulse_F(y, pdim)
+
+
+def sincPulse_I(x, pdim):
+ if (x == 0):
+ return 1./pdim
+ else:
+ return (1./pdim)*math.sin(np.pi*x/pdim)/(np.pi*x/pdim)
+
+
+def sincPulse_F(omega, pdim):
+ if (abs(omega) < np.pi/pdim):
+ return 1.0
+ else:
+ return 0.
+
+###################################################################################################
+# Circular Disk Pulse
+###################################################################################################
+
+# def circPulse2D(x, y, pdim, dom='F'):
+# rm = 0.5*pdim #max radius of the disk
+# if dom=='I':
+# if x**2 + y**2 <= rm**2:
+# return 1./np.pi/rm**2
+# else: return 0.
+# elif dom=='F':
+# return 2.*spec.j1(rm*np.sqrt(x**2 + y**2))/np.sqrt(x**2 + y**2)/rm**2
+
+###################################################################################################
+# Cubic Spline Pulse
+###################################################################################################
+
+# def cubicsplinePulse2D_F(omegaX, omegaY, pdim):
+# return cubicsplinePulse(omegaX, pdim)*cubicsplinePulse(omegaY,pdim)
+#
+# def cubicsplinePulse_F(omega, delta):
+# if (omega == 0):
+# coeff = delta
+# else:
+# omega_delta = omega*delta
+#
+# coeff = delta * (
+# (4.0/omega_delta**3)*math.sin(omega_delta)*(2.0*math.cos(omega_delta) + 1.0) +
+# (24.0/omega_delta**4)*math.cos(omega_delta)*(math.cos(omega_delta) - 1.0) )
+#
+#     return coeff / pdim # TODO : CHECK IF YOU DIVIDE BY PDIM FOR CUBIC SPLINE PULSE
diff --git a/parloop.py b/parloop.py
new file mode 100644
index 00000000..af1a638f
--- /dev/null
+++ b/parloop.py
@@ -0,0 +1,156 @@
+# parloop.py
+# Wraps up some helper functions for parallel loops
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import sys
+import os
+
+from multiprocessing import cpu_count
+from multiprocessing import Pool, Value, Lock
+
+TIMEOUT = 31536000
+
+
+class Parloop(object):
+
+ """A simple parallel loop with a counter
+ """
+
+ def __init__(self, func):
+ """Initialize the loop
+ """
+
+ self.func = func
+
+ def run_loop(self, arglist, processes=-1):
+ """Run the loop on the list of arguments with multiple processes
+ """
+
+ n = len(arglist)
+
+        if not isinstance(arglist[0], list):
+ arglist = [[arg] for arg in arglist]
+
+ if processes > 0:
+ print("Set up loop with %d Processes" % processes)
+ elif processes == 0: # maximum number of processes -- different argument?
+ processes = int(cpu_count())
+ print("Set up loop with all available (%d) Processes" % processes)
+ else:
+ print("Set up loop with no multiprocessing")
+
+ out = -1
+ if processes > 0: # run on multiple cores with multiprocessing
+ counter = Counter(initval=0, maxval=n)
+ pool = Pool(processes=processes, initializer=self._initcount, initargs=(counter,))
+ try:
+ print('Running the loop')
+ self.prog_msg(0, n, 0)
+ out = pool.map_async(self, arglist).get(TIMEOUT)
+ pool.close()
+ except KeyboardInterrupt:
+ print('\ngot ^C while pool mapping, terminating')
+ pool.terminate()
+ print('pool terminated')
+ except Exception as e:
+ print('\ngot exception: %r, terminating' % (e,))
+ pool.terminate()
+ print('pool terminated')
+ finally:
+ pool.join()
+
+ else: # run on a single core
+ out = []
+ for i in range(n):
+ self.prog_msg(i, n, i-1)
+ args = arglist[i]
+ out.append(self.func(*args))
+
+ return out
+
+ def _initcount(self, x):
+ """Initialize the counter
+ """
+ global counter
+ counter = x
+
+ def __call__(self, args):
+ """Call the loop function
+ """
+
+ try:
+ outval = self.func(*args)
+ counter.increment()
+ self.prog_msg(counter.value(), counter.maxval, counter.value()-1)
+ return outval
+ except KeyboardInterrupt:
+            raise  # re-raise the KeyboardInterrupt so the pool terminates
+
+ def prog_msg(self, i, n, i_last=0):
+ """Print a progress bar
+ """
+
+ # complete_percent_last = int(100*float(i_last)/float(n))
+ complete_percent = int(100*float(i)/float(n))
+ ndigit = str(len(str(n)))
+
+ bar_width = 30
+ progress = int(bar_width * complete_percent/float(100))
+ barparams = (i, n, ("-"*progress) + (" " * (bar_width-progress)), complete_percent)
+
+ printstr = "\rProcessed %0"+ndigit+"i/%i : [%s]%i%%"
+ sys.stdout.write(printstr % barparams)
+ sys.stdout.flush()
+
+
+class Counter(object):
+ """Counter object for sharing among multiprocessing jobs
+ """
+
+ def __init__(self, initval=0, maxval=0):
+ self.val = Value('i', initval)
+ self.maxval = maxval
+ self.lock = Lock()
+
+ def increment(self):
+ with self.lock:
+ self.val.value += 1
+
+ def value(self):
+ with self.lock:
+ return self.val.value
+
+
+class HiddenPrints:
+ """Suppresses printing from the loop function
+ """
+
+ def __enter__(self):
+ self._original_stdout = sys.stdout
+ sys.stdout = open(os.devnull, 'w')
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ sys.stdout.close()
+ sys.stdout = self._original_stdout
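+
+
+# Example usage (a minimal sketch; `myfunc` is a hypothetical stand-in and
+# must be defined at module level so multiprocessing can pickle it):
+#
+#   from ehtim.parloop import Parloop, HiddenPrints
+#
+#   def myfunc(x, y):
+#       return x * y
+#
+#   loop = Parloop(myfunc)
+#   # each entry of arglist holds the arguments for one call of myfunc
+#   results = loop.run_loop([[i, 2] for i in range(100)], processes=4)
+#
+#   with HiddenPrints():   # suppress all printing inside the context
+#       results = loop.run_loop([[i, 2] for i in range(100)], processes=-1)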
diff --git a/plotting/__init__.py b/plotting/__init__.py
new file mode 100644
index 00000000..c63279a7
--- /dev/null
+++ b/plotting/__init__.py
@@ -0,0 +1,13 @@
+"""
+.. module:: ehtim.plotting
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: plotting functions
+
+.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu)
+
+"""
+from . import comp_plots
+from . import comparisons
+from . import summary_plots
+
+from ..const_def import *
diff --git a/plotting/comp_plots.py b/plotting/comp_plots.py
new file mode 100644
index 00000000..9862a44b
--- /dev/null
+++ b/plotting/comp_plots.py
@@ -0,0 +1,914 @@
+# comp_plots.py
+# Make data plots with multiple observations, images, etc.
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import numpy.matlib as matlib
+import matplotlib.pyplot as plt
+import itertools as it
+import copy
+
+from ehtim.obsdata import merge_obs
+import ehtim.const_def as ehc
+
+COLORLIST = ehc.SCOLORS
+
+##################################################################################################
+# Plotters
+##################################################################################################
+
+
+def plotall_compare(obslist, imlist, field1, field2,
+ conj=False, debias=False, sgrscat=False,
+ ang_unit='deg', timetype='UTC', ttype='nfft',
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE,
+ export_pdf="", grid=False, ebar=True,
+ axislabels=True, legend=True, show=True):
+ """Plot data from observations compared to ground truth from an image on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ field1 (str): x-axis field (from FIELDS)
+ field2 (str): y-axis field (from FIELDS)
+
+        conj (bool): Plot conjugate baseline data points if True
+ debias (bool): If True, debias amplitudes.
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+ prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels,
+ sgrscat=sgrscat, ttype=ttype)
+ (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata
+ for i in range(len(obslist_plot)):
+ obs = obslist_plot[i]
+ axis = obs.plotall(field1, field2,
+ conj=conj, debias=debias,
+ ang_unit=ang_unit, timetype=timetype,
+ axis=axis, rangex=rangex, rangey=rangey,
+ grid=grid, ebar=ebar, axislabels=axislabels,
+ show=False, tag_bl=False, legend=False, snrcut=snrcut,
+ label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)],
+ marker=markers[i], markersize=markersize)
+
+ if legend:
+ plt.legend()
+ if grid:
+ axis.grid()
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ if export_pdf != "":
+ plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0)
+
+ return axis
+
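+# Example usage (a minimal sketch; the file names are hypothetical):
+#
+#   import ehtim as eh
+#   obs = eh.obsdata.load_uvfits('obs.uvfits')
+#   im = eh.image.load_fits('im.fits')
+#   ax = plotall_compare([obs], [im], 'uvdist', 'amp', show=False)
+#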
+
+def plot_bl_compare(obslist, imlist, site1, site2, field,
+ debias=False, sgrscat=False,
+ ang_unit='deg', timetype='UTC', ttype='nfft',
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE,
+ export_pdf="", grid=False, ebar=True,
+ axislabels=True, legend=True, show=True):
+ """Plot data from multiple observations vs time on a single baseline on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ field (str): y-axis field (from FIELDS)
+
+ debias (bool): If True, debias amplitudes.
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels,
+ sgrscat=sgrscat, ttype=ttype)
+ (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata
+ for i in range(len(obslist_plot)):
+ obs = obslist_plot[i]
+ axis = obs.plot_bl(site1, site2, field,
+ debias=debias, ang_unit=ang_unit, timetype=timetype,
+ axis=axis, rangex=rangex, rangey=rangey,
+ grid=grid, ebar=ebar, axislabels=axislabels,
+ show=False, legend=False, snrcut=snrcut,
+ label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)],
+ marker=markers[i], markersize=markersize)
+ if legend:
+ plt.legend()
+ if grid:
+ axis.grid()
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ if export_pdf != "":
+ plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0)
+
+ return axis
+
+
+def plot_cphase_compare(obslist, imlist, site1, site2, site3,
+ vtype='vis', cphases=[], force_recompute=False,
+ ang_unit='deg', timetype='UTC', ttype='nfft',
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE,
+ export_pdf="", grid=False, ebar=True,
+ axislabels=True, legend=True, show=True):
+ """Plot closure phase on a triangle compared to ground truth from an image on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+ cphases (list): optionally pass in a list of cphases so they don't have to be recomputed
+ force_recompute (bool): if True, recompute closure phases instead of using stored data
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ try:
+ len(obslist)
+ except TypeError:
+ obslist = [obslist]
+
+ if len(cphases) == 0:
+ cphases = matlib.repmat([], len(obslist), 1)
+
+ if len(cphases) != len(obslist):
+ raise Exception("cphases list must be same length as obslist!")
+
+ cphases_back = []
+ for i in range(len(obslist)):
+ cphases_back.append(obslist[i].cphase)
+ obslist[i].cphase = cphases[i]
+
+ prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels,
+ sgrscat=False, ttype=ttype)
+ (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata
+ for i in range(len(obslist_plot)):
+ obs = obslist_plot[i]
+ axis = obs.plot_cphase(site1, site2, site3,
+ vtype=vtype, force_recompute=force_recompute,
+ ang_unit=ang_unit, timetype=timetype,
+ axis=axis, rangex=rangex, rangey=rangey,
+ grid=grid, ebar=ebar, axislabels=axislabels,
+ show=False, legend=False, snrcut=snrcut,
+ label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)],
+ marker=markers[i], markersize=markersize)
+
+ # return to original cphase attribute
+ for i in range(len(obslist)):
+ obslist[i].cphase = cphases_back[i]
+
+ if legend:
+ plt.legend()
+ if grid:
+ axis.grid()
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+ if export_pdf != "":
+ plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0)
+
+ return axis
+
+
+def plot_camp_compare(obslist, imlist, site1, site2, site3, site4,
+ vtype='vis', ctype='camp', camps=[], force_recompute=False,
+ debias=False, sgrscat=False, timetype='UTC', ttype='nfft',
+ axis=False, rangex=False, rangey=False, snrcut=0.,
+ clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE,
+ export_pdf="", grid=False, ebar=True,
+ axislabels=True, legend=True, show=True):
+ """Plot closure amplitude on a triangle vs time from multiple observations on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+ site4 (str): station 4 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+        ctype (str): The closure amplitude type ('camp' or 'logcamp')
+        camps (list): optionally pass in a list of camps so they don't have to be recomputed
+        force_recompute (bool): if True, recompute closure amplitudes instead of using stored data
+
+ debias (bool): If True, debias amplitudes.
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+
+ try:
+ len(obslist)
+ except TypeError:
+ obslist = [obslist]
+
+ if len(camps) == 0:
+ camps = matlib.repmat([], len(obslist), 1)
+
+ if len(camps) != len(obslist):
+ raise Exception("camps list must be same length as obslist!")
+
+ camps_back = []
+ for i in range(len(obslist)):
+ if ctype == 'camp':
+ camps_back.append(obslist[i].camp)
+ obslist[i].camp = camps[i]
+ elif ctype == 'logcamp':
+ camps_back.append(obslist[i].logcamp)
+ obslist[i].logcamp = camps[i]
+
+ prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels,
+ sgrscat=sgrscat, ttype=ttype)
+ (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata
+
+ for i in range(len(obslist_plot)):
+ obs = obslist_plot[i]
+        axis = obs.plot_camp(site1, site2, site3, site4,
+                             vtype=vtype, ctype=ctype, force_recompute=force_recompute,
+                             debias=debias, timetype=timetype,
+                             axis=axis, rangex=rangex, rangey=rangey,
+                             grid=grid, ebar=ebar, axislabels=axislabels,
+                             show=False, legend=False, snrcut=snrcut,
+                             label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)],
+                             marker=markers[i], markersize=markersize)
+
+ for i in range(len(obslist)):
+ if ctype == 'camp':
+ obslist[i].camp = camps_back[i]
+ elif ctype == 'logcamp':
+ obslist[i].logcamp = camps_back[i]
+
+ if legend:
+ plt.legend()
+ if grid:
+ axis.grid()
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+ if export_pdf != "":
+ plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0)
+
+ return axis
+
+
+##################################################################################################
+# Aliases
+##################################################################################################
+def plotall_obs_compare(obslist, field1, field2, **kwargs):
+ """Plot data from observations compared to ground truth from an image on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ field1 (str): x-axis field (from FIELDS)
+ field2 (str): y-axis field (from FIELDS)
+
+        conj (bool): Plot conjugate baseline data points if True
+ debias (bool): If True, debias amplitudes.
+ ang_unit (str): phase unit 'deg' or 'rad'
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ snrcut (float): a snr cutoff
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+ axis (matplotlib.axes.Axes): add plot to this axis
+
+ clist (list): list of colors scatterplot points
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+ axis = plotall_compare(obslist, [], field1, field2, **kwargs)
+ return axis
+
+
+def plotall_obs_im_compare(obslist, imlist, field1, field2, **kwargs):
+ """Plot data from observations compared to ground truth from an image on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of images to plot
+ field1 (str): x-axis field (from FIELDS)
+ field2 (str): y-axis field (from FIELDS)
+
+        conj (bool): Plot conjugate baseline data points if True
+ debias (bool): If True, debias amplitudes.
+ ang_unit (str): phase unit 'deg' or 'rad'
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ snrcut (float): a snr cutoff
+
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+ axis (matplotlib.axes.Axes): add plot to this axis
+
+ clist (list): list of colors scatterplot points
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+ """
+ axis = plotall_compare(obslist, imlist, field1, field2, **kwargs)
+ return axis
+
+
+def plot_bl_obs_compare(obslist, site1, site2, field, **kwargs):
+ """Plot data from multiple observations vs time on a single baseline on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ field (str): y-axis field (from FIELDS)
+
+ debias (bool): If True and plotting vis amplitudes, debias them
+ axislabels (bool): Show axis labels if True
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+ snrcut (float): a snr cutoff
+
+ rangex (list): [xmin, xmax] x-axis (time) limits
+ rangey (list): [ymin, ymax] y-axis limits
+
+ legend (bool): Show legend if True
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ show (bool): Display the plot if true
+ axis (matplotlib.axes.Axes): add plot to this axis
+ clist (list): list of color strings of scatterplot points
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_bl_compare(obslist, [], site1, site2, field, **kwargs)
+ return axis
+
+
+def plot_bl_obs_im_compare(obslist, imlist, site1, site2, field, **kwargs):
+ """Plot data from multiple observations vs time on a single baseline on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ field (str): y-axis field (from FIELDS)
+
+ debias (bool): If True and plotting vis amplitudes, debias them
+ axislabels (bool): Show axis labels if True
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+ snrcut (float): a snr cutoff
+
+ rangex (list): [xmin, xmax] x-axis (time) limits
+ rangey (list): [ymin, ymax] y-axis limits
+
+ legend (bool): Show legend if True
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ show (bool): Display the plot if true
+ axis (matplotlib.axes.Axes): add plot to this axis
+ clist (list): list of color strings of scatterplot points
+ export_pdf (str): path to pdf file to save figure
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_bl_compare(obslist, imlist, site1, site2, field, **kwargs)
+ return axis
+
+
+def plot_cphase_obs_compare(obslist, site1, site2, site3, **kwargs):
+ """Plot closure phase on a triangle vs time from multiple observations on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+ cphases (list): optionally pass in a list of cphases so they don't have to be recomputed
+ force_recompute (bool): if True, recompute closure phases instead of using stored data
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_cphase_compare(obslist, [], site1, site2, site3, **kwargs)
+ return axis
+
+
+def plot_cphase_obs_im_compare(obslist, imlist, site1, site2, site3, **kwargs):
+ """Plot closure phase on a triangle vs time from multiple observations on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+ cphases (list): optionally pass in a list of cphases so they don't have to be recomputed
+ force_recompute (bool): if True, recompute closure phases instead of using stored data
+
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_cphase_compare(obslist, imlist, site1, site2, site3, **kwargs)
+ return axis
+
+
+def plot_camp_obs_compare(obslist, site1, site2, site3, site4, **kwargs):
+ """Plot closure amplitude on a triangle vs time from multiple observations on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+ site4 (str): station 4 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+        ctype (str): The closure amplitude type ('camp' or 'logcamp')
+        camps (list): optionally pass in a list of camps so they don't have to be recomputed
+        force_recompute (bool): if True, recompute closure amplitudes instead of using stored data
+
+ debias (bool): If True, debias amplitudes.
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_camp_compare(obslist, [], site1, site2, site3, site4, **kwargs)
+ return axis
+
+
+def plot_camp_obs_im_compare(obslist, imlist, site1, site2, site3, site4, **kwargs):
+ """Plot closure amplitude on a triangle vs time from multiple observations on the same axes.
+
+ Args:
+ obslist (list): list of observations to plot
+        imlist (list): list of ground truth images to compare to
+ site1 (str): station 1 name
+ site2 (str): station 2 name
+ site3 (str): station 3 name
+ site4 (str): station 4 name
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+        ctype (str): The closure amplitude type ('camp' or 'logcamp')
+        camps (list): optionally pass in a list of camps so they don't have to be recomputed
+        force_recompute (bool): if True, recompute closure amplitudes instead of using stored data
+
+ debias (bool): If True, debias amplitudes.
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+ snrcut (float): a snr cutoff
+
+ axis (matplotlib.axes.Axes): add plot to this axis
+ rangex (list): [xmin, xmax] x-axis limits
+ rangey (list): [ymin, ymax] y-axis limits
+        clist (list): list of colors for scatterplot points
+        legendlabels (list): list of labels of the same length as obslist or imlist
+ markersize (int): size of plot markers
+ export_pdf (str): path to pdf file to save figure
+ grid (bool): Plot gridlines if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ legend (bool): Show legend if True
+ show (bool): Display the plot if true
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+ axis = plot_camp_compare(obslist, imlist, site1, site2, site3, site4, **kwargs)
+ return axis
+
+##################################################################################################
+# Plotters: Compare Observations to Image
+##################################################################################################
+
+
+def plotall_obs_im_cphases(obs, imlist,
+ vtype='vis', ang_unit='deg', timetype='UTC',
+ ttype='nfft', sgrscat=False,
+ rangex=False, rangey=[-180, 180], legend=False, legendlabels=None,
+ show=True, ebar=True, axislabels=False, print_chisqs=True,
+ display_mode='all'):
+ """Plot all observation closure phases on top of image ground truth values.
+ Works with ONE obs and MULTIPLE images.
+
+ Args:
+ obs (Obsdata): observation to plot
+ imlist (list): list of ground truth images to compare to
+
+        vtype (str): The visibility type ('vis','qvis','uvis','vvis','pvis')
+ ang_unit (str): phase unit 'deg' or 'rad'
+ timetype (str): 'GMST' or 'UTC'
+ ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT
+        sgrscat (bool): if True, the visibilities will be blurred by the Sgr A* scattering kernel
+
+ rangex (list): [xmin, xmax] x-axis (time) limits
+ rangey (list): [ymin, ymax] y-axis (phase) limits
+
+ show (bool): Display the plot if True
+ ebar (bool): Plot error bars if True
+ axislabels (bool): Show axis labels if True
+ print_chisqs (bool): print individual chisqs if True
+
+ display_mode (str): 'all' or 'individual'
+
+ Returns:
+ (matplotlib.axes.Axes): Axes object with data plot
+
+ """
+
+ try:
+ len(imlist)
+ except TypeError:
+ imlist = [imlist]
+
+ # get closure triangle combinations
+ sites = []
+ for i in range(0, len(obs.tarr)):
+ sites.append(obs.tarr[i][0])
+ uniqueclosure_tri = list(it.combinations(sites, 3))
+
+ # generate data
+ cphases_obs = obs.c_phases(mode='all', count='max', vtype=vtype)
+ obs_all = [obs]
+ cphases_all = [cphases_obs]
+ for image in imlist:
+ obs_model = image.observe_same(obs, sgrscat=sgrscat, add_th_noise=False, ttype=ttype)
+ cphases_model = obs_model.c_phases(mode='all', count='max', vtype=vtype)
+ obs_all.append(obs_model)
+ cphases_all.append(cphases_model)
+
+ # display as individual plots or as a huge sheet
+ if display_mode == 'individual':
+ show = True
+ else:
+ nplots = len(uniqueclosure_tri)
+ ncols = 4
+ nrows = np.ceil(nplots / ncols).astype(int)
+ show = False
+ fig = plt.figure(figsize=(nrows*20, 40))
+
+ # plot closure phases
+ print("\n")
+
+ nplot = 0
+ for c in range(0, len(uniqueclosure_tri)):
+        cphases_obs_tri = obs.cphase_tri(uniqueclosure_tri[c][0],
+                                         uniqueclosure_tri[c][1],
+                                         uniqueclosure_tri[c][2],
+                                         vtype=vtype, ang_unit='deg', cphases=cphases_obs)
+
+ if len(cphases_obs_tri) > 0:
+ if print_chisqs:
+ printstr = "%s %s %s :" % (
+ uniqueclosure_tri[c][0], uniqueclosure_tri[c][1], uniqueclosure_tri[c][2])
+ for i in range(1, len(obs_all)):
+ cphases_model_tri = obs_all[i].cphase_tri(uniqueclosure_tri[c][0],
+ uniqueclosure_tri[c][1],
+ uniqueclosure_tri[c][2],
+ vtype=vtype, ang_unit='deg',
+ cphases=cphases_all[i])
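+                    # closure-phase chi^2 on this triangle:
+                    # (2/N) * sum((1 - cos(dtheta)) / sigma_cp^2),
+                    # with dtheta the data-minus-model residual in radians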
+ resids = (cphases_obs_tri['cphase'] - cphases_model_tri['cphase'])*ehc.DEGREE
+ chisq_tri = np.sum((1.0 - np.cos(resids)) /
+ ((cphases_obs_tri['sigmacp']*ehc.DEGREE)**2))
+ chisq_tri *= (2.0/len(cphases_obs_tri))
+ printstr += " chisq%i: %0.2f" % (i, chisq_tri)
+ print(printstr)
+
+ if display_mode == 'individual':
+ ax = False
+
+ else:
+ ax = plt.subplot2grid((nrows, ncols), (nplot//ncols, nplot % ncols), fig=fig)
+ axislabels = False
+
+ f = plot_cphase_obs_compare(obs_all,
+ uniqueclosure_tri[c][0],
+ uniqueclosure_tri[c][1],
+ uniqueclosure_tri[c][2],
+ vtype=vtype, rangex=rangex, rangey=rangey, ebar=ebar,
+ show=show, legend=legend, legendlabels=legendlabels,
+ cphases=cphases_all, axis=ax, axislabels=axislabels)
+ nplot += 1
+
+ if display_mode != 'individual':
+ plt.ion()
+ f = fig
+ f.subplots_adjust(wspace=0.1, hspace=0.5)
+ f.show()
+
+ return f
+
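+# Example usage (a minimal sketch; `obs`, `im1`, and `im2` are assumed to be
+# an Obsdata object and two model images already loaded with ehtim):
+#
+#   f = plotall_obs_im_cphases(obs, [im1, im2], display_mode='all')
+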
+##################################################################################################
+# Misc
+##################################################################################################
+
+
+def prep_plot_lists(obslist, imlist, clist=ehc.SCOLORS, legendlabels=None,
+ sgrscat=False, ttype='nfft'):
+ """Return observation, color, marker, legend lists for comp plots"""
+
+ if imlist is None or imlist is False:
+ imlist = []
+
+ try:
+ len(obslist)
+ except TypeError:
+ obslist = [obslist]
+
+ try:
+ len(imlist)
+ except TypeError:
+ imlist = [imlist]
+
+ if not((len(imlist) == len(obslist)) or len(imlist) <= 1 or len(obslist) <= 1):
+ raise Exception("imlist and obslist must be the same length, or either must have length 1")
+
+ if not (legendlabels is None) and (len(legendlabels) != max(len(imlist), len(obslist))):
+ raise Exception("legendlabels should be the same length of the longer of imlist, obslist!")
+
+ if legendlabels is None:
+ legendlabels = [str(i+1) for i in range(max(len(imlist), len(obslist)))]
+
+ obslist_plot = []
+ clist_plot = copy.copy(clist)
+ legendlabels_plot = copy.copy(legendlabels)
+
+    # no images: just plot the observations
+ if len(imlist) == 0:
+ markers = []
+ for i in range(len(obslist)):
+ obslist_plot.append(obslist[i])
+ markers.append('o')
+
+    # one image, multiple observations
+    elif len(imlist) == 1 and len(obslist) > 1:
+ obslist_true = []
+ markers = ['s']
+ clist_plot = ['k']
+ for i in range(len(obslist)):
+ obslist_plot.append(obslist[i])
+ obstrue = imlist[0].observe_same(
+ obslist[i], sgrscat=sgrscat, add_th_noise=False, ttype=ttype)
+ for sigma_type in obstrue.data.dtype.names[-4:]:
+ obstrue.data[sigma_type] *= 0
+ obslist_true.append(obstrue)
+ markers.append('o')
+ clist_plot.append(clist[i])
+
+ obstrue = merge_obs(obslist_true)
+ obslist_plot.insert(0, obstrue)
+ legendlabels_plot.insert(0, 'Image')
+
+ # one observation, multiple images
+ elif len(obslist) == 1 and len(imlist) > 1:
+ obslist_plot.append(obslist[0])
+ markers = ['o']
+ for i in range(len(imlist)):
+ obstrue = imlist[i].observe_same(
+ obslist[0], sgrscat=sgrscat, add_th_noise=False, ttype=ttype)
+ for sigma_type in obstrue.data.dtype.names[-4:]:
+ obstrue.data[sigma_type] *= 0
+ obslist_plot.append(obstrue)
+ markers.append('s')
+
+ clist_plot.insert(0, 'k')
+ legendlabels_plot.insert(0, 'Observation')
+
+    # one observation, one image
+    elif len(obslist) == 1 and len(imlist) == 1:
+ obslist_plot.append(obslist[0])
+
+ obstrue = imlist[0].observe_same(
+ obslist[0], sgrscat=sgrscat, add_th_noise=False, ttype=ttype)
+ for sigma_type in obstrue.data.dtype.names[-4:]:
+ obstrue.data[sigma_type] *= 0
+ obslist_plot.append(obstrue)
+
+ markers = ['o', 's']
+ clist_plot = ['k', clist[0]]
+ legendlabels_plot = [legendlabels[0]+'_obs', legendlabels[0]+'_im']
+
+    # same number of images and observations
+    else:
+ markers = []
+ legendlabels_plot = []
+ clist_plot = []
+ for i in range(len(obslist)):
+ obstrue = imlist[i].observe_same(
+ obslist[i], sgrscat=sgrscat, add_th_noise=False, ttype=ttype)
+ for sigma_type in obstrue.data.dtype.names[-4:]:
+ obstrue.data[sigma_type] *= 0
+ obslist_plot.append(obstrue)
+ clist_plot.append(clist[i])
+ legendlabels_plot.append(legendlabels[i]+'_im')
+ markers.append('s')
+
+ obslist_plot.append(obslist[i])
+ clist_plot.append(clist[i])
+ legendlabels_plot.append(legendlabels[i]+'_obs')
+ markers.append('o')
+
+ if len(obslist_plot) > len(clist):
+ Exception("More observations than colors -- Add more colors to clist!")
+
+ return (obslist_plot, clist_plot, legendlabels_plot, markers)
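+
+
+# Example of the return contract (a sketch): for a single observation and a
+# single image, the image is observed with observe_same, its sigmas zeroed,
+# and appended after the data, with markers ['o', 's'] and colors
+# ['k', clist[0]]:
+#
+#   obslist_plot, clist_plot, labels, markers = prep_plot_lists([obs], [im])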
diff --git a/plotting/comparisons.py b/plotting/comparisons.py
new file mode 100644
index 00000000..38c055ff
--- /dev/null
+++ b/plotting/comparisons.py
@@ -0,0 +1,238 @@
+# comparisons.py
+# Image Consistency Comparisons
+#
+# Copyright (C) 2018 Katie Bouman
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import sys
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.offsetbox import (OffsetImage, AnnotationBbox)
+from itertools import cycle
+
+try:
+ import networkx as nx
+except ImportError:
+ print("Warning: networkx not installed! Cannot use image_agreements()")
+
+
+import ehtim.const_def as ehc
+
+
+def image_consistency(imarr, beamparams, metric='nxcorr',
+ blursmall=True, beam_max=1.0, beam_steps=5, savepath=[]):
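+    """Compare every pair of images in imarr over a range of beam blurrings.
+
+    For each of beam_steps blurring fractions between 0 and beam_max, every
+    image pair is blurred with blur_gauss and compared with compare_images
+    under the given metric.
+
+    Returns:
+        (tuple): (metric_mtx, fracsteps), where metric_mtx[i, j, k] compares
+                 images i and j at blurring fraction fracsteps[k]
+    """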
+
+ # get the pixel sizes and fov to compare images at
+ (min_psize, max_fov) = get_psize_fov(imarr)
+
+    # initialize the metric matrix
+ metric_mtx = np.zeros([len(imarr), len(imarr), beam_steps])
+
+ # get the different fracsteps
+ fracsteps = np.linspace(0, beam_max, beam_steps)
+
+ # loop over the different beam sizes
+ for fracidx in range(beam_steps):
+ # print(fracidx)
+ # look at every pair of images and compute their beam convolved metrics
+ for i in range(len(imarr)):
+ img1 = imarr[i]
+ if fracsteps[fracidx] > 0:
+ img1 = img1.blur_gauss(beamparams, fracsteps[fracidx])
+
+ for j in range(i+1, len(imarr)):
+ img2 = imarr[j]
+ if fracsteps[fracidx] > 0:
+ img2 = img2.blur_gauss(beamparams, fracsteps[fracidx])
+
+                # compute the image comparison under a specified blur_frac
+ (error, im1_pad, im2_shift) = img1.compare_images(img2, metric=[metric],
+ psize=min_psize,
+ target_fov=max_fov,
+ blur_frac=0.0,
+ beamparams=beamparams)
+
+                # if specified, save the shifted images used for the comparison
+ if savepath:
+ im1_pad.save_fits(savepath + '/' + str(i) + '_' + str(fracidx) + '.fits')
+ im2_shift.save_fits(savepath + '/' + str(j) + '_' + str(fracidx) + '.fits')
+
+ # save the metric value in a matrix
+ metric_mtx[i, j, fracidx] = error[0]
+
+ return (metric_mtx, fracsteps)
+
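+# Example usage (a minimal sketch; `imarr` is a list of Image objects and
+# `beamparams` a [fwhm_maj, fwhm_min, theta] beam, e.g. [obs.res(), obs.res(), 0]
+# as used in summary_plots.imgsum):
+#
+#   metric_mtx, fracsteps = image_consistency(imarr, beamparams,
+#                                             metric='nxcorr', beam_steps=5)
+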
+
+def get_psize_fov(imarr):
+ """Look over an array of images and determine the min pixel size and max fov
+ that can be used consistently across them
+ """
+
+ min_psize = 100
+ for i in range(0, len(imarr)):
+ if i == 0:
+ max_fov = np.max([imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
+ min_psize = imarr[i].psize
+ else:
+ max_fov = np.max([max_fov, imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim])
+ min_psize = np.min([min_psize, imarr[i].psize])
+ return (min_psize, max_fov)
+
+
+def image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=0.95):
+
+ if 'networkx' not in sys.modules:
+ raise Exception("networkx not installed: cannot use image_agreements()!")
+
+ (min_psize, max_fov) = get_psize_fov(imarr)
+
+ im_cliques_fraclevels = []
+ cliques_fraclevels = []
+ for fracidx in range(len(fracsteps)):
+ # print(fracidx)
+
+ slice_metric_mtx = metric_mtx[:, :, fracidx]
+        cutoffidx = np.where(slice_metric_mtx >= cutoff)
+        consistent = list(zip(*cutoffidx))
+
+        # make graph
+        G = nx.Graph()
+        for i in range(len(consistent)):
+            G.add_edge(consistent[i][0], consistent[i][1])
+
+ # find all cliques
+ cliques = list(nx.find_cliques(G))
+ # print(cliques)
+
+ cliques_fraclevels.append(cliques)
+
+ im_clique = []
+ for c in range(len(cliques)):
+ clique = cliques[c]
+ im_avg = imarr[clique[0]].blur_gauss(beamparams, fracsteps[fracidx])
+
+ for n in range(1, len(clique)):
+ imcomp = imarr[clique[n]].blur_gauss(beamparams, fracsteps[fracidx])
+ (error, im_avg, im2_shift) = im_avg.compare_images(imcomp, metric=['xcorr'],
+ psize=min_psize,
+ target_fov=max_fov,
+ blur_frac=0.0,
+ beamparams=beamparams)
+ im_avg.imvec = (im_avg.imvec + im2_shift.imvec) / 2.0
+
+ im_clique.append(im_avg.copy())
+
+ im_cliques_fraclevels.append(im_clique)
+
+    return (cliques_fraclevels, im_cliques_fraclevels)
+
+
+def change_cut_off(metric_mtx, fracsteps, imarr, beamparams, cutoff=0.95, zoom=0.1, fov=1):
+ (cliques_fraclevels, im_cliques_fraclevels) = image_agreements(
+ imarr, beamparams, metric_mtx, fracsteps, cutoff=cutoff)
+ generate_consistency_plot(cliques_fraclevels, im_cliques_fraclevels, metric_mtx=metric_mtx,
+ fracsteps=fracsteps, beamparams=beamparams, zoom=zoom, fov=fov)
+
+
+def generate_consistency_plot(clique_fraclevels, im_clique_fraclevels, zoom=0.1, fov=1, show=True,
+ framesize=(20, 10), fracsteps=None, cutoff=None, r_offset=1):
+
+ fig, ax = plt.subplots(figsize=framesize)
+ cycol = cycle('bgrcmk')
+
+ x_loc = []
+
+ for c, column in enumerate(clique_fraclevels):
+        colorc = next(cycol)  # itertools.cycle has no .next() in Python 3
+ x_loc.append(((20./len(clique_fraclevels))*c))
+ if len(column) == 0:
+ continue
+
+ for r, row in enumerate(column):
+
+ # adding the images
+ lenx = len(clique_fraclevels)
+ leny = 0
+ for li in clique_fraclevels:
+ if len(li) > leny:
+ leny = len(li)
+ sample_image = im_clique_fraclevels[c][r].regrid_image(
+ fov*im_clique_fraclevels[c][r].fovx(), 512)
+ arr_img = sample_image.imvec.reshape(sample_image.xdim, sample_image.ydim)
+ imagebox = OffsetImage(arr_img, zoom=zoom, cmap='afmhot')
+
+ imagebox.image.axes = ax
+
+ ab = AnnotationBbox(imagebox, ((20./lenx)*c+r_offset, (20./leny)*r),
+ xycoords='data',
+ pad=0.0,
+ arrowprops=None)
+
+ ax.add_artist(ab)
+
+ # adding the arrows
+ if c+1 != len(clique_fraclevels):
+ for a, ro in enumerate(clique_fraclevels[c+1]):
+ if set(row).issubset(ro):
+ px = c+1
+ px = ((20./lenx)*px) + r_offset
+ py = a
+ py = (20./leny)*py
+ break
+
+ xx = (20./lenx)*c + (8./lenx) + r_offset
+ yy = (20./leny)*r
+ ax.arrow(xx, yy,
+ px - xx - (9./lenx), py - yy,
+ head_width=0.05,
+ head_length=0.1,
+ color=colorc
+ )
+ row.sort()
+ # adding the text
+ txtstring = str(row)
+ ax.text((20./lenx)*c, (20./leny)*(r-0.5), txtstring, fontsize=10,
+ horizontalalignment='center', color='black', zorder=1000)
+
+ ax.set_xlim(0, 22)
+ ax.set_ylim(-10, 22)
+
+ ax.set_xticks(x_loc)
+ ax.set_xticklabels(fracsteps)
+
+ ax.set_yticks([])
+ ax.set_yticklabels([])
+
+ ax.spines['right'].set_visible(False)
+ ax.spines['top'].set_visible(False)
+ ax.spines['left'].set_visible(False)
+ ax.spines['bottom'].set_visible(False)
+
+ ax.set_title('Blurred comparison of all images; cutoff={0}, fov (uas)={1}'.format(
+ str(cutoff), str(im_clique_fraclevels[-1][-1].fovx()/ehc.RADPERUAS)))
+
+# for item in [fig, ax]:
+# item.patch.set_visible(False)
+# fig.patch.set_visible(False)
+# ax.axis('off')
+ if show is True:
+ plt.show()
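+
+
+# Putting the pieces together (a minimal sketch; `imarr` and `beamparams`
+# are assumed as above):
+#
+#   metric_mtx, fracsteps = image_consistency(imarr, beamparams)
+#   cliques, im_cliques = image_agreements(imarr, beamparams, metric_mtx,
+#                                          fracsteps, cutoff=0.95)
+#   generate_consistency_plot(cliques, im_cliques, fracsteps=fracsteps,
+#                             cutoff=0.95)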
diff --git a/plotting/summary_plots.py b/plotting/summary_plots.py
new file mode 100644
index 00000000..c8de71c5
--- /dev/null
+++ b/plotting/summary_plots.py
@@ -0,0 +1,1733 @@
+# summary_plots.py
+# Make data plots with multiple observations, images, etc.
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+# TODO add systematic noise to individual closure quantities?
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.gridspec as gridspec
+from matplotlib.backends.backend_pdf import PdfPages
+import datetime
+
+from ehtim.plotting.comp_plots import plotall_obs_compare
+from ehtim.plotting.comp_plots import plot_bl_obs_compare
+from ehtim.plotting.comp_plots import plot_cphase_obs_compare
+from ehtim.plotting.comp_plots import plot_camp_obs_compare
+from ehtim.calibrating.self_cal import self_cal as selfcal
+from ehtim.calibrating.pol_cal import leakage_cal, plot_leakage
+import ehtim.const_def as ehc
+
+FONTSIZE = 22
+WSPACE = 0.8
+HSPACE = 0.3
+MARGINS = 0.5
+PROCESSES = 4
+MARKERSIZE = 5
+
+
+def imgsum(im_or_mov, obs, obs_uncal, outname, outdir='.', title='imgsum', commentstr="",
+ fontsize=FONTSIZE, cfun='afmhot', snrcut=0., maxset=False, ttype='nfft',
+ gainplots=True, ampplots=True, cphaseplots=True, campplots=True, ebar=True,
+ debias=True, cp_uv_min=False, force_extrapolate=True, processes=PROCESSES,
+ sysnoise=0, syscnoise=0):
+ """Produce an image summary plot for an image and uvfits file.
+
+ Args:
+ im_or_mov (Image or Movie): an Image object or Movie
+ obs (Obsdata): the self-calibrated Obsdata object
+ obs_uncal (Obsdata): the original Obsdata object
+ outname (str): output pdf file name
+
+ outdir (str): directory for output file
+ title (str): the pdf file title
+ commentstr (str): a comment for the top line of the pdf
+ fontsize (float): the font size for text in the sheet
+        cfun (str): matplotlib colormap name
+
+ maxset (bool): True to use a maximal set of closure quantities
+
+ gainplots (bool): include gain plots or not
+ ampplots (bool): include amplitude consistency plots or not
+ cphaseplots (bool): include closure phase consistency plots or not
+ campplots (bool): include closure amplitude consistency plots or not
+ ebar (bool): include error bars or not
+ debias (bool): debias visibility amplitudes before computing chisq or not
+        cp_uv_min (bool or float): minimum uv-distance cutoff for including a baseline in closure phase
+
+ sysnoise (float): percent systematic noise added in quadrature
+ syscnoise (float): closure phase systematic noise in degrees added in quadrature
+
+ snrcut (dict): a dictionary of snrcut values for each quantity
+
+ ttype (str): "fast" or "nfft" or "direct"
+ force_extrapolate (bool): if True, always extrapolate movie start/stop frames
+ processes (int): number of cores to use in multiprocessing
+ Returns:
+
+ """
+
+ plt.close('all') # close conflicting plots
+ plt.rc('font', family='serif')
+ plt.rc('text', usetex=True)
+ plt.rc('font', size=FONTSIZE)
+ plt.rc('axes', titlesize=FONTSIZE)
+ plt.rc('axes', labelsize=FONTSIZE)
+ plt.rc('xtick', labelsize=FONTSIZE)
+ plt.rc('ytick', labelsize=FONTSIZE)
+ plt.rc('legend', fontsize=FONTSIZE)
+ plt.rc('figure', titlesize=FONTSIZE)
+
+ if fontsize == 0:
+ fontsize = FONTSIZE
+
+ if maxset:
+ count = 'max'
+ else:
+ count = 'min'
+
+ snrcut_dict = {key: 0. for key in ['vis', 'amp', 'cphase', 'logcamp', 'camp']}
+
+ if type(snrcut) is dict:
+ for key in snrcut.keys():
+ snrcut_dict[key] = snrcut[key]
+ else:
+ for key in snrcut_dict.keys():
+ snrcut_dict[key] = snrcut
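+    # e.g. snrcut=3. cuts every data product at S/N < 3, while a dict such as
+    # snrcut={'cphase': 3.} applies the cut to closure phases only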
+
+ with PdfPages(outname) as pdf:
+ titlestr = 'Summary Sheet for %s on MJD %s' % (im_or_mov.source, im_or_mov.mjd)
+
+ # pdf metadata
+ d = pdf.infodict()
+ d['Title'] = title
+ d['Author'] = u'EHT Team 1'
+ d['Subject'] = titlestr
+ d['CreationDate'] = datetime.datetime.today()
+ d['ModDate'] = datetime.datetime.today()
+
+ # define the figure
+ fig = plt.figure(1, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+
+ # user comments
+ if len(commentstr) > 1:
+ titlestr = titlestr+'\n'+str(commentstr)
+
+ plt.suptitle(titlestr, y=.9, va='center', fontsize=int(1.2*fontsize))
+
+ ################################################################################
+ print("===========================================")
+ print("displaying the image")
+ ax = plt.subplot(gs[0:2, 0:2])
+ ax.set_title('Submitted Image')
+
+ movie = hasattr(im_or_mov, 'get_image')
+ if movie:
+ im_display = im_or_mov.avg_frame()
+
+ # TODO --- ok to always extrapolate?
+ if force_extrapolate:
+ im_or_mov.reset_interp(bounds_error=False)
+ elif hasattr(im_or_mov, 'make_image'):
+ im_display = im_or_mov.make_image(obs.res() * 10., 512)
+ else:
+ im_display = im_or_mov.copy()
+
+ ax = _display_img(im_display, axis=ax, show=False,
+ has_title=False, cfun=cfun, fontsize=fontsize)
+
+ print("===========================================")
+ print("displaying the blurred image")
+ ax = plt.subplot(gs[0:2, 2:5])
+ ax.set_title('Image blurred to nominal resolution')
+ fwhm = obs.res()
+ print("blur_FWHM: ", fwhm/ehc.RADPERUAS)
+ beamparams = [fwhm, fwhm, 0]
+
+ imblur = im_display.blur_gauss(beamparams, frac=1.0)
+ ax = _display_img(imblur, beamparams=beamparams, axis=ax, show=False,
+ has_title=False, cfun=cfun, fontsize=fontsize)
+
+ ################################################################################
+ print("===========================================")
+ print("calculating statistics")
+ # display the overall chi2
+ ax = plt.subplot(gs[2, 0:2])
+ ax.set_title('Image statistics')
+ # ax.axis('off')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ flux = im_display.total_flux()
+
+ # SNR ordering
+ # obs.reorder_tarr_snr()
+ # obs_uncal.reorder_tarr_snr()
+
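+        # note: maxset is overridden below, so the chi^2 statistics are always
+        # computed on the default (minimal) closure set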
+ maxset = False
+ # compute chi^2
+ chi2vis = obs.chisq(im_or_mov, dtype='vis', ttype=ttype,
+ systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['vis'])
+ chi2amp = obs.chisq(im_or_mov, dtype='amp', ttype=ttype,
+ systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['amp'])
+ chi2cphase = obs.chisq(im_or_mov, dtype='cphase', ttype=ttype, systematic_noise=sysnoise,
+ systematic_cphase_noise=syscnoise,
+ maxset=maxset, cp_uv_min=cp_uv_min, snrcut=snrcut_dict['cphase'])
+ chi2logcamp = obs.chisq(im_or_mov, dtype='logcamp', ttype=ttype, systematic_noise=sysnoise,
+ maxset=maxset, snrcut=snrcut_dict['logcamp'])
+ chi2camp = obs.chisq(im_or_mov, dtype='camp', ttype=ttype,
+ systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['camp'])
+
+ chi2vis_uncal = obs_uncal.chisq(im_or_mov, dtype='vis', ttype=ttype, systematic_noise=0,
+ maxset=maxset, snrcut=snrcut_dict['vis'])
+ chi2amp_uncal = obs_uncal.chisq(im_or_mov, dtype='amp', ttype=ttype, systematic_noise=0,
+ maxset=maxset, snrcut=snrcut_dict['amp'])
+ chi2cphase_uncal = obs_uncal.chisq(im_or_mov, dtype='cphase', ttype=ttype,
+ systematic_noise=0,
+ systematic_cphase_noise=0, maxset=maxset,
+ cp_uv_min=cp_uv_min, snrcut=snrcut_dict['cphase'])
+ chi2logcamp_uncal = obs_uncal.chisq(im_or_mov, dtype='logcamp', ttype=ttype,
+ systematic_noise=0, maxset=maxset,
+ snrcut=snrcut_dict['logcamp'])
+ chi2camp_uncal = obs_uncal.chisq(im_or_mov, dtype='camp', ttype=ttype, systematic_noise=0,
+ maxset=maxset, snrcut=snrcut_dict['camp'])
+
+ print("chi^2 vis: %0.2f %0.2f" % (chi2vis, chi2vis_uncal))
+ print("chi^2 amp: %0.2f %0.2f" % (chi2amp, chi2amp_uncal))
+ print("chi^2 cphase: %0.2f %0.2f" % (chi2cphase, chi2cphase_uncal))
+ print("chi^2 logcamp: %0.2f %0.2f" % (chi2logcamp, chi2logcamp_uncal))
+ print("chi^2 camp: %0.2f %0.2f" % (chi2logcamp, chi2logcamp_uncal))
+
+ fs = int(1*fontsize)
+ fs2 = int(.8*fontsize)
+ ax.text(.05, .9, "Source:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .7, "MJD:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .5, "FREQ:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .3, "FOV:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .1, "FLUX:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.23, .9, "%s" % im_or_mov.source, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .7, "%i" % im_or_mov.mjd, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .5, "%0.0f GHz" % (im_or_mov.rf/1.e9), fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .3, "%0.1f $\mu$as" % (im_display.fovx()/ehc.RADPERUAS), fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .1, "%0.2f Jy" % flux, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.5, .9, "$\chi^2_{vis}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .7, "$\chi^2_{amp}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .5, "$\chi^2_{cphase}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .3, "$\chi^2_{log camp}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .1, "$\chi^2_{camp}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.72, .9, "%0.2f" % chi2vis, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .7, "%0.2f" % chi2amp, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .5, "%0.2f" % chi2cphase, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .3, "%0.2f" % chi2logcamp, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .1, "%0.2f" % chi2camp, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.85, .9, "(%0.2f)" % chi2vis_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .7, "(%0.2f)" % chi2amp_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .5, "(%0.2f)" % chi2cphase_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+        ax.text(.85, .3, "(%0.2f)" % chi2logcamp_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .1, "(%0.2f)" % chi2camp_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ################################################################################
+ print("===========================================")
+ print("calculating cphase statistics")
+ # display the closure phase chi2
+ ax = plt.subplot(gs[3:6, 0:2])
+ ax.set_title('Closure phase statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ # get closure triangle combinations
+ # ANDREW -- hacky, fix this!
+ cp = obs.c_phases(mode="all", count=count, uv_min=cp_uv_min, snrcut=snrcut_dict['cphase'])
+ n_cphase = len(cp)
+ alltris = [(str(cpp['t1']), str(cpp['t2']), str(cpp['t3'])) for cpp in cp]
+ uniqueclosure_tri = []
+ for tri in alltris:
+ if tri not in uniqueclosure_tri:
+ uniqueclosure_tri.append(tri)
+
+ # generate data
+ obs_model = im_or_mov.observe_same(obs, add_th_noise=False, ttype=ttype)
+
+ # TODO: check SNR cut
+ cphases_obs = obs.c_phases(mode='all', count='max', vtype='vis',
+ uv_min=cp_uv_min, snrcut=snrcut_dict['cphase'])
+ if snrcut_dict['cphase'] > 0:
+ cphases_obs_all = obs.c_phases(mode='all', count='max',
+ vtype='vis', uv_min=cp_uv_min, snrcut=0.)
+ cphases_model_all = obs_model.c_phases(
+ mode='all', count='max', vtype='vis', uv_min=cp_uv_min, snrcut=0.)
+ mask = [cphase in cphases_obs for cphase in cphases_obs_all]
+ cphases_model = cphases_model_all[mask]
+ print('cphase snr cut', snrcut_dict['cphase'], ' : kept', len(
+ cphases_obs), '/', len(cphases_obs_all))
+ else:
+ cphases_model = obs_model.c_phases(
+ mode='all', count='max', vtype='vis', uv_min=cp_uv_min, snrcut=0.)
+
+ # generate chi^2 -- NO SYSTEMATIC NOISES
+ cphase_chisq_data = []
+ for c in range(0, len(uniqueclosure_tri)):
+ cphases_obs_tri = obs.cphase_tri(uniqueclosure_tri[c][0],
+ uniqueclosure_tri[c][1],
+ uniqueclosure_tri[c][2],
+ vtype='vis', ang_unit='deg', cphases=cphases_obs)
+
+ if len(cphases_obs_tri) > 0:
+ cphases_model_tri = obs_model.cphase_tri(uniqueclosure_tri[c][0],
+ uniqueclosure_tri[c][1],
+ uniqueclosure_tri[c][2],
+ vtype='vis', ang_unit='deg',
+ cphases=cphases_model)
+
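+                # angular chi^2: 2*(1 - cos(dtheta))/sigma^2 reduces to
+                # (dtheta/sigma)^2 for small residuals but is insensitive to
+                # 360-degree phase wraps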
+ resids = (cphases_obs_tri['cphase'] - cphases_model_tri['cphase'])*ehc.DEGREE
+ chisq_tri = 2*np.sum((1.0 - np.cos(resids)) /
+ ((cphases_obs_tri['sigmacp']*ehc.DEGREE)**2))
+
+ npts = len(cphases_obs_tri)
+ data = [uniqueclosure_tri[c][0], uniqueclosure_tri[c]
+ [1], uniqueclosure_tri[c][2], npts, chisq_tri]
+ cphase_chisq_data.append(data)
+
+ # sort by decreasing chi^2
+ idx = np.argsort([data[-1] for data in cphase_chisq_data])
+ idx = list(reversed(idx))
+
+ chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Triangle " +
+ r"& $N_{tri}$ & $\chi^2_{tri}/N_{tri}$ & $\chi^2_{tri}/N_{tot}$" +
+ r"\\ \hline \hline")
+ first = True
+ for i in range(len(cphase_chisq_data)):
+ if i > 30:
+ break
+ data = cphase_chisq_data[idx[i]]
+ tristr = r"%s-%s-%s" % (data[0], data[1], data[2])
+ nstr = r"%i" % data[3]
+ rchisqstr = r"%0.1f" % (float(data[4])/float(data[3]))
+ rrchisqstr = r"%0.3f" % (float(data[4])/float(n_cphase))
+ if first:
+ chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ first = False
+ else:
+ chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ chisqtab += r" \end{tabular}"
+
+ ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+ ################################################################################
+ print("===========================================")
+ print("calculating camp statistics")
+ # display the log closure amplitude chi2
+ ax = plt.subplot(gs[2:6, 2::])
+        ax.set_title('Log closure amplitude statistics')
+ # ax.axis('off')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ # get closure amplitude combinations
+ # TODO -- hacky, fix this!
+ cp = obs.c_amplitudes(mode="all", count=count, ctype='logcamp', debias=debias)
+ n_camps = len(cp)
+ allquads = [(str(cpp['t1']), str(cpp['t2']), str(cpp['t3']), str(cpp['t4'])) for cpp in cp]
+ uniqueclosure_quad = []
+ for quad in allquads:
+ if quad not in uniqueclosure_quad:
+ uniqueclosure_quad.append(quad)
+
+ # generate data
+ # TODO: check SNR cut
+ camps_obs = obs.c_amplitudes(mode='all', count='max', ctype='logcamp',
+ debias=debias, snrcut=snrcut_dict['logcamp'])
+ if snrcut_dict['logcamp'] > 0:
+ camps_obs_all = obs.c_amplitudes(
+ mode='all', count='max', ctype='logcamp', debias=debias, snrcut=0.)
+ camps_model_all = obs_model.c_amplitudes(
+ mode='all', count='max', ctype='logcamp', debias=False, snrcut=0.)
+ mask = [camp['camp'] in camps_obs['camp'] for camp in camps_obs_all]
+ camps_model = camps_model_all[mask]
+ print('closure amp snrcut', snrcut_dict['logcamp'],
+ ': kept', len(camps_obs), '/', len(camps_obs_all))
+ else:
+ camps_model = obs_model.c_amplitudes(
+ mode='all', count='max', ctype='logcamp', debias=False, snrcut=0.)
+
+ # generate chi2 -- NO SYSTEMATIC NOISES
+ camp_chisq_data = []
+ for c in range(0, len(uniqueclosure_quad)):
+ camps_obs_quad = obs.camp_quad(uniqueclosure_quad[c][0], uniqueclosure_quad[c][1],
+ uniqueclosure_quad[c][2], uniqueclosure_quad[c][3],
+ vtype='vis', camps=camps_obs, ctype='logcamp')
+
+ if len(camps_obs_quad) > 0:
+ camps_model_quad = obs.camp_quad(uniqueclosure_quad[c][0], uniqueclosure_quad[c][1],
+ uniqueclosure_quad[c][2], uniqueclosure_quad[c][3],
+ vtype='vis', camps=camps_model, ctype='logcamp')
+
+ resids = camps_obs_quad['camp'] - camps_model_quad['camp']
+ chisq_quad = np.sum(np.abs(resids/camps_obs_quad['sigmaca'])**2)
+ npts = len(camps_obs_quad)
+
+ data = (uniqueclosure_quad[c][0], uniqueclosure_quad[c][1],
+ uniqueclosure_quad[c][2], uniqueclosure_quad[c][3],
+ npts,
+ chisq_quad)
+ camp_chisq_data.append(data)
+
+ # sort by decreasing chi^2
+ idx = np.argsort([data[-1] for data in camp_chisq_data])
+ idx = list(reversed(idx))
+
+ chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Quadrangle " +
+ r"& $N_{quad}$ & $\chi^2_{quad}/N_{quad}$ & $\chi^2_{quad}/N_{tot}$ " +
+ r"\\ \hline \hline")
+ for i in range(len(camp_chisq_data)):
+ if i > 45:
+ break
+ data = camp_chisq_data[idx[i]]
+ tristr = r"%s-%s-%s-%s" % (data[0], data[1], data[2], data[3])
+ nstr = r"%i" % data[4]
+ rchisqstr = r"%0.1f" % (data[5]/float(data[4]))
+ rrchisqstr = r"%0.3f" % (data[5]/float(n_camps))
+ if i == 0:
+ chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+
+ chisqtab += r" \end{tabular}"
+
+ ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+ # save the first page of the plot
+ print('saving pdf page 1')
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ ################################################################################
+ # plot the vis amps
+ fig = plt.figure(2, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+
+ print("===========================================")
+ print("plotting vis amps")
+ ax = plt.subplot(gs[0:2, 0:2])
+ obs_tmp = obs_model.copy()
+ obs_tmp.data['sigma'] *= 0.
+ ax = plotall_obs_compare([obs, obs_tmp],
+ 'uvdist', 'amp', axis=ax, legend=False,
+ clist=['k', ehc.SCOLORS[1]],
+ ttype=ttype, show=False, debias=debias,
+ snrcut=snrcut_dict['amp'],
+ ebar=ebar, markersize=MARKERSIZE)
+ # modify the labels
+        ax.set_title('Calibrated Visibility Amplitudes')
+ ax.set_xlabel('u-v distance (G$\lambda$)')
+ ax.set_xlim([0, 1.e10])
+ ax.set_xticks([0, 2.e9, 4.e9, 6.e9, 8.e9, 10.e9])
+ ax.set_xticklabels(["0", "2", "4", "6", "8", "10"])
+ ax.set_xticks([1.e9, 3.e9, 5.e9, 7.e9, 9.e9], minor=True)
+ ax.set_xticklabels([], minor=True)
+
+ ax.set_ylabel('Amplitude (Jy)')
+ ax.set_ylim([0, 1.2*flux])
+ yticks_maj = np.array([0, .2, .4, .6, .8, 1])*flux
+ ax.set_yticks(yticks_maj)
+ ax.set_yticklabels(["%0.2f" % fl for fl in yticks_maj])
+ yticks_min = np.array([.1, .3, .5, .7, .9])*flux
+ ax.set_yticks(yticks_min, minor=True)
+ ax.set_yticklabels([], minor=True)
+
+ # plot the caltable gains
+ if gainplots:
+ print("===========================================")
+ print("plotting gains")
+ ax2 = plt.subplot(gs[0:2, 2:6])
+ obs_tmp = obs_uncal.copy()
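+            # a single amplitude self-cal round; raising the range() bound
+            # iterates, with the caltables from successive rounds merged below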
+ for i in range(1):
+ ct = selfcal(obs_tmp, im_or_mov,
+ method='amp', ttype=ttype,
+ caltable=True, gain_tol=.2,
+ processes=processes)
+ ct = ct.pad_scans()
+ obs_tmp = ct.applycal(obs_tmp, interp='nearest', extrapolate=True)
+                if i > 0:
+                    ct_out = ct_out.merge([ct])
+                else:
+                    ct_out = ct
+                if np.any(np.isnan(obs_tmp.data['vis'])):
+                    print("Warning: NaN in applycal vis table!")
+                    break
+
+ ax2 = ct_out.plot_gains('all', rangey=[.1, 10],
+ yscale='log', axis=ax2, legend=True, show=False)
+
+ # median gains
+ ax = plt.subplot(gs[3:6, 2:5])
+ ax.set_title('Station gain statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ gain_data = []
+ for station in ct_out.tarr['site']:
+ try:
+ gain = np.median(np.abs(ct_out.data[station]['lscale']))
+                except KeyError:
+ continue
+ pdiff = np.abs(gain-1)*100
+ data = (station, gain, pdiff)
+ gain_data.append(data)
+
+            # sort by decreasing percent difference
+ idx = np.argsort([data[-1] for data in gain_data])
+ idx = list(reversed(idx))
+
+ chisqtab = (r"\begin{tabular}{ l|l|l } \hline Site & " +
+ r"Median Gain & Percent diff. \\ \hline \hline")
+ for i in range(len(gain_data)):
+ if i > 45:
+ break
+ data = gain_data[idx[i]]
+ sitestr = r"%s" % (data[0])
+ gstr = r"%0.2f" % data[1]
+ pstr = r"%0.0f" % data[2]
+ if i == 0:
+ chisqtab += r" " + sitestr + " & " + gstr + " & " + pstr
+ else:
+ chisqtab += r" \\" + sitestr + " & " + gstr + " & " + pstr
+
+ chisqtab += r" \end{tabular}"
+ ax.text(0.5, .975, chisqtab, ha="center", va="top",
+ transform=ax.transAxes, size=fontsize)
+
+ # baseline amplitude chi2
+ print("===========================================")
+ print("baseline vis amps chisq")
+ ax = plt.subplot(gs[3:6, 0:2])
+ ax.set_title('Visibility amplitude statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ bl_unpk = obs.unpack(['t1', 't2'], debias=debias)
+ n_bl = len(bl_unpk)
+ allbl = [(str(bl['t1']), str(bl['t2'])) for bl in bl_unpk]
+ uniquebl = []
+ for bl in allbl:
+ if bl not in uniquebl:
+ uniquebl.append(bl)
+
+ # generate chi2 -- NO SYSTEMATIC NOISES
+ bl_chisq_data = []
+ for ii in range(0, len(uniquebl)):
+ bl = uniquebl[ii]
+
+ amps_bl = obs.unpack_bl(bl[0], bl[1], ['amp', 'sigma'], debias=debias)
+ if len(amps_bl) > 0:
+
+ amps_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['amp', 'sigma'], debias=False)
+
+ if snrcut_dict['amp'] > 0:
+ amask = amps_bl['amp']/amps_bl['sigma'] > snrcut_dict['amp']
+ amps_bl = amps_bl[amask]
+ amps_bl_model = amps_bl_model[amask]
+
+ chisq_bl = np.sum(
+ np.abs((amps_bl['amp'] - amps_bl_model['amp'])/amps_bl['sigma'])**2)
+ npts = len(amps_bl_model)
+
+ data = (bl[0], bl[1],
+ npts,
+ chisq_bl)
+ bl_chisq_data.append(data)
+
+ # sort by decreasing chi^2
+ idx = np.argsort([data[-1] for data in bl_chisq_data])
+ idx = list(reversed(idx))
+
+ chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & " +
+ r"$N_{amp}$ & $\chi^2_{amp}/N_{amp}$ & $\chi^2_{amp}/N_{total}$ " +
+ r"\\ \hline \hline")
+ for i in range(len(bl_chisq_data)):
+ if i > 45:
+ break
+ data = bl_chisq_data[idx[i]]
+ tristr = r"%s-%s" % (data[0], data[1])
+ nstr = r"%i" % data[2]
+ rchisqstr = r"%0.1f" % (data[3]/float(data[2]))
+ rrchisqstr = r"%0.3f" % (data[3]/float(n_bl))
+ if i == 0:
+ chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+
+ chisqtab += r" \end{tabular}"
+
+ ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+        # save the second page of the plot
+ print('saving pdf page 2')
+ # plt.tight_layout()
+ # plt.subplots_adjust(wspace=1,hspace=1)
+ # plt.savefig(outname, pad_inches=MARGINS,bbox_inches='tight')
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ ################################################################################
+ # plot the visibility amplitudes
+ page = 3
+ if ampplots:
+ print("===========================================")
+ print("plotting amplitudes")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Amplitude Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+
+ obs_model.data['sigma'] *= 0
+            amax = 1.1*np.max(np.abs(obs_model.data['vis']))
+ obs_all = [obs, obs_model]
+ for bl in uniquebl:
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'amp', rangey=[0, amax],
+ markersize=MARKERSIZE, debias=debias,
+ snrcut=snrcut_dict['amp'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype=ttype, show=False, ebar=ebar)
+ if ax is None:
+ continue
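+                # two panels per row: 'switch' toggles the column and i counts
+                # rows; after three rows the page is saved and a new one begins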
+ if switch:
+ i += 1
+ j = 0
+ switch = False
+ else:
+ j = 1
+ switch = True
+
+ ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ ################################################################################
+ # plot the closure phases
+ if cphaseplots:
+ print("===========================================")
+ print("plotting closure phases")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Closure Phase Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+ obs_all = [obs, obs_model]
+ cphases_model['sigmacp'] *= 0
+ cphases_all = [cphases_obs, cphases_model]
+ for tri in uniqueclosure_tri:
+
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_cphase_obs_compare(obs_all, tri[0], tri[1], tri[2], rangey=[-185, 185],
+ cphases=cphases_all, markersize=MARKERSIZE,
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype=ttype, show=False, ebar=ebar)
+ if ax is None:
+ continue
+ if switch:
+ i += 1
+ j = 0
+ switch = False
+ else:
+ j = 1
+ switch = True
+
+ ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ ################################################################################
+ # plot the log closure amps
+ if campplots:
+ print("===========================================")
+ print("plotting closure amplitudes")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Closure Amplitude Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+ obs_all = [obs, obs_model]
+ camps_model['sigmaca'] *= 0
+ camps_all = [camps_obs, camps_model]
+ cmax = 1.1*np.max(np.abs(camps_obs['camp']))
+ for quad in uniqueclosure_quad:
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_camp_obs_compare(obs_all, quad[0], quad[1], quad[2], quad[3],
+ markersize=MARKERSIZE,
+ ctype='logcamp', rangey=[-cmax, cmax], camps=camps_all,
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype=ttype, show=False, ebar=ebar)
+ if ax is None:
+ continue
+ if switch:
+ i += 1
+ j = 0
+ switch = False
+ else:
+ j = 1
+ switch = True
+
+ ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+ switch = False
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
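+# Example imgsum call (a minimal sketch; the file names here are hypothetical):
+#
+#   import ehtim as eh
+#   im = eh.image.load_fits('image.fits')
+#   obs_sc = eh.obsdata.load_uvfits('obs_selfcal.uvfits')
+#   obs = eh.obsdata.load_uvfits('obs_original.uvfits')
+#   imgsum(im, obs_sc, obs, 'summary.pdf', commentstr='test run')
+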
+
+def imgsum_pol(im, obs, obs_uncal, outname,
+ leakage_arr=False, nvec=False,
+ outdir='.', title='imgsum_pol', commentstr="",
+ fontsize=FONTSIZE, cfun='afmhot', snrcut=0.,
+ dtermplots=True, pplots=True, mplots=True, qplots=True, uplots=True, ebar=True,
+ sysnoise=0):
+ """Produce a polarimetric image summary plot for an image and uvfits file.
+
+ Args:
+ im (Image): an Image object
+ obs (Obsdata): the calibrated Obsdata object
+ obs_uncal (Obsdata): the original Obsdata object
+ outname (str): output pdf file name
+
+        leakage_arr (Array or bool): an array with calibrated d-terms, or False to calibrate leakage here
+        nvec (int or bool): number of polarimetric vectors to plot in each direction (False for the default)
+
+ outdir (str): directory for output file
+ title (str): the pdf file title
+ commentstr (str): a comment for the top line of the pdf
+ fontsize (float): the font size for text in the sheet
+        cfun (str): matplotlib colormap for the image panels
+        snrcut (float or dict): snr cutoff, either a single value or a dictionary with a value per data product
+
+ dtermplots (bool): plot the d-terms or not
+ mplots (bool): plot the fractional polarizations or not
+ pplots (bool): plot the P=RL polarization or not
+        qplots (bool): plot the Q data or not
+        uplots (bool): plot the U data or not
+
+ ebar (bool): include error bars or not
+ sysnoise (float): percent systematic noise added in quadrature
+
+    Returns:
+        None: the summary sheet is saved to the pdf file given by outname
+ """
+
+ # switch polreps and mask nan data
+ im = im.switch_polrep(polrep_out='stokes')
+ obs = obs.switch_polrep(polrep_out='stokes')
+ obs_uncal = obs_uncal.switch_polrep(polrep_out='stokes')
+
+    mask_nan = (np.isnan(obs_uncal.data['vis']) |
+                np.isnan(obs_uncal.data['qvis']) |
+                np.isnan(obs_uncal.data['uvis']) |
+                np.isnan(obs_uncal.data['vvis']))
+ obs_uncal.data = obs_uncal.data[~mask_nan]
+
+    mask_nan = (np.isnan(obs.data['vis']) |
+                np.isnan(obs.data['qvis']) |
+                np.isnan(obs.data['uvis']) |
+                np.isnan(obs.data['vvis']))
+ obs.data = obs.data[~mask_nan]
+
+ if len(im.qvec) == 0 or len(im.uvec) == 0:
+ raise Exception("the image isn't polarized!")
+
+ plt.close('all') # close conflicting plots
+ plt.rc('font', family='serif')
+ plt.rc('text', usetex=True)
+ plt.rc('font', size=FONTSIZE)
+ plt.rc('axes', titlesize=FONTSIZE)
+ plt.rc('axes', labelsize=FONTSIZE)
+ plt.rc('xtick', labelsize=FONTSIZE)
+ plt.rc('ytick', labelsize=FONTSIZE)
+ plt.rc('legend', fontsize=FONTSIZE)
+ plt.rc('figure', titlesize=FONTSIZE)
+
+ if fontsize == 0:
+ fontsize = FONTSIZE
+
+ snrcut_dict = {key: 0. for key in ['m', 'pvis', 'qvis', 'uvis']}
+
+ if type(snrcut) is dict:
+ for key in snrcut.keys():
+ snrcut_dict[key] = snrcut[key]
+ else:
+ for key in snrcut_dict.keys():
+ snrcut_dict[key] = snrcut
+
+    # TODO -- ok? prevent errors in division
+    if np.any(im.ivec == 0):
+ im.ivec += 1.e-50*np.max(im.ivec)
+
+ with PdfPages(outname) as pdf:
+ titlestr = 'Summary Sheet for %s on MJD %s' % (im.source, im.mjd)
+
+ # pdf metadata
+ d = pdf.infodict()
+ d['Title'] = title
+ d['Author'] = u'EHT Team 1'
+ d['Subject'] = titlestr
+ d['CreationDate'] = datetime.datetime.today()
+ d['ModDate'] = datetime.datetime.today()
+
+ # define the figure
+ fig = plt.figure(1, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+
+ # user comments
+ if len(commentstr) > 1:
+ titlestr = titlestr+'\n'+str(commentstr)
+
+ plt.suptitle(titlestr, y=.9, va='center', fontsize=int(1.2*fontsize))
+
+ ################################################################################
+ print("===========================================")
+ print("displaying the images")
+
+ # unblurred image IQU
+ ax = plt.subplot(gs[0:2, 0:2])
+ ax.set_title('I')
+
+ ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=cfun,
+ pol='I', polticks=True,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[2:4, 0:2])
+ ax.set_title('Q')
+ ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=plt.get_cmap('bwr'),
+ pol='Q', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=True,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[4:6, 0:2])
+ ax.set_title('U')
+ ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=plt.get_cmap('bwr'),
+ pol='U', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=True,
+ fontsize=fontsize)
+
+ # blurred image IQU
+ ax = plt.subplot(gs[0:2, 2:5])
+ beamparams = obs_uncal.fit_beam()
+ fwhm = np.min((np.abs(beamparams[0]), np.abs(beamparams[1])))
+ print("blur_FWHM: ", fwhm/ehc.RADPERUAS)
+
+ imblur = im.blur_gauss(beamparams, frac=1.0, frac_pol=1.)
+
+ ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, cfun=cfun,
+ pol='I', polticks=True, beamparams=beamparams,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[2:4, 2:5])
+ ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('bwr'),
+ pol='Q', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=True,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[4:6, 2:5])
+ ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('bwr'),
+ pol='U', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=True,
+ fontsize=fontsize)
+
+ print('saving pdf page 1')
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ # unblurred image m chi
+ fig = plt.figure(2, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+
+ ax = plt.subplot(gs[0:2, 0:2])
+ ax.set_title('m')
+        ax = _display_img_pol(im, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('jet'),
+ pol='m', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[2:4, 0:2])
+ ax.set_title('chi')
+ ax = _display_img_pol(im, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('jet'),
+ pol='chi', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[0:2, 2:5])
+ ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('jet'),
+ pol='m', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ax = plt.subplot(gs[2:4, 2:5])
+ ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False,
+ cfun=plt.get_cmap('jet'),
+ pol='chi', polticks=False,
+ nvec=nvec, pcut=0.1, mcut=0.01, contour=False,
+ fontsize=fontsize)
+
+ ################################################################################
+ print("===========================================")
+ print("calculating statistics")
+ # display the overall chi2
+ ax = plt.subplot(gs[4, 0:2])
+ ax.set_title('Image statistics')
+ # ax.axis('off')
+ ax.set_yticks([])
+ ax.set_xticks([])
+
+ flux = im.total_flux()
+
+ # SNR ordering
+ # obs.reorder_tarr_snr()
+ # obs_uncal.reorder_tarr_snr()
+
+ # compute chi^2
+        chi2pvis = obs.polchisq(im, dtype='pvis', ttype='nfft',
+ systematic_noise=sysnoise, pol_trans=False)
+ chi2m = obs.polchisq(im, dtype='m', ttype='nfft',
+ systematic_noise=sysnoise, pol_trans=False)
+ chi2qvis = obs.chisq(im, dtype='vis', ttype='nfft',
+ systematic_noise=sysnoise, pol='Q')
+ chi2uvis = obs.chisq(im, dtype='vis', ttype='nfft',
+ systematic_noise=sysnoise, pol='U')
+
+        chi2pvis_uncal = obs_uncal.polchisq(im, dtype='pvis', ttype='nfft',
+ systematic_noise=sysnoise, pol_trans=False)
+ chi2m_uncal = obs_uncal.polchisq(im, dtype='m', ttype='nfft',
+ systematic_noise=sysnoise, pol_trans=False)
+ chi2qvis_uncal = obs_uncal.chisq(im, dtype='vis', ttype='nfft',
+ systematic_noise=sysnoise, pol='Q')
+ chi2uvis_uncal = obs_uncal.chisq(im, dtype='vis', ttype='nfft',
+ systematic_noise=sysnoise, pol='U')
+
+ print("chi^2 m: %0.2f %0.2f" % (chi2m, chi2m_uncal))
+ print("chi^2 pvis: %0.2f %0.2f" % (chi2pvis, chi2pvis_uncal))
+ print("chi^2 qvis: %0.2f %0.2f" % (chi2qvis, chi2qvis_uncal))
+ print("chi^2 uvis: %0.2f %0.2f" % (chi2uvis, chi2uvis_uncal))
+
+ fs = int(1*fontsize)
+ fs2 = int(.8*fontsize)
+ ax.text(.05, .9, "Source:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .7, "MJD:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .5, "FREQ:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .3, "FOV:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.05, .1, "FLUX:", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.23, .9, "%s" % im.source, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .7, "%i" % im.mjd, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .5, "%0.0f GHz" % (im.rf/1.e9), fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .3, "%0.1f $\mu$as" % (im.fovx()/ehc.RADPERUAS), fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.23, .1, "%0.2f Jy" % flux, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.5, .9, "$\chi^2_{m}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .7, "$\chi^2_{P}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .5, "$\chi^2_{Q}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.5, .3, "$\chi^2_{U}$", fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.72, .9, "%0.2f" % chi2m, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .7, "%0.2f" % chi2pvis, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .5, "%0.2f" % chi2qvis, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.72, .3, "%0.2f" % chi2uvis, fontsize=fs,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ax.text(.85, .9, "(%0.2f)" % chi2m_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .7, "(%0.2f)" % chi2pvis_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .5, "(%0.2f)" % chi2qvis_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+ ax.text(.85, .3, "(%0.2f)" % chi2uvis_uncal, fontsize=fs2,
+ ha='left', va='center', transform=ax.transAxes)
+
+ ################################################################################
+ # plot the D terms
+
+ if dtermplots:
+ print("===========================================")
+ print("plotting d terms")
+ ax = plt.subplot(gs[4:6, 2:5])
+
+ if leakage_arr:
+ obs_polcal = obs_uncal.copy()
+ obs_polcal.tarr = leakage_arr.tarr
+ else:
+ obs_polcal = leakage_cal(obs_uncal, im, leakage_tol=1e6, ttype='nfft')
+
+ ax = plot_leakage(obs_polcal, axis=ax, show=False,
+ rangex=[-20, 20], rangey=[-20, 20], markersize=5)
+
+ print('saving pdf page 2')
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+        # page 3: baseline polarimetric chi^2 statistics
+ fig = plt.figure(2, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+
+ print("===========================================")
+ print("baseline m&p chisq")
+
+ bl_unpk = obs.unpack(['t1', 't2'])
+ n_bl = len(bl_unpk)
+ allbl = [(str(bl['t1']), str(bl['t2'])) for bl in bl_unpk]
+ uniquebl = []
+ for bl in allbl:
+ if bl not in uniquebl:
+ uniquebl.append(bl)
+
+ # generate data
+ obs_model = im.observe_same(obs, add_th_noise=False, ttype='nfft')
+
+ # generate chi2 -- NO SYSTEMATIC NOISES
+ bl_chisq_data_m = []
+ bl_chisq_data_pvis = []
+ bl_chisq_data_qvis = []
+ bl_chisq_data_uvis = []
+
+ for ii in range(0, len(uniquebl)):
+ bl = uniquebl[ii]
+
+ m_bl = obs.unpack_bl(bl[0], bl[1], ['m', 'msigma'], debias=False)
+ pvis_bl = obs.unpack_bl(bl[0], bl[1], ['pvis', 'psigma'], debias=False)
+ qvis_bl = obs.unpack_bl(bl[0], bl[1], ['qvis', 'qsigma'], debias=False)
+ uvis_bl = obs.unpack_bl(bl[0], bl[1], ['uvis', 'usigma'], debias=False)
+
+ if len(m_bl) > 0:
+
+ m_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['m', 'msigma'], debias=False)
+ pvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['pvis', 'psigma'], debias=False)
+ qvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['qvis', 'qsigma'], debias=False)
+ uvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['uvis', 'usigma'], debias=False)
+
+ if snrcut_dict['m'] > 0:
+ amask = np.abs(m_bl['m'])/m_bl['msigma'] > snrcut_dict['m']
+ m_bl = m_bl[amask]
+ m_bl_model = m_bl_model[amask]
+ if snrcut_dict['pvis'] > 0:
+ amask = np.abs(pvis_bl['pvis'])/pvis_bl['psigma'] > snrcut_dict['pvis']
+ pvis_bl = pvis_bl[amask]
+ pvis_bl_model = pvis_bl_model[amask]
+ if snrcut_dict['qvis'] > 0:
+ amask = np.abs(qvis_bl['qvis'])/qvis_bl['qsigma'] > snrcut_dict['qvis']
+ qvis_bl = qvis_bl[amask]
+ qvis_bl_model = qvis_bl_model[amask]
+ if snrcut_dict['uvis'] > 0:
+ amask = np.abs(uvis_bl['uvis'])/uvis_bl['usigma'] > snrcut_dict['uvis']
+ uvis_bl = uvis_bl[amask]
+ uvis_bl_model = uvis_bl_model[amask]
+
+ chisq_m_bl = np.sum(np.abs((m_bl['m'] - m_bl_model['m'])/m_bl['msigma'])**2)
+ npts_m = len(m_bl_model)
+ data_m = (bl[0], bl[1], npts_m, chisq_m_bl)
+ bl_chisq_data_m.append(data_m)
+
+ chisq_pvis_bl = np.sum(
+ np.abs((pvis_bl['pvis'] - pvis_bl_model['pvis'])/pvis_bl['psigma'])**2)
+ npts_pvis = len(pvis_bl_model)
+ data_pvis = (bl[0], bl[1], npts_pvis, chisq_pvis_bl)
+ bl_chisq_data_pvis.append(data_pvis)
+
+ chisq_qvis_bl = np.sum(
+ np.abs((qvis_bl['qvis'] - qvis_bl_model['qvis'])/qvis_bl['qsigma'])**2)
+ npts_qvis = len(qvis_bl_model)
+ data_qvis = (bl[0], bl[1], npts_qvis, chisq_qvis_bl)
+ bl_chisq_data_qvis.append(data_qvis)
+
+ chisq_uvis_bl = np.sum(
+ np.abs((uvis_bl['uvis'] - uvis_bl_model['uvis'])/uvis_bl['usigma'])**2)
+ npts_uvis = len(uvis_bl_model)
+ data_uvis = (bl[0], bl[1], npts_uvis, chisq_uvis_bl)
+ bl_chisq_data_uvis.append(data_uvis)
+
+ # sort by decreasing chi^2
+ idx_m = np.argsort([data[-1] for data in bl_chisq_data_m])
+ idx_m = list(reversed(idx_m))
+ idx_p = np.argsort([data[-1] for data in bl_chisq_data_pvis])
+ idx_p = list(reversed(idx_p))
+ idx_q = np.argsort([data[-1] for data in bl_chisq_data_qvis])
+ idx_q = list(reversed(idx_q))
+ idx_u = np.argsort([data[-1] for data in bl_chisq_data_uvis])
+ idx_u = list(reversed(idx_u))
+
+ chisqtab_m = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{m}$ & " +
+ r"$\chi^2_{m}/N_{m}$ & $\chi^2_{m}/N_{total}$ \\ " +
+ r"\hline \hline")
+ chisqtab_p = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{p}$ & " +
+ r"$\chi^2_{p}/N_{p}$ & $\chi^2_{p}/N_{total}$ \\ " +
+ r"\hline \hline")
+ chisqtab_q = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{Q}$ & " +
+ r"$\chi^2_{Q}/N_{Q}$ & $\chi^2_{Q}/N_{total}$ \\ " +
+ r"\hline \hline")
+ chisqtab_u = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{U}$ & " +
+ r"$\chi^2_{U}/N_{U}$ & $\chi^2_{U}/N_{total}$ \\ " +
+ r"\hline \hline")
+
+ for i in range(len(bl_chisq_data_m)):
+ if i > 45:
+ break
+ data = bl_chisq_data_m[idx_m[i]]
+ tristr = r"%s-%s" % (data[0], data[1])
+ nstr = r"%i" % data[2]
+ rchisqstr = r"%0.1f" % (data[3]/float(data[2]))
+ rrchisqstr = r"%0.3f" % (data[3]/float(n_bl))
+ if i == 0:
+ chisqtab_m += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab_m += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ for i in range(len(bl_chisq_data_pvis)):
+ if i > 45:
+ break
+ data = bl_chisq_data_pvis[idx_p[i]]
+ tristr = r"%s-%s" % (data[0], data[1])
+ nstr = r"%i" % data[2]
+ rchisqstr = r"%0.1f" % (data[3]/float(data[2]))
+ rrchisqstr = r"%0.3f" % (data[3]/float(n_bl))
+ if i == 0:
+ chisqtab_p += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab_p += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ for i in range(len(bl_chisq_data_qvis)):
+ if i > 45:
+ break
+ data = bl_chisq_data_qvis[idx_q[i]]
+ tristr = r"%s-%s" % (data[0], data[1])
+ nstr = r"%i" % data[2]
+ rchisqstr = r"%0.1f" % (data[3]/float(data[2]))
+ rrchisqstr = r"%0.3f" % (data[3]/float(n_bl))
+ if i == 0:
+ chisqtab_q += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab_q += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ for i in range(len(bl_chisq_data_uvis)):
+ if i > 45:
+ break
+            data = bl_chisq_data_uvis[idx_u[i]]
+ tristr = r"%s-%s" % (data[0], data[1])
+ nstr = r"%i" % data[2]
+ rchisqstr = r"%0.1f" % (data[3]/float(data[2]))
+ rrchisqstr = r"%0.3f" % (data[3]/float(n_bl))
+ if i == 0:
+ chisqtab_u += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+ else:
+ chisqtab_u += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr
+
+ chisqtab_m += r" \end{tabular}"
+ chisqtab_p += r" \end{tabular}"
+ chisqtab_q += r" \end{tabular}"
+ chisqtab_u += r" \end{tabular}"
+
+ ax = plt.subplot(gs[0:3, 0:2])
+ ax.set_title('baseline m statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+ ax.text(0.5, .975, chisqtab_m, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+ ax = plt.subplot(gs[0:3, 2:5])
+ ax.set_title('baseline P statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+ ax.text(0.5, .975, chisqtab_p, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+ ax = plt.subplot(gs[3:6, 0:2])
+ ax.set_title('baseline Q statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+ ax.text(0.5, .975, chisqtab_q, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+ ax = plt.subplot(gs[3:6, 2:5])
+ ax.set_title('baseline U statistics')
+ ax.set_yticks([])
+ ax.set_xticks([])
+ ax.text(0.5, .975, chisqtab_u, ha="center", va="top", transform=ax.transAxes, size=fontsize)
+
+        # save the third page of the plot
+ print('saving pdf page 3')
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ ################################################################################
+ # plot the baseline pol amps and phases
+ page = 4
+ if mplots:
+ print("===========================================")
+ print("plotting fractional polarizatons")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Fractional Polarization Plots", y=.9,
+ va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+ obs_model.data['sigma'] *= 0
+ obs_model.data['qsigma'] *= 0
+ obs_model.data['usigma'] *= 0
+ obs_model.data['vsigma'] *= 0
+
+            amax = 1.1*np.max(np.abs(obs_model.unpack(['mamp'])['mamp']))
+ obs_all = [obs, obs_model]
+ for nbl, bl in enumerate(uniquebl):
+ j = 0
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'mamp', rangey=[0, amax],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['m'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+ j = 1
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+                ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'mphase', rangey=[-180, 180],
+                                         markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['m'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                i += 1
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+                if nbl == len(uniquebl) - 1:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ if pplots:
+ print("===========================================")
+ print("plotting total polarizaton")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Total Polarization Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+ obs_model.data['sigma'] *= 0
+ obs_model.data['qsigma'] *= 0
+ obs_model.data['usigma'] *= 0
+ obs_model.data['vsigma'] *= 0
+
+            amax = 1.1*np.max(np.abs(obs_model.unpack(['pamp'])['pamp']))
+ obs_all = [obs, obs_model]
+ for nbl, bl in enumerate(uniquebl):
+ j = 0
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'pamp', rangey=[0, amax],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['pvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+ j = 1
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'pphase', rangey=[-180, 180],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['pvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                i += 1
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+                if nbl == len(uniquebl) - 1:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ if qplots:
+ print("===========================================")
+ print("plotting Q fit")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("Q Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+ obs_model.data['sigma'] *= 0
+ obs_model.data['qsigma'] *= 0
+ obs_model.data['usigma'] *= 0
+ obs_model.data['vsigma'] *= 0
+
+            amax = 1.1*np.max(np.abs(obs_model.unpack(['qamp'])['qamp']))
+ obs_all = [obs, obs_model]
+ for nbl, bl in enumerate(uniquebl):
+ j = 0
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'qamp', rangey=[0, amax],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['qvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+ j = 1
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'qphase', rangey=[-180, 180],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['qvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                i += 1
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+                if nbl == len(uniquebl) - 1:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
+ if uplots:
+ print("===========================================")
+ print("plotting U fit")
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ plt.suptitle("U Plots", y=.9, va='center', fontsize=int(1.2*fontsize))
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+ obs_model.data['sigma'] *= 0
+ obs_model.data['qsigma'] *= 0
+ obs_model.data['usigma'] *= 0
+ obs_model.data['vsigma'] *= 0
+
+            amax = 1.1*np.max(np.abs(obs_model.unpack(['uamp'])['uamp']))
+ obs_all = [obs, obs_model]
+ for nbl, bl in enumerate(uniquebl):
+ j = 0
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'uamp', rangey=[0, amax],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['uvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+ j = 1
+ ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)])
+ ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'uphase', rangey=[-180, 180],
+ markersize=MARKERSIZE, debias=False,
+ snrcut=snrcut_dict['uvis'],
+ axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]],
+ ttype='nfft', show=False, ebar=ebar)
+                i += 1
+                if ax is None:
+                    continue
+                ax.set_xlabel('')
+
+ if i == 3:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+ fig = plt.figure(3, figsize=(18, 28), dpi=200)
+ gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE)
+ i = 0
+ j = 0
+
+                if nbl == len(uniquebl) - 1:
+ print('saving pdf page %i' % page)
+ page += 1
+ pdf.savefig(pad_inches=MARGINS, bbox_inches='tight')
+ plt.close()
+
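+# Example imgsum_pol call (a sketch; assumes a polarized Image 'im' and the
+# Obsdata objects loaded as in the imgsum example above):
+#
+#   imgsum_pol(im, obs_sc, obs, 'summary_pol.pdf', nvec=16)
+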
+
+def _display_img(im, beamparams=None, scale='linear', gamma=0.5, cbar_lims=False,
+ has_cbar=True, has_title=True, cfun='afmhot', dynamic_range=100,
+ axis=False, show=False, fontsize=FONTSIZE):
+ """display the figure on a given axis
+ cannot use im.display because it makes a new figure
+ """
+
+ interp = 'gaussian'
+
+ if axis:
+ ax = axis
+ else:
+ fig = plt.figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ imvec = np.array(im.imvec).reshape(-1)
+ # flux unit is mJy/uas^2
+ imvec = imvec * 1.e3
+ fovfactor = im.xdim*im.psize*(1/ehc.RADPERUAS)
+ factor = (1./fovfactor)**2 / (1./im.xdim)**2
+ imvec = imvec * factor
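+    # imvec is now in mJy per square microarcsecond:
+    # (Jy/pixel) * 1e3 / (pixel area in uas^2)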
+
+ imarr = (imvec).reshape(im.ydim, im.xdim)
+    unit = 'mJy/$\mu$as$^2$'
+ if scale == 'log':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = np.log(imarr + np.max(imarr)/dynamic_range)
+ unit = 'log(' + unit + ')'
+
+ if scale == 'gamma':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma)
+ unit = '(' + unit + ')^gamma'
+
+ if cbar_lims:
+ imarr[imarr > cbar_lims[1]] = cbar_lims[1]
+ imarr[imarr < cbar_lims[0]] = cbar_lims[0]
+
+ if cbar_lims:
+ ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp,
+ vmin=cbar_lims[0], vmax=cbar_lims[1])
+ else:
+ ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp)
+
+ if has_cbar:
+ cbar = plt.colorbar(ax, fraction=0.046, pad=0.04, format='%1.2g')
+ cbar.set_label(unit, fontsize=fontsize)
+ cbar.ax.xaxis.set_label_position('top')
+ cbar.ax.tick_params(labelsize=16)
+ if cbar_lims:
+ plt.clim(cbar_lims[0], cbar_lims[1])
+
+    if beamparams is not None:
+ beamparams = [beamparams[0], beamparams[1], beamparams[2],
+ -.35*im.fovx(), -.35*im.fovy()]
+ beamimage = im.copy()
+ beamimage.imvec *= 0
+ beamimage = beamimage.add_gauss(1, beamparams)
+ halflevel = 0.5*np.max(beamimage.imvec)
+ beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim)
+ plt.contour(beamimarr, levels=[halflevel], colors='w', linewidths=3)
+ ax = plt.gca()
+
+ plt.axis('off')
+ fov_uas = im.xdim * im.psize / ehc.RADPERUAS # get the fov in uas
+ roughfactor = 1./3. # make the bar about 1/3 the fov
+ fov_scale = 40
+ start = im.xdim * roughfactor / 3.0 # select the start location
+ end = start + fov_scale/fov_uas * im.xdim # determine the end location
+ plt.plot([start, end], [im.ydim-start, im.ydim-start], color="white", lw=1) # plot line
+ plt.text(x=(start+end)/2.0, y=im.ydim-start-im.ydim/20, s=str(fov_scale) + " $\mu$as",
+ color="white", ha="center", va="center",
+ fontsize=int(1.2*fontsize), fontweight='bold')
+
+ ax.axes.get_xaxis().set_visible(False)
+ ax.axes.get_yaxis().set_visible(False)
+
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return ax
+
+
+def _display_img_pol(im, beamparams=None, scale='linear', gamma=0.5, cbar_lims=False,
+ has_cbar=True, has_title=True, cfun='afmhot', pol=None, polticks=False,
+ nvec=False, pcut=0.1, mcut=0.01, contour=False, dynamic_range=100,
+ axis=False, show=False, fontsize=FONTSIZE):
+ """display the polarimetric figure on a given axis
+ cannot use im.display because it makes a new figure
+ """
+
+ interp = 'gaussian'
+
+ if axis:
+ ax = axis
+ else:
+ fig = plt.figure()
+ ax = fig.add_subplot(1, 1, 1)
+
+ if pol == 'm':
+ imvec = im.mvec
+ unit = r'$|\breve{m}|$'
+ factor = 1
+ cbar_lims = [0, 1]
+ elif pol == 'chi':
+ imvec = im.chivec / ehc.DEGREE
+ unit = r'$\chi (^\circ)$'
+ factor = 1
+ cbar_lims = [0, 180]
+ else:
+ # flux unit is Tb
+ factor = 3.254e13/(im.rf**2 * im.psize**2)
+ unit = 'Tb (K)'
+ try:
+ imvec = np.array(im._imdict[pol]).reshape(-1)
+ except KeyError:
+ try:
+ if im.polrep == 'stokes':
+ im2 = im.switch_polrep('circ')
+ elif im.polrep == 'circ':
+ im2 = im.switch_polrep('stokes')
+ imvec = np.array(im2._imdict[pol]).reshape(-1)
+ except KeyError:
+ raise Exception("Cannot make pol %s image in display()!" % pol)
+
+ # flux unit is Tb
+ imvec = imvec * factor
+ imarr = (imvec).reshape(im.ydim, im.xdim)
+
+ if scale == 'log':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = np.log(imarr + np.max(imarr)/dynamic_range)
+ unit = 'log(' + unit + ')'
+
+ if scale == 'gamma':
+ if (imarr < 0.0).any():
+ print('clipping values less than 0')
+ imarr[imarr < 0.0] = 0.0
+ imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma)
+ unit = '(' + unit + ')^gamma'
+
+ if cbar_lims:
+ imarr[imarr > cbar_lims[1]] = cbar_lims[1]
+ imarr[imarr < cbar_lims[0]] = cbar_lims[0]
+
+ if cbar_lims:
+ ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp,
+ vmin=cbar_lims[0], vmax=cbar_lims[1])
+ else:
+ ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp)
+
+ if contour:
+ plt.contour(imarr, colors='k', linewidths=.25)
+
+ if polticks:
+ im_stokes = im.switch_polrep(polrep_out='stokes')
+ ivec = np.array(im_stokes.imvec).reshape(-1)
+ qvec = np.array(im_stokes.qvec).reshape(-1)
+ uvec = np.array(im_stokes.uvec).reshape(-1)
+ vvec = np.array(im_stokes.vvec).reshape(-1)
+
+ if len(ivec) == 0:
+ ivec = np.zeros(im_stokes.ydim*im_stokes.xdim)
+ if len(qvec) == 0:
+ qvec = np.zeros(im_stokes.ydim*im_stokes.xdim)
+ if len(uvec) == 0:
+ uvec = np.zeros(im_stokes.ydim*im_stokes.xdim)
+ if len(vvec) == 0:
+ vvec = np.zeros(im_stokes.ydim*im_stokes.xdim)
+
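+        # polarization ticks are drawn at the EVPA chi = 0.5*angle(Q + iU);
+        # the (a, b) components below are (-sin(chi), cos(chi))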
+ if not nvec:
+ nvec = im.xdim // 2
+
+ thin = im.xdim//nvec
+ maska = (ivec).reshape(im.ydim, im.xdim) > pcut * np.max(ivec)
+ maskb = (np.abs(qvec + 1j*uvec)/ivec).reshape(im.ydim, im.xdim) > mcut
+ mask = maska * maskb
+ mask2 = mask[::thin, ::thin]
+ x = (np.array([[i for i in range(im.xdim)] for j in range(im.ydim)])[::thin, ::thin])
+ x = x[mask2]
+ y = (np.array([[j for i in range(im.xdim)] for j in range(im.ydim)])[::thin, ::thin])
+ y = y[mask2]
+ a = (-np.sin(np.angle(qvec+1j*uvec)/2).reshape(im.ydim, im.xdim)[::thin, ::thin])
+ a = a[mask2]
+ b = (np.cos(np.angle(qvec+1j*uvec)/2).reshape(im.ydim, im.xdim)[::thin, ::thin])
+ b = b[mask2]
+
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.01*im.xdim, units='x', pivot='mid', color='k', angles='uv',
+ scale=1.0/thin)
+ plt.quiver(x, y, a, b,
+ headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1,
+ width=.005*im.xdim, units='x', pivot='mid', color='w', angles='uv',
+ scale=1.1/thin)
+
+ if has_cbar:
+ cbar = plt.colorbar(ax, fraction=0.046, pad=0.04, format='%1.2g')
+ cbar.set_label(unit, fontsize=fontsize)
+ cbar.ax.xaxis.set_label_position('top')
+ cbar.ax.tick_params(labelsize=16)
+ if cbar_lims:
+ ax.set_clim(cbar_lims[0], cbar_lims[1])
+
+    if beamparams is not None:
+ beamparams = [beamparams[0], beamparams[1], beamparams[2],
+ -.35*im.fovx(), -.35*im.fovy()]
+ beamimage = im.copy()
+ beamimage.imvec *= 0
+ beamimage = beamimage.add_gauss(1, beamparams)
+ halflevel = 0.5*np.max(beamimage.imvec)
+ beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim)
+ plt.contour(beamimarr, levels=[halflevel], colors='w', linewidths=3)
+ ax = plt.gca()
+
+ plt.axis('off')
+ if has_cbar:
+ fov_uas = im.xdim * im.psize / ehc.RADPERUAS # get the fov in uas
+ roughfactor = 1./3. # make the bar about 1/3 the fov
+ fov_scale = 40
+        start = im.xdim * roughfactor / 3.0  # select the start location
+        end = start + fov_scale/fov_uas * im.xdim  # end location for the scale bar
+ plt.plot([start, end], [im.ydim-start, im.ydim-start], color="white", lw=1) # plot line
+ plt.text(x=(start+end)/2.0, y=im.ydim-start-im.ydim/20, s=str(fov_scale) + " $\mu$as",
+ color="white", ha="center", va="center",
+ fontsize=int(1.2*fontsize), fontweight='bold')
+
+ ax.axes.get_xaxis().set_visible(False)
+ ax.axes.get_yaxis().set_visible(False)
+
+ if show:
+ #plt.show(block=False)
+ ehc.show_noblock()
+
+ return ax
diff --git a/scattering/__init__.py b/scattering/__init__.py
new file mode 100644
index 00000000..3ad5e415
--- /dev/null
+++ b/scattering/__init__.py
@@ -0,0 +1,11 @@
+"""
+.. module:: ehtim.scattering
+   :platform: Unix
+   :synopsis: EHT Imaging Utilities: scattering functions
+
+.. moduleauthor:: Michael Johnson (mjohnson@cfa.harvard.edu)
+
+"""
+
+from .stochastic_optics import *
+from ..const_def import *
diff --git a/scattering/stochastic_optics.py b/scattering/stochastic_optics.py
new file mode 100644
index 00000000..d328d921
--- /dev/null
+++ b/scattering/stochastic_optics.py
@@ -0,0 +1,817 @@
+# Michael Johnson, 2/15/2017
+# See http://adsabs.harvard.edu/abs/2016ApJ...833...74J for details about this module
+
+from __future__ import print_function
+from builtins import range
+from builtins import object
+import numpy as np
+import scipy.signal
+import scipy.special as sps
+import scipy.integrate as integrate
+from scipy.optimize import minimize
+
+import matplotlib.pyplot as plt
+
+import ehtim.image as image
+import ehtim.movie as movie
+import ehtim.obsdata as obsdata
+from ehtim.observing.obs_helpers import *
+from ehtim.const_def import * #Note: C is m/s rather than cm/s.
+
+from multiprocessing import cpu_count
+from multiprocessing import Pool
+
+import math
+import cmath
+
+################################################################################
+# The class ScatteringModel encompasses a generic scattering model, determined by the power spectrum Q and phase structure function Dphi
+################################################################################
+
+class ScatteringModel(object):
+ """A scattering model based on a thin-screen approximation.
+
+ Models include:
+ ('von_Mises', 'boxcar', 'dipole'): These scattering models are motivated by observations of Sgr A*.
+ Each gives a Gaussian at long wavelengths that matches the model defined
+ by {theta_maj_mas_ref, theta_min_mas_ref, POS_ANG} at the reference wavelength wavelength_reference_cm
+ with a lambda^2 scaling. The source sizes {theta_maj, theta_min} are the image FWHM in milliarcseconds
+ at the reference wavelength. Note that this may not match the ensemble-average kernel at the reference wavelength,
+ if the reference wavelength is short enough to be beyond the lambda^2 regime!
+ This model also includes an inner and outer scale and will thus transition to scattering with scatt_alpha at shorter wavelengths
+ Note: This model *requires* a finite inner scale
+ 'power-law': This scattering model gives a pure power law at all wavelengths. There is no inner scale, but there can be an outer scale.
+ The ensemble-average image is given by {theta_maj_mas_ref, theta_min_mas_ref, POS_ANG} at the reference wavelength wavelength_reference_cm.
+ The ensemble-average image size is proportional to wavelength^(1+2/scatt_alpha) = wavelength^(11/5) for Kolmogorov
+
+ Attributes:
+ model (string): The type of scattering model (determined by the power spectrum of phase fluctuations).
+ scatt_alpha (float): The power-law index of the phase fluctuations (Kolmogorov is 5/3).
+ observer_screen_distance (float): The distance from the observer to the scattering screen in cm.
+ source_screen_distance (float): The distance from the source to the scattering screen in cm.
+ theta_maj_mas_ref (float): FWHM in mas of the major axis angular broadening at the specified reference wavelength.
+ theta_min_mas_ref (float): FWHM in mas of the minor axis angular broadening at the specified reference wavelength.
+ POS_ANG (float): The position angle of the major axis of the scattering.
+ wavelength_reference_cm (float): The reference wavelength for the scattering model in cm.
+ r_in (float): The inner scale of the scattering screen in cm.
+ r_out (float): The outer scale of the scattering screen in cm.
+ rF (function): Returns the Fresnel scale [cm] of the scattering screen at a specified wavelength [cm].
+ """
+
+ def __init__(self, model = 'dipole', scatt_alpha = 1.38, observer_screen_distance = 2.82 * 3.086e21, source_screen_distance = 5.53 * 3.086e21, theta_maj_mas_ref = 1.380, theta_min_mas_ref = 0.703, POS_ANG = 81.9, wavelength_reference_cm = 1.0, r_in = 800e5, r_out = 1e20):
+ """To initialize the scattering model, specify:
+
+ Attributes:
+ model (string): The type of scattering model (determined by the power spectrum of phase fluctuations). Options are 'von_Mises', 'boxcar', 'dipole', and 'power-law'
+ scatt_alpha (float): The power-law index of the phase fluctuations (Kolmogorov is 5/3).
+ observer_screen_distance (float): The distance from the observer to the scattering screen in cm.
+ source_screen_distance (float): The distance from the source to the scattering screen in cm.
+ theta_maj_mas_ref (float): FWHM in mas of the major axis angular broadening at the specified reference wavelength.
+ theta_min_mas_ref (float): FWHM in mas of the minor axis angular broadening at the specified reference wavelength.
+ POS_ANG (float): The position angle of the major axis of the scattering.
+ wavelength_reference_cm (float): The reference wavelength for the scattering model in cm.
+ r_in (float): The inner scale of the scattering screen in cm.
+ r_out (float): The outer scale of the scattering screen in cm.
+ """
+
+ self.model = model
+ self.POS_ANG = POS_ANG #Major axis position angle [degrees, east of north]
+ self.observer_screen_distance = observer_screen_distance #cm
+ self.source_screen_distance = source_screen_distance #cm
+ M = observer_screen_distance/source_screen_distance
+ self.wavelength_reference = wavelength_reference_cm #Reference wavelength [cm]
+ self.r_in = r_in #inner scale [cm]
+ self.r_out = r_out #outer scale [cm]
+ self.scatt_alpha = scatt_alpha
+
+ FWHM_fac = (2.0 * np.log(2.0))**0.5/np.pi
+ self.Qbar = 2.0/sps.gamma((2.0 - self.scatt_alpha)/2.0) * (self.r_in**2*(1.0 + M)/(FWHM_fac*(self.wavelength_reference/(2.0*np.pi))**2) )**2 * ( (theta_maj_mas_ref**2 + theta_min_mas_ref**2)*(1.0/1000.0/3600.0*np.pi/180.0)**2)
+ self.C_scatt_0 = (self.wavelength_reference/(2.0*np.pi))**2 * self.Qbar*sps.gamma(1.0 - self.scatt_alpha/2.0)/(8.0*np.pi**2*self.r_in**2)
+ A = theta_maj_mas_ref/theta_min_mas_ref # Anisotropy, >=1, as lambda->infinity
+ self.phi0 = (90 - self.POS_ANG) * np.pi/180.0
+
+ # Parameters for the approximate phase structure function
+ theta_maj_rad_ref = theta_maj_mas_ref/1000.0/3600.0*np.pi/180.0
+ theta_min_rad_ref = theta_min_mas_ref/1000.0/3600.0*np.pi/180.0
+ self.Amaj_0 = ( self.r_in*(1.0 + M) * theta_maj_rad_ref/(FWHM_fac * (self.wavelength_reference/(2.0*np.pi)) * 2.0*np.pi ))**2
+ self.Amin_0 = ( self.r_in*(1.0 + M) * theta_min_rad_ref/(FWHM_fac * (self.wavelength_reference/(2.0*np.pi)) * 2.0*np.pi ))**2
+
+ if model == 'von_Mises':
+ def avM_Anisotropy(kzeta):
+ return np.abs( (kzeta*sps.i0(kzeta)/sps.i1(kzeta) - 1.0)**0.5 - A )
+
+ self.kzeta = minimize(avM_Anisotropy, A**2, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x
+ self.P_phi_prefac = 1.0/(2.0*np.pi*sps.i0(self.kzeta))
+ elif model == 'boxcar':
+ def boxcar_Anisotropy(kzeta):
+ return np.abs( np.sin(np.pi/(1.0 + kzeta))/(np.pi/(1.0 + kzeta)) - (theta_maj_mas_ref**2 - theta_min_mas_ref**2)/(theta_maj_mas_ref**2 + theta_min_mas_ref**2) )
+
+ self.kzeta = minimize(boxcar_Anisotropy, A, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x
+ self.P_phi_prefac = (1.0 + self.kzeta)/(2.0*np.pi)
+ elif model == 'dipole':
+ def dipole_Anisotropy(kzeta):
+ return np.abs( sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 0.5, 2.0, -kzeta)/sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 1.5, 2.0, -kzeta) - A**2 )
+
+ self.kzeta = minimize(dipole_Anisotropy, A, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x
+ self.P_phi_prefac = 1.0/(2.0*np.pi*sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 0.5, 1.0, -self.kzeta))
+ else:
+ print("Scattering Model Not Recognized!")
+
+ # More parameters for the approximate phase structure function
+ int_maj = integrate.quad(lambda phi_q: np.abs( np.cos( self.phi0 - phi_q ) )**self.scatt_alpha * self.P_phi(phi_q), 0, 2.0*np.pi, limit=250)[0]
+ int_min = integrate.quad(lambda phi_q: np.abs( np.sin( self.phi0 - phi_q ) )**self.scatt_alpha * self.P_phi(phi_q), 0, 2.0*np.pi, limit=250)[0]
+ B_prefac = self.C_scatt_0 * 2.0**(2.0 - self.scatt_alpha) * np.pi**0.5/(self.scatt_alpha * sps.gamma((self.scatt_alpha + 1.0)/2.0))
+ self.Bmaj_0 = B_prefac*int_maj
+ self.Bmin_0 = B_prefac*int_min
+
+ #Check normalization:
+ #print("Checking Normalization:",integrate.quad(lambda phi_q: self.P_phi(phi_q), 0, 2.0*np.pi)[0])
+
+ return
+
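+ # Example usage (sketch): construct the default Sgr A* dipole model and
+ # scatter a single image; here 'im' stands in for an ehtim Image object.
+ # sm = ScatteringModel()
+ # ea_im = sm.Ensemble_Average_Blur(im) # ensemble-average (blurred) image
+ # scatt_im = sm.Scatter(im) # one realization with a random screen
+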
+ def P_phi(self, phi):
+ if self.model == 'von_Mises':
+ return self.P_phi_prefac * np.cosh(self.kzeta*np.cos(phi - self.phi0))
+ elif self.model == 'boxcar':
+ return self.P_phi_prefac * (1.0 - ((np.pi/(2.0*(1.0 + self.kzeta)) < (phi - self.phi0) % np.pi) & ((phi - self.phi0) % np.pi < np.pi - np.pi/(2.0*(1.0 + self.kzeta)))))
+ elif self.model == 'dipole':
+ return self.P_phi_prefac * (1.0 + self.kzeta*np.sin(phi - self.phi0)**2)**(-(self.scatt_alpha + 2.0)/2.0)
+
+ def rF(self, wavelength):
+ """Returns the Fresnel scale [cm] of the scattering screen at the specified wavelength [cm].
+
+ Args:
+ wavelength (float): The desired wavelength [cm]
+
+ Returns:
+ rF (float): The Fresnel scale [cm]
+ """
+ return (self.source_screen_distance*self.observer_screen_distance/(self.source_screen_distance + self.observer_screen_distance)*wavelength/(2.0*np.pi))**0.5
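+ # For the default screen distances (D = 2.82 kpc, R = 5.53 kpc), this
+ # evaluates to rF ~ 3e10 cm at a wavelength of 1 cm.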
+
+ def Mag(self):
+ """Returns the effective magnification the scattering screen: (observer-screen distance)/(source-screen distance).
+
+ Returns:
+ M (float): The effective magnification of the scattering screen.
+ """
+ return self.observer_screen_distance/self.source_screen_distance
+
+ def dDphi_dz(self, r, phi, phi_q, wavelength):
+ """differential contribution to the phase structure function
+ """
+ return 4.0 * (wavelength/self.wavelength_reference)**2 * self.C_scatt_0/self.scatt_alpha * (sps.hyp1f1(-self.scatt_alpha/2.0, 0.5, -r**2/(4.0*self.r_in**2)*np.cos(phi - phi_q)**2) - 1.0)
+
+ def Dphi_exact(self, x, y, wavelength_cm):
+ r = (x**2 + y**2)**0.5
+ phi = np.arctan2(y, x)
+
+ return integrate.quad(lambda phi_q: self.dDphi_dz(r, phi, phi_q, wavelength_cm)*self.P_phi(phi_q), 0, 2.0*np.pi)[0]
+
+ def Dmaj(self, r, wavelength_cm):
+ return (wavelength_cm/self.wavelength_reference)**2 * self.Bmaj_0 * (2.0 * self.Amaj_0/(self.scatt_alpha * self.Bmaj_0))**(-self.scatt_alpha/(2.0 - self.scatt_alpha)) * ((1.0 + (2.0*self.Amaj_0/(self.scatt_alpha * self.Bmaj_0))**(2.0/(2.0 - self.scatt_alpha)) * (r/self.r_in)**2 )**(self.scatt_alpha/2.0) - 1.0)
+
+ def Dmin(self, r, wavelength_cm):
+ return (wavelength_cm/self.wavelength_reference)**2 * self.Bmin_0 * (2.0 * self.Amin_0/(self.scatt_alpha * self.Bmin_0))**(-self.scatt_alpha/(2.0 - self.scatt_alpha)) * ((1.0 + (2.0*self.Amin_0/(self.scatt_alpha * self.Bmin_0))**(2.0/(2.0 - self.scatt_alpha)) * (r/self.r_in)**2 )**(self.scatt_alpha/2.0) - 1.0)
+
+ def Dphi_approx(self, x, y, wavelength_cm):
+ r = (x**2 + y**2)**0.5
+ phi = np.arctan2(y, x)
+
+ Dmaj_eval = self.Dmaj(r, wavelength_cm)
+ Dmin_eval = self.Dmin(r, wavelength_cm)
+
+ return (Dmaj_eval + Dmin_eval)/2.0 + (Dmaj_eval - Dmin_eval)/2.0*np.cos(2.0*(phi - self.phi0))
+
+ def Q(self, qx, qy):
+ """Computes the power spectrum of the scattering model at a wavenumber {qx,qy} (in 1/cm).
+ The power spectrum is part of what defines the scattering model (along with Dphi).
+ Q(qx,qy) is independent of the observing wavelength.
+
+ Args:
+ qx (float): x coordinate of the wavenumber in 1/cm.
+ qy (float): y coordinate of the wavenumber in 1/cm.
+ Returns:
+ (float): The power spectrum Q(qx,qy)
+ """
+
+ q = (qx**2 + qy**2)**0.5 + 1e-12/self.r_in #Add a small offset to avoid division by zero
+ phi_q = np.arctan2(qy, qx)
+
+ return self.Qbar * (q*self.r_in)**(-(self.scatt_alpha + 2.0)) * np.exp(-(q * self.r_in)**2) * self.P_phi(phi_q)
+
+
+ def sqrtQ_Matrix(self, Reference_Image, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0):
+ """Computes the square root of the power spectrum on a discrete grid. Because translation of the screen is done most conveniently in Fourier space, a screen translation can also be included.
+
+ Args:
+ Reference_Image (Image): Reference image to determine image and pixel dimensions and wavelength.
+ Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s.
+ Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s.
+ t_hr (float): The current time of the scattering in hours.
+ Returns:
+ sqrtQ (2D complex ndarray): The square root of the power spectrum of the screen with an additional phase that translates the screen.
+ """
+
+ #Derived parameters
+ FOV = Reference_Image.psize * Reference_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen
+ N = Reference_Image.xdim
+ dq = 2.0*np.pi/FOV #this is the spacing in wavenumber
+ screen_x_offset_pixels = (Vx_km_per_s * 1.e5) * (t_hr*3600.0) / (FOV/float(N))
+ screen_y_offset_pixels = (Vy_km_per_s * 1.e5) * (t_hr*3600.0) / (FOV/float(N))
+
+ s, t = np.meshgrid(np.fft.fftfreq(N, d=1.0/N), np.fft.fftfreq(N, d=1.0/N))
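+ # By the Fourier shift theorem, the phase factor below translates the screen
+ # by (screen_x_offset_pixels, screen_y_offset_pixels) on the image grid.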
+ sqrtQ = np.sqrt(self.Q(dq*s, dq*t)) * np.exp(2.0*np.pi*1j*(s*screen_x_offset_pixels +
+ t*screen_y_offset_pixels)/float(N))
+ sqrtQ[0][0] = 0.0 #A DC offset doesn't affect scattering
+
+ return sqrtQ
+
+ def Ensemble_Average_Kernel(self, Reference_Image, wavelength_cm = None, use_approximate_form=True):
+ """The ensemble-average convolution kernel for images; returns a 2D array corresponding to the image dimensions of the reference image
+
+ Args:
+ Reference_Image (Image): Reference image to determine image and pixel dimensions and wavelength.
+ wavelength_cm (float): The observing wavelength for the scattering kernel in cm. If unspecified, this will default to the wavelength of the Reference image.
+
+ Returns:
+ ker (2D ndarray): The ensemble-average scattering kernel in the image domain.
+ """
+
+ if wavelength_cm is None:
+ wavelength_cm = C/Reference_Image.rf*100.0 #Observing wavelength [cm]
+
+ uvlist = np.fft.fftfreq(Reference_Image.xdim)/Reference_Image.psize # assume square kernel. FIXME: create ulist and vlist, and construct u_grid and v_grid with the correct dimension
+ if use_approximate_form == True:
+ u_grid, v_grid = np.meshgrid(uvlist, uvlist)
+ ker_uv = self.Ensemble_Average_Kernel_Visibility(u_grid, v_grid, wavelength_cm, use_approximate_form=use_approximate_form)
+ else:
+ ker_uv = np.array([[self.Ensemble_Average_Kernel_Visibility(u, v, wavelength_cm, use_approximate_form=use_approximate_form) for u in uvlist] for v in uvlist])
+
+ ker = np.real(np.fft.fftshift(np.fft.fft2(ker_uv)))
+ ker = ker / np.sum(ker) # normalize to 1
+ return ker
+
+ def Ensemble_Average_Kernel_Visibility(self, u, v, wavelength_cm, use_approximate_form=True):
+ """The ensemble-average multiplicative scattering kernel for visibilities at a particular {u,v} coordinate
+
+ Args:
+ u (float): u baseline coordinate (dimensionless)
+ v (float): v baseline coordinate (dimensionless)
+ wavelength_cm (float): The observing wavelength for the scattering kernel in cm.
+
+ Returns:
+ float: The ensemble-average kernel at the specified {u,v} point and observing wavelength.
+ """
+ if use_approximate_form == True:
+ return np.exp(-0.5*self.Dphi_approx(u*wavelength_cm/(1.0+self.Mag()), v*wavelength_cm/(1.0+self.Mag()), wavelength_cm))
+ else:
+ return np.exp(-0.5*self.Dphi_exact(u*wavelength_cm/(1.0+self.Mag()), v*wavelength_cm/(1.0+self.Mag()), wavelength_cm))
+
+ def Ensemble_Average_Blur(self, im, wavelength_cm = None, ker = None, use_approximate_form=True):
+ """Blurs an input Image with the ensemble-average scattering kernel.
+
+ Args:
+ im (Image): The unscattered image.
+ wavelength_cm (float): The observing wavelength for the scattering kernel in cm. If unspecified, this will default to the wavelength of the input image.
+ ker (2D ndarray): The user can optionally pass a pre-computed ensemble-average blurring kernel.
+
+ Returns:
+ out (Image): The ensemble-average scattered image.
+ """
+
+ # Inputs an unscattered image and an ensemble-average blurring kernel (2D array); returns the ensemble-average image
+ # The pre-computed kernel can optionally be specified (ker)
+
+ if wavelength_cm is None:
+ wavelength_cm = C/im.rf*100.0 #Observing wavelength [cm]
+
+ if ker is None:
+ ker = self.Ensemble_Average_Kernel(im, wavelength_cm, use_approximate_form)
+
+ Iim = Wrapped_Convolve((im.imvec).reshape(im.ydim, im.xdim), ker)
+ out = image.Image(Iim, im.psize, im.ra, im.dec, rf=C/(wavelength_cm/100.0), source=im.source, mjd=im.mjd, pulse=im.pulse)
+ if len(im.qvec):
+ Qim = Wrapped_Convolve((im.qvec).reshape(im.ydim, im.xdim), ker)
+ Uim = Wrapped_Convolve((im.uvec).reshape(im.ydim, im.xdim), ker)
+ out.add_qu(Qim, Uim)
+ if len(im.vvec):
+ Vim = Wrapped_Convolve((im.vvec).reshape(im.ydim, im.xdim), ker)
+ out.add_v(Vim)
+
+ return out
+
+ def Deblur_obs(self, obs, use_approximate_form=True):
+ """Deblurs the observation obs by dividing visibilities by the ensemble-average scattering kernel. See Fish et al. (2014): arXiv:1409.4690.
+
+ Args:
+ obs (Obsdata): The observation data (including scattering).
+
+ Returns:
+ obsdeblur (Obsdata): The deblurred observation.
+ """
+
+ # make a copy of observation data
+ datatable = (obs.copy()).data
+
+ vis = datatable['vis']
+ qvis = datatable['qvis']
+ uvis = datatable['uvis']
+ vvis = datatable['vvis']
+ sigma = datatable['sigma']
+ qsigma = datatable['qsigma']
+ usigma = datatable['usigma']
+ vsigma = datatable['vsigma']
+ u = datatable['u']
+ v = datatable['v']
+
+ # divide visibilities by the scattering kernel
+ for i in range(len(vis)):
+ ker = self.Ensemble_Average_Kernel_Visibility(u[i], v[i], wavelength_cm = C/obs.rf*100.0, use_approximate_form=use_approximate_form)
+ vis[i] = vis[i] / ker
+ qvis[i] = qvis[i] / ker
+ uvis[i] = uvis[i] / ker
+ vvis[i] = vvis[i] / ker
+ sigma[i] = sigma[i] / ker
+ qsigma[i] = qsigma[i] / ker
+ usigma[i] = usigma[i] / ker
+ vsigma[i] = vsigma[i] / ker
+
+ datatable['vis'] = vis
+ datatable['qvis'] = qvis
+ datatable['uvis'] = uvis
+ datatable['vvis'] = vvis
+ datatable['sigma'] = sigma
+ datatable['qsigma'] = qsigma
+ datatable['usigma'] = usigma
+ datatable['vsigma'] = vsigma
+
+ obsdeblur = obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, datatable, obs.tarr, source=obs.source, mjd=obs.mjd,
+ ampcal=obs.ampcal, phasecal=obs.phasecal, opacitycal=obs.opacitycal, dcal=obs.dcal, frcal=obs.frcal)
+ return obsdeblur
+
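+ # Example usage (sketch): with a ScatteringModel instance 'sm' and an Obsdata
+ # object 'obs' (hypothetical variables), the kernel can be divided out via
+ # obs_deblur = sm.Deblur_obs(obs)
+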
+ def MakePhaseScreen(self, EpsilonScreen, Reference_Image, obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0, sqrtQ_init=None):
+ """Create a refractive phase screen from standardized Fourier components (the EpsilonScreen).
+ All lengths should be specified in centimeters
+ If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Reference_Image
+ Note: an odd image dimension is required!
+
+ Args:
+ EpsilonScreen (2D ndarray): The standardized Fourier components of the scattering screen.
+ Reference_Image (Image): The reference image.
+ obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Reference_Image.
+ Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s.
+ Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s.
+ t_hr (float): The current time of the scattering in hours.
+ sqrtQ_init (2D ndarray): The user can optionally pass a precomputed array of the square root of the power spectrum.
+
+ Returns:
+ phi_Image (Image): The phase screen.
+ """
+
+ #Observing wavelength
+ if obs_frequency_Hz == 0.0:
+ obs_frequency_Hz = Reference_Image.rf
+
+ wavelength = C/obs_frequency_Hz*100.0 #Observing wavelength [cm]
+ wavelengthbar = wavelength/(2.0*np.pi) #lambda/(2pi) [cm]
+
+ #Derived parameters
+ FOV = Reference_Image.psize * Reference_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen
+ rF = self.rF(wavelength)
+ Nx = EpsilonScreen.shape[1]
+ Ny = EpsilonScreen.shape[0]
+
+# if Nx%2 == 0:
+# print("The image dimension should really be odd...")
+
+ #Now we'll calculate the power spectrum for each pixel in Fourier space
+ screen_x_offset_pixels = (Vx_km_per_s*1.e5) * (t_hr*3600.0) / (FOV/float(Nx))
+ screen_y_offset_pixels = (Vy_km_per_s*1.e5) * (t_hr*3600.0) / (FOV/float(Nx))
+
+ if sqrtQ_init is None:
+ sqrtQ = self.sqrtQ_Matrix(Reference_Image, Vx_km_per_s=Vx_km_per_s, Vy_km_per_s=Vy_km_per_s, t_hr=t_hr)
+ else:
+ #If a matrix for sqrtQ_init is passed, we still potentially need to apply a translation phase
+
+ if screen_x_offset_pixels != 0.0 or screen_y_offset_pixels != 0.0:
+ s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0/Nx), np.fft.fftfreq(Ny, d=1.0/Ny))
+ sqrtQ = sqrtQ_init * np.exp(2.0*np.pi*1j*(s*screen_x_offset_pixels +
+ t*screen_y_offset_pixels)/float(Nx))
+ else:
+ sqrtQ = sqrtQ_init
+
+ #Now calculate the phase screen
+ phi = np.real(wavelengthbar/FOV*EpsilonScreen.shape[0]*EpsilonScreen.shape[1]*np.fft.ifft2(sqrtQ*EpsilonScreen))
+ phi_Image = image.Image(phi, Reference_Image.psize, Reference_Image.ra, Reference_Image.dec, rf=Reference_Image.rf, source=Reference_Image.source, mjd=Reference_Image.mjd)
+
+ return phi_Image
+
+ def Scatter2(self, args, kwargs):
+ """Call self.Scatter with expanded args and kwargs."""
+ return self.Scatter(*args, **kwargs)
+
+ def Scatter(self, Unscattered_Image, Epsilon_Screen=np.array([]), obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0, ea_ker=None, sqrtQ=None, Linearized_Approximation=False, DisplayImage=False, Force_Positivity=False, use_approximate_form=True):
+ """Scatter an image using the specified epsilon screen.
+ All lengths should be specified in centimeters
+ If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Unscattered_Image
+ Note: an odd image dimension is required!
+
+ Args:
+ Unscattered_Image (Image): The unscattered image.
+ Epsilon_Screen (2D ndarray): Optionally, the scattering screen can be specified. If none is given, a random one will be generated.
+ obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Unscattered_Image.
+ Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s.
+ Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s.
+ t_hr (float): The current time of the scattering in hours.
+ ea_ker (2D ndarray): The user can optionally pass a precomputed array of the ensemble-average blurring kernel.
+ sqrtQ (2D ndarray): The user can optionally pass a precomputed array of the square root of the power spectrum.
+ Linearized_Approximation (bool): If True, uses a linearized approximation for the scattering (Eq. 10 of Johnson & Narayan 2016). If False, uses Eq. 9 of that paper.
+ DisplayImage (bool): If True, show a plot of the unscattered, ensemble-average, and scattered images as well as the phase screen.
+ Force_Positivity (bool): If True, eliminates negative flux from the scattered image from the linearized approximation.
+ use_approximate_form (bool): If True, use the approximate form of the phase structure function.
+
+ Returns:
+ AI_Image (Image): The scattered image.
+ """
+
+ #Observing wavelength
+ if obs_frequency_Hz == 0.0:
+ obs_frequency_Hz = Unscattered_Image.rf
+
+ wavelength = C/obs_frequency_Hz*100.0 #Observing wavelength [cm]
+ wavelengthbar = wavelength/(2.0*np.pi) #lambda/(2pi) [cm]
+
+ #Derived parameters
+ FOV = Unscattered_Image.psize * Unscattered_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen
+ rF = self.rF(wavelength)
+ Nx = Unscattered_Image.xdim
+ Ny = Unscattered_Image.ydim
+
+ #First we need to calculate the ensemble-average image by blurring the unscattered image with the correct kernel
+ EA_Image = self.Ensemble_Average_Blur(Unscattered_Image, wavelength, ker = ea_ker, use_approximate_form=use_approximate_form)
+
+ # If no epsilon screen is specified, then generate a random realization
+ if Epsilon_Screen.shape[0] == 0:
+ Epsilon_Screen = MakeEpsilonScreen(Nx, Ny)
+
+ #We'll now calculate the phase screen.
+ phi_Image = self.MakePhaseScreen(Epsilon_Screen, Unscattered_Image, obs_frequency_Hz, Vx_km_per_s=Vx_km_per_s, Vy_km_per_s=Vy_km_per_s, t_hr=t_hr, sqrtQ_init=sqrtQ)
+ phi = phi_Image.imvec.reshape(Ny,Nx)
+
+ #Next, we need the gradient of the phase screen
+ phi_Gradient = Wrapped_Gradient(phi/(FOV/Nx))
+ #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top)
+ phi_Gradient_x = -phi_Gradient[1]
+ phi_Gradient_y = -phi_Gradient[0]
+
+ if Linearized_Approximation == True: #Use Equation 10 of Johnson & Narayan (2016)
+ #Calculate the gradient of the ensemble-average image
+ EA_Gradient = Wrapped_Gradient((EA_Image.imvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim))
+ #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top)
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+ #Now we can patch together the average image
+ AI = (EA_Image.imvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y )
+ if len(Unscattered_Image.qvec):
+ # Scatter the Q image
+ EA_Gradient = Wrapped_Gradient((EA_Image.qvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim))
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+ AI_Q = (EA_Image.qvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y )
+ # Scatter the U image
+ EA_Gradient = Wrapped_Gradient((EA_Image.uvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim))
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+ AI_U = (EA_Image.uvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y )
+ if len(Unscattered_Image.vvec):
+ # Scatter the V image
+ EA_Gradient = Wrapped_Gradient((EA_Image.vvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim))
+ EA_Gradient_x = -EA_Gradient[1]
+ EA_Gradient_y = -EA_Gradient[0]
+ AI_V = (EA_Image.vvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y )
+ else: #Use Equation 9 of Johnson & Narayan (2016)
+ EA_im = (EA_Image.imvec).reshape(Ny,Nx)
+ AI = np.copy((EA_Image.imvec).reshape(Ny,Nx))
+ if len(Unscattered_Image.qvec):
+ AI_Q = np.copy((EA_Image.imvec).reshape(Ny,Nx))
+ AI_U = np.copy((EA_Image.imvec).reshape(Ny,Nx))
+ EA_im_Q = (EA_Image.qvec).reshape(Ny,Nx)
+ EA_im_U = (EA_Image.uvec).reshape(Ny,Nx)
+ if len(Unscattered_Image.vvec):
+ AI_V = np.copy((EA_Image.imvec).reshape(Ny,Nx))
+ EA_im_V = (EA_Image.vvec).reshape(Ny,Nx)
+ for rx in range(Nx):
+ for ry in range(Ny):
+ # Annoyingly, the signs here must be negative to match the other approximation. I'm not sure which is correct, but it really shouldn't matter anyway because -phi has the same power spectrum as phi. However, getting the *relative* sign for the x- and y-directions correct is important.
+ rxp = int(np.round(rx - rF**2.0 * phi_Gradient_x[ry,rx]/self.observer_screen_distance/Unscattered_Image.psize))%Nx
+ ryp = int(np.round(ry - rF**2.0 * phi_Gradient_y[ry,rx]/self.observer_screen_distance/Unscattered_Image.psize))%Ny
+ AI[ry,rx] = EA_im[ryp,rxp]
+ if len(Unscattered_Image.qvec):
+ AI_Q[ry,rx] = EA_im_Q[ryp,rxp]
+ AI_U[ry,rx] = EA_im_U[ryp,rxp]
+ if len(Unscattered_Image.vvec):
+ AI_V[ry,rx] = EA_im_V[ryp,rxp]
+
+ #Optional: eliminate negative flux
+ if Force_Positivity == True:
+ AI = abs(AI)
+
+ #Make it into a proper image format
+ AI_Image = image.Image(AI, EA_Image.psize, EA_Image.ra, EA_Image.dec, rf=EA_Image.rf, source=EA_Image.source, mjd=EA_Image.mjd)
+ if len(Unscattered_Image.qvec):
+ AI_Image.add_qu(AI_Q, AI_U)
+ if len(Unscattered_Image.vvec):
+ AI_Image.add_v(AI_V)
+
+ if DisplayImage:
+ plot_scatt(Unscattered_Image.imvec, EA_Image.imvec, AI_Image.imvec, phi_Image.imvec, Unscattered_Image, 0, 0, ipynb=False)
+
+ return AI_Image
+
+ def Scatter_Movie(self, Unscattered_Movie, Epsilon_Screen=np.array([]), obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, framedur_sec=None, N_frames = None, ea_ker=None, sqrtQ=None, Linearized_Approximation=False, Force_Positivity=False, Return_Image_List=False, processes=0):
+ """Scatter a movie using the specified epsilon screen. The movie can either be a movie object, an image list, or a static image
+ If scattering a list of images or static image, the frame duration in seconds (framedur_sec) must be specified
+ If scattering a static image, the total number of frames must be specified (N_frames)
+ All lengths should be specified in centimeters
+ If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Unscattered_Movie
+ Note: an odd image dimension is required!
+
+ Args:
+ Unscattered_Movie: This can be a movie object, an image list, or a static image
+ Epsilon_Screen (2D ndarray): Optionally, the scattering screen can be specified. If none is given, a random one will be generated.
+ obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Unscattered_Movie.
+ Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s.
+ Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s.
+ framedur_sec (float): Duration of each frame, in seconds. Only needed if Unscattered_Movie is not a movie object.
+ N_frames (int): Total number of frames. Only needed if Unscattered_Movie is a static image object.
+ ea_ker (2D ndarray): The user can optionally pass a precomputed array of the ensemble-average blurring kernel.
+ sqrtQ (2D ndarray): The user can optionally pass a precomputed array of the square root of the power spectrum.
+ Linearized_Approximation (bool): If True, uses a linearized approximation for the scattering (Eq. 10 of Johnson & Narayan 2016). If False, uses Eq. 9 of that paper.
+ Force_Positivity (bool): If True, eliminates negative flux from the scattered image from the linearized approximation.
+ Return_Image_List (bool): If True, returns a list of the scattered frames. If False, returns a movie object.
+ processes (int): Number of cores to use in multiprocessing. Default value (0) means no multiprocessing. Uses all available cores if processes < 0.
+
+ Returns:
+ Scattered_Movie: Either a movie object or a list of images, depending on the flag Return_Image_List.
+ """
+
+ print("Warning!! assuming a constant frame duration, but Movie objects now support unequally spaced frames!")
+
+ if type(Unscattered_Movie) != movie.Movie and framedur_sec is None:
+ print("If scattering a list of images or static image, the framedur must be specified!")
+ return
+
+ if type(Unscattered_Movie) == image.Image and N_frames is None:
+ print("If scattering a static image, the total number of frames must be specified (N_frames)!")
+ return
+
+ # time list in hr (for an image list, take N_frames from the list length)
+ if hasattr(Unscattered_Movie, 'times'):
+ tlist_hr = Unscattered_Movie.times
+ else:
+ if N_frames is None:
+ N_frames = len(Unscattered_Movie)
+ tlist_hr = [framedur_sec/3600.0*j for j in range(N_frames)]
+
+ if type(Unscattered_Movie) == movie.Movie:
+ N = Unscattered_Movie.xdim
+ N_frames = len(Unscattered_Movie.frames)
+ psize = Unscattered_Movie.psize
+ ra = Unscattered_Movie.ra
+ dec = Unscattered_Movie.dec
+ rf = Unscattered_Movie.rf
+ pulse=Unscattered_Movie.pulse
+ source=Unscattered_Movie.source
+ mjd=Unscattered_Movie.mjd
+ start_hr=Unscattered_Movie.start_hr
+ has_pol = len(Unscattered_Movie.qframes)
+ has_circ_pol = len(Unscattered_Movie.vframes)
+ elif type(Unscattered_Movie) == list:
+ N = Unscattered_Movie[0].xdim
+ N_frames = len(Unscattered_Movie)
+ psize = Unscattered_Movie[0].psize
+ ra = Unscattered_Movie[0].ra
+ dec = Unscattered_Movie[0].dec
+ rf = Unscattered_Movie[0].rf
+ pulse=Unscattered_Movie[0].pulse
+ source=Unscattered_Movie[0].source
+ mjd=Unscattered_Movie[0].mjd
+ start_hr=0.0
+ has_pol = len(Unscattered_Movie[0].qvec)
+ has_circ_pol = len(Unscattered_Movie[0].vvec)
+ else:
+ N = Unscattered_Movie.xdim
+ psize = Unscattered_Movie.psize
+ ra = Unscattered_Movie.ra
+ dec = Unscattered_Movie.dec
+ rf = Unscattered_Movie.rf
+ pulse=Unscattered_Movie.pulse
+ source=Unscattered_Movie.source
+ mjd=Unscattered_Movie.mjd
+ start_hr=0.0
+ has_pol = len(Unscattered_Movie.qvec)
+ has_circ_pol = len(Unscattered_Movie.vvec)
+
+ def get_frame(j):
+ if type(Unscattered_Movie) == movie.Movie:
+ im = image.Image(Unscattered_Movie.frames[j].reshape((N,N)), psize=psize, ra=ra, dec=dec, rf=rf, pulse=pulse, source=source, mjd=mjd)
+ if len(Unscattered_Movie.qframes) > 0:
+ im.add_qu(Unscattered_Movie.qframes[j].reshape((N,N)), Unscattered_Movie.uframes[j].reshape((N,N)))
+ if len(Unscattered_Movie.vframes) > 0:
+ im.add_v(Unscattered_Movie.vframes[j].reshape((N,N)))
+ return im
+ elif type(Unscattered_Movie) == list:
+ return Unscattered_Movie[j]
+ else:
+ return Unscattered_Movie
+
+ #If it isn't specified, calculate the matrix sqrtQ for efficiency
+ if sqrtQ is None:
+ sqrtQ = self.sqrtQ_Matrix(get_frame(0))
+
+ # If no epsilon screen is specified, then generate a random realization
+ if Epsilon_Screen.shape[0] == 0:
+ Epsilon_Screen = MakeEpsilonScreen(N, N)
+
+ # possibly parallelize
+ if processes < 0:
+ processes = cpu_count()
+ processes = min(processes, N_frames)
+
+ # generate scattered images
+ if processes > 0:
+ pool = Pool(processes=processes)
+ args = [
+ (
+ [get_frame(j), Epsilon_Screen],
+ dict(obs_frequency_Hz = obs_frequency_Hz, Vx_km_per_s = Vx_km_per_s, Vy_km_per_s = Vy_km_per_s, t_hr=tlist_hr[j], sqrtQ=sqrtQ, Linearized_Approximation=Linearized_Approximation, Force_Positivity=Force_Positivity)
+ ) for j in range(N_frames)
+ ]
+ scattered_im_List = pool.starmap(self.Scatter2, args)
+ pool.close()
+ pool.join()
+ else:
+ scattered_im_List = [self.Scatter(get_frame(j), Epsilon_Screen, obs_frequency_Hz = obs_frequency_Hz, Vx_km_per_s = Vx_km_per_s, Vy_km_per_s = Vy_km_per_s, t_hr=tlist_hr[j], ea_ker=ea_ker, sqrtQ=sqrtQ, Linearized_Approximation=Linearized_Approximation, Force_Positivity=Force_Positivity) for j in range(N_frames)]
+
+ if Return_Image_List == True:
+ return scattered_im_List
+
+ Scattered_Movie = movie.Movie( [im.imvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List],
+ times=tlist_hr, psize = psize, ra = ra, dec = dec, rf=rf, pulse=pulse, source=source, mjd=mjd)
+
+ if has_pol:
+ Scattered_Movie_Q = [im.qvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List]
+ Scattered_Movie_U = [im.uvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List]
+ Scattered_Movie.add_qu(Scattered_Movie_Q, Scattered_Movie_U)
+ if has_circ_pol:
+ Scattered_Movie_V = [im.vvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List]
+ Scattered_Movie.add_v(Scattered_Movie_V)
+ return Scattered_Movie
+
+################################################################################
+# These are helper functions
+################################################################################
+
+def Wrapped_Convolve(sig,ker):
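+ # Periodic convolution: pad the signal with wrapped copies, pad the kernel
+ # with zeros, convolve, and cut out the central N x N block.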
+ N = sig.shape[0]
+ return scipy.signal.fftconvolve(np.pad(sig,((N, N), (N, N)), 'wrap'), np.pad(ker,((N, N), (N, N)), 'constant'),mode='same')[N:(2*N),N:(2*N)]
+
+def Wrapped_Gradient(M):
+ # np.gradient differentiates along (axis 0 = y, axis 1 = x); pad with wrapped
+ # boundaries so the gradient is periodic, then trim the padding.
+ G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap'))
+ Gy = G[0][1:-1,1:-1]
+ Gx = G[1][1:-1,1:-1]
+ return (Gy, Gx)
+
+def MakeEpsilonScreenFromList(EpsilonList, N):
+ epsilon = np.zeros((N,N),dtype=complex)
+ #If N is odd: there are (N^2-1)/2 real elements followed by their corresponding (N^2-1)/2 imaginary elements
+ #If N is even: there are (N^2+2)/2 of each, although 3 of these must be purely real, also giving a total of N^2-1 degrees of freedom
+ #This is because of conjugation symmetry in Fourier space to ensure a real Fourier transform
+
+ #The first (N-1)/2 are the top row
+ N_re = (N*N-1)//2 # FIXME: check logic if N is even
+ i = 0
+ for x in range(1,(N+1)//2): # FIXME: check logic if N is even
+ epsilon[0][x] = EpsilonList[i] + 1j * EpsilonList[i+N_re]
+ epsilon[0][N-x] = np.conjugate(epsilon[0][x])
+ i=i+1
+
+ #The next N(N-1)/2 are filling the next N rows
+ for y in range(1,(N+1)//2): # FIXME: check logic if N is even
+ for x in range(N):
+ epsilon[y][x] = EpsilonList[i] + 1j * EpsilonList[i+N_re]
+
+ x2 = N - x
+ y2 = N - y
+ if x2 == N:
+ x2 = 0
+ if y2 == N:
+ y2 = 0
+
+ epsilon[y2][x2] = np.conjugate(epsilon[y][x])
+ i=i+1
+
+ return epsilon
+
+def MakeEpsilonScreen(Nx, Ny, rngseed = 0):
+ """Create a standardized Fourier representation of a scattering screen
+
+ Args:
+ Nx (int): Number of pixels in the x direction
+ Ny (int): Number of pixels in the y direction
+ rngseed (int): Seed for the random number generator
+
+ Returns:
+ epsilon: A 2D numpy ndarray.
+ """
+
+ if rngseed != 0:
+ np.random.seed( rngseed )
+
+ epsilon = np.random.normal(loc=0.0, scale=1.0/math.sqrt(2), size=(Ny,Nx)) + 1j * np.random.normal(loc=0.0, scale=1.0/math.sqrt(2), size=(Ny,Nx))
+
+ # The zero frequency doesn't affect scattering
+ epsilon[0][0] = 0.0
+
+ #Now let's ensure that it has the necessary conjugation symmetry
+ if Nx%2 == 0:
+ epsilon[0][Nx//2] = np.real(epsilon[0][Nx//2])
+ if Ny%2 == 0:
+ epsilon[Ny//2][0] = np.real(epsilon[Ny//2][0])
+ if Nx%2 == 0 and Ny%2 == 0:
+ epsilon[Ny//2][Nx//2] = np.real(epsilon[Ny//2][Nx//2])
+
+ for x in range(Nx):
+ if x > (Nx-1)//2:
+ epsilon[0][x] = np.conjugate(epsilon[0][Nx-x])
+ for y in range((Ny-1)//2, Ny):
+ x2 = Nx - x
+ y2 = Ny - y
+ if x2 == Nx:
+ x2 = 0
+ if y2 == Ny:
+ y2 = 0
+ epsilon[y][x] = np.conjugate(epsilon[y2][x2])
+
+ return epsilon
+
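+# Sanity check (sketch): the conjugation symmetry enforced above should make
+# the inverse FFT real to within floating-point noise, e.g. for odd dimensions:
+# eps = MakeEpsilonScreen(65, 65, rngseed=1)
+# print(np.max(np.abs(np.fft.ifft2(eps).imag))) # expect ~ machine precision
+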
+##################################################################################################
+# Plotting Functions
+##################################################################################################
+
+def plot_scatt(im_unscatt, im_ea, im_scatt, im_phase, Prior, nit, chi2, ipynb=False):
+ # Get vectors and ratio from current image
+ x = np.array([[i for i in range(Prior.xdim)] for j in range(Prior.ydim)])
+ y = np.array([[j for i in range(Prior.xdim)] for j in range(Prior.ydim)])
+
+ # Create figure and title
+ plt.ion()
+ plt.clf()
+ if chi2 > 0.0:
+ plt.suptitle("step: %i $\chi^2$: %f " % (nit, chi2), fontsize=20)
+
+ # Unscattered Image
+ plt.subplot(141)
+ plt.imshow(im_unscatt.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0)
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel('Relative RA ($\mu$as)')
+ plt.ylabel('Relative Dec ($\mu$as)')
+ plt.title('Unscattered')
+
+ # Ensemble Average
+ plt.subplot(142)
+ plt.imshow(im_ea.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0)
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel('Relative RA ($\mu$as)')
+ plt.ylabel('Relative Dec ($\mu$as)')
+ plt.title('Ensemble Average')
+
+ # Scattered
+ plt.subplot(143)
+ plt.imshow(im_scatt.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0)
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel('Relative RA ($\mu$as)')
+ plt.ylabel('Relative Dec ($\mu$as)')
+ plt.title('Average Image')
+
+ # Phase
+ plt.subplot(144)
+ plt.imshow(im_phase.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian')
+ xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6)
+ yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6)
+ plt.xticks(xticks[0], xticks[1])
+ plt.yticks(yticks[0], yticks[1])
+ plt.xlabel('Relative RA ($\mu$as)')
+ plt.ylabel('Relative Dec ($\mu$as)')
+ plt.title('Phase Screen')
+
+ # Display
+ plt.draw()
diff --git a/statistics/__init__.py b/statistics/__init__.py
new file mode 100644
index 00000000..aa83fb1a
--- /dev/null
+++ b/statistics/__init__.py
@@ -0,0 +1,12 @@
+"""
+.. module:: ehtim.statistics
+ :platform: Unix
+ :synopsis: EHT Imaging Utilities: statistics and DataFrame format
+
+.. moduleauthor:: Maciek Wielgus (mwielgus@cfa.harvard.edu)
+
+"""
+from . import stats
+from . import dataframes
+
+
diff --git a/statistics/dataframes.py b/statistics/dataframes.py
new file mode 100644
index 00000000..81781872
--- /dev/null
+++ b/statistics/dataframes.py
@@ -0,0 +1,1044 @@
+# dataframes.py
+# a variety of statistical functions useful for working with
+# visibility data in pandas DataFrame format
+#
+# Copyright (C) 2018 Maciek Wielgus
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+from builtins import str
+from builtins import map
+from builtins import range
+
+import numpy as np
+
+try:
+ import pandas as pd
+except ImportError:
+ print("Warning: pandas not installed!")
+ print("Please install pandas to use statistics package!")
+
+import datetime as datetime
+from astropy.time import Time
+from ehtim.statistics.stats import *
+
+def make_df(obs,polarization='unknown',band='unknown',round_s=0.1):
+
+ """converts visibilities from obs.data to DataFrame format
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: observation visibility data in DataFrame format
+ """
+ sour=obs.source
+ df = pd.DataFrame(data=obs.data)
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ telescopes = list(zip(df['t1'],df['t2']))
+ df['baseline'] = [x[0]+'-'+x[1] for x in telescopes]
+ if obs.polrep=='stokes':
+ vis1='vis'; sig1='sigma'
+ elif obs.polrep=='circ':
+ vis1='rrvis'; sig1='rrsigma'
+ df['vis']=df[vis1]
+ df['sigma']=df[sig1]
+ df['rramp']=np.abs(df['rrvis'])
+ df['llamp']=np.abs(df['llvis'])
+ df['rlamp']=np.abs(df['rlvis'])
+ df['lramp']=np.abs(df['lrvis'])
+ df['rrsnr']=df['rramp']/df['rrsigma']
+ df['llsnr']=df['llamp']/df['llsigma']
+ df['rlsnr']=df['rlamp']/df['rlsigma']
+ df['lrsnr']=df['lramp']/df['lrsigma']
+ #df = df.dropna(subset=['rrvis', 'llvis','rrsigma','llsigma'])
+ df['amp'] = list(map(np.abs,df[vis1]))
+ df['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),df[vis1]))
+ df['snr'] = df['amp']/df[sig1]
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ df['baselength'] = np.sqrt(np.asarray(df.u)**2+np.asarray(df.v)**2)
+ return df
+
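+# Example usage (sketch), assuming 'obs' is an ehtim Obsdata object:
+# df = make_df(obs)
+# df.groupby('baseline')['snr'].median() # median S/N per baseline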
+
+def make_amp(obs,debias=True,polarization='unknown',band='unknown',round_s=0.1):
+
+ """converts visibilities from obs.data to amplitudes inDataFrame format
+
+ Args:
+ obs: ObsData object
+ debias (bool): whether to debias the amplitudes
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: observation visibility data in DataFrame format
+ """
+ sour=obs.source
+ df = pd.DataFrame(data=obs.data)
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ telescopes = list(zip(df['t1'],df['t2']))
+ df['baseline'] = [x[0]+'-'+x[1] for x in telescopes]
+ df['amp'] = list(map(np.abs,df['vis']))
+ if debias==True:
+ amp2 = np.maximum(np.asarray(df['amp'])**2-np.asarray(df['sigma'])**2,np.asarray(df['sigma'])**2)
+ df['amp'] = np.sqrt(amp2)
+ df['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),df['vis']))
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['snr'] = df['amp']/df['sigma']
+ df['source'] = sour
+ df['baselength'] = np.sqrt(np.asarray(df.u)**2+np.asarray(df.v)**2)
+ return df
+
+def coh_avg_vis(obs,dt=0,scan_avg=False,return_type='rec',err_type='predicted',num_samples=int(1e3)):
+ """coherently averages visibilities
+ Args:
+ obs: ObsData object
+ dt (float): integration time in seconds
+ return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+ err_type (str): 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator
+ num_samples (int): number of bootstrap resamples for the 'measured' error estimate
+ scan_avg (bool): if True, average entire scans, overriding dt
+ Returns:
+ vis_avg: coherently averaged visibilities
+ """
+ if (dt<=0)&(scan_avg==False):
+ return obs.data
+ else:
+ vis = make_df(obs)
+ if scan_avg==False:
+ #TODO
+ #we don't have to work on datetime products at all
+ #change it to only use 'time' in mjd
+ t0 = datetime.datetime(1960,1,1)
+ vis['round_time'] = list(map(lambda x: np.floor((x- t0).total_seconds()/float(dt)),vis.datetime))
+ grouping=['tau1','tau2','polarization','band','baseline','t1','t2','round_time']
+ else:
+ bins, labs = get_bins_labels(obs.scans)
+ vis['scan'] = list(pd.cut(vis.time, bins,labels=labs))
+ grouping=['tau1','tau2','polarization','band','baseline','t1','t2','scan']
+ #column just for counting the elements
+ vis['number'] = 1
+ aggregated = {'datetime': np.min, 'time': np.min,
+ 'number': lambda x: len(x), 'u':np.mean, 'v':np.mean,'tint': np.sum}
+
+ if err_type not in ['measured', 'predicted']:
+ print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.")
+ err_type='predicted'
+
+ if obs.polrep=='stokes':
+ vis1='vis'; vis2='qvis'; vis3='uvis'; vis4='vvis'
+ sig1='sigma'; sig2='qsigma'; sig3='usigma'; sig4='vsigma'
+ elif obs.polrep=='circ':
+ vis1='rrvis'; vis2='llvis'; vis3='rlvis'; vis4='lrvis'
+ sig1='rrsigma'; sig2='llsigma'; sig3='rlsigma'; sig4='lrsigma'
+
+ #AVERAGING-------------------------------
+ if err_type=='measured':
+ vis['dummy'] = vis[vis1]
+ vis['qdummy'] = vis[vis2]
+ vis['udummy'] = vis[vis3]
+ vis['vdummy'] = vis[vis4]
+ meanF = lambda x: np.nanmean(np.asarray(x))
+ meanerrF = lambda x: bootstrap(np.abs(x), np.mean, num_samples=num_samples,wrapping_variable=False)
+ aggregated[vis1] = meanF
+ aggregated[vis2] = meanF
+ aggregated[vis3] = meanF
+ aggregated[vis4] = meanF
+ aggregated['dummy'] = meanerrF
+ aggregated['udummy'] = meanerrF
+ aggregated['vdummy'] = meanerrF
+ aggregated['qdummy'] = meanerrF
+
+ elif err_type=='predicted':
+ meanF = lambda x: np.nanmean(np.asarray(x))
+ #meanerrF = lambda x: bootstrap(np.abs(x), np.mean, num_samples=num_samples,wrapping_variable=False)
+ def meanerrF(x):
+ x = np.asarray(x)
+ x = x[x==x] # drop NaNs (NaN != NaN)
+
+ if len(x)>0: ret = np.sqrt(np.sum(x**2)/len(x)**2)
+ else: ret = np.nan +1j*np.nan
+ return ret
+
+ aggregated[vis1] = meanF
+ aggregated[vis2] = meanF
+ aggregated[vis3] = meanF
+ aggregated[vis4] = meanF
+ aggregated[sig1] = meanerrF
+ aggregated[sig2] = meanerrF
+ aggregated[sig3] = meanerrF
+ aggregated[sig4] = meanerrF
+
+ #ACTUAL AVERAGING
+ vis_avg = vis.groupby(grouping).agg(aggregated).reset_index()
+
+ if err_type=='measured':
+ vis_avg[sig1] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['dummy'])]
+ vis_avg[sig2] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['qdummy'])]
+ vis_avg[sig3] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['udummy'])]
+ vis_avg[sig4] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['vdummy'])]
+
+ vis_avg['amp'] = list(map(np.abs,vis_avg[vis1]))
+ vis_avg['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),vis_avg[vis1]))
+ vis_avg['snr'] = vis_avg['amp']/vis_avg[sig1]
+
+ if scan_avg==False:
+ #round datetime and time to the beginning of the bucket and add half of a bucket time
+ half_bucket = dt/2.
+ vis_avg['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x) + half_bucket), vis_avg['round_time']))
+ vis_avg['time'] = list(map(lambda x: (Time(x).mjd-obs.mjd)*24., vis_avg['datetime']))
+ else:
+ #drop values that couldn't be matched to any scan
+ vis_avg.drop(list(vis_avg[vis_avg.scan<0].index.values),inplace=True)
+ if err_type=='measured':
+ vis_avg.drop(labels=['udummy','vdummy','qdummy','dummy'],axis='columns',inplace=True)
+ if return_type=='rec':
+ if obs.polrep=='stokes':
+ return df_to_rec(vis_avg,'vis')
+ elif obs.polrep=='circ':
+ return df_to_rec(vis_avg,'vis_circ')
+ elif return_type=='df':
+ return vis_avg
+
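+# Example usage (sketch): coherently average an Obsdata object 'obs' to 60 s
+# and store the result back as a record array:
+# obs.data = coh_avg_vis(obs, dt=60., return_type='rec', err_type='predicted')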
+
+def coh_moving_avg_vis(obs,dt=50,return_type='rec'):
+ """coherently averages visibilities with moving window
+ Args:
+ obs: ObsData object
+ dt (float): averaging window size in seconds
+ return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+ Returns:
+ vis: coherently averaged visibilities on same grid
+ """
+ min_periods=1
+ if dt <= 0:
+ raise Exception('Time dt must be positive!')
+ if obs.polrep=='stokes':
+ vis1='vis'; vis2='qvis'; vis3='uvis'; vis4='vvis'
+ sig1='sigma'; sig2='qsigma'; sig3='usigma'; sig4='vsigma'
+ elif obs.polrep=='circ':
+ vis1='rrvis'; vis2='llvis'; vis3='rlvis'; vis4='lrvis'
+ sig1='rrsigma'; sig2='llsigma'; sig3='rlsigma'; sig4='lrsigma'
+
+ vis = make_df(obs)
+ vis = vis.sort_values(['baseline','datetime']).reset_index().copy()
+ #vis['total_seconds'] = list(map(lambda x: int(x.total_seconds()), vis['datetime'] - vis['datetime'].min()))
+ vis['total_seconds'] = [pd.Timestamp(x) for x in vis.datetime] # timestamps used as the rolling-window index
+ vis['roll_vis'] = list(zip(vis['total_seconds'],vis[vis1],vis[vis2],vis[vis3],vis[vis4],vis['datetime']))
+ vis['roll_sig'] = list(zip(vis['total_seconds'],vis[sig1],vis[sig2],vis[sig3],vis[sig4],vis['datetime']))
+
+ roll_vis_local = lambda x: roll_vis(x,dt=str(int(dt))+'s',min_periods=min_periods)
+ roll_sig_local = lambda x: roll_sig(x,dt=str(int(dt))+'s',min_periods=min_periods)
+ vis_avg_roll_vis = vis[['baseline','roll_vis']].groupby('baseline').transform(roll_vis_local)['roll_vis'].copy()
+ vis_avg_roll_sig = vis[['baseline','roll_sig']].groupby('baseline').transform(roll_sig_local)['roll_sig'].copy()
+
+ for cou,col in enumerate([vis1,vis2,vis3,vis4]):
+ vis[col] = [x[2*cou] + 1j*x[2*cou+1] for x in vis_avg_roll_vis]
+ for cou,col in enumerate([sig1,sig2,sig3,sig4]):
+ vis[col] = [x[cou] for x in vis_avg_roll_sig]
+ #shift to match with original data
+ vis.datetime = vis.datetime.apply(lambda x: x-datetime.timedelta(seconds=dt))
+ vis.time = vis.time - dt/2.0/3600.
+ if return_type=='rec':
+ if obs.polrep=='stokes':
+ return df_to_rec(vis.copy(),'vis')
+ elif obs.polrep=='circ':
+ return df_to_rec(vis.copy(),'vis_circ')
+ elif return_type=='df':
+ return vis.copy()
+
+
+def roll_vis(ser,dt='1s',min_periods=1):
+ """functtion helper for coh_moving_avg_vis
+ """
+ foo = pd.DataFrame({'REvis1': [np.real(x[1]) for x in ser],'IMvis1': [np.imag(x[1]) for x in ser],
+ 'REvis2': [np.real(x[2]) for x in ser],'IMvis2': [np.imag(x[2]) for x in ser],
+ 'REvis3': [np.real(x[3]) for x in ser],'IMvis3': [np.imag(x[3]) for x in ser],
+ 'REvis4': [np.real(x[4]) for x in ser],'IMvis4': [np.imag(x[4]) for x in ser]},
+ index=[x[0] for x in ser])
+ avg = foo.rolling(dt, min_periods=min_periods).mean()
+ avg_list = list(zip(avg['REvis1'],avg['IMvis1'],avg['REvis2'],avg['IMvis2'],avg['REvis3'],avg['IMvis3'],avg['REvis4'],avg['IMvis4'],[x[5] for x in ser]))
+ return avg_list
+
+def roll_sig(ser,dt='1s',min_periods=1):
+ """functtion helper for coh_moving_avg_vis
+ """
+ foo = pd.DataFrame({'sig1': [x[1]**2 for x in ser],'sig2': [x[2]**2 for x in ser],
+ 'sig3': [x[3]**2 for x in ser],'sig4': [x[4]**2 for x in ser]},
+ index=[x[0] for x in ser])
+ avg0 = foo.rolling(dt, min_periods=min_periods).mean()
+ sumSq = foo.rolling(dt, min_periods=min_periods).sum()
+ avg = pd.DataFrame({},index=[x[0] for x in ser])
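+ # Standard error propagation for a mean of N points with uncertainties
+ # sigma_i gives sigma_avg = sqrt(sum sigma_i^2)/N, which is algebraically
+ # the same as mean(sigma_i^2)/sqrt(sum sigma_i^2), the form used below.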
+ avg['sig1'] = (avg0['sig1']**1.0)/(sumSq['sig1']**0.5)
+ avg['sig2'] = (avg0['sig2']**1.0)/(sumSq['sig2']**0.5)
+ avg['sig3'] = (avg0['sig3']**1.0)/(sumSq['sig3']**0.5)
+ avg['sig4'] = (avg0['sig4']**1.0)/(sumSq['sig4']**0.5)
+ avg_list = list(zip(avg['sig1'],avg['sig2'],avg['sig3'],avg['sig4'],[x[5] for x in ser]))
+ return avg_list
+
+def incoh_avg_vis(obs,dt=0,debias=True,scan_avg=False,return_type='rec',rec_type='vis',err_type='predicted',num_samples=int(1e3)):
+ """incoherently averages visibilities
+ Args:
+ obs: ObsData object
+ dt (float): integration time in seconds
+ return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+ rec_type (str): 'vis' for DTPOL and 'amp' for DTAMP
+ err_type (str): 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator
+ num_samples (int): number of bootstrap resamples for the 'measured' error estimate
+ scan_avg (bool): if True, average entire scans, overriding dt
+
+ Returns:
+ vis_avg: incoherently averaged visibilities
+ """
+ if (dt<=0)&(scan_avg==False):
+ print('Either averaging time must be positive, or scan_avg option should be selected!')
+ return obs.data
+ else:
+ vis = make_df(obs)
+ if scan_avg==False:
+ #TODO
+ #we don't have to work on datetime products at all
+ #change it to only use 'time' in mjd
+ t0 = datetime.datetime(1960,1,1)
+ vis['round_time'] = list(map(lambda x: np.floor((x- t0).total_seconds()/float(dt)),vis.datetime))
+ grouping=['tau1','tau2','polarization','band','baseline','t1','t2','round_time']
+ else:
+ bins, labs = get_bins_labels(obs.scans)
+ vis['scan'] = list(pd.cut(vis.time, bins,labels=labs))
+ grouping=['tau1','tau2','polarization','band','baseline','t1','t2','scan']
+ #column just for counting the elements
+ vis['number'] = 1
+ aggregated = {'datetime': np.min, 'time': np.min,
+ 'number': lambda x: len(x), 'u':np.mean, 'v':np.mean,'tint': np.sum}
+
+ if err_type not in ['measured', 'predicted']:
+ print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.")
+ err_type='predicted'
+
+ #AVERAGING-------------------------------
+ vis['dummy'] = list(zip(np.abs(vis['vis']),vis['sigma']))
+ vis['udummy'] = list(zip(np.abs(vis['uvis']),vis['usigma']))
+ vis['vdummy'] = list(zip(np.abs(vis['vvis']),vis['vsigma']))
+ vis['qdummy'] = list(zip(np.abs(vis['qvis']),vis['qsigma']))
+
+ if err_type=='predicted':
+ aggregated['dummy'] = lambda x: mean_incoh_avg(x,debias=debias)
+ aggregated['udummy'] = lambda x: mean_incoh_avg(x,debias=debias)
+ aggregated['vdummy'] = lambda x: mean_incoh_avg(x,debias=debias)
+ aggregated['qdummy'] = lambda x: mean_incoh_avg(x,debias=debias)
+
+ elif err_type=='measured':
+ aggregated['dummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False)
+ aggregated['udummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False)
+ aggregated['vdummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False)
+ aggregated['qdummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False)
+
+ #ACTUAL AVERAGING
+ vis_avg = vis.groupby(grouping).agg(aggregated).reset_index()
+
+ if err_type=='predicted':
+ vis_avg['vis'] = [x[0] for x in list(vis_avg['dummy'])]
+ vis_avg['uvis'] = [x[0] for x in list(vis_avg['udummy'])]
+ vis_avg['qvis'] = [x[0] for x in list(vis_avg['qdummy'])]
+ vis_avg['vvis'] = [x[0] for x in list(vis_avg['vdummy'])]
+ vis_avg['sigma'] = [x[1] for x in list(vis_avg['dummy'])]
+ vis_avg['usigma'] = [x[1] for x in list(vis_avg['udummy'])]
+ vis_avg['qsigma'] = [x[1] for x in list(vis_avg['qdummy'])]
+ vis_avg['vsigma'] = [x[1] for x in list(vis_avg['vdummy'])]
+
+ elif err_type=='measured':
+ vis_avg['vis'] = [x[0] for x in list(vis_avg['dummy'])]
+ vis_avg['uvis'] = [x[0] for x in list(vis_avg['udummy'])]
+ vis_avg['qvis'] = [x[0] for x in list(vis_avg['qdummy'])]
+ vis_avg['vvis'] = [x[0] for x in list(vis_avg['vdummy'])]
+ vis_avg['sigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['dummy'])]
+ vis_avg['usigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['udummy'])]
+ vis_avg['qsigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['qdummy'])]
+ vis_avg['vsigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['vdummy'])]
+
+ vis_avg['amp'] = list(map(np.abs,vis_avg['vis']))
+ vis_avg['phase'] = 0
+ vis_avg['snr'] = vis_avg['amp']/vis_avg['sigma']
+ if scan_avg==False:
+ #round datetime and time to the beginning of the bucket and add half of a bucket time
+ half_bucket = dt/2.
+ vis_avg['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x) + half_bucket), vis_avg['round_time']))
+ vis_avg['time'] = list(map(lambda x: (Time(x).mjd-obs.mjd)*24., vis_avg['datetime']))
+ else:
+ #drop values that couldn't be matched to any scan
+ vis_avg.drop(list(vis_avg[vis_avg.scan<0].index.values),inplace=True)
+
+ vis_avg.drop(labels=['udummy','vdummy','qdummy','dummy'],axis='columns',inplace=True)
+ if return_type=='rec':
+ return df_to_rec(vis_avg,rec_type)
+ elif return_type=='df':
+ return vis_avg
+
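+# Example usage (sketch): incoherent (amplitude) averaging over whole scans,
+# with debiasing, returned as an amplitude record array:
+# amp_rec = incoh_avg_vis(obs, scan_avg=True, debias=True, rec_type='amp')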
+
+def make_cphase_df(obs,band='unknown',polarization='unknown',mode='all',count='max',round_s=0.1,snrcut=0.,uv_min=False):
+
+ """generate DataFrame of closure phases
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: closure phase data in DataFrame format
+ """
+
+ data=obs.c_phases(mode=mode,count=count,snrcut=snrcut,uv_min=uv_min)
+ sour=obs.source
+ df = pd.DataFrame(data=data).copy()
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ df['triangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2],zip(df['t1'],df['t2'],df['t3'])))
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ return df
+
+def make_cphase_diag_df(obs,vtype='vis',band='unknown',polarization='unknown',count='min',round_s=0.1,snrcut=0.,uv_min=False):
+
+ """generate DataFrame of diagonalized closure phases
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: diagonalized closure phase data in DataFrame format
+ """
+
+ data=obs.c_phases_diag(vtype=vtype,count=count,snrcut=snrcut,uv_min=uv_min)
+ sour=obs.source
+
+ tarr = []
+ dcparr = []
+ dcperrarr = []
+ triangle_arr = []
+ u_arr = []
+ v_arr = []
+ tform_arr = []
+ for d in data:
+ tarr.append(list(d[0]['time']))
+ dcparr.append(list(d[0]['cphase']))
+ dcperrarr.append(list(d[0]['sigmacp']))
+
+ triarr = []
+ u = []
+ v = []
+ for ant in d[1]:
+ triarr.append(ant[0][0]+'-'+ant[1][0]+'-'+ant[2][0])
+ for iu in d[2]:
+ u.append((iu[0][0],iu[1][0],iu[2][0]))
+ for iv in d[3]:
+ v.append((iv[0][0],iv[1][0],iv[2][0]))
+
+ for i in range(len(d[0])):
+ triangle_arr.append(triarr)
+ u_arr.append(u)
+ v_arr.append(v)
+ tform_arr.append(list(d[4].view('f8')))
+
+ df = pd.DataFrame()
+ df['time'] = np.concatenate(tarr)
+ df['cphase'] = np.concatenate(dcparr)
+ df['sigmacp'] = np.concatenate(dcperrarr)
+ df['triangles'] = triangle_arr
+ df['u'] = u_arr
+ df['v'] = v_arr
+ df['tform_matrix'] = tform_arr
+
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] = list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ return df
+
+def make_camp_df(obs,ctype='logcamp',debias=False,band='unknown',polarization='unknown',mode='all',count='max',round_s=0.1,snrcut=0.):
+
+ """generate DataFrame of closure amplitudes
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: closure amplitude data in DataFrame format
+ """
+
+ data = obs.c_amplitudes(mode=mode,count=count,debias=debias,ctype=ctype,snrcut=snrcut)
+ sour=obs.source
+ df = pd.DataFrame(data=data).copy()
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ df['quadrangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2]+'-'+x[3],zip(df['t1'],df['t2'],df['t3'],df['t4'])))
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ df['catype'] = ctype
+ return df
+
+def make_logcamp_diag_df(obs,debias=True,band='unknown',polarization='unknown',mode='all',count='min',round_s=0.1,snrcut=0.):
+
+ """generate DataFrame of closure amplitudes
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+        df: diagonalized log closure amplitude data in DataFrame format
+ """
+
+ data = obs.c_log_amplitudes_diag(mode=mode,count=count,debias=debias,snrcut=snrcut)
+ sour=obs.source
+
+ tarr = []
+ dlcaarr = []
+ dlcaerrarr = []
+ quadrangle_arr = []
+ u_arr = []
+ v_arr = []
+ tform_arr = []
+ for d in data:
+ tarr.append(list(d[0]['time']))
+ dlcaarr.append(list(d[0]['camp']))
+ dlcaerrarr.append(list(d[0]['sigmaca']))
+
+ quadarr = []
+ u = []
+ v = []
+ for ant in d[1]:
+ quadarr.append(ant[0][0]+'-'+ant[1][0]+'-'+ant[2][0]+'-'+ant[3][0])
+ for iu in d[2]:
+ u.append((iu[0][0],iu[1][0],iu[2][0],iu[3][0]))
+ for iv in d[3]:
+ v.append((iv[0][0],iv[1][0],iv[2][0],iv[3][0]))
+
+ for i in range(len(d[0])):
+ quadrangle_arr.append(quadarr)
+ u_arr.append(u)
+ v_arr.append(v)
+ tform_arr.append(list(d[4].view('f8')))
+
+ df = pd.DataFrame()
+ df['time'] = np.concatenate(tarr)
+ df['camp'] = np.concatenate(dlcaarr)
+ df['sigmaca'] = np.concatenate(dlcaerrarr)
+ df['quadrangles'] = quadrangle_arr
+ df['u'] = u_arr
+ df['v'] = v_arr
+ df['tform_matrix'] = tform_arr
+
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] = list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ return df
+
+def make_bsp_df(obs,band='unknown',polarization='unknown',mode='all',count='min',round_s=0.1,snrcut=0., uv_min=False):
+
+ """generate DataFrame of bispectra
+
+ Args:
+ obs: ObsData object
+ round_s: accuracy of datetime object in seconds
+
+ Returns:
+ df: bispectra data in DataFrame format
+ """
+
+ data = obs.bispectra(mode=mode,count=count,snrcut=snrcut,uv_min=uv_min)
+ sour=obs.source
+ df = pd.DataFrame(data=data).copy()
+ df['fmjd'] = df['time']/24.
+ df['mjd'] = obs.mjd + df['fmjd']
+ df['triangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2],zip(df['t1'],df['t2'],df['t3'])))
+ df['datetime'] = Time(df['mjd'], format='mjd').datetime
+ df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime']))
+ df['jd'] = Time(df['mjd'], format='mjd').jd
+ df['polarization'] = polarization
+ df['band'] = band
+ df['source'] = sour
+ return df
+
+def average_cphases(cdf,dt,return_type='rec',err_type='predicted',num_samples=1000,snrcut=0.):
+
+ """averages DataFrame of cphases
+
+ Args:
+ cdf: data frame of closure phases
+ dt: integration time in seconds
+ return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+ err_type: 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator
+
+ Returns:
+ cdf2: averaged closure phases
+ """
+
+ cdf2 = cdf.copy()
+ t0 = datetime.datetime(1960,1,1)
+ cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime))
+ grouping=['polarization','band','triangle','t1','t2','t3','round_time']
+ #column just for counting the elements
+ cdf2['number'] = 1
+ aggregated = {'datetime': np.min, 'time': np.mean,
+ 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean,'v1':np.mean, 'v2': np.mean, 'v3':np.mean}
+
+ #AVERAGING-------------------------------
+ if err_type=='measured':
+ cdf2['dummy'] = cdf2['cphase']
+ aggregated['dummy'] = lambda x: bootstrap(x, circular_mean, num_samples=num_samples,wrapping_variable=True)
+ elif err_type=='predicted':
+ aggregated['cphase'] = circular_mean
+ aggregated['sigmacp'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2)
+ else:
+ print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.")
+ aggregated['cphase'] = circular_mean
+ aggregated['sigmacp'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2)
+
+ #ACTUAL AVERAGING
+ cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index()
+
+ if err_type=='measured':
+ cdf2['cphase'] = [x[0] for x in list(cdf2['dummy'])]
+ cdf2['sigmacp'] = [0.5*(x[1][1]-x[1][0]) for x in list(cdf2['dummy'])]
+
+ # snrcut
+ # CHECK
+ if snrcut==0:
+ snrcut=EP
+ cdf2 = cdf2[cdf2['sigmacp'] < 180./np.pi/snrcut].copy() # TODO CHECK
+
+ #round datetime
+ cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time']))
+
+ #ANDREW TODO-- this can lead to big problems!!
+ #drop values averaged from less than 3 datapoints
+ #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True)
+ if return_type=='rec':
+ return df_to_rec(cdf2,'cphase')
+ elif return_type=='df':
+ return cdf2
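+
+# Usage sketch (illustrative; assumes `obs` is an Obsdata object loaded
+# elsewhere): build a closure-phase DataFrame and average it into 60 s bins
+# with bootstrap ('measured') error bars.
+#   cdf = make_cphase_df(obs)
+#   cp_rec = average_cphases(cdf, 60., return_type='rec', err_type='measured')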
+
+
+def average_bispectra(cdf,dt,return_type='rec',num_samples=int(1e3), snrcut=0.):
+
+ """averages DataFrame of bispectra
+
+ Args:
+ cdf: data frame of bispectra
+ dt: integration time in seconds
+ return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+
+ Returns:
+ cdf2: averaged bispectra
+ """
+
+ cdf2 = cdf.copy()
+ t0 = datetime.datetime(1960,1,1)
+ cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime))
+ grouping=['polarization','band','triangle','t1','t2','t3','round_time']
+ #column just for counting the elements
+ cdf2['number'] = 1
+ aggregated = {'datetime': np.min, 'time': np.mean,
+ 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean,'v1':np.mean, 'v2': np.mean, 'v3':np.mean}
+
+ #AVERAGING-------------------------------
+ aggregated['bispec'] = np.mean
+ aggregated['sigmab'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2)
+
+ #ACTUAL AVERAGING
+ cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index()
+
+ # snrcut
+ cdf2 = cdf2[np.abs(cdf2['bispec']/cdf2['sigmab']) > snrcut].copy() # TODO CHECK
+
+ #round datetime
+ cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time']))
+
+ #ANDREW TODO -- this can lead to big problems!!
+ #drop values averaged from less than 3 datapoints
+ #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True)
+ if return_type=='rec':
+ return df_to_rec(cdf2,'bispec')
+ elif return_type=='df':
+ return cdf2
+
+
+def average_camp(cdf,dt,return_type='rec',err_type='predicted',num_samples=int(1e3)):
+ #TODO: SNRCUT?
+ """averages DataFrame of closure amplitudes
+
+ Args:
+ cdf: data frame of closure amplitudes
+ dt: integration time in seconds
+ return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame
+ err_type: 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator
+
+ Returns:
+ cdf2: averaged closure amplitudes
+ """
+
+ cdf2 = cdf.copy()
+ t0 = datetime.datetime(1960,1,1)
+ cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime))
+ grouping=['polarization','band','quadrangle','t1','t2','t3','t4','round_time']
+ #column just for counting the elements
+ cdf2['number'] = 1
+ aggregated = {'datetime': np.min, 'time': np.mean,
+ 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean, 'u4': np.mean, 'v1':np.mean, 'v2': np.mean, 'v3':np.mean,'v4':np.mean}
+
+ #AVERAGING-------------------------------
+ if err_type=='measured':
+ cdf2['dummy'] = cdf2['camp']
+ aggregated['dummy'] = lambda x: bootstrap(x, np.mean, num_samples=num_samples,wrapping_variable=False)
+ elif err_type=='predicted':
+ aggregated['camp'] = np.mean
+ aggregated['sigmaca'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2)
+ else:
+ print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.")
+ aggregated['camp'] = np.mean
+ aggregated['sigmaca'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2)
+
+ #ACTUAL AVERAGING
+ cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index()
+
+ if err_type=='measured':
+ cdf2['camp'] = [x[0] for x in list(cdf2['dummy'])]
+ cdf2['sigmaca'] = [0.5*(x[1][1]-x[1][0]) for x in list(cdf2['dummy'])]
+
+ #round datetime
+ cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time']))
+
+ #ANDREW TODO -- this can lead to big problems!!
+ #drop values averaged from less than 3 datapoints
+ #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True)
+ if return_type=='rec':
+ return df_to_rec(cdf2,'camp')
+ elif return_type=='df':
+ return cdf2
+
+def df_to_rec(df,product_type):
+
+ """converts DataFrame to numpy recarray used by ehtim
+
+ Args:
+ df: DataFrame to convert
+ product_type: vis, cphase, camp, amp, bispec, cphase_diag, logcamp_diag
+ """
+ if product_type=='cphase':
+ out= df[['time','t1','t2','t3','u1','v1','u2','v2','u3','v3','cphase','sigmacp']].to_records(index=False)
+ return np.array(out,dtype=DTCPHASE)
+ elif product_type=='camp':
+ out= df[['time','t1','t2','t3','t4','u1','v1','u2','v2','u3','v3','u4','v4','camp','sigmaca']].to_records(index=False)
+ return np.array(out,dtype=DTCAMP)
+ elif product_type=='vis':
+ out= df[['time','tint','t1','t2','tau1','tau2','u','v','vis','qvis','uvis','vvis','sigma','qsigma','usigma','vsigma']].to_records(index=False)
+ return np.array(out,dtype=DTPOL_STOKES)
+ elif product_type=='vis_circ':
+ out= df[['time','tint','t1','t2','tau1','tau2','u','v','rrvis','llvis','rlvis','lrvis','rrsigma','llsigma','rlsigma','lrsigma']].to_records(index=False)
+ return np.array(out,dtype=DTPOL_CIRC)
+ elif product_type=='amp':
+ out= df[['time','tint','t1','t2','u','v','amp','sigma']].to_records(index=False)
+ return np.array(out,dtype=DTAMP)
+ elif product_type=='bispec':
+ out= df[['time','t1','t2','t3','u1','v1','u2','v2','u3','v3','bispec','sigmab']].to_records(index=False)
+ return np.array(out,dtype=DTBIS)
+ elif product_type=='cphase_diag':
+ out= df[['time','cphase','sigmacp','triangles','u','v','tform_matrix']].to_records(index=False)
+ return np.array(out,dtype=DTCPHASEDIAG)
+ elif product_type=='logcamp_diag':
+ out= df[['time','camp','sigmaca','quadrangles','u','v','tform_matrix']].to_records(index=False)
+ return np.array(out,dtype=DTLOGCAMPDIAG)
+
+
+def round_time(t,round_s=0.1):
+
+ """rounding time to given accuracy
+
+ Args:
+ t: time
+ round_s: delta time to round to in seconds
+
+ Returns:
+ round_t: rounded time
+ """
+ t0 = datetime.datetime(t.year,1,1)
+ foo = t - t0
+ foo_s = foo.days*24*3600 + foo.seconds + foo.microseconds*(1e-6)
+ foo_s = np.round(foo_s/round_s)*round_s
+ days = np.floor(foo_s/24/3600)
+ seconds = np.floor(foo_s - 24*3600*days)
+ microseconds = int(1e6*(foo_s - days*3600*24 - seconds))
+ round_t = t0+datetime.timedelta(days,seconds,microseconds)
+ return round_t
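+
+# Worked example (hypothetical timestamp): rounding to the nearest 0.1 s,
+#   round_time(datetime.datetime(2017, 4, 5, 3, 16, 30, 234567), round_s=0.1)
+# returns datetime.datetime(2017, 4, 5, 3, 16, 30, 200000)
+# (up to floating-point rounding in the microseconds).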
+
+def get_bins_labels(intervals,dt=0.00001):
+ '''gets bins and labels necessary to perform averaging by scan
+ Args:
+        intervals: 2D array of scan (start, stop) times, in hours
+ dt (float): time margin to add to the scan limits
+ '''
+
+    def fix_midnight_overlap(x):
+        # intervals are tuples after the set/zip above, so return a new tuple
+        # rather than mutating in place
+        if x[1] < x[0]:
+            return (x[0], x[1] + 24.)
+        return x
+
+ def is_overlapping(interval0,interval1):
+ if ((interval1[0]<=interval0[0])&(interval1[1]>=interval0[0]))|((interval1[0]<=interval0[1])&(interval1[1]>=interval0[1])):
+ return True
+ else: return False
+
+ def merge_overlapping_intervals(intervals):
+ return (np.min([x[0] for x in intervals]),np.max([x[1] for x in intervals]))
+
+ def replace_overlapping_intervals(intervals,element_ind):
+ indic_not_overlap=[not is_overlapping(x,intervals[element_ind]) for x in intervals]
+ indic_overlap=[is_overlapping(x,intervals[element_ind]) for x in intervals]
+ fooarr=np.asarray(intervals)
+ return sorted([tuple(x) for x in fooarr[indic_not_overlap]]+[merge_overlapping_intervals(list(fooarr[indic_overlap]))])
+
+ intervals = sorted(list(set(zip(intervals[:,0],intervals[:,1]))))
+ intervals = [fix_midnight_overlap(x) for x in intervals]
+ cou=0
+ while cou < len(intervals):
+ intervals = replace_overlapping_intervals(intervals,cou)
+ cou+=1
+
+ binsT=[None]*(2*np.shape(intervals)[0])
+ binsT[::2] = [x[0]-dt for x in intervals]
+ binsT[1::2] = [x[1]+dt for x in intervals]
+ labels=[None]*(2*np.shape(intervals)[0]-1)
+ labels[::2] = [cou for cou in range(1,len(intervals)+1)]
+ labels[1::2] = [-cou for cou in range(1,len(intervals))]
+
+ return binsT, labels
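+
+# Example (made-up scan list): two scans spanning 0-1 h and 2-3 h give bin
+# edges padded by dt around each scan, positive labels for scans and negative
+# labels for the gaps between them:
+#   binsT, labels = get_bins_labels(np.array([[0., 1.], [2., 3.]]))
+#   # binsT  ~ [-1e-05, 1.00001, 1.99999, 3.00001]
+#   # labels = [1, -1, 2]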
+
+def common_set(obs1, obs2, tolerance = 0,uniquely=False, by_what='uvdist'):
+ '''
+ Selects common subset of obs1, obs2 data
+ tolerance: time tolerance to accept common subsets [s] if by_what = 'ut' or in [h] if 'gmst'
+ or u,v tolerance in lambdas if by_what='uvdist'
+ uniquely: whether matching single value to single value
+ by_what: matching common sets by ut time 'ut' or by uvdistance 'uvdist' or by 'gmst'
+ '''
+ if obs1.polrep!=obs2.polrep:
+ raise ValueError('Observations must be in the same polrep!')
+ #make a dataframe with visibilities
+ #tolerance in seconds
+ df1 = make_df(obs1)
+ df2 = make_df(obs2)
+
+ #we need this to match baselines with possibly different station orders between the pipelines
+ df1['ta'] = list(map(lambda x: sorted(x)[0],zip(df1.t1,df1.t2)))
+ df1['tb'] = list(map(lambda x: sorted(x)[1],zip(df1.t1,df1.t2)))
+ df2['ta'] = list(map(lambda x: sorted(x)[0],zip(df2.t1,df2.t2)))
+ df2['tb'] = list(map(lambda x: sorted(x)[1],zip(df2.t1,df2.t2)))
+
+ if by_what=='ut':
+ if tolerance>0:
+ d_mjd = tolerance/24.0/60.0/60.0
+ df1['roundtime']=np.round(df1.mjd/d_mjd)
+ df2['roundtime']=np.round(df2.mjd/d_mjd)
+ else:
+ df1['roundtime'] = df1['time']
+ df2['roundtime'] = df2['time']
+ #matching data
+ df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundtime'],uniquely=uniquely)
+
+ elif by_what=='gmst':
+ df1 = add_gmst(df1)
+ df2 = add_gmst(df2)
+ if tolerance>0:
+ d_gmst = tolerance
+ df1['roundgmst']=np.round(df1.gmst/d_gmst)
+ df2['roundgmst']=np.round(df2.gmst/d_gmst)
+ else:
+ df1['roundgmst'] = df1['gmst']
+ df2['roundgmst'] = df2['gmst']
+ #matching data
+ df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundgmst'],uniquely=uniquely)
+
+ elif by_what=='uvdist':
+ if tolerance>0:
+ d_lambda = tolerance
+ df1['roundu'] = np.round(df1.u/d_lambda)
+ df1['roundv'] = np.round(df1.v/d_lambda)
+ df2['roundu'] = np.round(df2.u/d_lambda)
+ df2['roundv'] = np.round(df2.v/d_lambda)
+ else:
+ df1['roundu'] = df1['u']
+ df1['roundv'] = df1['v']
+ df2['roundu'] = df2['u']
+ df2['roundv'] = df2['v']
+ #matching data
+ df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundu','roundv'],uniquely=uniquely)
+
+ #replace visibility data with common subset
+ obs1cut = obs1.copy()
+ obs2cut = obs2.copy()
+ if obs1.polrep=='stokes':
+ obs1cut.data = df_to_rec(df1,'vis')
+ obs2cut.data = df_to_rec(df2,'vis')
+ elif obs1.polrep=='circ':
+ obs1cut.data = df_to_rec(df1,'vis_circ')
+ obs2cut.data = df_to_rec(df2,'vis_circ')
+
+ return obs1cut,obs2cut
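+
+# Usage sketch (illustrative; obs1 and obs2 are assumed to be Obsdata objects
+# of the same source reduced by different pipelines): keep only visibilities
+# whose (u,v) coordinates match to within 100 lambda.
+#   obs1cut, obs2cut = common_set(obs1, obs2, tolerance=100., by_what='uvdist')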
+
+"""
+def common_multiple_sets(obsL, tolerance = 0,uniquely=False, by_what='uvdist'):
+ '''
+    Selects common subset of data from a list of observations obsL
+ tolerance: time tolerance to accept common subsets [s] if by_what = 'ut' or 'gmst'
+ or u,v tolerance in lambdas if by_what='uvdist'
+ uniquely: whether matching single value to single value
+ by_what: matching common sets by ut time 'ut' or by uvdistance 'uvdist' or by 'gmst'
+ '''
+ polrepL = list(set([obs.polrep for obs in obsL]))
+ if len(polrepL)>1:
+ raise ValueError('Observations must be in the same polrep!')
+ #make a dataframe with visibilities
+ #tolerance in seconds
+ dfL = [make_df(obs) for obs in obsL]
+
+ #we need this to match baselines with possibly different station orders between the pipelines
+ for df in dfL:
+ df['ta'] = list(map(lambda x: sorted(x)[0],zip(df.t1,df.t2)))
+ df['tb'] = list(map(lambda x: sorted(x)[1],zip(df.t1,df.t2)))
+
+ if by_what=='ut':
+ if tolerance>0:
+ d_mjd = tolerance/24.0/60.0/60.0
+ for df in dfL: df['roundtime']=np.round(df.mjd/d_mjd)
+ else:
+ for df in dfL: df['roundtime']=df['time']
+ #matching data
+        dfL = match_multiple_frames(dfL,['ta','tb','roundtime'],uniquely=uniquely)
+
+ elif by_what=='gmst':
+ dfL = [add_gmst(df) for df in dfL]
+ if tolerance>0:
+ d_gmst = tolerance
+ for df in dfL: df['roundgmst']=np.round(df.gmst/d_gmst)
+ else:
+ for df in dfL: df['roundgmst']= df['gmst']
+ #matching data
+        dfL = match_multiple_frames([df.copy() for df in dfL],['ta','tb','roundgmst'],uniquely=uniquely)
+
+
+ elif by_what=='uvdist':
+ if tolerance>0:
+ d_lambda = tolerance
+ for df in dfL: df['roundu'] = np.round(df.u/d_lambda)
+ for df in dfL: df['roundv'] = np.round(df.v/d_lambda)
+ else:
+ for df in dfL: df['roundu'] = df['u']
+ for df in dfL: df['roundv'] = df['v']
+ #matching data
+        dfL = match_multiple_frames([df.copy() for df in dfL],['ta','tb','roundu','roundv'],uniquely=uniquely)
+
+    #replace visibility data with common subset
+    obscutL = [obs.copy() for obs in obsL]
+
+    if polrepL[0]=='stokes':
+        for obscut, df in zip(obscutL, dfL):
+            obscut.data = df_to_rec(df,'vis')
+    elif polrepL[0]=='circ':
+        for obscut, df in zip(obscutL, dfL):
+            obscut.data = df_to_rec(df,'vis_circ')
+
+    return obscutL
+"""
+
+def match_multiple_frames(frames, what_is_same, dt = 0,uniquely=True):
+
+ if dt > 0:
+ for frame in frames:
+ frame['round_time'] = list(map(lambda x: np.round((x- datetime.datetime(2017,4,4)).total_seconds()/dt),frame['datetime']))
+ what_is_same += ['round_time']
+
+ frames_common = {}
+ for frame in frames:
+ frame['all_ind'] = list(zip(*[frame[x] for x in what_is_same]))
+ if frames_common != {}:
+ frames_common = frames_common&set(frame['all_ind'])
+ else:
+ frames_common = set(frame['all_ind'])
+
+ frames_out = []
+ for frame in frames:
+ frame = frame[list(map(lambda x: x in frames_common, frame.all_ind))].copy()
+ if uniquely:
+ frame.drop_duplicates(subset=['all_ind'], keep='first', inplace=True)
+
+ frame = frame.sort_values('all_ind').reset_index(drop=True)
+ frame.drop('all_ind', axis=1,inplace=True)
+ frames_out.append(frame.copy())
+ return frames_out
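+
+# Sketch (illustrative; df1 and df2 are assumed to already carry the sorted
+# baseline columns 'ta'/'tb' and a 'datetime' column): keep rows common to two
+# visibility DataFrames whose timestamps match after rounding to 10 s.
+#   df1m, df2m = match_multiple_frames([df1, df2], ['ta', 'tb'], dt=10.)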
+
+
+def add_gmst(df):
+ #Lindy Blackburn's work borrowed from eat
+ """add *gmst* column to data frame with *datetime* field using astropy for conversion"""
+ from astropy import time
+ g = df.groupby('datetime')
+ (timestamps, indices) = list(zip(*iter(g.groups.items())))
+ # this broke in pandas 0.9 with API changes
+ if type(timestamps[0]) is np.datetime64: # pandas < 0.9
+ times_unix = 1e-9*np.array(
+ timestamps).astype('float') # note one float64 is not [ns] precision
+ elif type(timestamps[0]) is pd.Timestamp:
+ times_unix = np.array([1e-9 * t.value for t in timestamps]) # will be int64's
+ else:
+ raise Exception("do not know how to convert timestamp of type " + repr(type(timestamps[0])))
+ times_gmst = time.Time(
+ times_unix, format='unix').sidereal_time('mean', 'greenwich').hour # vectorized
+ df['gmst'] = 0. # initialize new column
+ for (gmst, idx) in zip(times_gmst, indices):
+        df.loc[idx, 'gmst'] = gmst  # .loc replaces .ix, which was removed in modern pandas
+ return df
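+
+# Usage sketch: attach Greenwich mean sidereal time to a visibility DataFrame,
+# e.g. df = add_gmst(make_df(obs)); the new 'gmst' column is in hours.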
diff --git a/statistics/stats.py b/statistics/stats.py
new file mode 100644
index 00000000..708b53e2
--- /dev/null
+++ b/statistics/stats.py
@@ -0,0 +1,336 @@
+# stats.py
+# a variety of statistical functions useful for data analysis
+#
+# Copyright (C) 2018 Maciek Wielgus
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+from __future__ import division
+from __future__ import print_function
+from builtins import str
+from builtins import map
+from builtins import range
+
+import numpy as np
+import numpy.random as npr
+import sys
+
+from ehtim.const_def import *
+
+def circular_mean(theta, unit='deg'):
+ '''circular mean for averaging angular quantities
+ Args:
+ theta: list/vector of angles to average
+ unit: degrees ('deg') or radians (any other string)
+
+ Returns:
+ circular mean
+ '''
+ theta = np.asarray(theta, dtype=np.float32)
+ theta= theta.flatten()
+ theta = theta[theta==theta]
+ if unit=='deg':
+ theta *= np.pi/180.
+ if len(theta)==0:
+ return None
+ else:
+ C = np.mean(np.cos(theta))
+ S = np.mean(np.sin(theta))
+ circ_mean = np.arctan2(S,C)
+ if unit=='deg':
+ circ_mean *= 180./np.pi
+ return np.mod(circ_mean+180.,360.)-180.
+ else:
+ return np.mod(circ_mean+np.pi,2.*np.pi)-np.pi
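+
+# Sanity check (illustrative numbers): angles straddling the +/-180 deg wrap
+# average to the wrap point rather than to the misleading arithmetic mean of 0.
+#   circular_mean([179., -179.])   # -> -180.0 (equivalent to +180 deg)
+#   np.mean([179., -179.])         # -> 0.0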
+
+def circular_std(theta,unit='deg'):
+ '''standard deviation of a circular distribution
+ Args:
+ theta: list/vector of angles
+ unit: degrees ('deg') or radians (any other string)
+
+ Returns:
+ circular standard deviation
+ '''
+ theta = np.asarray(theta, dtype=np.float32)
+ theta= theta.flatten()
+ theta = theta[theta==theta]
+ if unit=='deg':
+ theta *= np.pi/180.
+ if len(theta)<2:
+ return None
+ else:
+ C = np.mean(np.cos(theta))
+ S = np.mean(np.sin(theta))
+ circ_std = np.sqrt(-2.*np.log(np.sqrt(C**2+S**2)))
+ if unit=='deg':
+ circ_std *= 180./np.pi
+ return circ_std
+
+def circular_std_of_mean(theta,unit='deg'):
+ '''standard deviation of mean for a circular distribution
+ Args:
+ theta: list/vector of angles
+ unit: degrees ('deg') or radians (any other string)
+
+ Returns:
+ circular standard deviation of mean
+ '''
+ theta = np.asarray(theta, dtype=np.float32)
+ theta= theta.flatten()
+ theta = theta[theta==theta]
+ return circular_std(theta,unit)/np.sqrt(len(theta))
+
+def mean_incoh_amp(amp,sigma,debias=True,err_type='predicted',num_samples=int(1e3)):
+ """amplitude from ensemble of Rice-distributed measurements with debiasing
+ Args:
+ amp: vector of (biased) amplitudes
+ sigma: vector of errors
+ debias: whether debiasing is applied
+ Returns:
+ amp0: estimator of unbiased amplitude
+ """
+ if (not hasattr(amp, "__len__")):
+ amp = [amp]
+ amp = np.asarray(amp, dtype=np.float32)
+ N = len(amp)
+ if (not hasattr(sigma, "__len__")):
+ sigma = sigma*np.ones(N)
+ elif len(sigma)==1:
+ sigma = sigma*np.ones(N)
+ sigma = np.asarray(sigma, dtype=np.float32)
+    if len(sigma)!=len(amp):
+        print('Inconsistent length of amp and sigma')
+        return None, None
+    else:
+        amp_clean=amp[(amp==amp)&(sigma==sigma)&(sigma>0)&(amp>0)]
+        sigma_clean=sigma[(amp==amp)&(sigma==sigma)&(sigma>0)&(amp>0)]
+        Nc = len(amp_clean)
+        #eq. 9.86 from Thompson et al.; use the cleaned sample size as in mean_incoh_amp_from_vis
+        if debias==True:
+            amp0sq = ( np.mean(amp_clean**2 - (2. - 1./Nc)*sigma_clean**2) )
+        else: amp0sq = np.mean(amp_clean**2)
+        amp0sq = np.maximum(amp0sq,0.)
+        amp0 = np.sqrt(amp0sq)
+
+ #getting errors
+ if err_type=='predicted':
+ sigma0 = np.sqrt(np.sum(sigma_clean**2)/len(sigma_clean)**2)
+ elif err_type=='measured':
+ ampfoo, ci = bootstrap(amp_clean, np.mean, num_samples=num_samples,wrapping_variable=False)
+ sigma0 = 0.5*(ci[1]-ci[0])
+ return amp0,sigma0
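+
+# Worked example (synthetic data): ten noisy measurements of a 1 Jy point
+# source with sigma = 0.2 Jy per complex component are biased upward in
+# amplitude; debiasing pulls the estimate back toward 1 Jy.
+#   vis = 1.0 + 0.2*(npr.randn(10) + 1j*npr.randn(10))
+#   amp0, sigma0 = mean_incoh_amp(np.abs(vis), 0.2)   # amp0 ~ 1.0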
+
+def mean_incoh_amp_from_vis(vis,sigma,debias=True,err_type='predicted',num_samples=int(1e3)):
+ """Amplitude from ensemble of visibility measurements with debiasing
+ Args:
+ amp: vector of (biased) amplitudes
+ sigma: vector of errors
+ debias: whether debiasing is applied
+ Returns:
+ amp0: estimator of unbiased amplitude
+ """
+ if (not hasattr(vis, "__len__")):
+ vis = [vis]
+ vis= np.asarray(vis)
+ vis= vis[vis==vis]
+ amp=np.abs(vis)
+
+ N = len(amp)
+ if (not hasattr(sigma, "__len__")):
+ sigma = sigma*np.ones(N)
+ elif len(sigma)==1:
+ sigma = sigma*np.ones(N)
+ sigma = np.asarray(sigma, dtype=np.float32)
+ if len(sigma)!=len(amp):
+ print('Inconsistent length of amp and sigma')
+ return None, None
+ else:
+ amp_clean=amp[(amp==amp)&(sigma==sigma)&(sigma>=0)&(amp>=0)]
+ sigma_clean=sigma[(amp==amp)&(sigma==sigma)&(sigma>=0)&(amp>=0)]
+ Nc=len(amp_clean)
+ if Nc<1:
+ return None, None
+ else:
+ #eq. 9.86 from Thompson et al.
+ if debias==True:
+ amp0sq = ( np.mean(amp_clean**2 - (2. - 1./Nc)*sigma_clean**2) )
+ else: amp0sq = np.mean(amp_clean**2)
+ if (amp0sq!=amp0sq): amp0sq=0.
+ amp0sq = np.maximum(amp0sq,0.)
+ amp0 = np.sqrt(amp0sq)
+ #getting errors
+ if err_type=='predicted':
+ #sigma0 = np.sqrt(np.sum(sigma_clean**2)/Nc**2)
+ #Esigma = np.median(sigma_clean)
+ #snr0 = amp0/Esigma
+ #snrA = 1./(np.sqrt(1. + 2./np.sqrt(Nc)*(1./snr0)*np.sqrt(1.+1./snr0**2)) - 1.)
+ #sigma0=amp0/snrA
+ sigma0 = np.sqrt(np.sum(sigma_clean**2)/Nc**2)
+
+ elif err_type=='measured':
+ ampfoo, ci = bootstrap(amp_clean, np.mean, num_samples=num_samples,wrapping_variable=False,alpha='1sig')
+ sigma0 = 0.5*(ci[1]-ci[0])
+ return amp0,sigma0
+
+def bootstrap(data, statistic, num_samples=int(1e3), alpha='1sig',wrapping_variable=False):
+ """bootstrap estimate of 100.0*(1-alpha) confidence interval for a given statistic
+ Args:
+ data: vector of data to estimate bootstrap statistic on
+ statistic: function representing the statistic to be evaluated
+ num_samples: number of bootstrap (re)samples
+        alpha: parameter of the confidence interval; '1sig' gives an analog of 1 sigma confidence for a normal variable
+ wrapping_variable: True for circular variables, attempts to avoid problem related to estimating variability of wrapping variable
+
+ Returns:
+ bootstrap_value: bootstrap-estimated value of the statistic
+ bootstrap_CI: bootstrap-estimated confidence interval
+ """
+ if alpha=='1sig':
+ alpha=0.3173
+ elif alpha=='2sig':
+ alpha=0.0455
+ elif alpha=='3sig':
+ alpha=0.0027
+ stat = np.zeros(num_samples)
+ data = np.asarray(data)
+ if wrapping_variable==True:
+ m=statistic(data)
+ else:
+ m=0
+ data = data-m
+ n = len(data)
+ idx = npr.randint(0, n, (num_samples, n))
+ samples = data[idx]
+ for cou in range(num_samples):
+ stat[cou] = statistic(samples[cou,:])
+ stat = np.sort(stat)
+ bootstrap_value = np.median(stat)+m
+ bootstrap_CI = [stat[int((alpha/2.0)*num_samples)]+m, stat[int((1-alpha/2.0)*num_samples)]+m]
+ return bootstrap_value, bootstrap_CI
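+
+# Usage sketch (made-up data): bootstrap the mean of a small amplitude sample
+# and convert the 1-sigma confidence interval to an error bar, as done in the
+# err_type='measured' branches elsewhere in this module.
+#   amps = np.array([1.02, 0.98, 1.05, 0.97, 1.01])
+#   m, ci = bootstrap(amps, np.mean, num_samples=1000, alpha='1sig')
+#   err = 0.5*(ci[1] - ci[0])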
+
+def mean_incoh_avg(x,debias=True):
+ amp = np.abs(np.asarray([y[0] for y in x]))
+ sig = np.asarray([y[1] for y in x])
+ ampN = amp[(amp==amp)&(amp>=0)&(sig==sig)&(sig>=0)]
+ sigN = sig[(amp==amp)&(amp>=0)&(sig==sig)&(sig>=0)]
+ amp = ampN
+ sig = sigN
+ Nc = len(sig)
+ if Nc==0:
+ amp0 = -1
+ sig0 = -1
+ elif Nc==1:
+ amp0 = amp[0]
+ sig0 = sig[0]
+ else:
+ if debias==True:
+ amp0 = deb_amp(amp,sig)
+ else:
+ amp0= np.sqrt(np.maximum(np.mean(amp**2),0.))
+ sig0 = inc_sig(amp,sig)
+ #sig0 = coh_sig(amp,sig)
+ return amp0,sig0
+
+def deb_amp(amp,sig):
+ #eq. 9.86 from Thompson et al.
+ amp = np.abs(np.asarray(amp))
+ sig = np.asarray(sig)
+ Nc = len(amp)
+ amp0sq = ( np.mean(amp**2 - (2. - 1./Nc)*sig**2) )
+ amp0sq = np.maximum(amp0sq,0.)
+ amp0 = np.sqrt(amp0sq)
+ return amp0
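+
+# Numerical check (hypothetical numbers): for a single measurement,
+# deb_amp([1.0], [0.3]) returns sqrt(max(1.0**2 - (2. - 1./1)*0.3**2, 0)),
+# i.e. ~0.954, removing the Rice bias of eq. 9.86.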
+
+def inc_sig(amp,sig):
+ amp = np.abs(np.asarray(amp))
+ sig = np.asarray(sig)
+ Nc = len(amp)
+ amp0 = deb_amp(amp,sig)
+ Esigma = np.median(sig)
+ snr0 = amp0/Esigma
+ snrA = 1./(np.sqrt(1. + 2./np.sqrt(Nc)*(1./snr0)*np.sqrt(1.+1./snr0**2)) - 1.)
+ if snrA>0:
+ sigma0=amp0/snrA
+ else: sigma0=coh_sig(amp,sig)
+ return sigma0
+
+def coh_sig(amp,sig):
+ amp = np.abs(np.asarray(amp))
+ sig = np.asarray(sig)
+ Nc = len(amp)
+ sigma0 = np.sqrt(np.sum(sig**2)/Nc**2)
+ return sigma0
+
+
+def dicts_TV_report(obs,snr_cut=2.):
+ """Computes mean total variation reports
+ Args:
+ obs: ObsData object
+ snr_cut: threshold for data snr
+ Returns:
+ amptv: dictionary of baseline mean TV
+ cptv: dictionary of triangle mean TV
+ lcatv: dictionary of quadrangle mean TV
+ """
+ amp = obs.data
+    baselines = list(set([(x[0],x[1]) for x in amp[['t1','t2']]]))
+    amptv = {}
+    for cou,bsl in enumerate(baselines):
+        amptv[bsl] = np.mean(np.abs(np.diff(np.abs(amp[(amp['t1']==baselines[cou][0])&(amp['t2']==baselines[cou][1])]['vis']))))
+
+ cp = obs.c_phases()
+ obs = obs.flag_low_snr(snr_cut=snr_cut)
+ triangles = list(set([(x[0],x[1],x[2]) for x in cp[['t1','t2','t3']]]))
+ cptv = {}
+ for cou,tri in enumerate(triangles):
+ cptv[tri] = np.mean(np.abs(np.diff(cp[(cp['t1']==triangles[cou][0])&(cp['t2']==triangles[cou][1])&(cp['t3']==triangles[cou][2])]['cphase'])))
+
+ lca = obs.c_amplitudes(ctype='logcamp')
+ quadrangles = list(set([(x[0],x[1],x[2],x[3]) for x in lca[['t1','t2','t3','t4']]]))
+ lcatv = {}
+ for cou,quad in enumerate(quadrangles):
+ lcatv[quad] = np.mean(np.abs(np.diff(lca[(lca['t1']==quadrangles[cou][0])&(lca['t2']==quadrangles[cou][1])&(lca['t3']==quadrangles[cou][2])&(lca['t4']==quadrangles[cou][3])]['camp'])))
+
+ return amptv, cptv, lcatv
+
+def compare_TV(obs,obsref,snr_cut=2.,output=''):
+ """Computes mean total variation reports
+ Args:
+ obs: ObsData object
+        obsref: ObsData object to use as a reference
+ snr_cut: threshold for data snr
+ output (str): if empty, returns median relative TV across all (baselines, triangles, quadrangles)
+ if 'Full', returns full dictionaries comparing all (baselines, triangles, quadrangles)
+ Returns:
+ amprel / ampmed: dictionary of baseline relative differences in mean TV / median of baseline relative differences in mean TV
+ cprel / cpmed: dictionary of triangle relative differences in mean TV / median of triangle relative differences in mean TV
+ lcarel / lcamed: dictionary of quadrangle relative differences in mean TV / median of quadrangle relative differences in mean TV
+ """
+    amptv, cptv, lcatv = dicts_TV_report(obs,snr_cut=snr_cut)
+    ampref, cpref, lcaref = dicts_TV_report(obsref,snr_cut=snr_cut)
+
+ cprel = {key: (cptv[key] - cpref[key])/cpref[key] for key in cptv.keys() if key in set(cpref.keys())}
+ amprel = {key: (amptv[key] - ampref[key])/ampref[key] for key in amptv.keys() if key in set(ampref.keys())}
+ lcarel = {key: (lcatv[key] - lcaref[key])/lcaref[key] for key in lcatv.keys() if key in set(lcaref.keys())}
+ amprel = {key:amprel[key] for key in amprel.keys() if amprel[key]==amprel[key]}
+ cprel = {key:cprel[key] for key in cprel.keys() if cprel[key]==cprel[key]}
+ lcarel = {key:lcarel[key] for key in lcarel.keys() if lcarel[key]==lcarel[key]}
+
+ if output=='Full':
+ return amprel, cprel, lcarel
+ else:
+        ampmed = np.median([amprel[key] for key in amprel.keys() if amprel[key]==amprel[key]])
+        cpmed = np.median([cprel[key] for key in cprel.keys() if cprel[key]==cprel[key]])
+        lcamed = np.median([lcarel[key] for key in lcarel.keys() if lcarel[key]==lcarel[key]])
+ return ampmed,cpmed,lcamed
\ No newline at end of file
diff --git a/survey.py b/survey.py
new file mode 100644
index 00000000..fc7c2b0a
--- /dev/null
+++ b/survey.py
@@ -0,0 +1,626 @@
+# survey.py
+# a parameter survey class
+#
+# Copyright (C) 2018 Andrew Chael
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import os
+
+import ehtim as eh
+import paramsurvey
+import paramsurvey.params
+
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
+
+
+##################################################################################################
+# ParameterSet object
+##################################################################################################
+
+class ParameterSet:
+
+ """
+ Attributes:
+ paramset (dict): dict containing single parameter set. Each key in dict becomes its own attribute
+ params_fixed (dict): dict containing non-varying parameters. Each key in dict becomes its own attribute
+ outfile (str): base of outfile with 7 digit index of pset added to it
+ obs (Obsdata): observation object from input uvfits file, generated by load_data()
+ zbl_tot (float): flux on ALMA-APEX baseline, generated in preimcal()
+ obs_sc_init (Obsdata): original observation with systematic noise added, reverse taper applied, and extended flux removed
+ obs_sc (Obsdata): obs_sc_init after all rounds of self-calibration
+ initimg (Image): prior/initial image
+ im_out (Image): final reconstructed image
+ im_addcmp (Image): final reconstructed image with extended flux added back in
+ obs_sc_addcmp (Obsdata): original observation self-calibrated with im_addcmp
+ caltab (Caltable): final calibration table associated with im_out
+
+ """
+
+ def __init__(self, paramset, params_fixed={}):
+
+ """
+ An object for one parameter set
+
+ Args:
+ paramset (dict): A dict containing single parameter set
+ params_fixed (dict): A dict containing non-varying parameters
+
+ Returns:
+ a ParameterSet object
+ """
+
+ # set each item in paramset dict to an attribute
+ for param in paramset:
+ setattr(self, param, paramset[param])
+
+ if len(params_fixed) > 0:
+ for param in params_fixed:
+ setattr(self, param, params_fixed[param])
+
+ os.makedirs(self.outpath, exist_ok=True)
+ self.paramset = paramset
+ self.params_fixed = params_fixed
+ self.outfile = '%s_%0.7i' % (self.outfile_base, self.i)
+
+ self.fov *= eh.RADPERUAS
+ self.reverse_taper_uas *= eh.RADPERUAS
+ self.prior_fwhm *= eh.RADPERUAS
+
+ def load_data(self):
+ """
+ Loads in uvfits file from self.infile into eht-imaging obs object and averages using self.avg_time
+ Creates self.obs
+
+ Args:
+
+ Returns:
+
+ """
+
+ # load the uvfits file
+ self.obs = eh.obsdata.load_uvfits(self.infile)
+
+ # identify the scans (times of continuous observation) in the data
+ self.obs.add_scans()
+
+ # coherently average
+ if self.avg_time == 'scan':
+ self.obs = self.obs.avg_coherent(0., scan_avg=True)
+ else:
+ self.obs = self.obs.avg_coherent(self.avg_time)
+
+ def preimcal(self):
+
+ """
+ Applies pre-imaging calibration to self.obs. This includes flagging sites with no measurements, rescaling
+ short baselines so only compact flux is being imaged, applying a u-v taper, and adding extra sys noise. Creates
+ self.obs_sc which will have self-cal applied in future steps and self.obs_sc_init which is a copy of the initial
+ preimcal observation
+
+ Args:
+
+ Returns:
+
+ """
+
+        # handle site name change of APEX from 2017 (AP) to 2018 (AX)
+ sites = list(self.obs.tkey.keys())
+
+ alma_options = ['ALMA', 'AA']
+ for opt in alma_options:
+ if opt in sites:
+ alma = opt
+
+ apex_options = ['APEX', 'AP', 'AX']
+ for opt in apex_options:
+ if opt in sites:
+ apex = opt
+
+ self.zbl_tot = np.median(self.obs.unpack_bl(alma, apex, 'amp')['amp'])
+
+
+ # Flag out sites in the obs.tarr table with no measurements
+ allsites = set(self.obs.unpack(['t1'])['t1']) | set(self.obs.unpack(['t2'])['t2'])
+ self.obs.tarr = self.obs.tarr[[o in allsites for o in self.obs.tarr['site']]]
+ self.obs = eh.obsdata.Obsdata(self.obs.ra, self.obs.dec, self.obs.rf, self.obs.bw, self.obs.data, self.obs.tarr,
+ source=self.obs.source, mjd=self.obs.mjd,
+ ampcal=self.obs.ampcal, phasecal=self.obs.phasecal)
+
+ self.obs_orig = self.obs.copy()
+
+ # Rescale short baselines to excise contributions from extended flux.
+ # setting zbl < zbl_tot assumes there is an extended constant flux component of zbl_tot-zbl Jy
+
+ if self.zbl != self.zbl_tot:
+ for j in range(len(self.obs.data)):
+ if (self.obs.data['u'][j] ** 2 + self.obs.data['v'][j] ** 2) ** 0.5 >= self.uv_zblcut: continue
+ for field in ['vis', 'qvis', 'uvis', 'vvis', 'sigma', 'qsigma', 'usigma', 'vsigma']:
+ self.obs.data[field][j] *= self.zbl / self.zbl_tot
+
+ self.obs.reorder_tarr_snr()
+
+ self.obs_sc = self.obs.copy()
+ # Reverse taper the observation: this enforces a maximum resolution on reconstructed features
+ if self.reverse_taper_uas > 0:
+ self.obs_sc = self.obs_sc.reverse_taper(self.reverse_taper_uas)
+
+ # Add non-closing systematic noise to the observation
+ self.obs_sc = self.obs_sc.add_fractional_noise(self.sys_noise)
+
+ # Make a copy of the initial data (before any self-calibration but after the taper)
+ self.obs_sc_init = self.obs_sc.copy()
+
+ def init_img(self):
+
+ """
+        Creates the initial/prior image. Only a Gaussian prior option exists at present; creates the prior
+        attribute self.initimg using self.zbl and self.prior_fwhm
+
+ Args:
+
+ Returns:
+
+ """
+        # create Gaussian prior/initial image
+ emptyprior = eh.image.make_square(self.obs_sc, self.npixels, self.fov)
+
+ gaussprior = emptyprior.add_gauss(self.zbl, (self.prior_fwhm, self.prior_fwhm, 0, 0, 0))
+ # To avoid gradient singularities in the first step, add an additional small Gaussian
+ gaussprior = gaussprior.add_gauss(self.zbl * 1e-3, (self.prior_fwhm, self.prior_fwhm, 0,
+ self.prior_fwhm, self.prior_fwhm))
+ self.initimg = gaussprior.copy()
+
+ def make_img(self):
+
+ """
+ Reconstructs image with specified parameters (data weights, amount of self-cal, etc) described in paramset dict
+ Creates attributes self.im_out containing the final image and self.caltab containing a corresponding calibration
+ table object for the final image
+
+ Args:
+
+ Returns:
+
+ """
+
+ # specify data terms
+ data_term = {}
+ if hasattr(self, 'vis') and self.vis != 0.:
+ data_term['vis'] = self.vis
+ if hasattr(self, 'amp') and self.amp != 0.:
+ data_term['amp'] = self.amp
+        if hasattr(self, 'diag_closure') and self.diag_closure is True:
+            if hasattr(self, 'logcamp_diag') and self.logcamp_diag != 0.:
+                data_term['logcamp_diag'] = self.logcamp_diag
+            if hasattr(self, 'cphase_diag') and self.cphase_diag != 0.:
+                data_term['cphase_diag'] = self.cphase_diag
+ else:
+ if hasattr(self, 'logcamp') and self.logcamp != 0.:
+ data_term['logcamp'] = self.logcamp
+ if hasattr(self, 'cphase') and self.cphase != 0.:
+ data_term['cphase'] = self.cphase
+
+ # specify regularizer terms
+ reg_term = {}
+ if hasattr(self, 'simple') and self.simple != 0.:
+ reg_term['simple'] = self.simple
+ if hasattr(self, 'tv2') and self.tv2 != 0.:
+ reg_term['tv2'] = self.tv2
+ if hasattr(self, 'tv') and self.tv != 0.:
+ reg_term['tv'] = self.tv
+ if hasattr(self, 'l1') and self.l1 != 0.:
+ reg_term['l1'] = self.l1
+ if hasattr(self, 'flux') and self.flux != 0.:
+ reg_term['flux'] = self.flux
+ if hasattr(self, 'rgauss') and self.rgauss != 0.:
+ reg_term['rgauss'] = self.rgauss
+
+ ### How to make this more general? ###
+ # Add systematic noise tolerance for amplitude a-priori calibration errors
+ # Start with the SEFD noise (but need sqrt)
+ # then rescale to ensure that final results respect the stated error budget
+ systematic_noise = self.SEFD_error_budget.copy()
+ for key in systematic_noise.keys():
+ systematic_noise[key] = ((1.0 + systematic_noise[key]) ** 0.5 - 1.0) * 0.25
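+        # e.g. (worked number) a 10% SEFD error budget maps to a fractional
+        # systematic noise of ((1.1)**0.5 - 1.0)*0.25 ~ 0.0122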
+
+ # set up imager
+ imgr = eh.imager.Imager(self.obs_sc, self.initimg, prior_im=self.initimg, flux=self.zbl,
+ data_term=data_term, maxit=self.maxit, norm_reg=True, systematic_noise=systematic_noise,
+ reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop)
+
+ res = self.obs.res()
+
+ imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac)
+
+ if self.selfcal:
+ # Self-calibrate to the previous model (phase-only);
+ # The solution_interval is 0 to align phases from high and low bands if needed
+ self.obs_sc = eh.selfcal(self.obs_sc, imgr.out_last(), method='phase', ttype=self.ttype,
+ solution_interval=0.0, processes=-1)
+
+ sc_p_idx = 0
+ dterms = data_term.keys()
+ while sc_p_idx < self.sc_phase:
+
+ # Blur the previous reconstruction to the intrinsic resolution
+ init = imgr.out_last().blur_circ(res)
+
+ # Increase the data weights and reinitialize imaging
+ if sc_p_idx == 0:
+ for key in dterms:
+ data_term[key] *= self.xdw_phase
+
+ # set up imager
+ imgr = eh.imager.Imager(self.obs_sc, init, prior_im=self.initimg, flux=self.zbl,
+ data_term=data_term, maxit=self.maxit, norm_reg=True,
+ systematic_noise=systematic_noise,
+ reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop)
+
+ # Imaging
+ imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac)
+
+ # apply self-calibration to original calibrated data
+ self.obs_sc = eh.selfcal(self.obs_sc_init, imgr.out_last(), method='phase', ttype=self.ttype)
+
+ sc_p_idx += 1
+
+ # repeat amp+phase self-calibration
+ sc_ap_idx = 0
+
+ while sc_ap_idx < self.sc_ap:
+
+ # Blur the previous reconstruction to the intrinsic resolution
+ init = imgr.out_last().blur_circ(res)
+
+ # Increase the data weights and reinitialize imaging
+            if sc_ap_idx == 0:
+ for key in dterms:
+ data_term[key] *= self.xdw_ap
+
+ # set up imager
+ imgr = eh.imager.Imager(self.obs_sc, init, prior_im=self.initimg, flux=self.zbl,
+ data_term=data_term, maxit=self.maxit, norm_reg=True,
+ systematic_noise=systematic_noise,
+ reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop)
+
+ # Imaging
+ imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac)
+
+ caltab = eh.selfcal(self.obs_sc_init, imgr.out_last(), method='both',
+ ttype=self.ttype, gain_tol=self.gaintol, caltable=True, processes=-1)
+ self.obs_sc = caltab.applycal(self.obs_sc_init, interp='nearest', extrapolate=True)
+
+ sc_ap_idx += 1
+
+ self.im_out = imgr.out_last().copy()
+
+ # if no self-cal, no caltabs will be saved
+ if self.sc_phase == 0 and self.sc_ap == 0:
+ self.save_caltab = False
+
+ else:
+ self.caltab = caltab
+
+ def output_results(self):
+
+ """
+ Outputs all requested files pertaining to final image
+
+ Args:
+
+ Returns:
+
+ """
+
+ # Add a large gaussian component to account for the missing flux
+ # so the final image can be compared with the original data
+ self.im_addcmp = self.im_out.add_zblterm(self.obs_orig, self.uv_zblcut, debias=True)
+ self.obs_sc_addcmp = eh.selfcal(self.obs_orig, self.im_addcmp, method='both', ttype=self.ttype)
+
+ # If an inverse taper was used, restore the final image
+ # to be consistent with the original data
+ if self.reverse_taper_uas > 0.0:
+ self.im_out = self.im_out.blur_circ(self.reverse_taper_uas)
+
+ # Save the final image
+ outfits = os.path.join(self.outpath, '%s.fits' % (self.outfile))
+ self.im_out.save_fits(outfits)
+
+ # Save caltab
+ if hasattr(self, 'save_caltab') and self.save_caltab == True:
+ outcal = os.path.join(self.outpath, '%s/' % (self.outfile))
+ eh.caltable.save_caltable(self.caltab, self.obs_sc_init, outcal)
+
+ # Save self-calibrated uvfits
+ if self.save_uvfits:
+ outuvfits = os.path.join(self.outpath, '%s.uvfits' % (self.outfile))
+ self.obs_sc_addcmp.save_uvfits(outuvfits)
+
+ # Save pdf of final image
+ if self.save_pdf:
+ outpdf = os.path.join(self.outpath, '%s.pdf' % (self.outfile))
+ self.im_out.display(cbar_unit=['Tb'], label_type='scale', export_pdf=outpdf)
+
+ # Save pdf of image summary
+ if self.save_imgsums:
+ # Save an image summary sheet
+ plt.close('all')
+ outimgsum = os.path.join(self.outpath, '%s_imgsum.pdf' % (self.outfile))
+ eh.imgsum(self.im_addcmp, self.obs_sc_addcmp, self.obs_orig, outimgsum, cp_uv_min=self.uv_zblcut,
+ processes=-1)
+
+ def save_statistics(self):
+
+ """
+ Saves a csv file with the following statistics:
+ chi^2 closure phase, logcamp, vis wrt the original observation
+ chi^2 vis wrt to original observation with self-cal to final image applied
+ chi^2 closure phase, logcamp, vis wrt the original observation with sys noise and self-cal applied
+
+ Args:
+
+ Returns:
+
+ """
+ stats_dict = {}
+ stats_dict['i'] = [self.i]
+
+ outstats = os.path.join(self.outpath, '%s_stats.csv' % (self.outfile))
+
+ # if ground truth image available, compute nxcorr
+ if self.ground_truth_img != 'None':
+ gt_im = eh.image.load_fits(self.ground_truth_img)
+
+ fov_ = 200 * eh.RADPERUAS
+ psize_ = fov_ / 256
+ nxcorr_, _, _ = gt_im.compare_images(self.im_addcmp, metric='nxcorr', target_fov=fov_, psize=psize_)
+ nxcorr = nxcorr_[0]
+
+ stats_dict['nxcorr'] = [nxcorr]
+
+ # chi^2 for closure phase (cp) and log camp (lc)
+ # original uv data
+ obs_ref = self.obs_orig
+ chi2_cp_ref = obs_ref.chisq(self.im_addcmp, dtype='cphase',
+ ttype=self.ttype, systematic_noise=0.,
+ systematic_cphase_noise=0, maxset=False,
+ cp_uv_min=self.uv_zblcut)
+ chi2_lc_ref = obs_ref.chisq(self.im_addcmp, dtype='logcamp',
+ ttype=self.ttype, systematic_noise=0.,
+ snrcut=1.0, maxset=False,
+ cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point
+ chi2_vis_ref = obs_ref.chisq(self.im_addcmp, dtype='vis',
+ ttype=self.ttype, systematic_noise=0.,
+ snrcut=1.0, maxset=False,
+ cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point
+
+ stats_dict['chi2_cp_ref'] = [chi2_cp_ref]
+ stats_dict['chi2_lc_ref'] = [chi2_lc_ref]
+ stats_dict['chi2_vis_ref'] = [chi2_vis_ref]
+
+ # orig data self-cal to final image
+ obs_sub = self.obs_sc_addcmp
+ chi2_vis_sub = obs_sub.chisq(self.im_addcmp, dtype='vis',
+ ttype=self.ttype, systematic_noise=0.,
+ snrcut=1.0, maxset=False,
+ cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point
+
+ stats_dict['chi2_vis_sub'] = [chi2_vis_sub]
+
+ # orig data with sys noise added + self-cal to final image
+ self.obs_sc_addcmp_sys = eh.selfcal(self.obs_sc_init, self.im_addcmp, method='both', ttype=self.ttype)
+ obs_sys = self.obs_sc_addcmp_sys
+ chi2_cp_sys = obs_sys.chisq(self.im_addcmp, dtype='cphase',
+ ttype=self.ttype, systematic_noise=0.,
+ systematic_cphase_noise=0, maxset=False,
+ cp_uv_min=self.uv_zblcut)
+ chi2_lc_sys = obs_sys.chisq(self.im_addcmp, dtype='logcamp',
+ ttype=self.ttype, systematic_noise=0.,
+ snrcut=1.0, maxset=False,
+ cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point
+ chi2_vis_sys = obs_sys.chisq(self.im_addcmp, dtype='vis',
+ ttype=self.ttype, systematic_noise=0.,
+ snrcut=1.0, maxset=False,
+ cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point
+
+ stats_dict['chi2_cp_sys'] = [chi2_cp_sys]
+ stats_dict['chi2_lc_sys'] = [chi2_lc_sys]
+ stats_dict['chi2_vis_sys'] = [chi2_vis_sys]
+
+ df = pd.DataFrame.from_dict(stats_dict)
+ df.to_csv(outstats)
+
+ def save_params(self):
+ """
+ Saves a csv file with parameter set details
+
+ Args:
+
+ Returns:
+
+ """
+ self.paramset['fov'] = self.fov
+ self.paramset['zbl_tot'] = self.zbl_tot
+
+ df = pd.DataFrame.from_dict([self.paramset])
+
+ outparams = os.path.join(self.outpath, '%s_params.csv' % (self.outfile))
+ df.to_csv(outparams)
+
+ def run(self):
+
+ """
+ Run imaging pipeline for one parameter set.
+
+ Args:
+
+ Returns:
+
+ """
+
+        # if a *_params.csv file exists, this parameter set has already been run and can be skipped
+        # useful in case a survey with multiple parameter sets gets interrupted
+        outcsv = os.path.join(self.outpath, '%s_params.csv' % (self.outfile))
+
+        if (not self.overwrite) and os.path.exists(outcsv):
+            return
+
+        # load in data
+        self.load_data()
+
+        # do pre-imaging calibration
+        self.preimcal()
+
+        # create initial/prior image
+        self.init_img()
+
+        # run imaging step
+        self.make_img()
+
+        # output the results
+        self.output_results()
+
+        # save params to text
+        self.save_params()
+
+        if self.save_stats:
+            self.save_statistics()
+
+
+def run_pset(pset, system_kwargs, params_fixed):
+ """
+ Run imaging for one parameter set. Not to be used individually, but called in map function for run_survey
+
+    Args:
+        pset (dict): A dict containing single parameter set
+        system_kwargs (dict): system kwargs dict passed through by paramsurvey
+        params_fixed (dict): A dict containing non-varying parameters
+
+ Returns:
+
+ """
+
+ PSet = ParameterSet(pset, params_fixed)
+ PSet.run()
+
+def run_survey(psets, params_fixed):
+ """Run survey for all parameter sets using paramsurvey
+
+ Args:
+ psets (DataFrame): A pandas DataFrame containing all parameter sets
+ params_fixed (dict): A dict containing non-varying parameters
+
+ Returns:
+
+ """
+ # run whole survey using map function
+ paramsurvey.init(backend=params_fixed['backend'], ncores=params_fixed['nproc'],
+ verbose=0,vstats=0)
+ paramsurvey.map(run_pset, psets, user_kwargs=params_fixed, verbose=0)
+
+def create_params_fixed(infile, outfile_base, outpath, ground_truth_img='None',
+ save_imgsums=False, save_uvfits=True, save_pdf=False, save_stats=True, save_caltab=True,
+ nproc=1, backend='multiprocessing', ttype='nfft', overwrite=False,
+ selfcal=True, gaintol=[0.02,0.2], niter_static=3, blurfrac=1,
+ maxit=100, stop=1e-4, fov=128, npixels=64, reverse_taper_uas=5, uv_zblcut=0.1e9,
+ SEFD_error_budget={'AA':0.1,'AX':0.1,'GL':0.1,'LM':0.1,'MG':0.1,'MM':0.1,'PV':0.1,'SW':0.1}):
+ """Create a dict for all non-varying survery parameters
+
+ Args:
+ infile (str): path to input uvfits observation file
+ outfile_base (str): name of base filename for all outputs
+ outpath (str): path to directory where all outputs should be stored
+ ground_truth_img (str): if applicable, path to ground truth fits file
+ save_imgsums (bool): save summary pdf for each image
+ save_uvfits (bool): save final self-cal observation to uvfits file
+ save_pdf (bool): save pdf of each image
+ save_stats (bool): save csv file containing statistics for each image
+ save_caltab (bool): save a calibration table for each image
+ nproc (int): number of parallel processes
+ backend (str): either 'multiprocessing' or 'ray'
+        ttype (str): 'fast', 'nfft', or 'direct'
+ overwrite (bool): if True, write over existing files with same names - else, skip parameter set
+ selfcal (bool): perform self-calibration steps during imaging
+ gaintol (array): tolerance for gains to be under/over unity respectively during self-cal
+ niter_static (int): number of iterations for each imaging step
+ blurfrac (int): factor to blur initial image between iterations
+ maxit (int): maximum number of iterations if image does not converge
+ stop (float): convergence criterion for imaging
+ fov (int): image field of view in uas
+ npixels (int): number of image pixels
+        reverse_taper_uas (int): fwhm of the Gaussian in uas used to reverse taper the observation
+        uv_zblcut (float): maximum uv-distance below which baselines are considered to carry short-baseline flux
+ SEFD_error_budget (dict): SEFD percentage error for each station
+
+
+ Returns:
+        (dict): dict containing all non-varying survey parameters
+ """
+
+ # take all arguments and put them in a dict
+ args = list(locals().keys())
+ params_fixed = {}
+ for arg in args:
+ params_fixed[arg] = locals().get(arg)
+
+ return params_fixed
+
+def create_survey_psets(zbl=[0.6], sys_noise=[0.02], avg_time=['scan'], prior_fwhm=[40],
+ sc_phase=[0], xdw_phase=[10], sc_ap=[0], xdw_ap=[1], amp=[0.2], cphase=[1], logcamp=[1],
+ simple=[1], l1=[1], tv=[1], tv2=[1], flux=[1], epsilon_tv=[1e-10]):
+ """Create a dataframe given all survey parameters. Default values will create an example dataframe but these values should be adjusted for each specific observation
+
+ Args:
+ zbl (array): compact flux value (Jy)
+ sys_noise (array): percent addition of systematic noise
+ avg_time (array): in seconds or 'scan' for scan averaging
+ prior_fwhm (array): fwhm of gaussian prior in uas
+ sc_phase (array): number of iterations to perform phase-only self-cal
+ xdw_phase (array): multiplicative factor for data weights after one round of phase-only self-cal
+ sc_ap (array): number of iterations to perform amp+phase self-cal
+ xdw_ap (array): multiplicative factor for data weights after one round of amp+phase self-cal
+ amp (array): data weight to be placed on amplitudes
+ cphase (array): data weight to be placed on closure phases
+ logcamp (array): data weight to be placed on log closure amplitudes
+ simple (array): regularizer weight for relative entropy, favoring similarity to prior image
+ l1 (array): regularizer weight for l1 norm, favoring image sparsity
+ tv (array): regularizer weight for total variation, favoring sharp edges
+ tv2 (array): regularizer weight for total squared variation, favoring smooth edges
+ flux (array): regularizer weight for total flux density, favoring final images with flux close to zbl
+ epsilon_tv (array): epsilon value used in definition of total variation - rarely need to change this
+
+ Returns:
+ (DataFrame): pandas DataFrame containing all combination of parameter sets along with index values
+ """
+
+ # take all arguments and put them in a dict
+ args = list(locals().keys())
+ params = {}
+ for arg in args:
+ params[arg] = locals().get(arg)
+
+ # pandas dataframe containing all combinations of parameters to survey
+ psets = paramsurvey.params.product(params)
+
+ # add pset index number to each row of dataframe
+ psets['i'] = np.array(range(len(psets)))
+
+ return psets
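+
+# Minimal end-to-end sketch (file paths below are placeholders, not shipped
+# defaults):
+#   params_fixed = create_params_fixed('obs.uvfits', 'survey', './results/')
+#   psets = create_survey_psets(zbl=[0.5, 0.6], sys_noise=[0.02, 0.05])
+#   run_survey(psets, params_fixed)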
diff --git a/vex.py b/vex.py
new file mode 100644
index 00000000..38a42e43
--- /dev/null
+++ b/vex.py
@@ -0,0 +1,332 @@
+# vex.py
+# an interferometric array vex schedule class
+#
+# Copyright (C) 2018 Hotaka Shiokawa
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see <http://www.gnu.org/licenses/>.
+
+
+from __future__ import division
+from __future__ import print_function
+
+from builtins import str
+from builtins import range
+from builtins import object
+
+import numpy as np
+import re
+
+from astropy.time import Time
+import os
+import ehtim.array
+import ehtim.const_def as ehc
+
+###################################################################################################
+# Vex object
+###################################################################################################
+
+
+class Vex(object):
+ """Read in observing schedule data from .vex files.
+ Assumes there is only 1 MODE in vex file
+
+ Attributes:
+ filename (str): The .vex filename.
+ source (str): The source name.
+ metalist (list): The observation information.
+ sched (list): The schedule information.
+ array (Array): an Array object of sites.
+ """
+
+ def __init__(self, filename):
+
+ f = open(filename)
+ raw = f.readlines()
+ f.close()
+
+ self.filename = filename
+
+ # Divide 'raw' data into sectors delimited by '$' marks
+ # ASSUMING '$' is the very first character in a line (no space in front)
+ metalist = [] # meaning list of metadata
+
+ for i in range(len(raw)):
+ if raw[i][0] == '$':
+ temp = [raw[i]]
+ break
+
+ for j in range(i + 1, len(raw)):
+ if raw[j][0] != '$':
+ temp.append(raw[j])
+ else: # a new '$' sector begins
+ metalist.append(temp)
+ temp = [raw[j]]
+ metalist.append(temp) # don't forget to add the final one
+ self.metalist = metalist
+
+ # Extract desired information
+ # SOURCE ========================================================
+ SOURCE = self.get_sector('SOURCE')
+ source = []
+ indef = False
+
+ for i in range(len(SOURCE)):
+
+ line = SOURCE[i]
+ if line[0:3] == "def":
+ indef = True
+
+ if indef:
+ ret = self.get_variable("source_name", line)
+ if len(ret) > 0:
+ source_name = ret
+ ret = self.get_variable("ra", line)
+ if len(ret) > 0:
+ ra = ret
+ ret = self.get_variable("dec", line)
+ if len(ret) > 0:
+ dec = ret
+ ret = self.get_variable("ref_coord_frame", line)
+ if len(ret) > 0:
+ ref_coord_frame = ret
+
+ if line[0:6] == "enddef":
+ source.append({'source': source_name, 'ra': ra, 'dec': dec,
+ 'ref_coord_frame': ref_coord_frame})
+ indef = False
+
+ self.source = source
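+ # A typical self.source entry (values illustrative):
+ # {'source': 'M87', 'ra': '12h30m49.42s',
+ # 'dec': '12d23\'28.04"', 'ref_coord_frame': 'J2000'}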
+
+ # FREQ ==========================================================
+ FREQ = self.get_sector('FREQ')
+ indef = False
+ nfreq = 0
+ for i in range(len(FREQ)):
+
+ line = FREQ[i]
+ if line[0:3] == "def":
+ if nfreq > 0:
+ print("Multiple FREQ defs are not implemented yet.")
+ nfreq += 1
+ indef = True
+
+ if indef:
+ idx = line.find('chan_def')
+ if idx >= 0 and line[0] != '*':
+ chan_def = re.findall(r"[-+]?\d+[\.]?\d*", line)
+ self.freq = float(chan_def[0]) * 1.e6
+ self.bw_hz = float(chan_def[1]) * 1.e6
+
+ if line[0:6] == "enddef":
+ indef = False
+
+ # SITE ==========================================================
+ SITE = self.get_sector('SITE')
+ sites = []
+ site_ID_dict = {}
+ indef = False
+
+ for i in range(len(SITE)):
+
+ line = SITE[i]
+ if line[0:3] == "def":
+ indef = True
+
+ if indef:
+ # get site_name and SEFD
+ ret = self.get_variable("site_name", line)
+ if len(ret) > 0:
+ site_name = ret
+ SEFD = self.get_SEFD(site_name)
+
+ # making dictionary of site_ID:site_name
+ ret = self.get_variable("site_ID", line)
+ if len(ret) > 0:
+ site_ID_dict[ret] = site_name
+
+ # get site_position
+ ret = self.get_variable("site_position", line)
+ if len(ret) > 0:
+ site_position = re.findall(r"[-+]?\d+[\.]?\d*", line)
+
+ # same format as Andrew's array tables
+ if line[0:6] == "enddef":
+ sites.append([site_name, site_position[0],
+ site_position[1], site_position[2], SEFD])
+ indef = False
+
+ # Construct Array() object of Andrew's format
+ # mimic the function "load_array(filename)"
+ # TODO this does not store d-term and pol cal. information!
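+ # Both SEFD fields (R and L polarization) are filled with the single
+ # tabulated SEFD, and the remaining D-term / field-rotation entries
+ # are set to zero.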
+ tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[4]),
+ 0.0, 0.0, 0.0, 0.0, 0.0),
+ dtype=ehc.DTARR) for x in sites]
+
+ tdataout = np.array(tdataout)
+ self.array = ehtim.array.Array(tdataout)
+
+ # SCHED =========================================================
+ SCHED = self.get_sector('SCHED')
+ sched = []
+ inscan = False
+
+ for i in range(len(SCHED)):
+
+ line = SCHED[i]
+ if line[0:4] == "scan":
+ inscan = True
+ temp = {}
+ temp['scan'] = {}
+ cnt = 0
+
+ if inscan:
+ ret = self.get_variable("start", line)
+ if len(ret) > 0:
+ mjd, hr = vexdate_to_MJD_hr(ret) # convert vex time format to mjd and hour
+ temp['mjd_floor'] = mjd
+ temp['start_hr'] = hr
+
+ ret = self.get_variable("mode", line)
+ if len(ret) > 0:
+ temp['mode'] = ret
+
+ ret = self.get_variable("source", line)
+ if len(ret) > 0:
+ temp['source'] = ret
+
+ ret = self.get_variable("station", line)
+ if len(ret) > 0:
+ site_ID = ret
+ site_name = site_ID_dict[site_ID] # convert to the more familiar site name
+ sdur = re.findall(r"[-+]?\d+[\.]?\d*", line)
+ s_st = float(sdur[0]) # start time in sec
+ s_en = float(sdur[1]) # end time in sec
+ d_size = float(sdur[2]) # data size(?) in GB
+ temp['scan'][cnt] = {'site': site_name, 'scan_sec_start': s_st,
+ 'scan_sec': s_en, 'data_size': d_size}
+ cnt += 1
+
+ if line[0:7] == "endscan":
+ sched.append(temp)
+ inscan = False
+
+ self.sched = sched
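+ # Each entry of self.sched has the form (values illustrative):
+ # {'mjd_floor': 57486, 'start_hr': 5.0, 'mode': 'SX', 'source': 'M87',
+ # 'scan': {0: {'site': 'ALMA', 'scan_sec_start': 0.0,
+ # 'scan_sec': 300.0, 'data_size': 1.0}}}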
+
+ # Function to obtain a desired sector from 'metalist'
+
+ def get_sector(self, sname):
+ """Obtain a desired sector from 'metalist'.
+ """
+
+ for i in range(len(self.metalist)):
+ if sname in self.metalist[i][0]:
+ return self.metalist[i]
+ print('No sector named %s' % sname)
+ return False
+
+ # Function to get the value of 'vname' in a line with the format
+ # 'vname' = value ; (or :)
+ def get_variable(self, vname, line):
+ """Function to get a value of 'vname' in a line.
+ """
+
+ idx = self.find_variable(vname, line)
+ name = ''
+ if idx >= 0:
+ start = False
+ for i in range(idx + len(vname), len(line)):
+ if start is True:
+ if line[i] == ';' or line[i] == ':':
+ break
+ elif line[i] != ' ':
+ name += line[i]
+ if start is False and line[i] != ' ' and line[i] != '=':
+ break
+ if line[i] == '=':
+ start = True
+ return name
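+ # e.g. get_variable('source', 'source = M87;') returns 'M87'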
+
+ # Check whether the variable 'vname' exists by itself in a line.
+ # Returns the index of vname[0] in the line, or -1 if not found.
+ def find_variable(self, vname, line):
+ """Function to find a variable 'vname' in a line.
+ """
+ idx = line.find(vname)
+ if ((idx > 0 and line[idx - 1] == ' ') or idx == 0) and line[0] != '*':
+ if idx + len(vname) == len(line):
+ return idx
+ if (line[idx + len(vname)] == '=' or
+ line[idx + len(vname)] == ' ' or
+ line[idx + len(vname)] == ':' or
+ line[idx + len(vname)] == ';'):
+ return idx
+ return -1
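+ # e.g. find_variable('mode', 'mode = sched_a;') returns 0, while
+ # find_variable('mode', '*mode = sched_a;') returns -1 (commented line)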
+
+ # Find the SEFD for a given station name.
+ # For now, look it up in Andrew's array tables;
+ # vex files could in principle carry their own SEFD sector.
+ def get_SEFD(self, station):
+ """Find SEFD for a given station.
+ """
+ f = open(os.path.dirname(os.path.abspath(__file__)) + "/../arrays/SITES.txt")
+ sites = f.readlines()
+ f.close()
+ for i in range(len(sites)):
+ if sites[i].split()[0] == station:
+ return float(re.findall(r"[-+]?\d+[\.]?\d*", sites[i])[3])
+ print('No station named %s' % station)
+ return 10000. # some arbitrary value
+
+ # Find the time that any station starts observing the source in MJD.
+ # Find the time that the last station stops observing the source in MJD.
+ def get_obs_timerange(self, source):
+ """Find the time that any station starts observing the source in MJD,
+ and the time that the last station stops observing the source.
+ """
+
+ sched = self.sched
+ first = True
+ for i_scan in range(len(sched)):
+ if sched[i_scan]['source'] == source and first is True:
+ Tstart_hr = sched[i_scan]['start_hr']
+ mjd_s = sched[i_scan]['mjd_floor'] + Tstart_hr / 24.
+ first = False
+ if sched[i_scan]['source'] == source and first is False:
+ Tstop_hr = sched[i_scan]['start_hr'] + sched[i_scan]['scan'][0]['scan_sec'] / 3600.
+ mjd_e = sched[i_scan]['mjd_floor'] + Tstop_hr / 24.
+
+ return mjd_s, mjd_e
+
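+# A minimal usage sketch (filename hypothetical; assumes this module is
+# importable as ehtim.vex):
+#
+# from ehtim.vex import Vex
+# vex = Vex('track.vex') # parse the schedule file
+# print(vex.source[0]['source']) # first source defined in $SOURCE
+# mjd_s, mjd_e = vex.get_obs_timerange(vex.source[0]['source'])
+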
+# =================================================================
+# =================================================================
+
+# Function to find the MJD (int!) and UT hour from the vex date format,
+# e.g., 2016y099d05h00m00s
+
+
+def vexdate_to_MJD_hr(vexdate):
+ """Find the integer MJD and UT hour from vex format date.
+ """
+
+ time = re.findall(r"[-+]?\d+[\.]?\d*", vexdate)
+ year = int(time[0])
+ date = int(time[1])
+ yeardatetime = ("%04i" % year) + ':' + ("%03i" % date) + ":00:00:00.000"
+ t = Time(yeardatetime, format='yday')
+ mjd = t.mjd
+ hour = int(time[2]) + float(time[3]) / 60. + float(time[4]) / 60. / 60.
+
+ return mjd, hour
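+
+# A quick worked example (illustrative): day 099 of 2016 is 2016 Apr 8,
+# i.e. MJD 57486, so
+#
+# vexdate_to_MJD_hr('2016y099d05h30m00s') # returns (57486.0, 5.5)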