diff --git a/__init__.py b/__init__.py new file mode 100644 index 00000000..c560c12d --- /dev/null +++ b/__init__.py @@ -0,0 +1,72 @@ +""" +.. module:: ehtim + :platform: Unix + :synopsis: EHT Imaging Utilities + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import ehtim.observing +from ehtim.const_def import * +from ehtim.imaging.imager_utils import imager_func +from ehtim.modeling.modeling_utils import modeler_func +import ehtim.imaging +from ehtim.features import rex +import ehtim.features +from ehtim.plotting.summary_plots import * +from ehtim.plotting.comparisons import * +from ehtim.plotting.comp_plots import * +from ehtim.plotting import comparisons +from ehtim.plotting import comp_plots +import ehtim.plotting +from ehtim.calibrating.network_cal import network_cal as netcal +from ehtim.calibrating.self_cal import self_cal as selfcal +from ehtim.calibrating.pol_cal import * +from ehtim.calibrating.pol_cal_new import * +from ehtim.calibrating import pol_cal +from ehtim.calibrating import network_cal +from ehtim.calibrating import self_cal +import ehtim.calibrating +import ehtim.parloop +import ehtim.caltable +import ehtim.vex +import ehtim.imager +import ehtim.obsdata +import ehtim.array +import ehtim.movie +import ehtim.image +import ehtim.model +import ehtim.survey + + +import warnings +warnings.filterwarnings( + "ignore", message="numpy.dtype size changed, may indicate binary incompatibility.") + +# necessary to prevent hangs from astropy iers bug in astropy v 2.0.8 +#from astropy.utils import iers +#iers.conf.auto_download = False + +try: + import pkg_resources + version = pkg_resources.get_distribution("ehtim").version + print("Welcome to eht-imaging! v", version,'\n') +except: + print("Welcome to eht-imaging!\n") + + +def logo(): + for line in BHIMAGE: + print(line) + + +def eht(): + for line in EHTIMAGE: + print(line) diff --git a/array.py b/array.py new file mode 100644 index 00000000..8415f22d --- /dev/null +++ b/array.py @@ -0,0 +1,388 @@ +# array.py +# a interferometric telescope array class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import copy +from astropy.time import Time +import matplotlib.pyplot as plt +import matplotlib +import ehtim.observing.obs_simulate as simobs +import ehtim.observing.obs_helpers as obsh +import ehtim.io.save +import ehtim.io.load +import ehtim.const_def as ehc +from ehtim.caltable import plot_tarr_dterms + + +################################################################################################### +# Array object +################################################################################################### + + +class Array(object): + + """A VLBI array of telescopes with site locations, SEFDs, and other data. + + Attributes: + tarr (numpy.recarray): The array of telescope data with datatype DTARR + tkey (dict): A dictionary of rows in the tarr for each site name + ephem (dict): A dictionary of TLEs for each space antenna, + Space antennas have x=y=z=0 in the tarr + """ + + def __init__(self, tarr, ephem={}): + self.tarr = tarr + self.ephem = ephem + + # check to see if ephemeris is correct + for line in self.tarr: + if np.any(np.isnan([line['x'], line['y'], line['z']])): + sitename = str(line['site']) + try: + elen = len(ephem[sitename]) + except NameError: + raise Exception('no ephemeris for site %s !' % sitename) + if elen != 3: + raise Exception('wrong ephemeris format for site %s !' % sitename) + + # Dictionary of array indices for site names + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + + @property + def tarr(self): + return self._tarr + + @tarr.setter + def tarr(self, tarr): + self._tarr = tarr + self.tkey = {tarr[i]['site']: i for i in range(len(tarr))} + + def copy(self): + """Copy the array object. + + Args: + + Returns: + (Array): a copy of the Array object. + """ + + newarr = copy.deepcopy(self) + return newarr + + + def listbls(self): + """List all baselines. + + Args: + Returns: + numpy.array : array of baselines + """ + + bls = [] + for i1 in sorted(self.tarr['site']): + for i2 in sorted(self.tarr['site']): + if not ([i1, i2] in bls) and not ([i2, i1] in bls) and i1 != i2: + bls.append([i1, i2]) + bls = np.array(bls) + + return bls + + def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop, + mjd=ehc.MJD_DEFAULT, timetype='UTC', polrep='stokes', + elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, + no_elevcut_space=False, + tau=ehc.TAUDEF, fix_theta_GMST=False, reorder=True): + """Generate u,v points and baseline uncertainties. + + Args: + ra (float): the source right ascension in fractional hours + dec (float): the source declination in fractional degrees + tint (float): the scan integration time in seconds + tadv (float): the uniform cadence between scans in seconds + tstart (float): the start time of the observation in hours + tstop (float): the end time of the observation in hours + + mjd (int): the mjd of the observation + timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC' + polrep (str): polarization representation, either 'stokes' or 'circ' + elevmin (float): station minimum elevation in degrees + elevmax (float): station maximum elevation in degrees + no_elevcut_space (bool): if True, do not apply elevation cut to orbiters + tau (float): the base opacity at all sites, or a dict giving one opacity per site + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v points + + Returns: + Obsdata: an observation object with no data + + """ + + obsarr = simobs.make_uvpoints(self, ra, dec, rf, bw, + tint, tadv, tstart, tstop, + mjd=mjd, polrep=polrep, tau=tau, + elevmin=elevmin, elevmax=elevmax, + no_elevcut_space=no_elevcut_space, + timetype=timetype, fix_theta_GMST=fix_theta_GMST) + + uniquetimes = np.sort(np.unique(obsarr['time'])) + scans = np.array([[time - 0.5 * tadv, time + 0.5 * tadv] for time in uniquetimes]) + source = str(ra) + ":" + str(dec) + obs = ehtim.obsdata.Obsdata(ra, dec, rf, bw, obsarr, self.tarr, + source=source, mjd=mjd, timetype=timetype, polrep=polrep, + ampcal=True, phasecal=True, opacitycal=True, + dcal=True, frcal=True, + scantable=scans, reorder=reorder) + + return obs + + def make_subarray(self, sites): + """Make a subarray from the Array object array that only includes the sites listed. + + Args: + sites (list) : list of sites in the subarray + Returns: + Array: an Array object with specified sites and metadata + """ + all_sites = [t[0] for t in self.tarr] + mask = np.array([t in sites for t in all_sites]) + subarr = Array(self.tarr[mask], ephem=self.ephem) + return subarr + + def save_txt(self, fname): + """Save the array data in a text file. + + Args: + fname (str) : path to output array file + """ + ehtim.io.save.save_array_txt(self, fname) + return + + def plot_dterms(self, sites='all', label=None, legend=True, clist=ehc.SCOLORS, + rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE, + show=True, grid=True, export_pdf=""): + """Make a plot of the D-terms. + + Args: + sites (list) : list of sites to plot + label (str) : title for plot + legend (bool) : add telescope legend or not + clist (list) : list of colors for different stations + rangex (list) : lower and upper x-axis limits + rangey (list) : lower and upper y-axis limits + markersize (float) : marker size + show (bool) : display the plot or not + grid (bool) : add a grid to the plot or not + export_pdf (str) : save a pdf file to this path + + Returns: + matplotlib.axes + """ + # sites + if sites in ['all' or 'All'] or sites == []: + sites = list(self.tkey.keys()) + + if not isinstance(sites, list): + sites = [sites] + + keys = [self.tkey[site] for site in sites] + + axes = plot_tarr_dterms(self.tarr, keys=keys, label=label, legend=legend, clist=clist, + rangex=rangex, rangey=rangey, markersize=markersize, + show=show, grid=grid, export_pdf=export_pdf) + + return axes + + def add_site(self, site, coords, sefd=10000, + fr_par=0, fr_elev=0, fr_off=0, + dr=0.+0.j, dl=0.+0.j): + + """Add a ground station to the array + + """ + tarr_old = self.tarr.copy() + ephem_old = self.ephem.copy() + + + tarr_newline = np.array((str(site), float(coords[0]), float(coords[1]), float(coords[2]), + float(sefd), float(sefd), + dr, dl, + float(fr_par), float(fr_elev), float(fr_off)), dtype=ehc.DTARR) + tarr_new = np.append(tarr_old, tarr_newline) + + arr_out = Array(tarr_new, ephem_old) + return arr_out + + def remove_site(self, site): + """Remove a site from the array + + """ + tarr_old = self.tarr.copy() + ephem_old = self.ephem.copy() + ephem_new = ephem_old.copy() + + try: + tarr_new = np.delete(tarr_old.copy(), self.tkey[site]) + if site in ephem_old.keys(): + ephem_new.pop(site) + except: + raise Exception("could not find site %s to delete from Array!"%site) + + arr_out = Array(tarr_new, ephem_new) + return arr_out + + def add_satellite_tle(self, tlelist, sefd=10000): + + """Add an earth-orbiting satellite to the array from a TLE + + Args: + tlearr (str) : 3 element list with [name, tle line 1, tle line 2] as strings + sefd (float) : assumed sefd for the array file (assumes sefdl = sefdr) + """ + satname = tlearr[0] + tarr_new = self.tarr.copy() + ephem_new = self.ephem.copy() + + tarr_newline = np.array((str(satname), 0., 0., 0., + float(sefd), float(sefd), + 0., 0., 0., 0., 0.), dtype=ehc.DTARR) + tarr_new = np.append(tarr_new, tarr_newline) + ephem_new[satname] = tlearr + arr_out = Array(tarr_new, ephem_new) + + return arr_out + + def add_satellite_elements(self, satname, + perigee_mjd=Time.now().mjd, + period_days=1., eccentricity=0., + inclination=0., arg_perigee=0., long_ascending=0., + sefd=10000): + """Add an earth-orbiting satellite to the array from simple keplerian elements + perfect keplerian orbit is assumed, no derivatives + + Args: + perigee time given in mjd + period given in days + inclination, arg_perigee, long_ascending given in degrees + """ + + tarr_new = self.tarr.copy() + ephem_new = self.ephem.copy() + + tarr_newline = np.array((str(satname), 0., 0., 0., + float(sefd), float(sefd), + 0., 0., 0., 0., 0.), dtype=ehc.DTARR) + tarr_new = np.append(tarr_new, tarr_newline) + + ephem_new[satname] = [perigee_mjd, period_days, eccentricity, inclination, arg_perigee, long_ascending] + arr_out = Array(tarr_new, ephem_new) + + return arr_out + + def plot_satellite_orbits(self, tstart_mjd=Time.now().mjd, tstop_mjd=Time.now().mjd+1, npoints=1000): + earth_radius_polar = 6357. #km + earth_radius_eq = 6378. + + fig = plt.figure(figsize=(18,6)) + gs = matplotlib.gridspec.GridSpec(1,3,width_ratios=[1,1,1]) + + satellites = self.ephem.keys() + for i,satellite in enumerate(satellites): + + if i==0: color='k' + else: color=ehc.SCOLORS[i-1] + + # get skyfield satelllite object + if len(self.ephem[satellite])==3: # TLE + line1 = self.ephem[satellite][1] + line2 = self.ephem[satellite][2] + sat = obsh.sat_skyfield_from_tle(satellite, line1, line2) + elif len(self.ephem[satellite])==6: #keplerian elements + elements = self.ephem[satellite] + sat = obsh.sat_skyfield_from_elements(satellite, tstart_mjd, + elements[0],elements[1],elements[2],elements[3],elements[4],elements[5]) + else: + raise Exception("ephemeris format not recognized for %s"%satellite) + + # get GCRS positions + fracmjds = np.linspace(tstart_mjd, tstop_mjd, npoints) + positions = obsh.orbit_skyfield(sat, fracmjds, whichout='gcrs') + positions *= 1.e-3 # convert to km + distances = np.sqrt(positions[0]**2 + positions[1]**2 + positions[2]**2) + maxdist = np.max(distances) + + ax1 = fig.add_subplot(gs[0]) + ax1.set_aspect(1) + plt.plot(positions[0], positions[1], color=color, marker='.',ls='None') + circle1 = matplotlib.patches.Circle((0, 0), earth_radius_eq, color='b') + plt.gca().add_patch(circle1) + plt.xlabel('x (km)') + plt.ylabel('y (km)') + plt.xlim(-1.1*maxdist, 1.1*maxdist) + plt.ylim(-1.1*maxdist, 1.1*maxdist) + plt.grid() + + ax2 = fig.add_subplot(gs[1]) + ax2.set_aspect(1) + plt.plot(positions[1], positions[2], color=color, marker='.',ls='None') + circle1 = matplotlib.patches.Ellipse((0, 0), 2*earth_radius_eq, 2*earth_radius_polar, color='b') + plt.gca().add_patch(circle1) + plt.xlabel('y (km)') + plt.ylabel('z (km)') + plt.xlim(-1.1*maxdist, 1.1*maxdist) + plt.ylim(-1.1*maxdist, 1.1*maxdist) + plt.grid() + + ax3 = fig.add_subplot(gs[2]) + ax3.set_aspect(1) + plt.plot(positions[0], positions[2], color=color, marker='.',ls='None', label=satellite) + circle1 = matplotlib.patches.Ellipse((0, 0), 2*earth_radius_eq, 2*earth_radius_polar, color='b') + plt.gca().add_patch(circle1) + plt.xlabel('x (km)') + plt.ylabel('z (km)') + plt.xlim(-1.1*maxdist, 1.1*maxdist) + plt.ylim(-1.1*maxdist, 1.1*maxdist) + plt.legend(frameon=False,loc='center left', bbox_to_anchor=(1, 0.5)) + plt.grid() + + plt.subplots_adjust(wspace=1) + ehc.show_noblock() + return + +########################################################################## +# Array creation functions +########################################################################## + + +def load_txt(fname, ephemdir='ephemeris'): + """Read an array from a text file. + Sites with x=y=z=0 are spacecraft, TLE ephemerides read from ephemdir. + + Args: + fname (str) : path to input array file + ephemdir (str) : path to directory with TLE ephemerides for spacecraft + Returns: + Array: an Array object loaded from file + """ + + return ehtim.io.load.load_array_txt(fname, ephemdir=ephemdir) diff --git a/calibrating/__init__.py b/calibrating/__init__.py new file mode 100644 index 00000000..ede0bd27 --- /dev/null +++ b/calibrating/__init__.py @@ -0,0 +1,15 @@ +""" +.. module:: ehtim.calibrating + :platform: Unix + :synopsis: EHT Imaging Utilities: calibration functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from . import self_cal +from . import network_cal +from . import pol_cal +from . import pol_cal_new +from . import polgains_cal + +from ..const_def import * diff --git a/calibrating/cal_helpers.py b/calibrating/cal_helpers.py new file mode 100644 index 00000000..c3976400 --- /dev/null +++ b/calibrating/cal_helpers.py @@ -0,0 +1,74 @@ +# cal_helpers.py +# helper functions for calibration +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import itertools as it + +import ehtim.const_def as ehc + +ZBLCUTOFF = 1.e7 + + +def make_cluster_data(obs, zbl_uvdist_max=ZBLCUTOFF): + """Cluster sites in an observation into groups + with intra-group basline length not exceeding zbl_uvdist_max + """ + + clusters = [] + clustered_sites = [] + for i1 in range(len(obs.tarr)): + t1 = obs.tarr[i1] + + if t1['site'] in clustered_sites: + continue + + csites = [t1['site']] + clustered_sites.append(t1['site']) + for i2 in range(len(obs.tarr))[i1:]: + t2 = obs.tarr[i2] + if t2['site'] in clustered_sites: + continue + + site1coord = np.array([t1['x'], t1['y'], t1['z']]) + site2coord = np.array([t2['x'], t2['y'], t2['z']]) + uvdist = np.sqrt(np.sum((site1coord - site2coord)**2)) / (ehc.C / obs.rf) + + if uvdist < zbl_uvdist_max: + csites.append(t2['site']) + clustered_sites.append(t2['site']) + clusters.append(csites) + + clusterdict = {} + for site in obs.tarr['site']: + for k in range(len(clusters)): + if site in clusters[k]: + clusterdict[site] = k + + clusterbls = [set(comb) for comb in it.combinations(range(len(clusterdict)), 2)] + + cluster_data = (clusters, clusterdict, clusterbls) + + return cluster_data diff --git a/calibrating/network_cal.py b/calibrating/network_cal.py new file mode 100644 index 00000000..8fa1cedc --- /dev/null +++ b/calibrating/network_cal.py @@ -0,0 +1,463 @@ +# network_cal.py +# functions for network-calibration +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.optimize as opt +import time +import copy +from multiprocessing import cpu_count, Pool + +import ehtim.obsdata +import ehtim.parloop as parloop +from . import cal_helpers as calh +import ehtim.observing.obs_helpers as obsh +import ehtim.const_def as ehc + +ZBLCUTOFF = 1.e7 +MAXIT = 5000 + +################################################################################################### +# Network-Calibration +################################################################################################### + + +def network_cal(obs, zbl, sites=[], zbl_uvdist_max=ZBLCUTOFF, method="amp", minimizer_method='BFGS', + pol='I', pad_amp=0., gain_tol=.2, solution_interval=0.0, scan_solutions=False, + caltable=False, processes=-1, show_solution=False, debias=True, msgtype='bar'): + """Network-calibrate a dataset with zero baseline constraints. + + Args: + obs (Obsdata): The observation to be calibrated + zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour. + sites (list): list of sites to include in the network calibration. + empty list calibrates all sites + zbl_uvdist_max (float): maximum uv-distance considered a zero baseline + method (str): chooses what to calibrate, 'amp', 'phase', or 'both'. + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + pol (str): which visibility to compute gains for + + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + gain_tol (float): gains that exceed this value will be disfavored by the prior + solution_interval (float): solution interval in seconds; + one gain is derived for each interval. + If 0.0, a solution is determined for each unique time + scan_solutions (bool): If True, determine one gain per site per scan. + Supersedes solution_interval + + debias (bool): If True, debias the amplitudes + caltable (bool): if True, returns a Caltable instead of an Obsdata + processes (int): number of cores to use in multiprocessing + show_solution (bool): if True, display the solution as it is calculated + msgtype (str): type of progress message to be printed, default is 'bar' + + Returns: + (Obsdata): the calibrated observation, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + + # Here, RRLL means to use both RR and LL (both as proxies for Stokes I) + # to derive a network calibration solution + if pol not in ['I', 'Q', 'U', 'V', 'RR', 'LL', 'RRLL']: + raise Exception("Can only network-calibrate to I, Q, U, V, RR, LL, or RRLL!") + if pol in ['I', 'Q', 'U', 'V']: + if obs.polrep != 'stokes': + raise Exception("netcal pol is a stokes parameter, but obs.polrep!='stokes'") + # obs = obs.switch_polrep('stokes',pol) + elif pol in ['RR', 'LL', 'RRLL']: + if obs.polrep != 'circ': + raise Exception("netcal pol is RR or LL or RRLL, but obs.polrep!='circ'") + # obs = obs.switch_polrep('circ',pol) + + # V = model visibility, V' = measured visibility, G_i = site gain + # G_i * conj(G_j) * V_ij = V'_ij + if len(sites) == 0: + print("No stations specified in network cal: defaulting to calibrating all stations!") + sites = obs.tarr['site'] + + # find colocated sites and put into list allclusters + cluster_data = calh.make_cluster_data(obs, zbl_uvdist_max) + + # get scans + scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions) + scans_cal = copy.copy(scans) + + # Make the pool for parallel processing + if processes > 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + if processes > len(scans): + processes = len(scans) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + elif processes == 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + processes = int(cpu_count()) + if processes > len(scans): + processes = len(scans) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + else: + print("Not Using Multiprocessing") + + # loop over scans and calibrate + tstart = time.time() + if processes > 0: # with multiprocessing + scans_cal = np.array(pool.map(get_network_scan_cal, [[i, len(scans), scans[i], + zbl, sites, cluster_data, obs.polrep, pol, + method, pad_amp, gain_tol, + caltable, show_solution, debias, msgtype + ] for i in range(len(scans))]), + dtype=object) + else: # without multiprocessing + for i in range(len(scans)): + obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1) + scans_cal[i] = network_cal_scan(scans[i], zbl, sites, cluster_data, + polrep=obs.polrep, pol=pol, + method=method, minimizer_method=minimizer_method, + show_solution=show_solution, caltable=caltable, + pad_amp=pad_amp, gain_tol=gain_tol, debias=debias) + + tstop = time.time() + print("\nnetwork_cal time: %f s" % (tstop - tstart)) + + if caltable: # create and return a caltable + allsites = obs.tarr['site'] + caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()} + for i in range(1, len(scans_cal)): + row = scans_cal[i] + if len(row) == 0: + continue + + for site in allsites: + try: + dat = row[site] + except KeyError: + continue + + try: + caldict[site] = np.append(caldict[site], row[site]) + except KeyError: + caldict[site] = [dat] + + caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr, + source=obs.source, mjd=obs.mjd, timetype=obs.timetype) + out = caltable + + else: # return the calibrated observation + arglist, argdict = obs.obsdata_args() + arglist[4] = np.concatenate(scans_cal) + out = ehtim.obsdata.Obsdata(*arglist, **argdict) + + # close multiprocessing jobs + if processes != -1: + pool.close() + + return out + + +def network_cal_scan(scan, zbl, sites, clustered_sites, polrep='stokes', pol='I', + zbl_uvidst_max=ZBLCUTOFF, method="both", minimizer_method='BFGS', + show_solution=False, pad_amp=0., gain_tol=.2, caltable=False, debias=True): + """Network-calibrate a scan with zero baseline constraints. + + Args: + obs (Obsdata): The observation to be calibrated + zbl (float or function): constant zero baseline flux in Jy, or a function of UT hour. + sites (list): list of sites to include in the network calibration. + empty list calibrates all sites + clustered_sites (tuple): information on clustered sites, returned by make_cluster_data + + polrep (str): 'stokes' or 'circ' to specify the polarization products in scan + pol (str): which image polarization to self-calibrate visibilities to + zbl_uvdist_max (float): maximum uv-distance considered a zero baseline + method (str): chooses what to calibrate, 'amp', 'phase', or 'both' + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + gain_tol (float): gains that exceed this value will be disfavored by the prior + + debias (bool): If True, debias the amplitudes + caltable (bool): if True, returns a Caltable instead of an Obsdata + show_solution (bool): if True, display the solution as it is calculated + + + Returns: + (Obsdata): the calibrated scan, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + + # determine the zero-baseline flux of the scan + if callable(zbl): + zbl_scan = np.median(zbl(scan['time'])) + else: + zbl_scan = zbl + + # clustered site information + allclusters = clustered_sites[0] + clusterdict = clustered_sites[1] + clusterbls = clustered_sites[2] + + # all the sites in the scan + allsites = list(set(np.hstack((scan['t1'], scan['t2'])))) + + if len(sites) == 0: + print("No stations specified in network cal: defaulting to calibrating all !") + sites = allsites + + # only include sites that are present + sites = [s for s in sites if s in allsites] + + # create a dictionary to keep track of gains; + # sites that aren't network calibrated (no co-located partners) get a value of -1 + # so that they won't be network calibrated; other sites get a unique number + tkey = {b: a for a, b in enumerate(sites)} + for cluster in allclusters: + if len(cluster) == 1: + tkey[cluster[0]] = -1 + + clusterkey = clusterdict + + # restrict solved cluster visibilities to ones present in the scan + # (this is much faster than allowing many unconstrained variables) + clusterbls_scan = [set([clusterkey[row['t1']], clusterkey[row['t2']]]) + for row in scan + if len(set([clusterkey[row['t1']], clusterkey[row['t2']]])) == 2] + + # now delete duplicates + clusterbls = [cluster for cluster in clusterbls if cluster in clusterbls_scan] + + # make two lists of gain keys that relates scan bl gains to solved site ones + # (-1 means that this station does not have a gain that is being solved for) + # and make one list of scan keys that relates scan bl visibilities to solved cluster ones + # (-1 means it's a zero baseline!) + + g1_keys = [] + g2_keys = [] + scan_keys = [] + for row in scan: + try: + g1_keys.append(tkey[row['t1']]) + except KeyError: + g1_keys.append(-1) + try: + g2_keys.append(tkey[row['t2']]) + except KeyError: + g2_keys.append(-1) + + clusternum1 = clusterkey[row['t1']] + clusternum2 = clusterkey[row['t2']] + + if clusternum1 == clusternum2: # sites are in the same cluster + scan_keys.append(-1) + else: # sites are not in the same cluster + bl_index = clusterbls.index(set((clusternum1, clusternum2))) + scan_keys.append(bl_index) + + # no sites to calibrate on this scan! + # if np.all(g1_keys == -1): + # return scan #Doesn't work with the caldict options + + # Start by restricting to visibilities that include baselines to a site with a zero-baseline + vis_mask = [((row['t1'] in tkey.keys() and tkey[row['t1']] != -1) + or (row['t2'] in tkey.keys() and tkey[row['t2']] != -1)) for row in scan] + + # get scan visibilities of the specified polarization + if pol != 'RRLL': + vis = scan[ehc.vis_poldict[pol]] + sigma = scan[ehc.sig_poldict[pol]] + else: + vis = np.concatenate([scan[ehc.vis_poldict['RR']], scan[ehc.vis_poldict['LL']]]) + sigma = np.concatenate([scan[ehc.sig_poldict['RR']], scan[ehc.sig_poldict['LL']]]) + vis_mask = np.concatenate([vis_mask, vis_mask]) + + if method == 'amp': + if debias: + vis = obsh.amp_debias(np.abs(vis), np.abs(sigma)) + else: + vis = np.abs(vis) + + sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2) + + # initial guesses for parameters + n_gains = len(sites) + n_clusterbls = len(clusterbls) + if show_solution: + print('%d Gains; %d Clusters' % (n_gains, n_clusterbls)) + + gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64) + vpar_guess = np.ones(n_clusterbls, dtype=np.complex128) + for i in range(len(scan_keys)): + if scan_keys[i] < 0: + continue + if np.isnan(vis[i]): + continue + vpar_guess[scan_keys[i]] = vis[i] + + vpar_guess = vpar_guess.view(dtype=np.float64) + gvpar_guess = np.hstack((gpar_guess, vpar_guess)) + + # error function + def errfunc(gvpar): + + # all the forward site gains (complex) + g = gvpar[0:2 * n_gains].astype(np.float64).view(dtype=np.complex128) + + # all the intercluster visibilities (complex) + v = gvpar[2 * n_gains:].astype(np.float64).view(dtype=np.complex128) + + # choose to only scale ampliltudes or phases + if method == "phase": + g = g / np.abs(g) + elif method == "amp": + g = np.abs(np.real(g)) + + # append the default values to g for missing points + # and to v for the zero baseline points + g = np.append(g, 1.) + v = np.append(v, zbl_scan) + + # scan visibilities are either an intercluster visibility or the fixed zbl + v_scan = v[scan_keys] + g1 = g[g1_keys] + g2 = g[g2_keys] + if pol == 'RRLL': + v_scan = np.concatenate([v_scan, v_scan]) + g1 = np.concatenate([g1, g1]) + g2 = np.concatenate([g2, g2]) + + if method == 'amp': + verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan) + else: + verr = vis - g1 * g2.conj() * v_scan + + chi = np.abs(verr) * sigma_inv + chisq = np.sum((chi * chi)[np.isfinite(chi) * vis_mask]) + + # prior on the gains + g_fracerr = gain_tol + if method == "phase": + chisq_g = 0 # because |g| == 1 so log(|g|) = 0 + elif method == "amp": + logg = np.log(g) + chisq_g = np.sum(logg * logg) / (g_fracerr * g_fracerr) + else: + logabsg = np.log(np.abs(g)) + chisq_g = np.sum(logabsg * logabsg) / (g_fracerr * g_fracerr) + + absv = np.abs(v) + vv = absv * absv + chisq_v = np.sum(vv * vv) / zbl_scan**4 + return chisq + chisq_g + chisq_v + + if np.max(g1_keys) > -1 or np.max(g2_keys) > -1: + # run the minimizer to get a solution (but only run if there's at least one gain to fit) + optdict = {'maxiter': MAXIT} # minimizer params + res = opt.minimize(errfunc, gvpar_guess, method=minimizer_method, options=optdict) + + # get solution + g_fit = res.x[0:2 * n_gains].view(np.complex128) + v_fit = res.x[2 * n_gains:].view(np.complex128) + + if method == "phase": + g_fit = g_fit / np.abs(g_fit) + if method == "amp": + g_fit = np.abs(np.real(g_fit)) + + if show_solution: + print(np.abs(g_fit)) + print(np.abs(v_fit)) + else: + g_fit = [] + v_fit = [] + + g_fit = np.append(g_fit, 1.) + v_fit = np.append(v_fit, zbl_scan) + + # Derive a calibration table or apply the solution to the scan + if caltable: + allsites = list(set(scan['t1']).union(set(scan['t2']))) + + caldict = {} + for site in allsites: + if site in sites: + site_key = tkey[site] + else: + site_key = -1 + + # We will *always* set the R and L gain corrections to be equal in network calibration, + # to avoid breaking polarization consistency relationships + rscale = g_fit[site_key]**-1 + lscale = g_fit[site_key]**-1 + + # Note: we may want to give two entries for the start/stop times + # when a non-zero solution interval is used + caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL) + + out = caldict + + else: + g1_fit = g_fit[g1_keys] + g2_fit = g_fit[g2_keys] + + gij_inv = (g1_fit * g2_fit.conj())**(-1) + + if polrep == 'stokes': + # scale visibilities + for vistype in ['vis', 'qvis', 'uvis', 'vvis']: + scan[vistype] *= gij_inv + # scale sigmas + for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']: + scan[sigtype] *= np.abs(gij_inv) + elif polrep == 'circ': + # scale visibilities + for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']: + scan[vistype] *= gij_inv + # scale sigmas + for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']: + scan[sigtype] *= np.abs(gij_inv) + + out = scan + + return out + + +def init(x): + global counter + counter = x + + +def get_network_scan_cal(args): + return get_network_scan_cal2(*args) + + +def get_network_scan_cal2(i, n, scan, zbl, sites, cluster_data, polrep, pol, + method, pad_amp, gain_tol, caltable, show_solution, debias, msgtype): + if n > 1: + global counter + counter.increment() + obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1) + + return network_cal_scan(scan, zbl, sites, cluster_data, polrep=polrep, pol=pol, + zbl_uvidst_max=ZBLCUTOFF, method=method, caltable=caltable, + show_solution=show_solution, + pad_amp=pad_amp, gain_tol=gain_tol, debias=debias) diff --git a/calibrating/pol_cal.py b/calibrating/pol_cal.py new file mode 100644 index 00000000..d317ba35 --- /dev/null +++ b/calibrating/pol_cal.py @@ -0,0 +1,400 @@ +# pol_cal.py +# functions for polarimetric-calibration +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.optimize as opt +import matplotlib.pyplot as plt +import time + +import ehtim.imaging.imager_utils as iu +import ehtim.observing.obs_simulate as simobs +import ehtim.const_def as ehc + +MAXIT = 1000 # maximum number of iterations in pol-cal minimizer + +################################################################################################### +# Polarimetric Calibration +################################################################################################### + + +def leakage_cal(obs, im=None, sites=[], leakage_tol=.1, pol_fit=['RL', 'LR'], dtype='vis', + const_fpol=False, inverse=False, minimizer_method='L-BFGS-B', + ttype='direct', fft_pad_factor=2, show_solution=True, obs_apply=False): + """Polarimetric calibration (detects and removes polarimetric leakage, + based on consistency with a given image) + + Args: + obs (Obsdata): The observation to be calibrated + im (Image): the reference image used for calibration + (not needed if using const_fpol = True) + sites (list): list of sites to include in the polarimetric calibration. + empty list calibrates all sites + + leakage_tol (float): leakage values exceeding this value will be disfavored by the prior + pol_fit (list): list of visibilities to use; e.g., ['RL','LR'] or ['RR','LL','RL','LR'] + dtype (str): Type of data to fit ('vis' for complex visibilities; + 'amp' for just the amplitudes) + const_fpol (bool): If true, solve for a single fractional polarization + across all baselines in addition to leakage. + For this option, the passed image is not used. + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + show_solution (bool): if True, display the solution as it is calculated + obs_apply (Obsdata): apply the solution to another observation + Returns: + (Obsdata): the calibrated observation, with computed leakage values added to the obs.tarr + """ + tstart = time.time() + + mask = [] + + # Do everything in a circular basis + if not const_fpol: + im_circ = im.switch_polrep('circ') + else: + im_circ = None + + if dtype not in ['vis', 'amp']: + raise Exception('dtype must be vis or amp') + + # Create the obsdata object for searching + obs_test = obs.copy() + obs_test = obs_test.switch_polrep('circ') + + # Check to see if the field rotation is corrected + if obs_test.frcal is False: + print("Field rotation angles have not been corrected. Correcting now...") + obs_test.data = simobs.apply_jones_inverse(obs_test, frcal=False, dcal=True, verbose=False) + obs_test.frcal = True + + # List of all sites present in the observation + allsites = list(set(np.hstack((obs.data['t1'], obs.data['t2'])))) + + if len(sites) == 0: + print("No stations specified for leakage calibration: defaulting to calibrating all !") + sites = allsites + + # Set all leakage terms in obs_test to zero + # (we will only correct leakage for those sites with new solutions) + for j in range(len(obs_test.tarr)): + if obs_test.tarr[j]['site'] in sites: + continue + obs_test.tarr[j]['dr'] = obs_test.tarr[j]['dl'] = 0.0j + + # only include sites that are present + sites = [s for s in sites if s in allsites] + site_index = [list(obs.tarr['site']).index(s) for s in sites] + + if not const_fpol: + (dataRR, sigmaRR, ARR) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='RR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (dataLL, sigmaLL, ALL) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='LL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (dataRL, sigmaRL, ARL) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='RL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (dataLR, sigmaLR, ALR) = iu.chisqdata(obs, im_circ, mask=mask, dtype=dtype, pol='LR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + + # If inverse modeling, pre-compute field rotation angles + if inverse: + el1 = obs.unpack(['el1'], ang_unit='rad')['el1'] + el2 = obs.unpack(['el2'], ang_unit='rad')['el2'] + par1 = obs.unpack(['par_ang1'], ang_unit='rad')['par_ang1'] + par2 = obs.unpack(['par_ang2'], ang_unit='rad')['par_ang2'] + + fr_elev1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_elev'] for o in obs.data]) + fr_elev2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_elev'] for o in obs.data]) + fr_par1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_par'] for o in obs.data]) + fr_par2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_par'] for o in obs.data]) + fr_off1 = np.array([obs.tarr[obs.tkey[o['t1']]]['fr_off'] for o in obs.data]) + fr_off2 = np.array([obs.tarr[obs.tkey[o['t2']]]['fr_off'] for o in obs.data]) + + fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180. + fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180. + + def chisq_total(data, im, D, inverse=False): + if const_fpol: + if inverse: + # Note: These are linearized approximations (in leakage and source polarization) + fpol_model = D[-2] + cpol_model = np.real(D[-1]) + + D1R = [D[2*sites.index(o['t1'])] for o in data] + D1L = [D[2*sites.index(o['t1'])+1] for o in data] + D2R = [D[2*sites.index(o['t2'])] for o in data] + D2L = [D[2*sites.index(o['t2'])+1] for o in data] + + lrll = data['lrvis']/data['llvis'] + lrrr = data['lrvis']/data['rrvis'] + rlll = data['rlvis']/data['llvis'] + rlrr = data['rlvis']/data['rrvis'] + + lrll_sigma = data['lrsigma']/np.abs(data['llvis']) + lrrr_sigma = data['lrsigma']/np.abs(data['rrvis']) + rlll_sigma = data['rlsigma']/np.abs(data['llvis']) + rlrr_sigma = data['rlsigma']/np.abs(data['rrvis']) + + lrll_model = (np.conjugate(fpol_model) * (1.0 + cpol_model) + + D1L * np.exp(-2j*fr1) * (1.0 + 2.0*cpol_model) + + np.conjugate(D2R) * np.exp(-2j*fr2)) + lrrr_model = (np.conjugate(fpol_model) * (1.0 - cpol_model) + + D1L * np.exp(-2j*fr1) + + np.conjugate(D2R) * np.exp(-2j*fr2) * (1.0 - 2.0*cpol_model)) + rlll_model = (fpol_model * (1.0 + cpol_model) + + D1R * np.exp(2j*fr1) + + np.conjugate(D2L) * np.exp(2j*fr2) * (1.0 + 2.0*cpol_model)) + rlrr_model = (fpol_model * (1.0 - cpol_model) + + D1R * np.exp(2j*fr1) * (1.0 - 2.0*cpol_model) + + np.conjugate(D2L) * np.exp(2j*fr2)) + + chisq = np.concatenate([np.abs((lrll - lrll_model)/lrll_sigma)**2, + np.abs((lrrr - lrrr_model)/lrrr_sigma)**2, + np.abs((rlll - rlll_model)/rlll_sigma)**2, + np.abs((rlrr - rlrr_model)/rlrr_sigma)**2]) + chisq = chisq[~np.isnan(chisq)] + return np.mean(chisq) + else: + fpol_model = D[-2] + cpol_model = np.real(D[-1]) + + fpol_data_1 = 2.0 * data['rlvis']/(data['rrvis'] + data['llvis']) + fpol_data_2 = 2.0 * np.conj(data['lrvis']/(data['rrvis'] + data['llvis'])) + fpol_sigma_1 = 2.0/np.abs(data['rrvis'] + data['llvis']) * data['rlsigma'] + fpol_sigma_2 = 2.0/np.abs(data['rrvis'] + data['llvis']) * data['lrsigma'] + return 0.5*np.mean(np.abs((fpol_model - fpol_data_1)/fpol_sigma_1)**2 + + np.abs((fpol_model - fpol_data_2)/fpol_sigma_2)**2) + else: + chisq_RR = chisq_LL = chisq_RL = chisq_LR = 0.0 + if 'RR' in pol_fit: + chisq_RR = iu.chisq(im.rrvec, ARR, + obs_test.unpack_dat(data, ['rr' + dtype])['rr' + dtype], + data['rrsigma'], dtype=dtype, ttype=ttype, mask=mask) + if 'LL' in pol_fit: + chisq_LL = iu.chisq(im.llvec, ALL, + obs_test.unpack_dat(data, ['ll' + dtype])['ll' + dtype], + data['llsigma'], dtype=dtype, ttype=ttype, mask=mask) + if 'RL' in pol_fit: + chisq_RL = iu.chisq(im.rlvec, ARL, + obs_test.unpack_dat(data, ['rl' + dtype])['rl' + dtype], + data['rlsigma'], dtype=dtype, ttype=ttype, mask=mask) + if 'LR' in pol_fit: + chisq_LR = iu.chisq(im.lrvec, ALR, + obs_test.unpack_dat(data, ['lr' + dtype])['lr' + dtype], + data['lrsigma'], dtype=dtype, ttype=ttype, mask=mask) + return (chisq_RR + chisq_LL + chisq_RL + chisq_LR)/len(pol_fit) + + print("Finding leakage for sites:", sites) + + def errfunc(Dpar): + # all the D-terms (complex). If const_fpol, fpol is the last parameter. + D = Dpar.astype(np.float64).view(dtype=np.complex128) + + if not inverse: + for isite in range(len(sites)): + obs_test.tarr['dr'][site_index[isite]] = D[2*isite] + obs_test.tarr['dl'][site_index[isite]] = D[2*isite+1] + data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False) + else: + data = obs.data + + # goodness-of-fit for the leakage-corrected data + chisq = chisq_total(data, im_circ, D, inverse=inverse) + + # prior on the D terms + chisq_D = np.sum(np.abs(D/leakage_tol)**2) + + return chisq + chisq_D + + # Now, we will minimize the total chi-squared. We need two complex leakage terms for each site + optdict = {'maxiter': MAXIT} # minimizer params + Dpar_guess = np.zeros((len(sites) + const_fpol*2)*2, dtype=np.complex128).view(dtype=np.float64) + print("Minimizing...") + res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict) + print(errfunc(Dpar_guess),errfunc(res.x)) + + # get solution + D_fit = res.x.astype(np.float64).view(dtype=np.complex128) # all the D-terms (complex) + + # Apply the solution + for isite in range(len(sites)): + obs_test.tarr['dr'][site_index[isite]] = D_fit[2*isite] + obs_test.tarr['dl'][site_index[isite]] = D_fit[2*isite+1] + obs_test.data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False) + obs_test.dcal = True + + # Re-populate any additional leakage terms that were present + for j in range(len(obs_test.tarr)): + if obs_test.tarr[j]['site'] in sites: + continue + obs_test.tarr[j]['dr'] = obs.tarr[j]['dr'] + obs_test.tarr[j]['dl'] = obs.tarr[j]['dl'] + + if show_solution: + if inverse is False: + chisq_orig = chisq_total(obs.switch_polrep('circ').data, + im_circ, D_fit, inverse=inverse) + chisq_new = chisq_total(obs_test.data, im_circ, D_fit, inverse=inverse) + else: + chisq_orig = chisq_total(obs.switch_polrep('circ').data, + im_circ, D_fit*0, inverse=inverse) + chisq_new = chisq_total(obs.switch_polrep('circ').data, im_circ, D_fit, inverse=inverse) + + print("Original chi-squared: {:.4f}".format(chisq_orig)) + print("New chi-squared: {:.4f}\n".format(chisq_new)) + for isite in range(len(sites)): + print(sites[isite]) + print(' D_R: {:.4f}'.format(D_fit[2*isite])) + print(' D_L: {:.4f}\n'.format(D_fit[2*isite+1])) + if const_fpol: + print('Source Fractional Polarization Magnitude: {:.4f}'.format(np.abs(D_fit[-2]))) + print('Source Fractional Polarization EVPA [deg]: {:.4f}\n'.format( + 90./np.pi*np.angle(D_fit[-2]))) + if inverse: + print('Source Fractional Circular Polarization: {:.4f}'.format(np.real(D_fit[-1]))) + + tstop = time.time() + print("\nleakage_cal time: %f s" % (tstop - tstart)) + + if obs_apply is not False: + # Apply the solution to another observation + obs_test = obs_apply.copy() + obs_test.tarr['dr'] *= 0.0 + obs_test.tarr['dl'] *= 0.0 + + # Copy the solved D-terms + for isite in range(len(sites)): + if sites[isite] in list(obs_test.tarr['site']): + i_site = list(obs_test.tarr['site']).index(sites[isite]) + obs_test.tarr['dr'][i_site] = D_fit[2*isite] + obs_test.tarr['dl'][i_site] = D_fit[2*isite+1] + + obs_test.data = simobs.apply_jones_inverse(obs_test, dcal=False, verbose=False) + obs_test.dcal = True + + # Copy in the remaining D-terms that were there before + for j in range(len(obs_test.tarr)): + if obs_test.tarr[j]['site'] in sites: + continue + obs_test.tarr[j]['dr'] = obs_apply.tarr[j]['dr'] + obs_test.tarr[j]['dl'] = obs_apply.tarr[j]['dl'] + else: + obs_test = obs_test.switch_polrep(obs.polrep) + + if not const_fpol: + return obs_test + else: + if inverse: + return [obs_test, D_fit[-2], D_fit[-1]] + else: + return [obs_test, D_fit[-2]] + + +def plot_leakage(obs, sites=[], axis=False, rangex=False, rangey=False, + markers=['o', 's'], markersize=6, + export_pdf="", axislabels=True, legend=True, sort_tarr=True, show=True): + """Plot polarimetric leakage terms in an observation + + Args: + obs (Obsdata): observation (or Array) containing the tarr + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + markers (str): pair of matplotlib plot markers (for RCP and LCP) + markersize (int): size of plot markers + label (str): plot legend label + + export_pdf (str): path to pdf file to save figure + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + + tarr = obs.tarr.copy() + if sort_tarr: + tarr.sort(axis=0) + + if len(sites): + mask = [t in sites for t in tarr['site']] + tarr = tarr[mask] + + clist = ehc.SCOLORS + + # make plot(s) + if axis: + fig = axis.figure + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + plt.axhline(0, color='black') + plt.axvline(0, color='black') + + xymax = np.max([np.abs(tarr['dr']), np.abs(tarr['dl'])])*100.0 + + plot_points = [] + for i in range(len(tarr)): + color = clist[i % len(clist)] + label = tarr['site'][i] + dre, = x.plot(np.real(tarr['dr'][i])*100.0, np.imag(tarr['dr'][i])*100.0, markers[0], + markersize=markersize, color=color, label=label) + dim, = x.plot(np.real(tarr['dl'][i])*100.0, np.imag(tarr['dl'][i])*100.0, markers[1], + markersize=markersize, color=color, label=label) + plot_points.append([dre, dim]) + + # Data ranges + if not rangex: + rangex = [-xymax*1.1-0.01, xymax*1.1+0.01] + + if not rangey: + rangey = [-xymax*1.1-0.01, xymax*1.1+0.01] + +# if not rangex and not rangey: +# plt.axes().set_aspect('equal', 'datalim') + + x.set_xlim(rangex) + x.set_ylim(rangey) + + # label and save + if axislabels: + x.set_xlabel('Re[$D$] (\%)') + x.set_ylabel('Im[$D$] (\%)') + if legend: + legend1 = plt.legend([l[0] for l in plot_points], tarr['site'], ncol=1, loc=1) + plt.legend(plot_points[0], ['$D_R$ (\%)', '$D_L$ (\%)'], loc=4) + plt.gca().add_artist(legend1) + if export_pdf != "": # and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + + return x diff --git a/calibrating/pol_cal_new.py b/calibrating/pol_cal_new.py new file mode 100644 index 00000000..eac9841d --- /dev/null +++ b/calibrating/pol_cal_new.py @@ -0,0 +1,488 @@ +# pol_cal.py +# functions for D-term calibration +# new version that should be faster (2024) +# +# Copyright (C) 2024 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.optimize as opt +import matplotlib.pyplot as plt +import time + +import ehtim.imaging.imager_utils as iu +import ehtim.observing.obs_simulate as simobs +import ehtim.const_def as ehc + +MAXIT = 10000 # maximum number of iterations in self-cal minimizer +NHIST = 50 # number of steps to store for hessian approx +MAXLS = 40 # maximum number of line search steps in BFGS-B +STOP = 1e-6 # convergence criterion + +################################################################################################### +# Polarimetric Calibration +################################################################################################### + +# TODO - other chi^2 terms, not just 'vis'? +# TODO - do we want to start with some nonzero D-term initial guess? +# TODO - option to not frcal? +# TODO - pass other kwargs to the chisq? +# TODO - handle gain cal == False, read in gains from a caltable + +def leakage_cal_new(obs, im, sites=[], leakage_tol=.1, rescale_leakage_tol=False, + pol_fit=['RL', 'LR'], dtype='vis', + minimizer_method='L-BFGS-B', + ttype='direct', fft_pad_factor=2, + use_grad=True, + show_solution=True, apply_solution=True): + """Polarimetric calibration (detects and removes polarimetric leakage, + based on consistency with a given image) + + Args: + obs (Obsdata): The observation to be calibrated + im (Image): the reference image used for calibration + + sites (list): list of sites to include in the polarimetric calibration. + empty list calibrates all sites + + leakage_tol (float): leakage values exceeding this value will be disfavored by the prior + rescale_leakage_tol (bool): if True, properly scale leakage tol for number of sites + (not done correctly in old version) + + pol_fit (list): list of visibilities to use; e.g., ['RL','LR'] or ['RR','LL','RL','LR'] + + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + use_grad (bool): if True, use gradients in minimizer + + show_solution (bool): if True, display the solution as it is calculated + apply_solution (bool): if True, apply the solution to the Obsdata, otherwise just return in tarr + + Returns: + (Obsdata): the calibrated observation, with computed leakage values added to the obs.tarr + """ + + if not(obs.ampcal and obs.phasecal): + raise Exception("obs must be amplitude and phase calibrated before leakage_cal! (TODO: generalize)") + + + tstart = time.time() + + mask = [] # TODO: add image masks? + dtype = 'vis' # TODO: add other data terms? + + # Do everything in a circular basis + im_circ = im.switch_polrep('circ') + obs_circ = obs.copy().switch_polrep('circ') + + # Check to see if the field rotation is corrected + if obs_circ.frcal is False: + print("Field rotation angles have not been corrected. Correcting now...") + obs_circ.data = simobs.apply_jones_inverse(obs_circ, frcal=False, dcal=True, opacitycal=True, verbose=False) + obs_circ.frcal = True + + # List of all sites present in the observation. Make sure they are all in the tarr + allsites = list(set(np.hstack((obs_circ.data['t1'], obs_circ.data['t2'])))) + for site in allsites: + if not (site in obs_circ.tarr['site']): + raise Exception("site %s not in obs.tarr!"%site) + + if len(sites) == 0: + print("No stations specified for leakage calibration: defaulting to calibrating all sites !") + sites = allsites + # only include sites that are present in obs.tarr + sites = [s for s in sites if s in allsites] + site_index = [list(obs_circ.tarr['site']).index(s) for s in sites] + + # TODO do we want to start with some nonzero D-terms? + # Set all leakage terms in obs_circ to zero + # (we will only correct leakage for those sites with new solutions) + for j in range(len(obs_circ.tarr)): + if obs_circ.tarr[j]['site'] in sites: + continue + obs_circ.tarr[j]['dr'] = obs_circ.tarr[j]['dl'] = 0.0j + + print("Finding leakage for sites:", sites) + + print("Precomputing visibilities...") + # get stations + t1 = obs_circ.unpack('t1')['t1'] + t2 = obs_circ.unpack('t2')['t2'] + + # index sites in t1, t2 position. If no calibrated site is used in a baseline, -1 + idx1 = np.array([sites.index(t) if (t in sites) else -1 for t in t1]) + idx2 = np.array([sites.index(t) if (t in sites) else -1 for t in t2]) + + # get real data and sigmas + # TODO add other chisqdata parameters? + # TODO modify chisqdata function to have the option to return samples? + + (vis_RR, sigma_RR, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='RR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (vis_LL, sigma_LL, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='LL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (vis_RL, sigma_RL, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='RL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (vis_LR, sigma_LR, _) = iu.chisqdata(obs_circ, im_circ, mask=mask, dtype=dtype, pol='LR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + + # get simulated data (from simple Fourier transform) + obs_sim = im_circ.observe_same_nonoise(obs_circ, + ttype=ttype, fft_pad_factor=fft_pad_factor, + zero_empty_pol=True,verbose=False) + + (ft_RR, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='RR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (ft_LL, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='LL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (ft_RL, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='RL', + ttype=ttype, fft_pad_factor=fft_pad_factor) + (ft_LR, _, _) = iu.chisqdata(obs_sim, im_circ, mask=mask, dtype=dtype, pol='LR', + ttype=ttype, fft_pad_factor=fft_pad_factor) + + # field rotation angles + el1 = obs_circ.unpack(['el1'], ang_unit='rad')['el1'] + el2 = obs_circ.unpack(['el2'], ang_unit='rad')['el2'] + par1 = obs_circ.unpack(['par_ang1'], ang_unit='rad')['par_ang1'] + par2 = obs_circ.unpack(['par_ang2'], ang_unit='rad')['par_ang2'] + + fr_elev1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_elev'] for o in obs.data]) + fr_elev2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_elev'] for o in obs.data]) + fr_par1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_par'] for o in obs.data]) + fr_par2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_par'] for o in obs.data]) + fr_off1 = np.array([obs_circ.tarr[obs_circ.tkey[o['t1']]]['fr_off'] for o in obs.data]) + fr_off2 = np.array([obs_circ.tarr[obs_circ.tkey[o['t2']]]['fr_off'] for o in obs.data]) + + fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180. + fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180. + + Delta = fr1 - fr2 + Phi = fr1 + fr2 + + # TODO: read in gains from caltable? + # gains + GR1 = np.ones(fr1.shape) + GL1 = np.ones(fr1.shape) + GR2 = np.ones(fr2.shape) + GL2 = np.ones(fr2.shape) + + if not(len(Delta)==len(vis_RR)==len(sigma_LL)==len(ft_RR)==len(t1)): + raise Exception("not all data columns the right length in pol_cal!") + Nvis = len(vis_RR) + + # define the error function + def chisq_total(Dpar): + # all the D-terms as complex numbers. If const_fpol, fpol is the last parameter. + D = Dpar.astype(np.float64).view(dtype=np.complex128) + + # current D-terms for each baseline, zero for stations not calibrated (TODO faster?) + DR1 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t1]) + DL1 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t1]) + + DR2 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t2]) + DL2 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t2]) + + # simulated visibilities and chisqs with leakage + chisq_RR = chisq_LL = chisq_RL = chisq_LR = 0.0 + if 'RR' in pol_fit: + vis_RR_leak = ft_RR + DR1*DR2.conj()*np.exp(2j*Delta)*ft_LL + DR1*np.exp(2j*fr1)*ft_LR + DR2.conj()*np.exp(-2j*fr2)*ft_RL + vis_RR_leak *= GR1*GR2.conj() + + chisq_RR = np.sum(np.abs(vis_RR - vis_RR_leak)**2 / (sigma_RR**2)) + chisq_RR = chisq_RR / (2.*Nvis) + if 'LL' in pol_fit: + vis_LL_leak = ft_LL + DL1*DL2.conj()*np.exp(-2j*Delta)*ft_RR + DL1*np.exp(-2j*fr1)*ft_RL + DL2.conj()*np.exp(2j*fr2)*ft_LR + vis_LL_leak *= GL1*GL2.conj() + + chisq_LL = np.sum(np.abs(vis_LL - vis_LL_leak)**2 / (sigma_LL**2)) + chisq_LL = chisq_LL / (2.*Nvis) + if 'RL' in pol_fit: + vis_RL_leak = ft_RL + DR1*DL2.conj()*np.exp(2j*Phi)*ft_LR + DR1*np.exp(2j*fr1)*ft_LL + DL2.conj()*np.exp(2j*fr2)*ft_RR + vis_RL_leak *= GR1*GL2.conj() + + chisq_RL = np.sum(np.abs(vis_RL - vis_RL_leak)**2 / (sigma_RL**2)) + chisq_RL = chisq_RL / (2.*Nvis) + if 'LR' in pol_fit: + vis_LR_leak = ft_LR + DL1*DR2.conj()*np.exp(-2j*Phi)*ft_RL + DL1*np.exp(-2j*fr1)*ft_RR + DR2.conj()*np.exp(-2j*fr2)*ft_LL + vis_LR_leak *= GL1*GR2.conj() + + chisq_LR = np.sum(np.abs(vis_LR - vis_LR_leak)**2 / (sigma_LR**2)) + chisq_LR = chisq_LR / (2.*Nvis) + + chisq_tot = (chisq_RR + chisq_LL + chisq_RL + chisq_LR)/len(pol_fit) + return chisq_tot + + def errfunc(Dpar): + # chi-squared + chisq_tot = chisq_total(Dpar) + + # prior on the D terms + # TODO + prior = np.sum((np.abs(Dpar)**2)/(leakage_tol**2)) + if rescale_leakage_tol: + prior = prior / (len(Dpar)) + + return chisq_tot + prior + + # define the error function gradient + def chisq_total_grad(Dpar): + chisqgrad = np.zeros(len(Dpar)) + + # all the D-terms as complex numbers. + # stored in groups of 4 per site [Re(DR), Im(DR), Re(DL), Im(DL)] + D = Dpar.astype(np.float64).view(dtype=np.complex128) + + # current D-terms for each baseline, zero for stations not calibrated (TODO faster?) + DR1 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t1]) + DL1 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t1]) + + DR2 = np.asarray([D[2*sites.index(s)] if s in sites else 0. for s in t2]) + DL2 = np.asarray([D[2*sites.index(s)+1] if s in sites else 0. for s in t2]) + + # residual and dV/dD terms + if 'RR' in pol_fit: + vis_RR_leak = ft_RR + DR1*DR2.conj()*np.exp(2j*Delta)*ft_LL + DR1*np.exp(2j*fr1)*ft_LR + DR2.conj()*np.exp(-2j*fr2)*ft_RL + vis_RR_leak *= GR1*GR2.conj() + + resid_RR = (vis_RR - vis_RR_leak).conj() + + dRR_dReDR1 = DR2.conj()*np.exp(2j*Delta)*ft_LL + np.exp(2j*fr1)*ft_LR + dRR_dReDR1 *= GR1*GR2.conj() + + dRR_dReDR2 = DR1*np.exp(2j*Delta)*ft_LL + np.exp(-2j*fr2)*ft_RL + dRR_dReDR2 *= GR1*GR2.conj() + + if 'LL' in pol_fit: + vis_LL_leak = ft_LL + DL1*DL2.conj()*np.exp(-2j*Delta)*ft_RR + DL1*np.exp(-2j*fr1)*ft_RL + DL2.conj()*np.exp(2j*fr2)*ft_LR + vis_LL_leak *= GL1*GL2.conj() + + resid_LL = (vis_LL - vis_LL_leak).conj() + + dLL_dReDL1 = DL2.conj()*np.exp(-2j*Delta)*ft_RR + np.exp(-2j*fr1)*ft_RL + dLL_dReDL1 *= GL1*GL2.conj() + + dLL_dReDL2 = DL1*np.exp(-2j*Delta)*ft_RR + np.exp(2j*fr2)*ft_LR + dLL_dReDL2 *= GL1*GL2.conj() + + + if 'RL' in pol_fit: + vis_RL_leak = ft_RL + DR1*DL2.conj()*np.exp(2j*Phi)*ft_LR + DR1*np.exp(2j*fr1)*ft_LL + DL2.conj()*np.exp(2j*fr2)*ft_RR + vis_RL_leak *= GR1*GL2.conj() + + resid_RL = (vis_RL - vis_RL_leak).conj() + + dRL_dReDR1 = DL2.conj()*np.exp(2j*Phi)*ft_LR + np.exp(2j*fr1)*ft_LL + dRL_dReDR1 *= GR1*GL2.conj() + + dRL_dReDL2 = DR1*np.exp(2j*Phi)*ft_LR + np.exp(2j*fr2)*ft_RR + dRL_dReDL2 *= GR1*GL2.conj() + + + if 'LR' in pol_fit: + vis_LR_leak = ft_LR + DL1*DR2.conj()*np.exp(-2j*Phi)*ft_RL + DL1*np.exp(-2j*fr1)*ft_RR + DR2.conj()*np.exp(-2j*fr2)*ft_LL + vis_LR_leak *= GL1*GR2.conj() + + resid_LR = (vis_LR - vis_LR_leak).conj() + + dLR_dReDL1 = DR2.conj()*np.exp(-2j*Phi)*ft_RL + np.exp(-2j*fr1)*ft_RR + dLR_dReDL1 *= GL1*GR2.conj() + + dLR_dReDR2 = DL1*np.exp(-2j*Phi)*ft_RL + np.exp(-2j*fr2)*ft_LL + dLR_dReDR2 *= GL1*GR2.conj() + + + # to get gradients, sum over baselines + # TODO remove for loop with some fancy vectorization? + for isite in range(len(sites)): + mask1 = (idx1 == isite) + mask2 = (idx2 == isite) + + # DR + regrad = 0 + imgrad = 0 + if 'RR' in pol_fit: + terms = resid_RR[mask1] * dRR_dReDR1[mask1] / (sigma_RR[mask1]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += np.sum(np.imag(terms)) + + terms = resid_RR[mask2] * dRR_dReDR2[mask2] / (sigma_RR[mask2]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += -1*np.sum(np.imag(terms)) + + if 'RL' in pol_fit: + terms = resid_RL[mask1] * dRL_dReDR1[mask1] / (sigma_RL[mask1]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += np.sum(np.imag(terms)) + + if 'LR' in pol_fit: + terms = resid_LR[mask2] * dLR_dReDR2[mask2] / (sigma_LR[mask2]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += -1*np.sum(np.imag(terms)) + + chisqgrad[4*isite] += regrad # Re(DR) + chisqgrad[4*isite+1] += imgrad # Im(DR) + + # DL + regrad = 0 + imgrad = 0 + if 'LL' in pol_fit: + terms = resid_LL[mask1] * dLL_dReDL1[mask1] / (sigma_LL[mask1]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += np.sum(np.imag(terms)) + + terms = resid_LL[mask2] * dLL_dReDL2[mask2] / (sigma_LL[mask2]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += -1*np.sum(np.imag(terms)) + + if 'RL' in pol_fit: + terms = resid_RL[mask2] * dRL_dReDL2[mask2] / (sigma_RL[mask2]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += -1*np.sum(np.imag(terms)) + + if 'LR' in pol_fit: + terms = resid_LR[mask1] * dLR_dReDL1[mask1] / (sigma_LR[mask1]**2) + regrad += -1*np.sum(np.real(terms)) + imgrad += np.sum(np.imag(terms)) + + chisqgrad[4*isite+2] += regrad # Re(DL) + chisqgrad[4*isite+3] += imgrad # Im(DL) + + chisqgrad /= (Nvis*len(pol_fit)) + + return chisqgrad + + def errfunc_grad(Dpar): + # gradient of the chi^2 + chisqgrad = chisq_total_grad(Dpar) + + # gradient of the prior + priorgrad = 2*Dpar / (leakage_tol**2) + if rescale_leakage_tol: + priorgrad = priorgrad / (len(Dpar)) + + return chisqgrad + priorgrad + + # Gradient test - remove! +# def test_grad(Dpar): +# grad_ana = errfunc_grad(Dpar) +# grad_num1 = np.zeros(len(Dpar)) +# for i in range(len(Dpar)): +# dd = 1.e-8 +# Dpar_dd = Dpar.copy() +# Dpar_dd[i] += dd +# grad_num1[i] = (errfunc(Dpar_dd) - errfunc(Dpar))/dd +# grad_num2 = np.zeros(len(Dpar)) +# for i in range(len(Dpar)): +# dd = -1.e-8 +# Dpar_dd = Dpar.copy() +# Dpar_dd[i] += dd +# grad_num2[i] = (errfunc(Dpar_dd) - errfunc(Dpar))/dd +# +# plt.close('all') +# plt.ion() +# plt.figure() +# plt.plot(np.arange(len(Dpar)), grad_ana, 'ro') +# plt.plot(np.arange(len(Dpar)), grad_num1, 'b.') +# plt.plot(np.arange(len(Dpar)), grad_num2, 'bx') +# plt.xticks(np.arange(0,len(Dpar),4), sites) +# +# plt.figure() +# zscal = 1.e-32*np.min(np.abs(grad_ana)[grad_ana!=0]) +# plt.plot(np.arange(len(Dpar)), 100-100*(grad_num1+zscal)/(grad_ana+zscal),'b.') +# plt.plot(np.arange(len(Dpar)), 100-100*(grad_num2+zscal)/(grad_ana+zscal),'bx') +# plt.xticks(np.arange(0,len(Dpar),4), sites) +# plt.ylim(-1,1) +# plt.show() +# return + +# Dpar_guess = .1*np.random.randn(len(sites)*4) +# test_grad(Dpar_guess) + + print("Calibrating D-terms...") + # Now, we will finally minimize the total error term. We need two complex leakage terms for each site + if minimizer_method=='L-BFGS-B': + optdict = {'maxiter': MAXIT, + 'ftol': STOP, 'gtol': STOP, + 'maxcor': NHIST, 'maxls': MAXLS} + else: + optdict = {'maxiter': MAXIT} + + Dpar_guess = np.zeros(len(sites)*2, dtype=np.complex128).view(dtype=np.float64) + if use_grad: + res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict, jac=errfunc_grad) + else: + res = opt.minimize(errfunc, Dpar_guess, method=minimizer_method, options=optdict) + + print(errfunc(Dpar_guess),errfunc(res.x)) + + # get solution + Dpar_fit = res.x.astype(np.float64) + D_fit = Dpar_fit.view(dtype=np.complex128) # all the D-terms (complex) + + # fill in the new D-terms to the tarr + obs_out = obs_circ.copy() # TODO or overwrite directly? + for isite in range(len(sites)): + obs_out.tarr['dr'][site_index[isite]] = D_fit[2*isite] + obs_out.tarr['dl'][site_index[isite]] = D_fit[2*isite+1] + + # Apply the solution + if apply_solution: + obs_out.data = simobs.apply_jones_inverse(obs_out, dcal=False, frcal=True, opacitycal=True, verbose=False) + obs_out.dcal = True + else: + obs_out.dcal = False + + # Re-populate any additional leakage terms that were present + # NOTE we don't want to do this above, in case we want to ignore these terms in apply_jones inverse + # TODO can we do this better? + for j in range(len(obs_out.tarr)): + if obs_out.tarr[j]['site'] in sites: + continue + obs_out.tarr[j]['dr'] = obs.tarr[j]['dr'] + obs_out.tarr[j]['dl'] = obs.tarr[j]['dl'] + + # TODO are these diagnostics correct? + if show_solution: + chisq_orig = chisq_total(Dpar_fit*0) + chisq_new = chisq_total(Dpar_fit) + + print("Original chi-squared: {:.4f}".format(chisq_orig)) + print("New chi-squared: {:.4f}\n".format(chisq_new)) + for isite in range(len(sites)): + print(sites[isite]) + print(' D_R: {:.4f}'.format(D_fit[2*isite])) + print(' D_L: {:.4f}\n'.format(D_fit[2*isite+1])) + + tstop = time.time() + print("\nleakage_cal time: %f s" % (tstop - tstart)) + + + obs_out = obs_out.switch_polrep(obs.polrep) + + return obs_out + + + + diff --git a/calibrating/polgains_cal.py b/calibrating/polgains_cal.py new file mode 100644 index 00000000..38c36eaa --- /dev/null +++ b/calibrating/polgains_cal.py @@ -0,0 +1,312 @@ +# polgains_cal.py +# functions for calibrating RCP-LCP phase offsets +# +# Copyright (C) 2019 Maciek Wielgus (maciek.wielgus(at)gmail.com) +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.optimize as opt +import time +import copy +from multiprocessing import cpu_count, Pool + +import ehtim.obsdata +import ehtim.parloop as parloop +import ehtim.observing.obs_helpers as obsh +import ehtim.const_def as ehc + +MAXIT = 5000 + +################################################################################################### +# Polarimetric-Phase-Gains-Calibration +################################################################################################### + + +def polgains_cal(obs, reference='AA', sites=[], method='phase', minimizer_method='BFGS', pad_amp=0., + solution_interval=0.0, scan_solutions=False, + caltable=False, processes=-1, show_solution=False, msgtype='bar'): + + # TODO: function to globalize the polarimetric solution in time + # given provided absolute calibration of the reference station + # so that we have a meaningful EVPA + """Polarimeteric-phase-gains-calibrate a dataset. + Numerically solves for polarimetric gains to align RCP and LCP feeds. + Uses all baselines to find the solution. Effectively assumes phase of Stokes V to be zero. + Because fits are local, it's not providing absolute phase calibration. + + Args: + obs (Obsdata): The observation to be calibrated + reference (str): station used as reference to break the degeneracy + (LCP on baselines to the reference station remains unchanged) + sites (list): list of sites to include in the polarimetric calibration. + Empty list calibrates all sites + method (str): chooses what to calibrate, 'phase' or 'both' + 'phase' is default, most useful (instrumental offsets), + 'both' will align RCP/LCP amplitudes as well + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + solution_interval (float): solution interval in seconds; one gain is derived per interval. + If 0.0, a solution is determined for each unique time. + scan_solutions (bool): If True, determine one gain per site per scan + Supersedes solution_interval. + caltable (bool): if True, returns a Caltable instead of an Obsdata + processes (int): number of cores to use in multiprocessing + show_solution (bool): if True, display the solution as it is calculated + msgtype (str): type of progress message to be printed, default is 'bar' + + Returns: + (Obsdata): the calibrated observation, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + # circular representation is needed + if obs.polrep != 'circ': + obs = obs.switch_polrep('circ') + + if len(sites) == 0: + print("No stations specified in polgain cal!") + print('Defaulting to calibrating all stations with reference station as: ' + reference) + sites = np.array([x for x in obs.tarr['site'] if x != reference], dtype=' 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + if processes > len(scans): + processes = len(scans) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + elif processes == 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + processes = int(cpu_count()) + if processes > len(scans): + processes = len(scans) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + else: + print("Not Using Multiprocessing") + + # loop over scans and calibrate + tstart = time.time() + if processes > 0: # with multiprocessing + scans_cal = pool.map(get_polgains_scan_cal, + [[i, len(scans), scans[i], reference, + sites, method, pad_amp, caltable, show_solution, msgtype] + for i in range(len(scans)) + ]) + + else: # without multiprocessing + for i in range(len(scans)): + obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1) + scans_cal[i] = polgains_cal_scan(scans[i], reference, sites, + method=method, minimizer_method=minimizer_method, + show_solution=show_solution, caltable=caltable, + pad_amp=pad_amp) + + tstop = time.time() + print("\npolgain_cal time: %f s" % (tstop - tstart)) + + if caltable: # create and return a caltable + allsites = obs.tarr['site'] + caldict = {k: v.reshape(1) for k, v in scans_cal[0].items()} + for i in range(1, len(scans_cal)): + row = scans_cal[i] + if len(row) == 0: + continue + + for site in allsites: + try: + dat = row[site] + except KeyError: + continue + + try: + caldict[site] = np.append(caldict[site], row[site]) + except KeyError: + caldict[site] = [dat] + + caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr, + source=obs.source, mjd=obs.mjd, timetype=obs.timetype) + out = caltable + + else: # return the calibrated observation + arglist, argdict = obs.obsdata_args() + arglist[4] = np.concatenate(scans_cal) + out = ehtim.obsdata.Obsdata(*arglist, **argdict) + + # close multiprocessing jobs + if processes != -1: + pool.close() + + return out + + +def polgains_cal_scan(scan, reference='AA', sites=[], method='phase', minimizer_method='BFGS', + show_solution=False, pad_amp=0., caltable=False, msgtype='bar'): + """Polarimeteric-phase-gains-calibrate a dataset. + Numerically solves for polarimetric gains to align RCP and LCP feeds. + Uses all baselines to find the solution. Effectively assumes phase of Stokes V to be zero. + Because fits are local, it's not providing absolute phase calibration. + + Args: + obs (Obsdata): The observation to be calibrated + reference (str): station used as reference to break the degeneracy + (LCP on baselines to the reference station remains unchanged) + sites (list): list of sites to include in the polarimetric calibration. + Empty list calibrates all sites + method (str): chooses what to calibrate, 'phase' or 'both' + 'phase' is default, most useful (instrumental offsets), + 'both' will align RCP/LCP amplitudes as well + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + caltable (bool): if True, returns a Caltable instead of an Obsdata + show_solution (bool): if True, display the solution as it is calculated + msgtype (str): type of progress message to be printed, default is 'bar' + + Returns: + (Obsdata): the calibrated observation, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + indices_no_vis_nans = (~np.isnan(scan[ehc.vis_poldict['RR']])) & ( + ~np.isnan(scan[ehc.vis_poldict['LL']])) + scan_no_vis_nans = scan[indices_no_vis_nans] + allsites_no_vis_nans = list(set(np.hstack((scan_no_vis_nans['t1'], scan_no_vis_nans['t2'])))) + + # all the sites in the scan + allsites = list(set(np.hstack((scan['t1'], scan['t2'])))) + if len(sites) == 0: + print("No stations specified in polgain cal") + print("defaulting to calibrating all non-reference stations!") + sites = allsites + + # only include sites that are present, not singlepol and not reference + sites = [s for s in sites if (s in allsites_no_vis_nans) & (s != reference)] + + # create a dictionary to keep track of gains; reference site and singlepol sites get key -1 + tkey = {b: a for a, b in enumerate(sites)} + # make two lists of gain keys that relates scan bl gains to solved site ones + # -1 means that this station does not have a gain that is being solved for + g1_keys = [] + g2_keys = [] + for row in scan: + try: + g1_keys.append(tkey[row['t1']]) + except KeyError: + g1_keys.append(-1) + try: + g2_keys.append(tkey[row['t2']]) + except KeyError: + g2_keys.append(-1) + + # get scan visibilities of the specified polarization + visRR = scan_no_vis_nans[ehc.vis_poldict['RR']] + visLL = scan_no_vis_nans[ehc.vis_poldict['LL']] + sigmaRR = scan_no_vis_nans[ehc.sig_poldict['RR']] + sigmaLL = scan_no_vis_nans[ehc.sig_poldict['LL']] + sigma = np.sqrt(sigmaRR**2 + sigmaLL**2) + # sigma_inv = 1.0/np.sqrt(sigma**2 + (pad_amp*0.5*(np.abs(visRR)+np.abs(visLL)))**2) + # initial guesses for parameters + n_gains = len(sites) + gpar_guess = np.ones(n_gains, dtype=np.complex128).view(dtype=np.float64) + + # error function + def errfunc(g): + g = g.view(dtype=np.complex128) + g = np.append(g, 1. + 0.j) + if method == "phase": + g = g / np.abs(g) + g1 = g[g1_keys] + g2 = g[g2_keys] + chisq = np.sum(np.abs((visRR - g1[indices_no_vis_nans] * + g2[indices_no_vis_nans].conj() * visLL) / sigma)**2) + return chisq + + if np.max(g1_keys) > -1 or np.max(g2_keys) > -1: + # run the minimizer to get a solution (but only run if there's at least one gain to fit) + optdict = {'maxiter': MAXIT} # minimizer params + res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict) + + # get solution + g_fit = res.x.view(np.complex128) + if method == "phase": + g_fit = g_fit / np.abs(g_fit) + + if show_solution: + print(g_fit) + else: + g_fit = [] + g_fit = np.append(g_fit, 1.) + # Derive a calibration table or apply the solution to the scan + if caltable: + allsites = list(set(scan['t1']).union(set(scan['t2']))) + + caldict = {} + for site in allsites: + if site in sites: + site_key = tkey[site] + else: + site_key = -1 + + # Convention is that we calibrate RCP phase to align with the LCP phase + rscale = g_fit[site_key]**-1 + lscale = 1. + 0.j + + caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL) + out = caldict + else: + g1_fit = g_fit[g1_keys] + g2_fit = g_fit[g2_keys] + g1_inv = g1_fit**(-1) + g2_inv = g2_fit**(-1) + # scale visibilities + scan['rrvis'] *= g1_inv * g2_inv.conj() + scan['llvis'] *= 1. + 0.j + scan['rlvis'] *= g1_inv + scan['lrvis'] *= g2_inv.conj() + # don't scale sigmas + out = scan + return out + + +def init(x): + global counter + counter = x + + +def get_polgains_scan_cal(args): + return get_polgains_scan_cal2(*args) + + +def get_polgains_scan_cal2(i, n, scan, reference, sites, method, pad_amp, caltable, + show_solution, msgtype): + + if n > 1: + global counter + counter.increment() + obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1) + + return polgains_cal_scan(scan, reference, sites, + method=method, caltable=caltable, show_solution=show_solution, + pad_amp=pad_amp, msgtype=msgtype) diff --git a/calibrating/self_cal.py b/calibrating/self_cal.py new file mode 100644 index 00000000..90d8e9d9 --- /dev/null +++ b/calibrating/self_cal.py @@ -0,0 +1,594 @@ +# selfcal.py +# functions for self-calibration +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.optimize as opt +import time +import copy +from multiprocessing import cpu_count, Pool + +import ehtim.obsdata +import ehtim.parloop as parloop +from . import cal_helpers as calh +from ehtim.observing.obs_simulate import add_jones_and_noise +import ehtim.observing.obs_helpers as obsh +import ehtim.const_def as ehc + +import warnings +warnings.filterwarnings("ignore", message="divide by zero encountered in log") + +MAXIT = 10000 # maximum number of iterations in self-cal minimizer +NHIST = 50 # number of steps to store for hessian approx +MAXLS = 40 # maximum number of line search steps in BFGS-B +STOP = 1e-6 # convergence criterion + +################################################################################################### +# Self-Calibration +################################################################################################### + + +def self_cal(obs, im, sites=[], pol='I', apply_singlepol=False, method="both", + minimizer_method='BFGS', + pad_amp=0., gain_tol=.2, solution_interval=0.0, scan_solutions=False, + ttype='direct', fft_pad_factor=2, caltable=False, + debias=True, apply_dterms=False, + copy_closure_tables=False, + processes=-1, show_solution=False, msgtype='bar', + use_grad=False): + """Self-calibrate a dataset to an image. + + Args: + obs (Obsdata): The observation to be calibrated + im (Image): the image to be calibrated to + sites (list): list of sites to include in the self calibration. + empty list calibrates all sites + + pol (str): which image polarization to self-calibrate visibilities to + apply_singlepol (str): if calibrating to pol='RR' or pol='LL', + apply solution only to the single polarization + + method (str): chooses what to calibrate, 'amp', 'phase', or 'both' + minimizer_method (str): Method for scipy.optimize.minimize (e.g., 'CG', 'BFGS') + + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + gain_tol (float or list): gains that exceed this value will be disfavored by the prior + for asymmetric gain_tol for corrections below/above unity, + pass a 2-element list + solution_interval (float): solution interval in seconds; + If 0., determine solution for each unique time + scan_solutions (bool): If True, determine one gain per site per scan + (supersedes solution_interval) + + caltable (bool): if True, returns a Caltable instead of an Obsdata + processes (int): number of cores to use in multiprocessing + + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + debias (bool): If True, debias the amplitudes + apply_dterms (bool): if True, apply dterms (in obs.tarr) to clean data before calibrating + show_solution (bool): if True, display the solution as it is calculated + msgtype (str): type of progress message to be printed, default is 'bar' + use_grad (bool): if True, use gradients in minimizer + + Returns: + (Obsdata): the calibrated observation, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + if use_grad and (method=='phase' or method=='amp'): + raise Exception("errfunc_grad in self_cal only works with method=='both'!") + + if pol in ['I', 'Q', 'U', 'V']: + if obs.polrep != 'stokes': + raise Exception("selfcal pol is a stokes parameter, but obs.polrep!='stokes'") + im = im.switch_polrep('stokes', pol) + elif pol in ['RR', 'LL']: + if obs.polrep != 'circ': + raise Exception("selfcal pol is RR or LL, but obs.polrep!='circ'") + im = im.switch_polrep('circ', pol) + else: + raise Exception("Can only self-calibrate to I, Q, U, V, RR, or LL images!") + + if apply_singlepol and obs.polrep!='circ': + raise Exception("apply_singlepol must be False unless self-calibrating to 'RR' or 'LL'") + + # V = model visibility, V' = measured visibility, G_i = site gain + # G_i * conj(G_j) * V_ij = V'_ij + if len(sites) == 0: + print("No stations specified in self cal: defaulting to calibrating all stations!") + sites = obs.tarr['site'] + + # First, sample the model visibilities of the specified polarization + print("Computing the Model Visibilities with " + ttype + " Fourier Transform...") + obs_clean = im.observe_same_nonoise(obs, ttype=ttype, fft_pad_factor=fft_pad_factor) + + # apply dterms + # TODO check! + if apply_dterms: + print("Applying dterms in obs.tarr to clean visibilities before selfcal!") + obsdata_dterms = add_jones_and_noise(obs_clean, + add_th_noise=False, ampcal=True, phasecal=True, opacitycal=True, + dcal=False, frcal=True, dterm_offset=0.0) + obs_clean.data = obsdata_dterms + + # Partition the list of observed visibilities into scans + scans = obs.tlist(t_gather=solution_interval, scan_gather=scan_solutions) + scans_cal = copy.copy(scans) + + # Partition the list of model visibilities into scans + V_scans = [o[ehc.vis_poldict[pol]] for o in obs_clean.tlist( + t_gather=solution_interval, scan_gather=scan_solutions)] + + # Make the pool for parallel processing + if processes > 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + elif processes == 0: + counter = parloop.Counter(initval=0, maxval=len(scans)) + processes = int(cpu_count()) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes, initializer=init, initargs=(counter,)) + else: + print("Not Using Multiprocessing") + + # loop over scans and calibrate + tstart = time.time() + if processes > 0: # run on multiple cores with multiprocessing + scans_cal = np.array(pool.map(get_selfcal_scan_cal, [[i, len(scans), scans[i], + im, V_scans[i], sites, + obs.polrep, pol, apply_singlepol, + method, minimizer_method, + show_solution, pad_amp, gain_tol, + debias, caltable, msgtype, + use_grad + ] for i in range(len(scans))]), + dtype=object) + + else: # run on a single core + for i in range(len(scans)): + obsh.prog_msg(i, len(scans), msgtype=msgtype, nscan_last=i - 1) + scans_cal[i] = self_cal_scan(scans[i], im, V_scan=V_scans[i], sites=sites, + polrep=obs.polrep, pol=pol, apply_singlepol=apply_singlepol, + method=method, minimizer_method=minimizer_method, + show_solution=show_solution, + pad_amp=pad_amp, gain_tol=gain_tol, + debias=debias, caltable=caltable, + use_grad=use_grad) + + tstop = time.time() + print("\nself_cal time: %f s" % (tstop - tstart)) + + if caltable: # assemble the caltable to return + allsites = obs.tarr['site'] + caldict = scans_cal[0] + for i in range(1, len(scans_cal)): + row = scans_cal[i] + for site in allsites: + try: + dat = row[site] + except KeyError: + continue + + try: + caldict[site] = np.append(caldict[site], row[site]) + except KeyError: + caldict[site] = dat + + caltable = ehtim.caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr, + source=obs.source, mjd=obs.mjd, timetype=obs.timetype) + out = caltable + + else: # return a calibrated observation + arglist, argdict = obs.obsdata_args() + arglist[4] = np.concatenate(scans_cal) + out = ehtim.obsdata.Obsdata(*arglist, **argdict) + if copy_closure_tables: + out.camp = obs.camp + out.logcamp = obs.logcamp + out.cphase = obs.cphase + + # close multiprocessing jobs + if processes >= 0: + pool.close() + + return out + + +def self_cal_scan(scan, im, V_scan=[], sites=[], polrep='stokes', pol='I', apply_singlepol=False, + method="both", + minimizer_method='BFGS', show_solution=False, + pad_amp=0., gain_tol=.2, debias=True, caltable=False, + use_grad=False): + """Self-calibrate a scan to an image. + + Args: + scan (np.recarray): data array of type DTPOL_STOKES or DTPOL_CIRC + im (Image): the image to be calibrated to + sites (list): list of sites to include in the self calibration. + empty list calibrates all sites + V_scan (list) : precomputed scan visibilities + + polrep (str): 'stokes' or 'circ' to specify the polarization products in scan + pol (str): which image polarization to self-calibrate visibilities to + apply_singlepol (str): if calibrating to pol='RR' or pol='LL', + apply solution only to the single polarization + + method (str): chooses what to calibrate, 'amp', 'phase', or 'both' + minimizer_method (str): Method for scipy.optimize.minimize + (e.g., 'CG', 'BFGS', 'Nelder-Mead', etc.) + pad_amp (float): adds fractional uncertainty to amplitude sigmas in quadrature + gain_tol (float or list): gains that exceed this value will be disfavored by the prior + for asymmetric gain_tol for corrections below/above unity, + pass a 2-element list + + debias (bool): If True, debias the amplitudes + caltable (bool): if True, returns a Caltable instead of an Obsdata + show_solution (bool): if True, display the solution as it is calculated + use_grad (bool): if True, use gradients in minimizer + + Returns: + (Obsdata): the calibrated observation, if caltable==False + (Caltable): the derived calibration table, if caltable==True + """ + + if use_grad and (method=='phase' or method=='amp'): + raise Exception("errfunc_grad in self_cal only works with method=='both'!") + + if len(sites) == 0: + print("No stations specified in self cal: defaulting to calibrating all !") + sites = list(set(scan['t1']).union(set(scan['t2']))) + + if len(V_scan) < 1: + # TODO This is not correct. Need to update to use polarization dictionary + uv = np.hstack((scan['u'].reshape(-1, 1), scan['v'].reshape(-1, 1))) + A = obsh.ftmatrix(im.psize, im.xdim, im.ydim, uv, pulse=im.pulse) + V_scan = np.dot(A, im.imvec) + + # convert gain tolerance to lookup table if needed + if type(gain_tol) is not dict: + gain_tol = {'default':gain_tol} + # convert any 1-sided tolerance to 2-sided tolerance parameterization + for (key, val) in gain_tol.items(): + if type(val) == float or type(val) == int: + gain_tol[key] = [val, val] + + # create a dictionary to keep track of gains + tkey = {b: a for a, b in enumerate(sites)} + + # make a list of gain keys that relates scan bl gains to solved site ones + # -1 means that this station does not have a gain that is being solved for + g1_keys = [] + g2_keys = [] + for row in scan: + try: + g1_keys.append(tkey[row['t1']]) + except KeyError: + g1_keys.append(-1) + try: + g2_keys.append(tkey[row['t2']]) + except KeyError: + g2_keys.append(-1) + + # no sites to calibrate on this scan! + if np.all(g1_keys == -1) and np.all(g2_keys == -1): + return scan + + # get scan visibilities of the specified polarization + vis = scan[ehc.vis_poldict[pol]] + sigma = scan[ehc.sig_poldict[pol]] + + if method == 'amp': + if debias: + vis = obsh.amp_debias(np.abs(vis), np.abs(sigma)) + else: + vis = np.abs(vis) + + sigma_inv = 1.0 / np.sqrt(sigma**2 + (pad_amp * np.abs(vis))**2) + + # initial guess for gains + gpar_guess = np.ones(len(sites), dtype=np.complex128).view(dtype=np.float64) + + # error function + def errfunc(gpar): + return errfunc_full(gpar, vis, V_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method) + + def errfunc_grad(gpar): + return errfunc_grad_full(gpar, vis, V_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method) + + # use gradient descent to find the gains + # minimizer params + if minimizer_method=='L-BFGS-B': + optdict = {'maxiter': MAXIT, + 'ftol': STOP, 'gtol': STOP, + 'maxcor': NHIST, 'maxls': MAXLS} + else: + optdict = {'maxiter': MAXIT} + + if use_grad: + res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict, jac=errfunc_grad) + else: + res = opt.minimize(errfunc, gpar_guess, method=minimizer_method, options=optdict) + + # save the solution + g_fit = res.x.view(np.complex128) + + if show_solution: + print(np.abs(g_fit)) + + if method == "phase": + g_fit = g_fit / np.abs(g_fit) + if method == "amp": + g_fit = np.abs(np.real(g_fit)) + + g_fit = np.append(g_fit, 1.) + + # Derive a calibration table or apply the solution to the scan + if caltable: + allsites = list(set(scan['t1']).union(set(scan['t2']))) + + caldict = {} + for site in allsites: + if site in sites: + site_key = tkey[site] + else: + site_key = -1 + + # TODO: ANDREW - this has been changed + # We will *always* set the R and L gain corrections to be equal in self calibration, + # to avoid breaking polarization consistency relationships + if apply_singlepol: + if pol=='RR': + rscale = g_fit[site_key]**-1 + lscale = np.ones(g_fit[site_key].shape) + elif pol=='LL': + rscale = np.ones(g_fit[site_key].shape) + lscale = g_fit[site_key]**-1 + else: + rscale = g_fit[site_key]**-1 + lscale = g_fit[site_key]**-1 + + # TODO: we may want to give two entries for the start/stop times + # when a non-zero interval is used + caldict[site] = np.array((scan['time'][0], rscale, lscale), dtype=ehc.DTCAL) + + out = caldict + + else: + g1_fit = g_fit[g1_keys] + g2_fit = g_fit[g2_keys] + gij_inv = (g1_fit * g2_fit.conj())**(-1) + + if polrep == 'stokes': + # gain factors + g1_fit = g_fit[g1_keys] + g2_fit = g_fit[g2_keys] + gij_inv = (g1_fit * g2_fit.conj())**(-1) + + # scale visibilities + for vistype in ['vis', 'qvis', 'uvis', 'vvis']: + scan[vistype] *= gij_inv + # scale sigmas + for sigtype in ['sigma', 'qsigma', 'usigma', 'vsigma']: + scan[sigtype] *= np.abs(gij_inv) + + elif polrep == 'circ': + if apply_singlepol: #scale only solved polarization + if pol=='RR': + grr_inv = (g1_fit * g2_fit.conj())**(-1) + gll_inv = np.ones(g1_fit.shape) + grl_inv = (g1_fit)**(-1) + glr_inv = (g2_fit.conj())**(-1) + + elif pol=='LL': + grr_inv = np.ones(g1_fit.shape) + gll_inv = (g1_fit * g2_fit.conj())**(-1) + grl_inv = (g2_fit.conj())**(-1) + glr_inv = (g1_fit)**(-1) + + # scale visibilities + scan['rrvis'] *= grr_inv + scan['llvis'] *= gll_inv + scan['rlvis'] *= grl_inv + scan['lrvis'] *= glr_inv + + # scale sigmas + scan['rrsigma'] *= np.abs(grr_inv) + scan['llsigma'] *= np.abs(gll_inv) + scan['rlsigma'] *= np.abs(grl_inv) + scan['lrsigma'] *= np.abs(glr_inv) + + else: #scale both polarizations + gij_inv = (g1_fit * g2_fit.conj())**(-1) + + # scale visibilities + for vistype in ['rrvis', 'llvis', 'rlvis', 'lrvis']: + scan[vistype] *= gij_inv + # scale sigmas + for sigtype in ['rrsigma', 'llsigma', 'rlsigma', 'lrsigma']: + scan[sigtype] *= np.abs(gij_inv) + + out = scan + + return out + + +def init(x): + global counter + counter = x + + +def get_selfcal_scan_cal(args): + return get_selfcal_scan_cal2(*args) + + +def get_selfcal_scan_cal2(i, n, scan, im, V_scan, sites, polrep, pol, apply_singlepol, method, minimizer_method, + show_solution, pad_amp, gain_tol, debias, caltable, msgtype, use_grad): + if n > 1: + global counter + counter.increment() + obsh.prog_msg(counter.value(), counter.maxval, msgtype, counter.value() - 1) + + return self_cal_scan(scan, im, V_scan=V_scan, sites=sites, polrep=polrep, pol=pol, apply_singlepol=apply_singlepol, + method=method, minimizer_method=minimizer_method, + show_solution=show_solution, + pad_amp=pad_amp, gain_tol=gain_tol, debias=debias, caltable=caltable, + use_grad=use_grad) + +# error function +def errfunc_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method): + # all the forward site gains (complex) + g = gpar.astype(np.float64).view(dtype=np.complex128) + + if method == "phase": + g = g / np.abs(g) + if method == "amp": + g = np.abs(np.real(g)) + + # append the default values to g for missing gains + g = np.append(g, 1.) + g1 = g[g1_keys] + g2 = g[g2_keys] + + # build site specific tolerance parameters + tol0 = np.array([gain_tol.get(s, gain_tol['default'])[0] for s in sites]) + tol1 = np.array([gain_tol.get(s, gain_tol['default'])[1] for s in sites]) + + if method == 'amp': + verr = np.abs(vis) - g1 * g2.conj() * np.abs(v_scan) + else: + verr = vis - g1 * g2.conj() * v_scan + + nan_mask = [not np.isnan(v) for v in verr] + verr = verr[nan_mask] + + # goodness-of-fit for gains + chisq = np.sum((verr.real * sigma_inv[nan_mask])**2) + \ + np.sum((verr.imag * sigma_inv[nan_mask])**2) + + # prior on the gains + # don't count the last (default missing site) gain dummy value + tolsq = ((np.abs(g[:-1]) > 1) * tol0 + (np.abs(g[:-1]) <= 1) * tol1)**2 + chisq_g = np.sum(np.log(np.abs(g[:-1]))**2 / tolsq) + + # total chi^2 + chisqtot = chisq + chisq_g + return chisqtot + +def errfunc_grad_full(gpar, vis, v_scan, sigma_inv, gain_tol, sites, g1_keys, g2_keys, method): + # does not work for method=='phase' or method=='amp' + if method=='phase' or method=='amp': + raise Exception("errfunc_grad in self_cal only works with method=='both'!") + + # all the forward site gains (complex) + g = gpar.astype(np.float64).view(dtype=np.complex128) + gr = np.real(g) + gi = np.imag(g) + + # build site specific tolerance parameters + tol0 = np.array([gain_tol.get(s, gain_tol['default'])[0] for s in sites]) + tol1 = np.array([gain_tol.get(s, gain_tol['default'])[1] for s in sites]) + + # append the default values to g for missing gains + g = np.append(g, 1.) + g1 = g[g1_keys] + g2 = g[g2_keys] + + g1r = np.real(g1) + g1i = np.imag(g1) + g2r = np.real(g2) + g2i = np.imag(g2) + + v_scan_sq = v_scan*v_scan.conj() + g1sq = g1*(g1.conj()) + g2sq = g2*(g2.conj()) + + ################################### + # data term chi^2 derivitive + ################################### + + # chi^2 term gradients + dchisq_dg1r = (-g2.conj()*vis.conj()*v_scan - g2*vis*v_scan.conj() + 2*g1r*g2sq*v_scan_sq) + dchisq_dg1i = (-1j*g2.conj()*vis.conj()*v_scan + 1j*g2*vis*v_scan.conj() + 2*g1i*g2sq*v_scan_sq) + + dchisq_dg2r = (-g1*vis.conj()*v_scan - g1.conj()*vis*v_scan.conj() + 2*g2r*g1sq*v_scan_sq) + dchisq_dg2i = (1j*g1*vis.conj()*v_scan - 1j*g1.conj()*vis*v_scan.conj() + 2*g2i*g1sq*v_scan_sq) + + + dchisq_dg1r *= ((sigma_inv)**2) + dchisq_dg1i *= ((sigma_inv)**2) + dchisq_dg2r *= ((sigma_inv)**2) + dchisq_dg2i *= ((sigma_inv)**2) + + # same masking function as in errfunc + # preserve length of dchisq arrays + verr = vis - g1 * g2.conj() * v_scan + nan_mask = np.isnan(verr) + + dchisq_dg1r[nan_mask] = 0 + dchisq_dg1i[nan_mask] = 0 + dchisq_dg2r[nan_mask] = 0 + dchisq_dg2i[nan_mask] = 0 + + # derivitives of real and imaginary gains + dchisq_dgr = np.zeros(len(gpar)//2) #len(gpar) must be even + dchisq_dgi = np.zeros(len(gpar)//2) + + # TODO faster than a for loop? + for i in range(len(gpar)//2): + g1idx = np.argwhere(np.array(g1_keys)==i) + g2idx = np.argwhere(np.array(g2_keys)==i) + + dchisq_dgr[i] = np.sum(dchisq_dg1r[g1idx]) + np.sum(dchisq_dg2r[g2idx]) + dchisq_dgi[i] = np.sum(dchisq_dg1i[g1idx]) + np.sum(dchisq_dg2i[g2idx]) + + ################################### + # prior term chi^2 derivitive + ################################### + + # NOTE this derivitive doesn't account for possible sharp change in tol at g=1 + gsq = np.abs(g[:-1])**2 # don't count default missing site dummy value + tolsq = ((np.abs(g[:-1]) > 1) * tol0 + (np.abs(g[:-1]) <= 1) * tol1)**2 + + dchisqg_dgr = gr*np.log(gsq)/gsq/tolsq + dchisqg_dgi = gi*np.log(gsq)/gsq/tolsq + + # total derivative + dchisqtot_dgr = dchisq_dgr + dchisqg_dgr + dchisqtot_dgi = dchisq_dgi + dchisqg_dgi + + # interleave final derivs + dchisqtot_dgpar = np.zeros(len(gpar)) + dchisqtot_dgpar[0::2] = dchisqtot_dgr + dchisqtot_dgpar[1::2] = dchisqtot_dgi + + # any imaginary parts??? should all be real + dchisqtot_dgpar = np.real(dchisqtot_dgpar) + + return dchisqtot_dgpar + + + diff --git a/caltable.py b/caltable.py new file mode 100644 index 00000000..ed47e617 --- /dev/null +++ b/caltable.py @@ -0,0 +1,1031 @@ +# caltable.py +# a calibration table class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import matplotlib.pyplot as plt +import os +import copy +import scipy.interpolate + +import ehtim.io.save +import ehtim.io.load + +import ehtim.const_def as ehc +import ehtim.observing.obs_helpers as obsh + + +################################################################################################## +# Caltable object +################################################################################################## + + +class Caltable(object): + """ + Attributes: + source (str): The source name + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + mjd (int): The integer MJD of the observation + rf (float): The observation frequency in Hz + bw (float): The observation bandwidth in Hz + timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC' + + tarr (numpy.recarray): The array of telescope data with datatype DTARR + tkey (dict): A dictionary of rows in the tarr for each site name + + data (dict): keys are sites in tarr, entries are calibration data tables of type DTCAL + + """ + + def __init__(self, ra, dec, rf, bw, datadict, tarr, + source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC'): + """A Calibration Table. + + Args: + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The observation frequency in Hz + mjd (int): The integer MJD of the observation + bw (float): The observation bandwidth in Hz + + datadict (dict): keys are sites in tarr, entries are data tables of type DTCAL + tarr (numpy.recarray): The array of telescope data with datatype DTARR + + source (str): The source name + mjd (int): The integer MJD of the observation + timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC' + + Returns: + (Caltable): an Caltable object + """ + + # Set the various parameters + self.source = str(source) + self.ra = float(ra) + self.dec = float(dec) + self.rf = float(rf) + self.bw = float(bw) + self.mjd = int(mjd) + + if timetype not in ['GMST', 'UTC']: + raise Exception("timetype must by 'GMST' or 'UTC'") + self.timetype = timetype + + # Dictionary of array indices for site names + self.tarr = tarr + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + + # Save the data + self.data = datadict + + def copy(self): + """Copy the observation object. + + Args: + + Returns: + (Caltable): a copy of the Caltable object. + """ + new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, self.data, self.tarr, + source=self.source, mjd=self.mjd, timetype=self.timetype) + return new_caltable + + def plot_dterms(self, sites='all', label=None, legend=True, clist=ehc.SCOLORS, + rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE, + show=True, grid=True, export_pdf=""): + """Make a plot of the D-terms. + + Args: + sites (list) : list of sites to plot + label (str) : title for plot + legend (bool) : add telescope legend or not + clist (list) : list of colors for different stations + rangex (list) : lower and upper x-axis limits + rangey (list) : lower and upper y-axis limits + markersize (float) : marker size + show (bool) : display the plot or not + grid (bool) : add a grid to the plot or not + export_pdf (str) : save a pdf file to this path + + Returns: + matplotlib.axes + """ + # sites + if sites in ['all' or 'All'] or sites == []: + sites = list(self.data.keys()) + + if not isinstance(sites, list): + sites = [sites] + + keys = [self.tkey[site] for site in sites] + + axes = plot_tarr_dterms(self.tarr, keys=keys, label=label, legend=legend, clist=clist, + rangex=rangex, rangey=rangey, markersize=markersize, + show=show, grid=grid, export_pdf=export_pdf) + + return axes + + def plot_gains(self, sites, gain_type='amp', pol='R', label=None, + ang_unit='deg', timetype=False, yscale='log', legend=True, + clist=ehc.SCOLORS, rangex=False, rangey=False, markersize=[ehc.MARKERSIZE], + show=True, grid=False, axislabels=True, axis=False, export_pdf=""): + """Plot gains on multiple sites vs time. + Args: + sites (list): a list of site names for which to plot gains. Empty list is all sites. + gain_type (str): 'amp' or 'phase' + pol str(str): 'R' or 'L' + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + yscale (str): 'log' or 'lin', + clist (list): list of colors for the plot + label (str): base label for legend + + rangex (list): [xmin, xmax] x-axis (time) limits + rangey (list): [ymin, ymax] y-axis (gain) limits + + legend (bool): Plot legend if True + grid (bool): Plot gridlines if True + axislabels (bool): Show axis labels if True + show (bool): Display the plot if true + axis (matplotlib.axes.Axes): add plot to this axis + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with the plot + """ + + colors = iter(clist) + + if timetype is False: + timetype = self.timetype + if timetype not in ['GMST', 'UTC', 'utc', 'gmst']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + if gain_type not in ['amp', 'phase']: + raise Exception("gain_type must be 'amp' or 'phase' ") + if pol not in ['R', 'L', 'both']: + raise Exception("pol must be 'R' or 'L'") + + if ang_unit == 'deg': + angle = ehc.DEGREE + else: + angle = 1.0 + + # axis + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + # sites + if sites in ['all' or 'All'] or sites == []: + sites = sorted(list(self.data.keys())) + + if not isinstance(sites, list): + sites = [sites] + + if len(markersize) == 1: + markersize = markersize * np.ones(len(sites)) + + # plot gain on each site + tmins = tmaxes = gmins = gmaxes = [] + for s in range(len(sites)): + site = sites[s] + times = self.data[site]['time'] + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + times = obsh.gmst_to_utc(times, self.mjd) + elif timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + times = obsh.utc_to_gmst(times, self.mjd) + if pol == 'R': + gains = self.data[site]['rscale'] + elif pol == 'L': + gains = self.data[site]['lscale'] + + if gain_type == 'amp': + gains = np.abs(gains) + ylabel = r'$|G|$' + + if gain_type == 'phase': + gains = np.angle(gains) / angle + if ang_unit == 'deg': + ylabel = r'arg($|G|$) ($^\circ$)' + else: + ylabel = r'arg($|G|$) (radian)' + + tmins.append(np.min(times)) + tmaxes.append(np.max(times)) + gmins.append(np.min(gains)) + gmaxes.append(np.max(gains)) + + # Plot the data + if label is None: + bllabel = str(site) + else: + bllabel = label + ' ' + str(site) + plt.plot(times, gains, color=next(colors), marker='o', markersize=markersize[s], + label=bllabel, linestyle='none') + + if not rangex: + rangex = [np.min(tmins) - 0.2 * np.abs(np.min(tmins)), + np.max(tmaxes) + 0.2 * np.abs(np.max(tmaxes))] + if np.any(np.isnan(np.array(rangex))): + print("Warning: NaN in data x range: specifying rangex to default") + rangex = [0, 24] + if not rangey: + rangey = [np.min(gmins) - 0.2 * np.abs(np.min(gmins)), + np.max(gmaxes) + 0.2 * np.abs(np.max(gmaxes))] + if np.any(np.isnan(np.array(rangey))): + print("Warning: NaN in data x range: specifying rangey to default") + rangey = [1.e-2, 1.e2] + + plt.plot(np.linspace(rangex[0], rangex[1], 5), np.ones(5), 'k--') + x.set_xlim(rangex) + x.set_ylim(rangey) + + # labels + if axislabels: + x.set_xlabel(self.timetype + ' (hr)') + x.set_ylabel(ylabel) + plt.title('Caltable gains for %s on day %s' % (self.source, self.mjd)) + + if legend: + plt.legend() + + if yscale == 'log': + x.set_yscale('log') + if grid: + x.grid() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + + return x + + def enforce_positive(self, method='median', min_gain=0.9, sites=[], verbose=True): + """Enforce that caltable gains are not low + (e.g., that sites are not significantly more sensitive than estimated). + By rescaling the entire gain curve to enforce a specified minimum site gain. + + Args: + caltab (Caltable): Input Caltable with station gains + method (str): 'median', 'mean', or 'min' + min_gain (float): Site gains above this value are not modified. + sites (list): List of sites to check and adjust. For sites=[], all sites are fixed. + verbose (bool): If True, print corrections. + + Returns: + (Caltable): Axes object with the plot + """ + + if len(sites) == 0: + sites = self.data.keys() + + caltab_pos = self.copy() + for site in self.data.keys(): + if site not in sites: + continue + if len(self.data[site]['rscale']) == 0: + continue + + if method == 'min': + sitemin = np.min([np.abs(self.data[site]['rscale']), + np.abs(self.data[site]['lscale'])]) + elif method == 'mean': + sitemin = np.mean([np.abs(self.data[site]['rscale']), + np.abs(self.data[site]['lscale'])]) + elif method == 'median': + sitemin = np.median([np.abs(self.data[site]['rscale']), + np.abs(self.data[site]['lscale'])]) + else: + print('Method ' + method + ' not recognized!') + return caltab_pos + + if sitemin < min_gain: + if verbose: + print(method + ' gain for ' + site + ' is ' + str(sitemin) + '. Rescaling.') + caltab_pos.data[site]['rscale'] /= sitemin + caltab_pos.data[site]['lscale'] /= sitemin + else: + if verbose: + print(method + ' gain for ' + site + ' is ' + str(sitemin) + '. Not adjusting.') + + return caltab_pos + + # TODO default extrapolation? + def pad_scans(self, maxdiff=60, padtype='median'): + """Pad data points around scans. + + Args: + maxdiff (float): "scan" separation length (seconds) + padtype (str): padding type, 'endval' or 'median' + + Returns: + (Caltable): a padded caltable object + """ + + outdict = {} + scopes = list(self.data.keys()) + for scope in scopes: + if np.any(self.data[scope] is None) or len(self.data[scope]) == 0: + continue + + caldata = copy.deepcopy(self.data[scope]) + + # Gather data into "scans" + # TODO we could use a scan table for this as well! + gathered_data = [] + scandata = [caldata[0]] + for i in range(1, len(caldata)): + if (caldata[i]['time'] - caldata[i - 1]['time']) * 3600 > maxdiff: + scandata = np.array(scandata, dtype=ehc.DTCAL) + gathered_data.append(scandata) + scandata = [caldata[i]] + else: + scandata.append(caldata[i]) + + # This adds the last scan + scandata = np.array(scandata) + gathered_data.append(scandata) + + # Compute padding values and pad scans + for i in range(len(gathered_data)): + gg = gathered_data[i] + + medR = np.median(gg['rscale']) + medL = np.median(gg['lscale']) + + timepre = gg['time'][0] - maxdiff / 2. / 3600. + timepost = gg['time'][-1] + maxdiff / 2. / 3600. + + if padtype == 'median': # pad with median scan value + medR = np.median(gg['rscale']) + medL = np.median(gg['lscale']) + preR = medR + postR = medR + preL = medL + postL = medL + elif padtype == 'endval': # pad with endpoints + preR = gg['rscale'][0] + postR = gg['rscale'][-1] + preL = gg['lscale'][0] + postL = gg['lscale'][-1] + else: # pad with ones + preR = 1. + postR = 1. + preL = 1. + postL = 1. + + valspre = np.array([(timepre, preR, preL)], dtype=ehc.DTCAL) + valspost = np.array([(timepost, postR, postL)], dtype=ehc.DTCAL) + + gg = np.insert(gg, 0, valspre) + gg = np.append(gg, valspost) + + # output data table + if i == 0: + caldata_out = gg + else: + caldata_out = np.append(caldata_out, gg) + + try: + caldata_out # TODO: refractor to avoid using exception + except NameError: + print("No gathered_data") + else: + outdict[scope] = caldata_out + + return Caltable(self.ra, self.dec, self.rf, self.bw, outdict, self.tarr, + source=self.source, mjd=self.mjd, timetype=self.timetype) + + def applycal(self, obs, interp='linear', extrapolate=None, + force_singlepol=False, copy_closure_tables=True): + """Apply the calibration table to an observation. + + Args: + obs (Obsdata): The observation with data to be calibrated + interp (str): Interpolation method ('linear','nearest','cubic') + extrapolate (bool): If True, points outside interpolation range will be extrapolated. + force_singlepol (str): If 'L' or 'R', will set opposite polarization gains + equal to chosen polarization + + Returns: + (Obsdata): the calibrated Obsdata object + """ + if not (self.tarr == obs.tarr).all(): + raise Exception("The telescope array in the Caltable is not the same as in the Obsdata") + + if extrapolate is True: # extrapolate can be a tuple or numpy array + fill_value = "extrapolate" + else: + fill_value = extrapolate + + obs_orig = obs.copy() # Need to do this before switch_polrep to keep tables + orig_polrep = obs.polrep + obs = obs.switch_polrep('circ') + + rinterp = {} + linterp = {} + skipsites = [] + for s in range(0, len(self.tarr)): + site = self.tarr[s]['site'] + + try: + self.data[site] + except KeyError: + skipsites.append(site) + print("No Calibration Data for %s !" % site) + continue + + time_mjd = self.data[site]['time'] / 24.0 + self.mjd + rinterp[site] = relaxed_interp1d(time_mjd, self.data[site]['rscale'], + kind=interp, fill_value=fill_value, bounds_error=False) + linterp[site] = relaxed_interp1d(time_mjd, self.data[site]['lscale'], + kind=interp, fill_value=fill_value, bounds_error=False) + + bllist = obs.bllist() + datatable = [] + for bl_obs in bllist: + t1 = bl_obs['t1'][0] + t2 = bl_obs['t2'][0] + time_mjd = bl_obs['time'] / 24.0 + obs.mjd + + if t1 in skipsites: + rscale1 = lscale1 = np.array(1.) + else: + rscale1 = rinterp[t1](time_mjd) + lscale1 = linterp[t1](time_mjd) + if t2 in skipsites: + rscale2 = lscale2 = np.array(1.) + else: + rscale2 = rinterp[t2](time_mjd) + lscale2 = linterp[t2](time_mjd) + + if force_singlepol == 'R': + lscale1 = rscale1 + lscale2 = rscale2 + + if force_singlepol == 'L': + rscale1 = lscale1 + rscale2 = lscale2 + + rrscale = rscale1 * rscale2.conj() + llscale = lscale1 * lscale2.conj() + rlscale = rscale1 * lscale2.conj() + lrscale = lscale1 * rscale2.conj() + + bl_obs['rrvis'] = (bl_obs['rrvis']) * rrscale + bl_obs['llvis'] = (bl_obs['llvis']) * llscale + bl_obs['rlvis'] = (bl_obs['rlvis']) * rlscale + bl_obs['lrvis'] = (bl_obs['lrvis']) * lrscale + + bl_obs['rrsigma'] = bl_obs['rrsigma'] * np.abs(rrscale) + bl_obs['llsigma'] = bl_obs['llsigma'] * np.abs(llscale) + bl_obs['rlsigma'] = bl_obs['rlsigma'] * np.abs(rlscale) + bl_obs['lrsigma'] = bl_obs['lrsigma'] * np.abs(lrscale) + + if len(datatable): + datatable = np.hstack((datatable, bl_obs)) + else: + datatable = bl_obs + + calobs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, + np.array(datatable), obs.tarr, + polrep=obs.polrep, scantable=obs.scans, source=obs.source, + mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal, + opacitycal=obs.opacitycal, dcal=obs.dcal, + frcal=obs.frcal, timetype=obs.timetype) + calobs = calobs.switch_polrep(orig_polrep) + + if copy_closure_tables: + calobs.camp = obs_orig.camp + calobs.logcamp = obs_orig.logcamp + calobs.cphase = obs_orig.cphase + + return calobs + + def merge(self, caltablelist, interp='linear', extrapolate=1): + """Merge the calibration table with a list of other calibration tables + + Args: + caltablelist (list): The list of caltables to be merged + interp (str): Interpolation method ('linear','nearest','cubic') + extrapolate (bool): If True, points outside interpolation range will be extrapolated. + + Returns: + (Caltable): the merged Caltable object + """ + + if extrapolate is True: # extrapolate can be a tuple or numpy array + fill_value = "extrapolate" + else: + fill_value = extrapolate + + if not hasattr(caltablelist, '__iter__'): + caltablelist = [caltablelist] + + tarr1 = self.tarr.copy() + tkey1 = self.tkey.copy() + data1 = self.data.copy() + for caltable in caltablelist: + + # TODO check metadata! + + # TODO CHECK ARE THEY ALL REFERENCED TO SAME MJD??? + tarr2 = caltable.tarr.copy() + tkey2 = caltable.tkey.copy() + data2 = caltable.data.copy() + sites2 = list(data2.keys()) + sites1 = list(data1.keys()) + for site in sites2: + if site in sites1: # if site in both tables + + # merge the data by interpolating + time1 = data1[site]['time'] + time2 = data2[site]['time'] + + rinterp1 = relaxed_interp1d(time1, data1[site]['rscale'], + kind=interp, fill_value=fill_value, + bounds_error=False) + linterp1 = relaxed_interp1d(time1, data1[site]['lscale'], + kind=interp, fill_value=fill_value, + bounds_error=False) + rinterp2 = relaxed_interp1d(time2, data2[site]['rscale'], + kind=interp, fill_value=fill_value, + bounds_error=False) + linterp2 = relaxed_interp1d(time2, data2[site]['lscale'], + kind=interp, fill_value=fill_value, + bounds_error=False) + + times_merge = np.unique(np.hstack((time1, time2))) + + rscale_merge = rinterp1(times_merge) * rinterp2(times_merge) + lscale_merge = linterp1(times_merge) * linterp2(times_merge) + + # put the merged data back in data1 + # TODO can we do this faster? + datatable = [] + for i in range(len(times_merge)): + datatable.append( + np.array((times_merge[i], rscale_merge[i], lscale_merge[i]), + dtype=ehc.DTCAL)) + data1[site] = np.array(datatable) + + # sites not in both caltables + else: + if site not in tkey1.keys(): + tarr1 = np.append(tarr1, tarr2[tkey2[site]]) + data1[site] = data2[site] + + # update tkeys every time + tkey1 = {tarr1[i]['site']: i for i in range(len(tarr1))} + + new_caltable = Caltable(self.ra, self.dec, self.rf, self.bw, data1, tarr1, + source=self.source, mjd=self.mjd, timetype=self.timetype) + + return new_caltable + + def save_txt(self, obs, datadir='.', sqrt_gains=False): + """Saves a Caltable object to text files in the given directory + Args: + obs (Obsdata): The observation object associated with the Caltable + datadir (str): directory to save caltable in + sqrt_gains (bool): If True, we square gains before saving. + + Returns: + """ + + return save_caltable(self, obs, datadir=datadir, sqrt_gains=sqrt_gains) + + def scan_avg(self, obs, incoherent=True): + """average the gains across scans. + + Args: + obs (ehtim.Obsdata) : input observation + incoherent (bool) : True to average gain amps, False to average amps+phase + + Returns: + (Caltable): the averaged Caltable object + """ + sites = self.data.keys() + ntele = len(sites) + + datatables = {} + + # iterate over each site + for s in range(0, ntele): + site = sites[s] + + # make a list of times that is the same value for all points in the same scan + times = self.data[site]['time'] + times_stable = times.copy() + obs.add_scans() + scans = obs.scans + for j in range(len(times_stable)): + for scan in scans: + if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]: + times_stable[j] = scan[0] + break + + datatable = [] + for scan in scans: + gains_l = self.data[site]['lscale'] + gains_r = self.data[site]['rscale'] + + # if incoherent average then average the magnitude of gains + if incoherent: + gains_l = np.abs(gains_l) + gains_r = np.abs(gains_r) + + # average the gains + gains_l_avg = np.mean(gains_l[np.array(times_stable == scan[0])]) + gains_r_avg = np.mean(gains_r[np.array(times_stable == scan[0])]) + + # add them to a new datatable + datatable.append(np.array((scan[0], gains_r_avg, gains_l_avg), dtype=ehc.DTCAL)) + + datatables[site] = np.array(datatable) + + if len(datatables) > 0: + caltable = Caltable(obs.ra, obs.dec, obs.rf, + obs.bw, datatables, obs.tarr, source=obs.source, + mjd=obs.mjd, timetype=obs.timetype) + else: + caltable = False + + return caltable + + def invert_gains(self): + + sites = self.data.keys() + + for site in sites: + self.data[site]['rscale'] = 1 / self.data[site]['rscale'] + self.data[site]['lscale'] = 1 / self.data[site]['lscale'] + + return self + + +def load_caltable(obs, datadir, sqrt_gains=False): + """Load apriori Caltable object from text files in the given directory + Args: + obs (Obsdata): The observation object associated with the Caltable + datadir (str): directory to save caltable in + sqrt_gains (bool): If True, we take the sqrt of table gains before loading. + + Returns: + (Caltable): a caltable object + """ + + tarr = obs.tarr + array_filename = datadir + '/array.txt' + if os.path.exists(array_filename): + tarr = ehtim.io.load.load_array_txt(array_filename).tarr + + datatables = {} + for s in range(0, len(tarr)): + + site = tarr[s]['site'] + filename = os.path.join(datadir, obs.source + '_' + site + '.txt') + try: + data = np.loadtxt(filename, dtype=bytes).astype(str) + except IOError: + try: + filename = datadir + site + '.txt' + data = np.loadtxt(filename, dtype=bytes).astype(str) + except IOError: + continue + + + datatable = [] + + for row in data: + + time = (float(row[0]) - obs.mjd) * 24.0 # time is given in mjd + + if len(row) == 3: + rscale = float(row[1]) + lscale = float(row[2]) + elif len(row) == 5: + rscale = float(row[1]) + 1j * float(row[2]) + lscale = float(row[3]) + 1j * float(row[4]) + else: + raise Exception("cannot load caltable -- format unknown!") + if sqrt_gains: + rscale = rscale**.5 + lscale = lscale**.5 + datatable.append(np.array((time, rscale, lscale), dtype=ehc.DTCAL)) + + datatables[site] = np.array(datatable) + if len(datatables) > 0: + caltable = Caltable(obs.ra, obs.dec, obs.rf, obs.bw, datatables, tarr, + source=obs.source, mjd=obs.mjd, timetype=obs.timetype) + else: + print("COULD NOT FIND CALTABLE IN DIRECTORY %s" % datadir) + caltable = False + return caltable + + +def save_caltable(caltable, obs, datadir='.', sqrt_gains=False): + """Saves a Caltable object to text files in the given directory + Args: + obs (Obsdata): The observation object associated with the Caltable + datadir (str): directory to save caltable in + sqrt_gains (bool): If True, we square gains before saving. + + Returns: + """ + + if not os.path.exists(datadir): + os.makedirs(datadir) + + ehtim.io.save.save_array_txt(obs.tarr, datadir + '/array.txt') + + datatables = caltable.data + src = caltable.source + for site_info in caltable.tarr: + site = site_info['site'] + + if len(datatables.get(site, [])) == 0: + continue + + filename = datadir + '/' + src + '_' + site + '.txt' + outfile = open(filename, 'w') + site_data = datatables[site] + for entry in site_data: + time = entry['time'] / 24.0 + obs.mjd + + if sqrt_gains: + rscale = np.square(entry['rscale']) + lscale = np.square(entry['lscale']) + else: + rscale = entry['rscale'] + lscale = entry['lscale'] + + rreal = float(np.real(rscale)) + rimag = float(np.imag(rscale)) + lreal = float(np.real(lscale)) + limag = float(np.imag(lscale)) + outline = (str(float(time)) + ' ' + + str(float(rreal)) + ' ' + str(float(rimag)) + ' ' + + str(float(lreal)) + ' ' + str(float(limag)) + '\n') + outfile.write(outline) + outfile.close() + + return + + +def make_caltable(obs, gains, sites, times): + """Create a Caltable object for an observation + Args: + obs (Obsdata): The observation object associated with the Caltable + gains (list): list of gains (?? format ??) + sites (list): list of sites + times (list): list of times + + Returns: + (Caltable): a caltable object + """ + ntele = len(sites) + ntimes = len(times) + + datatables = {} + for s in range(0, ntele): + datatable = [] + for t in range(0, ntimes): + gain = gains[s * ntele + t] + datatable.append(np.array((times[t], gain, gain), dtype=ehc.DTCAL)) + datatables[sites[s]] = np.array(datatable) + if len(datatables) > 0: + caltable = Caltable(obs.ra, obs.dec, obs.rf, + obs.bw, datatables, obs.tarr, source=obs.source, + mjd=obs.mjd, timetype=obs.timetype) + else: + caltable = False + + return caltable + + +def relaxed_interp1d(x, y, **kwargs): + try: + len(x) + except TypeError: + x = np.asarray([x]) + y = np.asarray([y]) # allows to run on a single float number + if len(x) == 1: + x = np.array([-0.5, 0.5]) + x[0] + y = np.array([1.0, 1.0]) * y[0] + return scipy.interpolate.interp1d(x, y, **kwargs) + + +def plot_tarr_dterms(tarr, keys=None, label=None, legend=True, clist=ehc.SCOLORS, + rangex=False, rangey=False, markersize=2 * ehc.MARKERSIZE, + show=True, grid=True, export_pdf="", auto_order=True): + + if auto_order: + # Ensure that the plot will put the stations in alphabetical order + keys = np.argsort(tarr['site']) # range(len(tarr)) + else: + keys = range(len(tarr)) + + colors = iter(clist) + + if export_pdf != "": + fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True, figsize=(16, 8)) + else: + fig, axes = plt.subplots(nrows=1, ncols=2, sharey=True, sharex=True) + + for key in keys: + + # get the label + site = str(tarr[key]['site']) + if label is None: + bllabel = str(site) + else: + bllabel = label + ' ' + str(site) + color = next(colors) + + axes[0].plot(np.real(tarr[key]['dr']), np.imag(tarr[key]['dr']), + color=color, marker='o', markersize=markersize, + label=bllabel, linestyle='none') + axes[0].set_title("Right D-terms") + axes[0].set_xlabel("Real") + axes[0].set_ylabel("Imaginary") + + axes[1].plot(np.real(tarr[key]['dl']), np.imag(tarr[key]['dl']), + color=color, marker='o', markersize=markersize, + label=bllabel, linestyle='none') + axes[1].set_title("Left D-terms") + axes[1].set_xlabel("Real") + axes[1].set_ylabel("Imaginary") + + axes[0].axhline(y=0, color='k') + axes[0].axvline(x=0, color='k') + axes[1].axhline(y=0, color='k') + axes[1].axvline(x=0, color='k') + + if grid: + axes[0].grid() + axes[1].grid() + + if rangex: + axes[0].set_xlim(rangex) + axes[1].set_xlim(rangex) + + if rangey: + axes[0].set_ylim(rangey) + axes[1].set_ylim(rangey) + + if legend: + axes[1].legend(loc='center left', bbox_to_anchor=(1, 0.5)) + + if export_pdf != "": + fig.savefig(export_pdf, bbox_inches='tight') + + return axes + + +def plot_compare_gains(caltab1, caltab2, obs, sites='all', pol='R', gain_type='amp', ang_unit='deg', + scan_avg=True, site_name_dict=None, fontsize=13, legend_fontsize=13, + yscale='log', legend=True, clist=ehc.SCOLORS, + rangex=False, rangey=False, scalefac=[0.9, 1.1], + markersize=[2 * ehc.MARKERSIZE], show=True, grid=False, + axislabels=True, remove_ticks=False, axis=False, + export_pdf=""): + + colors = iter(clist) + + if ang_unit == 'deg': + angle = ehc.DEGREE + else: + angle = 1.0 + + # axis + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + if scan_avg: + caltab1 = caltab1.scan_avg(obs, incoherent=True) + caltab2 = caltab2.scan_avg(obs, incoherent=True) + + # sites + if sites in ['all' or 'All'] or sites == []: + sites = list(set(caltab1.data.keys()).intersection(caltab2.data.keys())) + + if not isinstance(sites, list): + sites = [sites] + + if site_name_dict is None: + print('hi') + site_name_dict = {} + for site in sites: + site_name_dict[site] = site + + if len(markersize) == 1: + markersize = markersize * np.ones(len(sites)) + + maxgain = 0.0 + mingain = 10000 + + for s in range(len(sites)): + + site = sites[s] + if pol == 'R': + gains1 = caltab1.data[site]['rscale'] + gains2 = caltab2.data[site]['rscale'] + elif pol == 'L': + gains1 = caltab1.data[site]['lscale'] + gains2 = caltab2.data[site]['lscale'] + + if gain_type == 'amp': + gains1 = np.abs(gains1) + gains2 = np.abs(gains2) + ylabel = 'Amplitudes' # r'$|G|$' + + if gain_type == 'phase': + gains1 = np.angle(gains1) / angle + gains2 = np.angle(gains2) / angle + if ang_unit == 'deg': + ylabel = r'arg($|G|$) ($^\circ$)' + else: + ylabel = 'Phases (radian)' # r'arg($|G|$) (radian)' + + # print a line + maxgain = np.nanmax([maxgain, np.nanmax(gains1), np.nanmax(gains2)]) + mingain = np.nanmin([mingain, np.nanmin(gains1), np.nanmin(gains2)]) + + # mark the gains on the plot + plt.plot(gains1, gains2, marker='.', linestyle='None', color=next( + colors), markersize=markersize[s], label=site_name_dict[site]) + + plt.xticks(fontsize=fontsize) + plt.yticks(fontsize=fontsize) + + plt.axes().set_aspect('equal') + + if rangex: + x.set_xlim(rangex) + else: + x.set_xlim([mingain * scalefac[0], maxgain * scalefac[1]]) + if rangey: + x.set_ylim(rangey) + else: + x.set_ylim([mingain * scalefac[0], maxgain * scalefac[1]]) + + plt.plot([mingain * scalefac[0], maxgain * scalefac[1]], + [mingain * scalefac[0], maxgain * scalefac[1]], 'grey', linewidth=1) + + # labels + if axislabels: + x.set_xlabel('Ground Truth Gain ' + ylabel, fontsize=fontsize) + x.set_ylabel('Recovered Gain ' + ylabel, fontsize=fontsize) + else: + x.tick_params(axis="y", direction="in", pad=-30) + x.tick_params(axis="x", direction="in", pad=-18) + + if remove_ticks: + plt.setp(x.get_xticklabels(), visible=False) + plt.setp(x.get_yticklabels(), visible=False) + + if legend: + plt.legend(frameon=False, fontsize=legend_fontsize) + + if yscale == 'log': + x.set_yscale('log') + x.set_xscale('log') + if grid: + x.grid() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + + return x diff --git a/const_def.py b/const_def.py new file mode 100644 index 00000000..67ed1f28 --- /dev/null +++ b/const_def.py @@ -0,0 +1,377 @@ +# const_def.py +# useful constants and definitions +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import matplotlib as mpl +import matplotlib.pyplot as plt +from packaging import version + +from ehtim.observing.pulses import trianglePulse2D + +mpl.rc('font', **{'family': 'serif', 'size': 12}) + +EP = 1.0e-10 +C = 299792458.0 +DEGREE = 3.141592653589/180.0 +HOUR = 15.0*DEGREE +RADPERAS = DEGREE/3600.0 +RADPERUAS = RADPERAS*1.e-6 + +# Default Parameters +SOURCE_DEFAULT = "SgrA" +RA_DEFAULT = 17.761122472222223 +DEC_DEFAULT = -28.992189444444445 +RF_DEFAULT = 230e9 +MJD_DEFAULT = 51544 +PULSE_DEFAULT = trianglePulse2D + +# movie parameters +INTERP_DEFAULT = 'linear' +BOUNDS_ERROR = True # When False, movie will return NEAREST NEIGHBOR frames +# for times beyond [movie.start_hr, movie.stop_hr] + +# Telescope elevation cuts (degrees) +ELEV_LOW = 10.0 +ELEV_HIGH = 85.0 + +TAUDEF = 0.1 # Default Optical Depth +GAINPDEF = 0.1 # Default rms of gain errors +DTERMPDEF = 0.05 # Default rms of D-term errors + +# Sgr A* Kernel Values (Bower et al., in uas/cm^2) +FWHM_MAJ = 1.309 * 1000 # in uas +FWHM_MIN = 0.64 * 1000 +POS_ANG = 78 # in degree, E of N + +# FFT & NFFT options +NFFT_KERSIZE_DEFAULT = 20 +GRIDDER_P_RAD_DEFAULT = 2 +GRIDDER_CONV_FUNC_DEFAULT = 'gaussian' +FFT_PAD_DEFAULT = 2 +FFT_INTERP_DEFAULT = 3 + +# Observation recarray datatypes +DTARR = [('site', 'U32'), ('x', 'f8'), ('y', 'f8'), ('z', 'f8'), + ('sefdr', 'f8'), ('sefdl', 'f8'), ('dr', 'c16'), ('dl', 'c16'), + ('fr_par', 'f8'), ('fr_elev', 'f8'), ('fr_off', 'f8')] + +DTPOL_STOKES = [('time', 'f8'), ('tint', 'f8'), + ('t1', 'U32'), ('t2', 'U32'), + ('tau1', 'f8'), ('tau2', 'f8'), + ('u', 'f8'), ('v', 'f8'), + ('vis', 'c16'), ('qvis', 'c16'), ('uvis', 'c16'), ('vvis', 'c16'), + ('sigma', 'f8'), ('qsigma', 'f8'), ('usigma', 'f8'), ('vsigma', 'f8')] + +DTPOL_CIRC = [('time', 'f8'), ('tint', 'f8'), + ('t1', 'U32'), ('t2', 'U32'), + ('tau1', 'f8'), ('tau2', 'f8'), + ('u', 'f8'), ('v', 'f8'), + ('rrvis', 'c16'), ('llvis', 'c16'), ('rlvis', 'c16'), ('lrvis', 'c16'), + ('rrsigma', 'f8'), ('llsigma', 'f8'), ('rlsigma', 'f8'), ('lrsigma', 'f8')] + +DTAMP = [('time', 'f8'), ('tint', 'f8'), + ('t1', 'U32'), ('t2', 'U32'), + ('u', 'f8'), ('v', 'f8'), + ('amp', 'f8'), ('sigma', 'f8')] + +DTBIS = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), + ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'), + ('bispec', 'c16'), ('sigmab', 'f8')] + +DTCPHASE = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), + ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'), + ('cphase', 'f8'), ('sigmacp', 'f8')] + +DTCAMP = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), ('t4', 'U32'), + ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), + ('u3', 'f8'), ('v3', 'f8'), ('u4', 'f8'), ('v4', 'f8'), + ('camp', 'f8'), ('sigmaca', 'f8')] + +DTCPHASEDIAG = [('time', 'f8'), ('cphase', 'f8'), ('sigmacp', 'f8'), + ('triangles', 'O'), ('u', 'O'), ('v', 'O'), ('tform_matrix', 'O')] + +DTLOGCAMPDIAG = [('time', 'f8'), ('camp', 'f8'), ('sigmaca', 'f8'), + ('quadrangles', 'O'), ('u', 'O'), ('v', 'O'), ('tform_matrix', 'O')] + +DTCAL = [('time', 'f8'), ('rscale', 'c16'), ('lscale', 'c16')] + +DTSCANS = [('time', 'f8'), ('interval', 'f8'), ('startvis', 'f8'), ('endvis', 'f8')] + +# Dictionaries for keeping track of polarization fields +POLDICT_STOKES = {'vis1': 'vis', 'vis2': 'qvis', 'vis3': 'uvis', 'vis4': 'vvis', + 'sigma1': 'sigma', 'sigma2': 'qsigma', 'sigma3': 'usigma', 'sigma4': 'vsigma'} +POLDICT_CIRC = {'vis1': 'rrvis', 'vis2': 'llvis', 'vis3': 'rlvis', 'vis4': 'lrvis', + 'sigma1': 'rrsigma', 'sigma2': 'llsigma', 'sigma3': 'rlsigma', 'sigma4': 'lrsigma'} +vis_poldict = {'I': 'vis', 'Q': 'qvis', 'U': 'uvis', 'V': 'vvis', + 'RR': 'rrvis', 'LL': 'llvis', 'RL': 'rlvis', 'LR': 'lrvis'} +amp_poldict = {'I': 'amp', 'Q': 'qamp', 'U': 'uamp', 'V': 'vamp', + 'RR': 'rramp', 'LL': 'llamp', 'RL': 'rlamp', 'LR': 'lramp'} +sig_poldict = {'I': 'sigma', 'Q': 'qsigma', 'U': 'usigma', 'V': 'vsigma', + 'RR': 'rrsigma', 'LL': 'llsigma', 'RL': 'rlsigma', 'LR': 'lrsigma'} + +# Observation fields for plotting and retrieving data +FIELDS = ['time', 'time_utc', 'time_gmst', + 'tint', 'u', 'v', 'uvdist', + 't1', 't2', 'tau1', 'tau2', + 'el1', 'el2', 'hr_ang1', 'hr_ang2', 'par_ang1', 'par_ang2', + 'vis', 'amp', 'phase', 'snr', + 'qvis', 'qamp', 'qphase', 'qsnr', + 'uvis', 'uamp', 'uphase', 'usnr', + 'vvis', 'vamp', 'vphase', 'vsnr', + 'sigma', 'qsigma', 'usigma', 'vsigma', + 'sigma_phase', 'qsigma_phase', 'usigma_phase', 'vsigma_phase', + 'psigma_phase', 'msigma_phase', + 'pvis', 'pamp', 'pphase', 'psnr', + 'evis', 'eamp', 'ephase', 'esnr', + 'bvis', 'bamp', 'bphase', 'bsnr', + 'm', 'mamp', 'mphase', 'msnr', + 'rrvis', 'rramp', 'rrphase', 'rrsnr', 'rrsigma', 'rrsigma_phase', + 'llvis', 'llamp', 'llphase', 'llsnr', 'llsigma', 'llsigma_phase', + 'rlvis', 'rlamp', 'rlphase', 'rlsnr', 'rlsigma', 'rlsigma_phase', + 'lrvis', 'lramp', 'lrphase', 'lrsnr', 'lrsigma', 'lrsigma_phase', + 'rrllvis', 'rrllamp', 'rrllphase', 'rrllsnr', 'rrllsigma', 'rrllsigma_phase'] + +FIELDS_AMPS = ["amp", "qamp", "uamp", "vamp", + "pamp", "mamp", "rramp", "llamp", "rlamp", "lramp", "rrllamp"] +FIELDS_SIGS = ["sigma", "qsigma", "usigma", "vsigma", + "psigma", "msigma", "rrsigma", "llsigma", "rlsigma", "lrsigma", "rrllsigma"] +FIELDS_PHASE = ["phase", "qphase", "uphase", "vphase", "pphase", "mphase", + "rrphase", "llphase", "rlphase", "lrphase", "rrllphase"] +FIELDS_SIGPHASE = ["sigma_phase", "qsigma_phase", "usigma_phase", "vsigma_phase", + "psigma_phase", "msigma_phase", "rrsigma_phase", "llsigma_phase", + "rlsigma_phase", "lrsigma_phase", "rrllsigma_phase"] +FIELDS_SNRS = ["snr", "qsnr", "usnr", "vsnr", "psnr", "msnr", + "rrsnr", "llsnr", "rlsnr", "lrsnr", "rrllsnr"] + +# Plotting +MARKERSIZE = 3 + +def show_noblock(pause=0.001): + """helper function for image display with different matplotlib versions""" + +# this seems to be required for matplotlib version 3.5 + if version.parse(mpl.__version__) <= version.parse('3.2.2') or version.parse(mpl.__version__) > version.parse('3.6'): + plt.show(block=False) + else: + plt.ion() + plt.show() + plt.pause(pause) + plt.draw() + plt.pause(pause) + +FIELD_LABELS = {'time': 'Time', + 'time_utc': 'Time (UTC)', + 'time_gmst': 'Time (GMST)', + 'tint': 'Integration Time', + 'u': r'$u$', + 'v': r'$v$', + 'uvdist': r'$u-v$ Distance', + 't1': 'Site 1', + 't2': 'Site 2', + 'tau1': r'$\tau_1$', + 'tau2': r'$\tau_2$', + 'el1': r'Elevation Angle$_1$', + 'el2': r'Elevation Angle$_2$', + 'hr_ang1': r'Hour Angle$_1$', + 'hr_ang2': r'Hour Angle$_2$', + 'par_ang1': r'Parallactic Angle$_1$', + 'par_ang2': r'Parallactic Angle$_2$', + 'vis': 'Visibility', + 'amp': 'Amplitude', + 'phase': 'Phase', + 'snr': 'SNR', + 'qvis': 'Q-Visibility', + 'qamp': 'Q-Amplitude', + 'qphase': 'Q-Phase', + 'qsnr': 'Q-SNR', + 'uvis': 'U-Visibility', + 'uamp': 'U-Amplitude', + 'uphase': 'U-Phase', + 'usnr': 'U-SNR', + 'vvis': 'V-Visibility', + 'vamp': 'V-Amplitude', + 'vphase': 'V-Phase', + 'vsnr': 'V-SNR', + 'sigma': r'$\sigma$', + 'qsigma': r'$\sigma_{Q}$', + 'usigma': r'$\sigma_{U}$', + 'vsigma': r'$\sigma_{V}$', + 'sigma_phase': r'$\sigma_{phase}$', + 'qsigma_phase': r'$\sigma_{Q phase}$', + 'usigma_phase': r'$\sigma_{U phase}$', + 'vsigma_phase': r'$\sigma_{V phase}$', + 'psigma_phase': r'$\sigma_{P phase}$', + 'msigma_phase': r'$\sigma_{m phase}$', + 'pvis': r'P-Visibility', + 'pamp': r'P-Amplitude', + 'pphase': 'P-Phase', + 'psnr': 'P-SNR', + 'mvis': r'm-Visibility', + 'mamp': r'm-Amplitude', + 'mphase': 'm-Phase', + 'msnr': 'm-SNR', + 'evis': r'E-Visibility', + 'eamp': r'E-Amplitude', + 'ephase': 'E-Phase', + 'esnr': 'E-SNR', + 'bvis': r'B-Visibility', + 'bamp': r'B-Amplitude', + 'bphase': 'B-Phase', + 'bsnr': 'B-SNR', + 'rrvis': r'RR-Visibility', + 'rramp': r'RR-Amplitude', + 'rrphase': 'RR-Phase', + 'rrsnr': 'RR-SNR', + 'rrsigma': r'$\sigma_{RR}$', + 'rrsigma_phase': r'$\sigma_{RR phase}$', + 'llvis': r'LL-Visibility', + 'llamp': r'LL-Amplitude', + 'llphase': 'LL-Phase', + 'llsnr': 'LL-SNR', + 'llsigma': r'$\sigma_{LL}$', + 'llsigma_phase': r'$\sigma_{LL phase}$', + 'rlvis': r'RL-Visibility', + 'rlamp': r'RL-Amplitude', + 'rlphase': 'RL-Phase', + 'rlsnr': 'RL-SNR', + 'rlsigma': r'$\sigma_{RL}$', + 'rlsigma_phase': r'$\sigma_{RL phase}$', + 'lrvis': r'LR-Visibility', + 'lramp': r'LR-Amplitude', + 'lrphase': 'LR-Phase', + 'lrsnr': 'LR-SNR', + 'lrsigma': r'$\sigma_{LR}$', + 'lrsigma_phase': r'$\sigma_{LR phase}$', + 'rrllvis': r'RR/LL-Visibility', + 'rrllamp': r'RR/LL-Amplitude', + 'rrllphase': 'RR/LL-Phase', + 'rrllsnr': 'RR/LL-SNR', + 'rrllsigma': r'$\sigma_{RR/LL}$', + 'rrllsigma_phase': r'$\sigma_{RR/LL phase}$'} + +# Seaborn Colors from Maciek +SCOLORS = [(0.11764705882352941, 0.5647058823529412, 1.0), + (1.0, 0.38823529411764707, 0.2784313725490196), + (0.5411764705882353, 0.16862745098039217, 0.8862745098039215), + (0.4196078431372549, 0.5568627450980392, 0.13725490196078433), + (1.0, 0.6470588235294118, 0.0), + (0.5450980392156862, 0.27058823529411763, 0.07450980392156863), + (0.0, 0.0, 0.803921568627451), + (1.0, 0.0, 0.0), + (0.0, 1.0, 1.0), + (1.0, 0.0, 1.0), + (0.0, 0.39215686274509803, 0.0), + (0.8235294117647058, 0.7058823529411765, 0.5490196078431373), + (0.0, 0.0, 0.0)] + + +BHIMAGE = [ + ' .. ', + ' ..... . ', + ' ........... .... ', + ' ............ ........................ ', + ' .........................,,******,,..... ', + ' .......,,,,..........,,**/(((/*,,... ', + ' .....,,,,,**,.......,**/(#%%#/,... ', + ' .....,,,,,,**,.....,*/(#%&&%/,. ', + ' ......,**,***/*,,,,*/(#%@@*. ', + ' ..,....,****////***/((%&@@@#, ', + ' ..,,,...****///////((#%&@@@%*. ', + ' .,**,,.,*////((((((##%&@@@&/. ', + ' .,//*,,,*/(((((((##%%&@@@@%, ,,,,,. .,.. ', + ' ,/(//***/(###((##%%%&@@@@#, . ,*/,. ... ', + ' .*(((////(#%%###%%&&@@@@@#, ..*(*. .. ', + ' ./((((//(#%&%%%%&&@@@@@@%*. .*#/. . ', + ' ,(#((#(((#&&&&&&@@@@@@@&(.. .(%* .. ', + ' ... ,(#####((#&&@@@@@@@@@@@%/..,##, .. ', + ' .... ,(###%%#(#%&@@@@@@@@@@@%/,#&* .. ', + ' ..... .(###%%%##%&@@@@@@@@@@@&%, . ', + ' ...... ./#%%%%%%##&@@@@@@@@@@@@@%, . ', + ' ........ .*%%%%%&%##%@@@@@@@@@@@@@(,. . ', + ' ..,,,... .. .,#%%%%%&%%%&@@@@@@@@@@@&%/. . ', + ' ..,,,,... ... ../#%%%%&%%%&@@@@@@@@@@@@*. @@ @ @ # ', + ' ..,,,,... ......*(##%%&&%%&&@@@@@@@@@@@&%(,. @ @ @ @ # ### ## . ', + ' ..,**,,.......,.,/###%%&%#%&@@@@@@@@@@@&%#/,.. @@@@ @@@ @@@ %%% # # # # .. ', + ' ..,**,,.......,,*(((#%&&%#%&@@@@@@@@&&&&%(*... @ @ @ @ # # # # .. ', + ' .,,*,,,,.......,/(((#%&&%##&@@&@@@&&%%%##(*.. @@ @ @ @@ ## # # # . ', + ' ..,,,,,,,.......,/((((%&&%#(%&&&@@%%%###((/*.. . ', + ' ...,,.,,,....,,,*/((/(#%&(#%&@@%%%#(///**,.. . ', + ' ........,.....,,*////((%&&%###&@&%###(/*,,,,... .. ', + ' .............,,**///((#%&&%%#@&%%#((((/*....... ,. ', + ' ............,,,**////((#%%&%@@#####((/***,.... .,. ', + ' ....,,*****////((#%%&@(((((((/**,,,..... ,, ', + ' . ...,**/****/////(#%&@&%##((((((//**,....... ..... .*, ', + ' ... ..,*/////***////(#%&%%%###((((///****,,,,.......,,....... .,,. ', + ' .... .,*/((//****///((#%#######((((///*****,,,,,,........... ,*. ', + ' ..... ..,/(((//****///((#(///((((((///******,,,,,............... .*, ', + ' ... ...,*/(#((///////((#(/*****///(((((////***,,,......... .*, ', + ' .....,*(#####(((((##%%##((((((////****,,,......... ,*, ', + ' .....,,*/(##%%#########(///*******,,,,........... .,*. ', + ' .......,,*/(((#########(//**,,,,,,,**,,,... .,*,. ', + ' ..........,,,*****/********,,,.... .,*,. ', + ' . .... ..**. ', + ' . .. ..,,,. ', + ' . ........ ', + ' ................... ', + ' '] + +EHTIMAGE = [ + ' `..----..` ', + ' `-/oyhmNNNNMMMMNNNNmhs+:-` ', + ' `.+ymNMMMMMMMMMMMMMMMMMMMMMMMMNds/. ', + ' `:ymNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNdo-` ', + ' `-yNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNmo. ', + ' `/dNMMMMMMMMMMMMMMmdysoo+++dMMMMMMMMMMMMMMMMMMMMNh: :////// ', + ' `+mMMMMMMMMMMMNmy+-``` sMMMMMNshNMMMMMMMMMMMMMNh- mNyoooo``` `` ``` ` ``` `ds`` ', + ' :dMMMMMMMMMMMho-` .:- :sd- sMMMMNs.-yMMMMMMMMMMMMMMMNy. mN: sd: /ds -ydyhh+` .dyyhhdh: :dMmyy. ', + ' `yNMMMMMMMMMmo-` `/ymh:`-hNMM- sMMMMMMmNMMMMMMNmMMMMMMMMMMm/ mMmdddh -mm. .Nm.-Nm-..oMs -Mm:``:Mh``oMy`` ', + ' .hMMMMMMMMMNo. :ymMNo` oNMMMM- sMMMMMMMMMMMMMm/`oNMMMMMMMMMNs` mM/...` /My`yM/ +Mmssssh+ -Md .Md` +Ms ', + ' -mMMMMMMMMMs. /hNMMN: sMMMMMM- sMMMMMMMMMNMMMNh+mMMMMMMMMMMMMy` mM/.... sNsMo -mm/```-` -Md .Md` +Mh`` ', + ' -mMMMMMMMMN/ -hNMMMN: oMMMMMMM: sMMMMMMMMy-sNMMMMMMMMMMMMMMMMMMy ydhyyyy` `yhy .oyhhyy- .ds .hy` `sdhy- ', + ' `mMMMMMMMMm- -:::::- -:::::::` sMMMMMMMMmodMMMMMMMMMdMMMMMMMMMMo `````` ` ``` ` ` ``` ', + ' oMMMMMMMMm- -::::::. `-::::::::` sMNhNMMMMMMMMMMMMMMmo`/mMMMMMMMMM: ', + ' -NMMMMMMMM: +NMMMMMM: /NMMMMMMMM- sMd -+hNMMMMMMMMMMMMm+dNMMMMMMMMMh` ', + ' yMMMMMMMMy .NMMMMMMd` hMMMMMMMMM- sMNo-``.+hNMMMMMMMMMMMMMMMMMMMMMMN: .. `.. `:: ', + ' `mMMMMMMMN- yMMMMMMMs NMMMMMMMMM- sMMMNho-` .+hNMMMMMMMMMMMMMMMMMMMMy dm: :mh `sy` ', + ' -NMMMMMMMm .mMMMMMMM+ `NMMMMMMMMM: sMMMMMMNds:` .+hNMMMNhsshMMMMMMMMMm mN: /Mm -osss+- .o::os:`os``ooooss/ `:osss/` :s:+sss/` ', + ' /MMMMMMMMh `////////. //////////` sMm////////. .+Mm: `+MMMMMMMMN` mMhsssshMm +Nh/-/dm: -MNy+/..mM` --:oMd- hNs:-+Nd. oMm+-:hMo ', + ' /MMMMMMMMh --------` ----------` sMm--------. `:Mm. /NMMMMMMMN` mMo////sMm `dM. /Md -Mm` .mM` `oms. .Md yM/ oMs /My ', + ' -NMMMMMMMm .mNNNNNNN+ `mNNNNNNNNN- sMMNNNNNdy/. `:sdMMMms++sNMMMMMMMMN mN: /Mm yM/` `oMy -Mm` .mM` -hm/ `Nm- .hM: oMo /My ', + ' `mMMMMMMMN- `hMMMMMMMs NMMMMMMMMM- sMMMMmy/. `:sdMMMMMMMMMMMMMMMMMMMMh dm: /md .sdyyyds` -md` .dm`.mNhsss+ :hdyyhd+ +N+ /Ns ', + ' hMMMMMMMMs :MMMMMMMh` dMMMMMMMMM- sMNy/. `:smMMMMMMMMMMMMMMMMMMMMMMM/ ..` `.. `---` .. .. ....... .--.` `.` `.` ', + ' :NMMMMMMMM- oMMMMMMM: /MMMMMMMMM: sMd `:smMMMMMMMMMMMMNymMMMMMMMMMMd` ', + ' `yMMMMMMMMd. `///////- `/////////` sMmsmMMMMMMMMMMMMMMm+ /mMMMMMMMMM/ ', + ' .NMMMMMMMMd. ....... ........` sMMMMMMMMNyNMMMMMMMMNyNMMMMMMMMMy` `-- ', + ' :NMMMMMMMMm- :dNNNNm- sNNNNNNN- sMMMMMMMMs`+NMMMMMMMMMMMMMMMMMMd` `yyyhhhyys` .Nd ', + ' /NMMMMMMMMNo` `omMMMm- `hMMMMMM- sMMMMMMMMMmMMMMmsNMMMMMMMMMMMMd. ...hMo... .-/:-` .Nm `-:/-. .-//:- `-:/:. `-:/:.` -..:/:. .-//-. ', + ' :mMMMMMMMMMm/` `+dNMm/``yMMMMM- sMMMMMMMMMMMMMd:`+NMMMMMMMMMMy` yM+ `omy++dm: .Nm :dd++yNo -mN++oy` /mds+s+ :mms+ymh. .MmyoodN+ /mh hm/ ', + ' .hMMMMMMMMMMd+. `-odms.`/mMMM- sMMMMMMNNMMMMMMNhMMMMMMMMMMNo` yM+ /MN+++sNh .Nm `mMo+++Nm..dmy+-` -mM` `mM- +Ms .Md` .mM`.NNo++oNN ', + ' `+mMMMMMMMMMMNy:. `:+: `+hm- sMMMMMd-/mMMMMMMMMMMMMMMMMd: yM+ /Mm:::::- .Nm `mN/:::::` ./sNm--NM` `mM- /Ms .Md` .dM..NN/ ', + ' .sNMMMMMMMMMMMNho:``` ` sMMMMMm/omMMMMMMMMMMMMMMN/` yM+ `sNy+//o: .Nd /md+//++ -o///dN: oNd+/++ :mdo/omd. .MNy/+hNo +Nh+//o: ', + ' .smMMMMMMMMMMMMMMmhyo+/:---hMMMMMMMMMMMMMMMMMMMMMd+` -/. .:+++:` `/: `:+++:. `:+++:. `:++/. `-++/- .Mh-/+/. `:+++:` ', + ' `+dNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNh:` .Mh ', + ' -odNMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMMNh/. `+: ', + ' ./sdNMMMMMMMMMMMMMMMMMMMMMMMMMMNho:` ', + ' `-/symNNMMMMMMMMMMMNNNdyo/.` ', + ' .----::::---.` '] diff --git a/diagnostics.py b/diagnostics.py new file mode 100644 index 00000000..f627ab9c --- /dev/null +++ b/diagnostics.py @@ -0,0 +1,89 @@ +# diagnostics.py +# useful diagnostic tests on images +# +# Copyright (C) 2018 Katie Bouman +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np + + +def sumdown_lin(y, n=16): + """Sum segments of a line together to reduce its size + """ + nold = y.shape[0] + nnew = n + + xold = np.linspace(0, 1, num=nold, endpoint=False) + xnew = np.linspace(0, 1, num=nnew, endpoint=False) + + csum = np.zeros(nnew+1) + for i, x in enumerate(xnew): + u = np.argmax(xold > x) + k = u - 1 + csum[i] = (np.sum(y[0:k]) + + y[k] * (x - xold[k]) / (xold[u] - xold[k])) + csum[nnew] = sum(y) + + return np.diff(csum) + + +def sumdown_img(img, n=16): + """Summing patches of an image together to reduce its size + + For simplicity, we just pad each side of an image to an + integer multiple of n and then down sample but summing up the + intensity of the subcells. This is not perfect but is good enough + for a proof of concept. + """ + nx = img.shape[0] + ny = img.shape[1] + mx = (nx - 1)//n + 1 + my = (ny - 1)//n + 1 + Nx = mx * n + Ny = my * n + px = (Nx - nx)//2 + py = (Ny - ny)//2 + + img = np.pad(img, ((px, Nx-nx-px), (py, Ny-ny-py)), 'constant') + return img.reshape(n, mx, n, my).sum(axis=(1, 3)) + + +def onedimize(imgs, n=16, gt=None): + """One-dimensionalize an image by sorting in terms of pixel intensity + + Args: + imgs: a python array of two-dimensional numpy arrays + n: the number of pixel in both dimensions of the output images + + Return: + oneds: a python array of one-dimensional numpy arrays + mean: the one-dimensionalized mean image + """ + imgs = [sumdown_img(img, n=n) for img in imgs] + + if gt is None: + gt = np.dstack(imgs).mean(axis=2) + else: + gt = sumdown_img(gt, n=n) + + idxs = np.argsort(-gt.reshape(n*n)) + return [img.reshape(n*n)[idxs] for img in imgs], gt.reshape(n*n)[idxs] diff --git a/ehtim/array.py b/ehtim/array.py index f64d4316..8415f22d 100644 --- a/ehtim/array.py +++ b/ehtim/array.py @@ -113,7 +113,7 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop, mjd=ehc.MJD_DEFAULT, timetype='UTC', polrep='stokes', elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False, - tau=ehc.TAUDEF, fix_theta_GMST=False): + tau=ehc.TAUDEF, fix_theta_GMST=False, reorder=True): """Generate u,v points and baseline uncertainties. Args: @@ -144,7 +144,7 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop, elevmin=elevmin, elevmax=elevmax, no_elevcut_space=no_elevcut_space, timetype=timetype, fix_theta_GMST=fix_theta_GMST) - + uniquetimes = np.sort(np.unique(obsarr['time'])) scans = np.array([[time - 0.5 * tadv, time + 0.5 * tadv] for time in uniquetimes]) source = str(ra) + ":" + str(dec) @@ -152,7 +152,8 @@ def obsdata(self, ra, dec, rf, bw, tint, tadv, tstart, tstop, source=source, mjd=mjd, timetype=timetype, polrep=polrep, ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True, - scantable=scans) + scantable=scans, reorder=reorder) + return obs def make_subarray(self, sites): diff --git a/ehtim/const_def.py b/ehtim/const_def.py index 67ed1f28..4984d700 100644 --- a/ehtim/const_def.py +++ b/ehtim/const_def.py @@ -79,20 +79,20 @@ DTPOL_STOKES = [('time', 'f8'), ('tint', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('tau1', 'f8'), ('tau2', 'f8'), - ('u', 'f8'), ('v', 'f8'), + ('u', 'f8'), ('v', 'f8'), ('w', 'f8'), ('vis', 'c16'), ('qvis', 'c16'), ('uvis', 'c16'), ('vvis', 'c16'), ('sigma', 'f8'), ('qsigma', 'f8'), ('usigma', 'f8'), ('vsigma', 'f8')] DTPOL_CIRC = [('time', 'f8'), ('tint', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('tau1', 'f8'), ('tau2', 'f8'), - ('u', 'f8'), ('v', 'f8'), + ('u', 'f8'), ('v', 'f8'), ('w', 'f8'), ('rrvis', 'c16'), ('llvis', 'c16'), ('rlvis', 'c16'), ('lrvis', 'c16'), ('rrsigma', 'f8'), ('llsigma', 'f8'), ('rlsigma', 'f8'), ('lrsigma', 'f8')] DTAMP = [('time', 'f8'), ('tint', 'f8'), ('t1', 'U32'), ('t2', 'U32'), - ('u', 'f8'), ('v', 'f8'), + ('u', 'f8'), ('v', 'f8'), ('amp', 'f8'), ('sigma', 'f8')] DTBIS = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), @@ -100,11 +100,11 @@ ('bispec', 'c16'), ('sigmab', 'f8')] DTCPHASE = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), - ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'), + ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'), ('cphase', 'f8'), ('sigmacp', 'f8')] DTCAMP = [('time', 'f8'), ('t1', 'U32'), ('t2', 'U32'), ('t3', 'U32'), ('t4', 'U32'), - ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), + ('u1', 'f8'), ('v1', 'f8'), ('u2', 'f8'), ('v2', 'f8'), ('u3', 'f8'), ('v3', 'f8'), ('u4', 'f8'), ('v4', 'f8'), ('camp', 'f8'), ('sigmaca', 'f8')] @@ -132,7 +132,7 @@ # Observation fields for plotting and retrieving data FIELDS = ['time', 'time_utc', 'time_gmst', - 'tint', 'u', 'v', 'uvdist', + 'tint', 'u', 'v', 'w', 'uvdist', 't1', 't2', 'tau1', 'tau2', 'el1', 'el2', 'hr_ang1', 'hr_ang2', 'par_ang1', 'par_ang2', 'vis', 'amp', 'phase', 'snr', @@ -186,6 +186,7 @@ def show_noblock(pause=0.001): 'tint': 'Integration Time', 'u': r'$u$', 'v': r'$v$', + 'w': r'$w$', 'uvdist': r'$u-v$ Distance', 't1': 'Site 1', 't2': 'Site 2', diff --git a/ehtim/image.py b/ehtim/image.py index 3dbf1933..73ccafaa 100644 --- a/ehtim/image.py +++ b/ehtim/image.py @@ -2303,7 +2303,7 @@ def sample_uv(self, uv, polrep_obs='stokes', def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft", cache=False, fft_pad_factor=2, - zero_empty_pol=True, verbose=True): + zero_empty_pol=True, verbose=True, reorder=True): """Observe the image on the same baselines as an existing observation without noise. Args: @@ -2327,10 +2327,10 @@ def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft", if (np.abs(self.rf - obs.rf) / obs.rf > tolerance): raise Exception("Image frequency is not the same as observation frequency!") - if (ttype == 'direct' or ttype == 'fast' or ttype == 'nfft'): + if (ttype == 'direct' or ttype == 'fast' or ttype == 'nfft' or ttype == 'DFT' or ttype == 'DFT_i'): if verbose: print("Producing clean visibilities from image with " + ttype + " FT . . . ") else: - raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft'" % ttype) + raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft', 'DFT', 'DFT_i'" % ttype) # Copy data to be safe obsdata = copy.deepcopy(obs.data) @@ -2361,13 +2361,13 @@ def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft", source=self.source, mjd=self.mjd, polrep=obs.polrep, ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True, - timetype=obs.timetype, scantable=obs.scans) + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) return obs_no_noise def observe_same(self, obs_in, ttype='nfft', fft_pad_factor=2, - sgrscat=False, add_th_noise=True, + sgrscat=False, add_th_noise=True, th_noise_factor=1, jones=False, inv_jones=False, opacitycal=True, ampcal=True, phasecal=True, frcal=True, dcal=True, rlgaincal=True, @@ -2379,7 +2379,7 @@ def observe_same(self, obs_in, dterm_offset=ehc.DTERMPDEF, rlratio_std=0., rlphase_std=0., sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, - caltable_path=None, seed=False, verbose=True): + caltable_path=None, seed=False, reorder=True, verbose=True): """Observe the image on the same baselines as an existing observation object and add noise. Args: @@ -2443,7 +2443,7 @@ def observe_same(self, obs_in, obs = self.observe_same_nonoise(obs_in, sgrscat=sgrscat,ttype=ttype, cache=False, fft_pad_factor=fft_pad_factor, - zero_empty_pol=True, verbose=verbose) + zero_empty_pol=True, reorder=reorder, verbose=verbose) # Jones Matrix Corruption & Calibration if jones: @@ -2467,7 +2467,7 @@ def observe_same(self, obs_in, source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, dcal=dcal, frcal=frcal, - timetype=obs.timetype, scantable=obs.scans) + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) if inv_jones: obsdata = simobs.apply_jones_inverse(obs, @@ -2478,7 +2478,7 @@ def observe_same(self, obs_in, source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, ampcal=ampcal, phasecal=phasecal, opacitycal=True, dcal=True, frcal=True, - timetype=obs.timetype, scantable=obs.scans) + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) # No Jones Matrices, Add noise the old way # NOTE There is an asymmetry here - in the old way, we don't offer the ability to @@ -2489,7 +2489,7 @@ def observe_same(self, obs_in, print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix') # TODO -- clean up arguments - obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, + obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor, opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal, stabilize_scan_phase=stabilize_scan_phase, stabilize_scan_amp=stabilize_scan_amp, @@ -2504,7 +2504,7 @@ def observe_same(self, obs_in, source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, ampcal=ampcal, phasecal=phasecal, opacitycal=True, dcal=True, frcal=True, - timetype=obs.timetype, scantable=obs.scans) + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) return obs @@ -2513,8 +2513,8 @@ def observe(self, array, tint, tadv, tstart, tstop, bw, elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False, ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False, - sgrscat=False, add_th_noise=True, - jones=False, inv_jones=False, + sgrscat=False, add_th_noise=True, th_noise_factor=1, + jones=False, inv_jones=False, noise=True, opacitycal=True, ampcal=True, phasecal=True, frcal=True, dcal=True, rlgaincal=True, stabilize_scan_phase=False, stabilize_scan_amp=False, @@ -2525,7 +2525,7 @@ def observe(self, array, tint, tadv, tstart, tstop, bw, dterm_offset=ehc.DTERMPDEF, rlratio_std=0.,rlphase_std=0., sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, - caltable_path=None, seed=False, verbose=True): + caltable_path=None, seed=False, reorder=True, verbose=True): """Generate baselines from an array object and observe the image. Args: @@ -2614,26 +2614,33 @@ def observe(self, array, tint, tadv, tstart, tstop, bw, polrep=polrep_obs, tau=tau, elevmin=elevmin, elevmax=elevmax, no_elevcut_space=no_elevcut_space, - timetype=timetype, fix_theta_GMST=fix_theta_GMST) + timetype=timetype, fix_theta_GMST=fix_theta_GMST, reorder=reorder) + # Observe on the same baselines as the empty observation and add noise - obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor, - sgrscat=sgrscat, add_th_noise=add_th_noise, - jones=jones, inv_jones=inv_jones, - opacitycal=opacitycal, ampcal=ampcal, - phasecal=phasecal, dcal=dcal, - frcal=frcal, rlgaincal=rlgaincal, - stabilize_scan_phase=stabilize_scan_phase, - stabilize_scan_amp=stabilize_scan_amp, - neggains=neggains, - taup=taup, - gain_offset=gain_offset, gainp=gainp, - phase_std=phase_std, - dterm_offset=dterm_offset, - rlratio_std=rlratio_std,rlphase_std=rlphase_std, - sigmat=sigmat,phasesigmat=phasesigmat, - rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, - caltable_path=caltable_path, seed=seed, verbose=verbose) + if noise: + obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor, + sgrscat=sgrscat, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor, + jones=jones, inv_jones=inv_jones, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, dcal=dcal, + frcal=frcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + phase_std=phase_std, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + sigmat=sigmat,phasesigmat=phasesigmat, + rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, + caltable_path=caltable_path, seed=seed, reorder=reorder, verbose=verbose) + else: + obs = self.observe_same_nonoise(obs, sgrscat=sgrscat, ttype=ttype, + cache=False, fft_pad_factor=fft_pad_factor, + zero_empty_pol=True, reorder=reorder, verbose=verbose) + obs.mjd = mjd diff --git a/ehtim/io/load.py b/ehtim/io/load.py index 8e38a4f5..7ed4d8b5 100644 --- a/ehtim/io/load.py +++ b/ehtim/io/load.py @@ -1436,7 +1436,7 @@ def load_obs_uvfits(filename, polrep='stokes', flipbl=False, (( times[i], tints[i], t1[i], t2[i], tau1[i], tau2[i], - u[i], v[i], + u[i], v[i], 0, rr[i], ll[i], rl[i], lr[i], rrsig[i], llsig[i], rlsig[i], lrsig[i] ), dtype=dtpol_out diff --git a/ehtim/obsdata.py b/ehtim/obsdata.py index 65ca8318..951f7e07 100644 --- a/ehtim/obsdata.py +++ b/ehtim/obsdata.py @@ -103,7 +103,7 @@ class Obsdata(object): def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None, polrep='stokes', source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC', ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True, - trial_speedups=False): + trial_speedups=False, reorder=True): """A polarimetric VLBI observation of visibility amplitudes and phases (in Jy). Args: @@ -175,7 +175,11 @@ def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None, self.reorder_tarr_sefd(reorder_baselines=False) # reorder baselines to uvfits convention - self.reorder_baselines(trial_speedups=trial_speedups) + if reorder: + self.reorder_baselines(trial_speedups=trial_speedups) # comment out for Closure Invariants + else: + self.data = np.array(sorted(self.data, key=lambda x: x['time'])) + # Get tstart, mjd and tstop times = self.unpack(['time'])['time'] @@ -615,7 +619,8 @@ def tlist(self, conj=False, t_gather=0., scan_gather=False): data, lambda x: np.searchsorted(self.scans[:, 0], x['time'])): datalist.append(np.array([obs for obs in group])) - return np.array(datalist, dtype=object) + # return np.array(datalist, dtype=object) + return datalist def split_obs(self, t_gather=0., scan_gather=False): @@ -1224,7 +1229,7 @@ def polchisq(self, im, dtype='pvis', ttype='nfft', pol_trans=True, mask=[], **kw return chisq - def recompute_uv(self): + def recompute_uv(self, with_w=False): """Recompute u,v points using observation times and metadata Args: @@ -1240,10 +1245,17 @@ def recompute_uv(self): print("Recomputing U,V Points using MJD %d \n RA %e \n DEC %e \n RF %e GHz" % (self.mjd, self.ra, self.dec, self.rf / 1.e9)) - (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times, - self.mjd, self.ra, self.dec, self.rf, - timetype=self.timetype, - elevmin=0, elevmax=90, no_elevcut_space=False) + if with_w: + (timesout, uout, vout, wout) = obsh.compute_uv_coordinates(arr, site1, site2, times, + self.mjd, self.ra, self.dec, self.rf, + timetype=self.timetype, + elevmin=0, elevmax=90, no_elevcut_space=False, + w_term=with_w) + else: + (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times, + self.mjd, self.ra, self.dec, self.rf, + timetype=self.timetype, + elevmin=0, elevmax=90, no_elevcut_space=False) if len(timesout) != len(times): raise Exception( @@ -1253,10 +1265,14 @@ def recompute_uv(self): datatable['u'] = uout datatable['v'] = vout + if with_w: + datatable['w'] = wout + arglist, argdict = self.obsdata_args() arglist[DATPOS] = np.array(datatable) out = Obsdata(*arglist, **argdict) + return out def avg_coherent(self, inttime, scan_avg=False, moving=False): @@ -2749,6 +2765,97 @@ def errfunc(p): return gparams + def ClosureInvariants(self): + """ + Calculates copolar closure invariants for visibilities assuming an n element + interferometer array using method 1. + + Nithyanandan, T., Rajaram, N., Joseph, S. 2022 “Invariants in copolar + interferometry: An Abelian gauge theory”, PHYS. REV. D 105, 043019. + https://doi.org/10.1103/PhysRevD.105.043019 + + Args: + vis (np.ndarray): visibility data sampled by the interferometer array + n (int): number of antenna as part of the interferometer array + + Returns: + ci (np.ndarray): closure invariants + """ + tlist = self.tlist() + out_ci = np.array([]) + for tdata in tlist: + num_antenna = len(np.unique(tdata['t1'])) + 1 + if num_antenna < 3: + continue + vis = tdata['vis'].reshape(1,1,-1) + ant_pairs = np.array([tdata['t1'], tdata['t2']]).T + unique_ant = pd.unique(ant_pairs.flatten()) + ant_pairs = np.array([np.where(i == unique_ant)[0][0] for i in ant_pairs.flatten()]).reshape(-1, 2) + ant_pairs = [tuple(i) for i in ant_pairs] + reverse_idx = [i for i in range(len(ant_pairs)) if ant_pairs[i][0] > ant_pairs[i][1]] + ant_pairs = [ant_pairs[i] if i not in reverse_idx else (ant_pairs[i][1], ant_pairs[i][0]) for i in range(len(ant_pairs))] + vis[:,:,reverse_idx] = np.conjugate(vis[:,:,reverse_idx]) + _, btriads = self.Triads(num_antenna, pairs=ant_pairs) + C_oa = vis[:, :, btriads[:, 0]] + C_ab = vis[:, :, btriads[:, 1]] + C_bo = np.conjugate(vis[:, :, btriads[:, 2]]) + A_oab = C_oa / np.conjugate(C_ab) * C_bo + A_oab = np.dstack((A_oab.real, A_oab.imag)) + A_max = np.nanmax(np.abs(A_oab), axis=-1, keepdims=True) + ci = A_oab / A_max + ci = ci.reshape(-1) + out_ci = np.concatenate([out_ci, ci], axis=0) + + return out_ci + + def Triads(self, n:int, pairs=None): + """ + Generates arrays of antenna and baseline indicies that form triangular + loops pivoted around the 0th antenna. Used to calculate closure invariants + whereby specific baseline correlations need to be indexed according + to those triangular loops. + Baseline array format [ant1, ant2]: + [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6] ... + [1, 2], [1, 3], [1, 4], [1, 5], [1, 6] ... + [2, 3], [2, 4], [2, 5], [2, 6] ... + [3, 4], [3, 5], [3, 6] ... + [4, 5], [4, 6] ... + [5, 6] ... + + Args: + n (int): number of antenna in the array + + Returns: + atriads (np.ndarray): antenna triangular loop indicies + btriads (np.ndarray): baseline triangular loop indicies + """ + ntriads = (n-1)*(n-2)//2 + ant1 = np.zeros(ntriads, dtype=np.uint8) + ant2 = np.arange(1, n, dtype=np.uint8).reshape(n-1, 1) + np.zeros(n-2, dtype=np.uint8).reshape(1, n-2) + ant3 = np.arange(2, n, dtype=np.uint8).reshape(1, n-2) + np.zeros(n-1, dtype=np.uint8).reshape(n-1, 1) + anti = np.where(ant3 > ant2) + ant2, ant3 = ant2[anti], ant3[anti] + atriads = np.concatenate([ant1.reshape(-1, 1), ant2.reshape(-1, 1), ant3.reshape(-1, 1)], axis=-1) + + ant_pairs_01 = list(zip(ant1, ant2)) + ant_pairs_12 = list(zip(ant2, ant3)) + ant_pairs_20 = list(zip(ant3, ant1)) + + t1 = np.arange(n, dtype=int).reshape(n, 1) + np.zeros(n, dtype=int).reshape(1, n) + t2 = np.arange(n, dtype=int).reshape(1, n) + np.zeros(n, dtype=int).reshape(n, 1) + bli = np.where(t2 > t1) + t1, t2 = t1[bli], t2[bli] + if pairs == None: + bl_pairs = list(zip(t1, t2)) + else: + bl_pairs = pairs + + bl_01 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_01]) + bl_12 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_12]) + bl_20 = np.asarray([bl_pairs.index(tuple(reversed(apair))) for apair in ant_pairs_20]) + btriads = np.concatenate([bl_01.reshape(-1, 1), bl_12.reshape(-1, 1), bl_20.reshape(-1, 1)], axis=-1) + return atriads, btriads + def bispectra(self, vtype='vis', mode='all', count='min', timetype=False, uv_min=False, snrcut=0.): """Return a recarray of the equal time bispectra. diff --git a/ehtim/observing/obs_helpers.py b/ehtim/observing/obs_helpers.py index 6dc1dbaa..893b73fa 100644 --- a/ehtim/observing/obs_helpers.py +++ b/ehtim/observing/obs_helpers.py @@ -55,7 +55,7 @@ def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype='UTC', elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False, - fix_theta_GMST=False): + fix_theta_GMST=False, w_term=False): """Compute u,v coordinates for an array at a given time for a source at a given ra,dec,rf """ @@ -189,6 +189,7 @@ def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype # u,v coordinates u = np.dot((coord1 - coord2)/wvl, projU) # u (lambda) v = np.dot((coord1 - coord2)/wvl, projV) # v (lambda) + w = np.dot((coord1 - coord2)/wvl, sourcevec) # w (lambda) # mask out below elevation cut mask_elev_1 = elevcut(coord1, sourcevec, elevmin=elevmin, elevmax=elevmax) @@ -205,8 +206,12 @@ def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype time = time[mask] u = u[mask] v = v[mask] + w = w[mask] - # return times and uv points where we have data + if w_term: + return (time, u, v, w) + + # return times and uv points where we have data return (time, u, v) diff --git a/ehtim/observing/obs_simulate.py b/ehtim/observing/obs_simulate.py index 1d3d2e6a..99028e08 100644 --- a/ehtim/observing/obs_simulate.py +++ b/ehtim/observing/obs_simulate.py @@ -148,9 +148,10 @@ def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop, ra, dec, rf, timetype=timetype, elevmin=elevmin, elevmax=elevmax, no_elevcut_space=no_elevcut_space, - fix_theta_GMST=fix_theta_GMST) + fix_theta_GMST=fix_theta_GMST, + w_term=True) - (timesout, uout, vout) = uvdat + (timesout, uout, vout, wout) = uvdat for k in range(len(timesout)): outlist.append(np.array(( timesout[k], @@ -161,6 +162,7 @@ def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop, tau2, # Station 1 zenith optical depth uout[k], # u (lambda) vout[k], # v (lambda) + wout[k], # w (lambda) 0.0, # 1st Visibility (Jy) 0.0, # 2nd Visibility 0.0, # 3rd Visibility @@ -345,6 +347,38 @@ def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes', obsdata.append(vis) + elif ttype == "DFT": + xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12 + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + imarr = imvec.reshape(im.ydim, im.xdim) + vis = DFT(imarr, uv, xfov=xfov, yfov=yfov) + obsdata.append(vis) + + elif ttype == "DFT_i": + xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12 + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + uv = np.array([uv[:,1], uv[:,0]]).T # uv swap hack + imarr = imvec.reshape(im.ydim, im.xdim) + vis = DFT(imarr, uv, xfov=xfov, yfov=yfov) + obsdata.append(vis) + + # Get visibilities from DTFT else: # Construct Fourier matrix @@ -375,6 +409,35 @@ def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes', return obsdata + +def DFT(data, uv, xfov=225, yfov=225): + if data.ndim == 2: + data = data[np.newaxis,...] + out_shape = (uv.shape[0],) + elif data.ndim > 2: + data = data.reshape((-1,) + data.shape[-2:]) + out_shape = data.shape[:-2] + (uv.shape[0],) + ny, nx = data.shape[-2:] + dx = xfov*4.84813681109536e-12 / nx + dy = yfov*4.84813681109536e-12 / ny + angx = (np.arange(nx) - nx//2) * dx + angy = (np.arange(ny) - ny//2) * dy + lvect = np.sin(angx) + mvect = np.sin(angy) + l, m = np.meshgrid(lvect, mvect) + lm = np.concatenate([l.reshape(1,-1), m.reshape(1,-1)], axis=0) + imgvect = data.reshape((data.shape[0],-1)) + x = -2*np.pi*np.dot(uv,lm)[np.newaxis, ...] + visr = np.sum(imgvect[:, np.newaxis, :] * np.cos(x, dtype=np.float32), axis=-1) + visi = np.sum(imgvect[:, np.newaxis, :] * np.sin(x, dtype=np.float32), axis=-1) + if data.ndim == 2: + vis = visr.ravel() + 1j*visi.ravel() + else: + vis = visr.ravel() + 1j*visi.ravel() + vis = vis.reshape(out_shape) + return vis + + ################################################################################################## # Noise + miscalibration funcitons ################################################################################################## @@ -1199,7 +1262,7 @@ def apply_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True, verbose=Tru # The old noise generating function. -def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=True, +def add_noise(obs, add_th_noise=True, th_noise_factor=1, opacitycal=True, ampcal=True, phasecal=True, stabilize_scan_amp=False, stabilize_scan_phase=False, neggains=False, taup=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, @@ -1297,7 +1360,7 @@ def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=Tru sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'], obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw) for i in range(len(tint))), float) - sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'], + sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'], obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw) for i in range(len(tint))), float) sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'], @@ -1401,10 +1464,10 @@ def add_noise(obs, add_th_noise=True, opacitycal=True, ampcal=True, phasecal=Tru sigma_est4 = sigma_perf4 * gain_true * tau_est if add_th_noise: - vis1 = (vis1 + obsh.cerror(sigma_true1)) - vis2 = (vis2 + obsh.cerror(sigma_true2)) - vis3 = (vis3 + obsh.cerror(sigma_true3)) - vis4 = (vis4 + obsh.cerror(sigma_true4)) + vis1 = (vis1 + th_noise_factor*obsh.cerror(sigma_true1)) + vis2 = (vis2 + th_noise_factor*obsh.cerror(sigma_true2)) + vis3 = (vis3 + th_noise_factor*obsh.cerror(sigma_true3)) + vis4 = (vis4 + th_noise_factor*obsh.cerror(sigma_true4)) # Add the gain error to the true visibilities vis1 = vis1 * gain_true * tau_est / tau_true diff --git a/features/__init__.py b/features/__init__.py new file mode 100644 index 00000000..d7cc51cb --- /dev/null +++ b/features/__init__.py @@ -0,0 +1,11 @@ +""" +.. module:: ehtim.features + :platform: Unix + :synopsis: EHT Imaging Utilities: feature extraction functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from . import rex + +from ..const_def import * diff --git a/features/rex.py b/features/rex.py new file mode 100644 index 00000000..62fafd31 --- /dev/null +++ b/features/rex.py @@ -0,0 +1,1057 @@ +# rex.py +# ring fitting code for ehtim +# +# Copyright (C) 2019 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import os +import glob +import matplotlib.pyplot as plt +import numpy as np +import astropy.io.fits as fits +import subprocess + +import scipy.interpolate +import scipy.optimize +import scipy.stats +from astropy.stats import median_absolute_deviation + +from ehtim.image import load_image +import ehtim.imaging.dynamical_imaging as di +import ehtim.parloop as ploop +import ehtim.const_def as ehc + +################################################################################################### +# Parameters +################################################################################################### + +EP = 1.e-16 +BIG = 1./EP + +IMSIZE = 160*ehc.RADPERUAS # 250*ehc.RADPERUAS # FOV of resampled image (muas) +NPIX = 160 # 128 # pixels in resampled image + +NRAYS = 360 # number of angular rays in final profile +NRS = 100 # number of radial points in final profile + +RMAX = 50 # maximum radius in every profile slice (muas) +RMIN = 5 # radius threshold for averaging inside ring (muas) + +RPRIOR_MIN = 15. # 5. # minimum radius for search (muas) +RPRIOR_MAX = 50. # 60. # maximum radius for search (muas) +NRAYS_SEARCH = 25 # number of angular rays in search profiles +NRS_SEARCH = 50 # number of radial points in search profiles +THRESH = 0.05 # thresholding level for the images in the search +# BLUR_VALUE_MIN=2 # blur to this value for initial centroid search (uas) +FOVP_SEARCH = 0.1 # fractional FOV around image center for brute force search +NSEARCH = 10 # number of points in each dimension for brute force search +NORMFLUX = 1 # normalized image flux for outputted profiles (Jy) + +POSTPROCDIR = '.' # default postprocessing directory + +################################################################################################### +# Profiles class +################################################################################################### + + +class Profiles(object): + + def __init__(self, im, x0, y0, profs, thetas, rmin=RMIN, rmax=RMAX, + interptype='cubic',flux_norm=NORMFLUX, + profsQ=[], profsU=[]): + + self.x0 = x0 + self.y0 = y0 + self.im = im + self.rmin = rmin + self.rmax = rmax + + # store the center image + deltay = -(im.fovy()/2. - y0*ehc.RADPERUAS)/im.psize + deltax = (im.fovx()/2. - x0*ehc.RADPERUAS)/im.psize + self.im_center = im.shift([int(np.round(deltay)), int(np.round(deltax))]) + + # total flux and normalization + self.flux = im.total_flux() + self.parea = (im.psize/ehc.RADPERUAS)**2 + + # factor to convert to normalized brightness temperature (total flux of 1 Jy) + self.flux_norm = flux_norm + self.normfactor = self.flux_norm / im.total_flux() + + # image array and profiles + factor = 3.254e13/(im.rf**2 * im.psize**2) # factor to convert to brightness temperature + self.imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in Tb + + self.xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS + self.ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS + self.interp = scipy.interpolate.interp2d(self.ys, self.xs, self.imarr, kind=interptype) + + self.profiles = np.array(profs) + self.profilesQ = np.array(profsQ) + self.profilesU = np.array(profsU) + self.profilesP = np.sqrt(self.profilesQ**2 + self.profilesU**2) + self.thetas = np.array(thetas) + self.nang = len(thetas) + self.nrs = len(self.profiles[0]) + self.nthetas = len(self.thetas) + self.rs = np.linspace(0, self.rmax, self.nrs) + self.dr = self.rs[-1] - self.rs[-2] + self.pks = [] + self.pk_vals = [] + self.diameters = [] + + for prof in self.profiles: + pk, vpk = self.calc_pkrad_from_prof(prof) + + self.pks.append(pk) + self.pk_vals.append(vpk) + self.diameters.append(2*np.abs(pk)) + + self.pks = np.array(self.pks) + self.pk_vals = np.array(self.pk_vals) + self.diameters = np.array(self.diameters) + + # ring size + self.RingSize1 = (np.mean(self.diameters), np.std(self.diameters)) + self.RingSize1_med = (np.median(self.diameters), median_absolute_deviation(self.diameters)) + + def calc_pkrad_from_prof(self, prof): + """calculate peak radius and value with linear interpolation""" + args = np.argsort(prof) + pkpos = args[-1] + pk = self.rs[pkpos] + vpk = prof[pkpos] + if pkpos > 0 and pkpos < self.nrs-1: + vals = [prof[pkpos-1], prof[pkpos], prof[pkpos+1]] + pk, vpk = quad_interp_radius(pk, self.dr, vals) + return (pk, vpk) + + def calc_meanprof_and_stats(self): + + # calculate mean profile + self.meanprof = np.mean(self.profiles, axis=0) + args = np.argsort(self.meanprof) + + self.pkloc = args[-1] + self.pkrad = self.rs[self.pkloc] + self.meanpk = self.meanprof[self.pkloc] + + # absolute peak in angle and radius + profile_peak_loc = np.unravel_index(np.argmax(self.profiles), self.profiles.shape) + self.abspk_loc_rad = profile_peak_loc[1] + self.abspk_rad = self.rs[self.abspk_loc_rad] + self.abspk_loc_ang = profile_peak_loc[0] + self.abspk_ang = self.thetas[self.abspk_loc_ang] + + # find inside mean flux + inner_loc = np.argmin((self.rs-self.rmin)**2) + self.in_level = np.mean(self.meanprof[0:inner_loc]) # profile avg inside ring + + # find outside mean flux + outer_loc = np.argmin((self.rs-(self.rmax-self.rmin))**2) + self.out_level = np.mean(self.meanprof[outer_loc:]) # profile avg outside ring + + # find mean profile FWHM with spline interpolation + meanprof_zeroed = self.meanprof - self.out_level + (lh_meanprof, rh_meanprof) = self.calc_width(meanprof_zeroed) + lhloc_meanprof = np.argmin((self.rs-lh_meanprof)**2) + rhloc_meanprof = np.argmin((self.rs-rh_meanprof)**2) + + self.lh = lh_meanprof + self.rh = rh_meanprof + self.lhloc = lhloc_meanprof + self.rhloc = rhloc_meanprof + + # ring diameter and width from the mean profile + meanprof_diameter = 2*self.pkrad + meanprof_width = np.abs(rh_meanprof - lh_meanprof) + self.RingSize2 = (meanprof_diameter, meanprof_width) + + # find ring width with all angular profiles + ringwidths = [] + for i in range(self.nang): + rprof = self.profiles[i] + # TODO zero min profile before taking width??? + # rprof_zeroed = rprof - np.max((np.min(rprof), 0)) + (lh, rh) = self.calc_width(rprof) + width = rh-lh + if width <= 0 or width >= 2*meanprof_width: + continue # AC ?? ok to exclude huge widths??? + ringwidths.append(width) + + self.RingWidth = (np.mean(ringwidths), np.std(ringwidths)) + + # ring angle 1: mean and std deviation of individual profiles + ringangles = [] + ringasyms = [] + for i in range(self.lhloc, self.rhloc+1): + angprof = self.profiles.T[i] + if i == self.lhloc: + prof_mean_r = angprof.reshape(1, len(self.profiles.T[i])) + else: + prof_mean_r = np.vstack((prof_mean_r, angprof)) + + angle_asym = self.calc_ringangle_asymmetry(angprof) + ringangles.append(angle_asym[0]) + ringasyms.append(angle_asym[1]) + + self.RingAngle1 = (scipy.stats.circmean(ringangles), scipy.stats.circstd(ringangles)) + + # ring angle 2: ring angle function on avg profile + prof_mean_r = np.mean(np.array(prof_mean_r), axis=0) + self.meanprof_theta = prof_mean_r + ringangle2 = self.calc_ringangle_asymmetry(prof_mean_r) + self.RingAngle2 = (ringangle2[0], ringangle2[-1]) + + # contrast 1: maximum profile value / mean of inner region + # self.RingContrast1 = np.max(self.profiles.T[self.pkloc]) / self.in_level + # self.RingContrast1 = np.max(self.profiles) / self.in_level + self.RingContrast1 = np.max(self.profiles[:, self.lhloc:self.rhloc+1]) / self.in_level + + # contrast 1: mean profile max value / mean of inner region + self.RingContrast2 = self.meanpk / self.in_level + + # asymmetry 1: m1 mode of angular profile + self.RingAsym1 = (np.mean(ringasyms), np.std(ringasyms)) + + # asymmetry 2: integrated flux in bottom half of ring vs top half of ring + mask_inner = self.im.copy() + mask_outer = self.im.copy() + immask = self.im.copy() + + x0_c = self.im.fovx()/2. - self.x0*ehc.RADPERUAS + y0_c = self.y0*ehc.RADPERUAS - self.im.fovy()/2. + + # mask annulus + rad_inner = (self.RingSize1[0]/2. - self.RingWidth[0]/2.)*ehc.RADPERUAS + rad_outer = (self.RingSize1[0]/2. + self.RingWidth[0]/2.)*ehc.RADPERUAS + + mask_inner.imvec *= 0 + mask_outer.imvec *= 0 + mask_inner = mask_inner.add_gauss(1, [2*rad_inner, 2*rad_inner, 0, x0_c, y0_c]) + mask_inner = mask_inner.mask(cutoff=.5) + mask_outer = mask_outer.add_gauss(1, [2*rad_outer, 2*rad_outer, 0, x0_c, y0_c]) + mask_outer = mask_outer.mask(cutoff=.5) + + maskvec_annulus = np.logical_xor( + mask_inner.imvec.astype(bool), mask_outer.imvec.astype(bool)) + + # mask angle + xlist = np.arange(0, -self.im.xdim, -1)*self.im.psize + \ + (self.im.psize*self.im.xdim)/2.0 - self.im.psize/2.0 + ylist = np.arange(0, -self.im.ydim, -1)*self.im.psize + \ + (self.im.psize*self.im.ydim)/2.0 - self.im.psize/2.0 + + cangle = self.RingAngle1[0] + + def anglemask(x, y): + ang = np.mod(-np.arctan2(y-y0_c, x-x0_c)+np.pi/2., 2*np.pi) + # return ang + if np.mod(np.abs(ang-cangle), 2*np.pi) > 0.5*np.pi: + return False + else: + return True + + maskvec_ang = np.array([[anglemask(i, j) for i in xlist] + for j in ylist]).flatten().astype(bool) + + # combine masks and get the bright and dim flux + maskvec_brighthalf = maskvec_annulus * maskvec_ang + brightflux = np.sum(immask.imvec[(maskvec_brighthalf)]) + + maskvec_dimhalf = maskvec_annulus * ~maskvec_ang + dimflux = np.sum(immask.imvec[(maskvec_dimhalf)]) + self.RingFlux = brightflux + dimflux + self.RingAsym2 = ((brightflux-dimflux)/(brightflux+dimflux), brightflux/dimflux) + + # Polarization brightness ratio + # AC TODO FOR PAPER VIII ANALYSIS + self.RingAsymPol = (0., 0.) + if len(self.profilesP)>0 and len(self.im.qvec) > 0 and len(self.im.uvec) > 0: + pvec = np.sqrt(self.im.qvec**2 + self.im.uvec**2) + pvec_C = (self.im.qvec + 1j*self.im.uvec) + + ringanglesPol = [] + # ringasymsPol = [] + for i in range(self.lhloc, self.rhloc+1): + angprof = self.profilesP.T[i] + # simple maximum AC TODO + ringanglesPol.append(self.thetas[np.argmax(angprof)]) + + # weighted avg + # angle_asym = self.calc_ringangle_asymmetry(angprof) + # ringanglesPol.append(angle_asym[0]) + # ringasymsPol.append(angle_asym[1]) + + self.RingAnglePol = (scipy.stats.circmean(ringanglesPol), + scipy.stats.circstd(ringanglesPol)) + + cangle = self.RingAnglePol[0] + + def anglemask_pol(x, y): + ang = np.mod(-np.arctan2(y-y0_c, x-x0_c)+np.pi/2., 2*np.pi) + # return ang + if np.mod(np.abs(ang-cangle), 2*np.pi) > 0.5*np.pi: + return False + else: + return True + + maskvec_ang = np.array([[anglemask_pol(i, j) for i in xlist] + for j in ylist]).flatten().astype(bool) + + # combine masks and get the bright and dim pol flux + maskvec_brighthalf = maskvec_annulus * maskvec_ang + maskvec_dimhalf = maskvec_annulus * ~maskvec_ang + # maskvec_brighthalf = maskvec_ang + # maskvec_dimhalf = ~maskvec_ang + + # calculate polarized asymmetry / birghtness ratio + brightflux_pol_C = np.abs(np.sum(pvec_C[(maskvec_brighthalf)])) + dimflux_pol_C = np.abs(np.sum(pvec_C[(maskvec_dimhalf)])) + brightflux_pol = np.sum(pvec[(maskvec_brighthalf)]) + dimflux_pol = np.sum(pvec[(maskvec_dimhalf)]) + self.RingAsymPol = ((brightflux_pol_C/dimflux_pol_C), + brightflux_pol/dimflux_pol) + + # calculate dynamic range + mask = self.im.copy() + immask = self.im.copy() + + x0_c = mask.fovx()/2. - self.x0*ehc.RADPERUAS + y0_c = self.y0*ehc.RADPERUAS - mask.fovy()/2. + rad = self.RingSize1[0]*ehc.RADPERUAS + + mask.imvec *= 0 + mask = mask.add_gauss(1, [2*rad, 2*rad, 0, x0_c, y0_c]) + mask = mask.mask(cutoff=.5) + maskvec = mask.imvec.astype(bool) + (immask.imvec < EP*self.flux) + offsource_vec = immask.imvec[~(maskvec)] + + self.impeak = np.max(self.im.imvec) + self.std_offsource = np.std(offsource_vec) + EP + self.mean_offsource = np.mean(offsource_vec) + EP + self.dynamic_range = self.impeak / self.std_offsource + + def calc_width(self, prof): + pkrad, maxval = self.calc_pkrad_from_prof(prof) + spline = scipy.interpolate.UnivariateSpline(self.rs, prof-0.5*maxval, s=0) + roots = spline.roots() # find the roots + + if len(roots) == 0: + return(self.rs[0], self.rs[-1]) + + lh = self.rs[0] + rh = self.rs[-1] + for root in np.sort(roots): + if root < pkrad: + lh = root + else: + rh = root + break + + return (lh, rh) + + def calc_ringangle_asymmetry(self, prof): + dtheta = self.thetas[-1]-self.thetas[-2] + prof = prof / np.sum(prof*dtheta) # normalize + x = np.sum(prof * np.exp(1j*self.thetas) * dtheta) + ang = np.mod(np.angle(x), 2*np.pi) + asym = np.abs(x) + std = np.sqrt(-2*np.log(np.abs(x))) + return (ang, asym, std) + + def plot_img(self, postprocdir=POSTPROCDIR, save_png=False): + plt.figure() + + fovx = self.im.fovx()/ehc.RADPERUAS + fovy = self.im.fovy()/ehc.RADPERUAS + xsplot = self.xs - fovx/2. + ysplot = self.ys - fovy/2. + x0plot = self.x0 - fovx/2. + y0plot = self.y0 - fovy/2. + + plt.pcolormesh(xsplot,ysplot,self.imarr,cmap='afmhot') + plt.contour(xsplot,ysplot, self.imarr, colors='k',levels=5) + + plt.xlabel(r"-RA ($\mu$as)") + plt.ylabel(r"Dec ($\mu$as)") + + plt.plot(x0plot,y0plot, 'r*', markersize=20) + + thetas = np.linspace(0, 2*np.pi, 100) + plt.plot(x0plot+ np.cos(thetas) * self.RingSize1[0]/2, + y0plot + np.sin(thetas) * self.RingSize1[0]/2, + 'r-', markersize=1) + + #plt.axes().set_aspect('equal', 'datalim') + + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_contour.png' + + plt.savefig(fname) + plt.close() + else: + #plt.show(block=False) + ehc.show_noblock() + + def plot_unwrapped(self, postprocdir=POSTPROCDIR, save_png=False, + xlabel=True, ylabel=True, xticklabel=True, yticklabel=True, + ax=False, imrange=[], show=True, cfun='jet', linecolor='r', labelsize=14): + + # line colors + angcolor = np.array([100, 149, 237])/256. + pkcolor = np.array([219, 0., 219])/256. + # pkcolor = np.array([0,255,0])/256. + + imarr = np.array(self.profiles).T/1.e9 + if ax is False: + plt.figure() + ax = plt.gca() + + if imrange: + plt.imshow(imarr, cmap=plt.get_cmap(cfun), origin='lower', + vmin=imrange[0], vmax=imrange[1], interpolation='gaussian') + else: + plt.imshow(imarr, cmap=plt.get_cmap(cfun), origin='lower', interpolation='gaussian') + + uas_to_pix = self.nrs/np.max(self.rs) # convert radius to pixels + rad_to_pix = self.nang/(2*np.pi) # convert az. angle to pixels + + # horizontal lines -- radius + pkloc = self.RingSize1[0]/2. * uas_to_pix + lhloc = (self.RingSize1[0] - self.RingSize1[1])/2. * uas_to_pix + rhloc = (self.RingSize1[0] + self.RingSize1[1]) / 2. * uas_to_pix + + plt.axhline(y=pkloc, color=linecolor, linewidth=1) + plt.axhline(y=lhloc, color=linecolor, linewidth=1, linestyle=':') + plt.axhline(y=rhloc, color=linecolor, linewidth=1, linestyle=':') + + # horizontal lines -- width + # add radius and half width sigma in quadrature + bandloc_sigma = np.sqrt((self.RingWidth[1]/2)**2 + (self.RingSize1[1]/2)**2) + + rhloc = (self.RingSize1[0]/2. + self.RingWidth[0]/2.) * uas_to_pix + rhloc2 = (self.RingSize1[0]/2. + self.RingWidth[0]/2. + bandloc_sigma) * uas_to_pix + rhloc3 = (self.RingSize1[0]/2. + self.RingWidth[0]/2. - bandloc_sigma) * uas_to_pix + + lhloc = (self.RingSize1[0]/2. - self.RingWidth[0]/2.) * uas_to_pix + lhloc2 = (self.RingSize1[0]/2. - self.RingWidth[0]/2. + bandloc_sigma) * uas_to_pix + lhloc3 = (self.RingSize1[0]/2. - self.RingWidth[0]/2. - bandloc_sigma) * uas_to_pix + + plt.axhline(y=lhloc, color=linecolor, linewidth=1, linestyle='--') + plt.axhline(y=lhloc2, color=linecolor, linewidth=1, linestyle=':') + plt.axhline(y=lhloc3, color=linecolor, linewidth=1, linestyle=':') + plt.axhline(y=rhloc, color=linecolor, linewidth=1, linestyle='--') + plt.axhline(y=rhloc2, color=linecolor, linewidth=1, linestyle=':') + plt.axhline(y=rhloc3, color=linecolor, linewidth=1, linestyle=':') + + # position angle line + pkloc = self.RingAngle1[0] * rad_to_pix + lhloc = (self.RingAngle1[0] + self.RingAngle1[1]) * rad_to_pix + rhloc = (self.RingAngle1[0] - self.RingAngle1[1]) * rad_to_pix + + plt.axvline(x=pkloc, color=angcolor, linewidth=1) + plt.axvline(x=lhloc, color=angcolor, linewidth=1, linestyle=':') + plt.axvline(x=rhloc, color=angcolor, linewidth=1, linestyle=':') + + # bright peak point + plt.plot([self.abspk_loc_ang], [self.abspk_loc_rad], marker='x', mew=2, ms=6, color=pkcolor) + + # labels + if xlabel: + plt.xlabel(r"$\theta $ ($^\circ$)", size=labelsize) + if ylabel: + plt.ylabel(r"$r$ ($\mu$as)", size=labelsize) + + xticklabels = np.arange(0, 360, 60) + xticks = (360/imarr.shape[1])*xticklabels + + yticks = np.floor(np.arange(0, imarr.shape[0], imarr.shape[0]/5)).astype(int) + yticklabels = ["%0.0f" % r for r in self.rs[yticks]] + + if not xticklabel: + xticklabels = [] + if not yticklabel: + yticklabels = [] + + plt.xticks(xticks, xticklabels) + plt.yticks(yticks, yticklabels) + plt.tick_params(axis='both', which='major', length=6) + + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_unwrapped.png' + plt.savefig(fname) + plt.close() + elif show: + #plt.show(block=False) + ehc.show_noblock() + + def save_unwrapped(self, fname): + + imarr = np.array(self.profiles).T + + header = fits.Header() + header['CTYPE1'] = 'RA---SIN' + header['CTYPE2'] = 'DEC--SIN' + header['CDELT1'] = 2*np.pi/float(len(self.profiles)) + header['CDELT2'] = np.max(self.rs)/float(len(self.rs)) + header['BUNIT'] = 'K' + hdu = fits.PrimaryHDU(imarr, header=header) + hdulist = [hdu] + hdulist = fits.HDUList(hdulist) + hdulist.writeto(fname, overwrite=True) + + def plot_profs(self, postprocdir=POSTPROCDIR, save_png=False, colors=ehc.SCOLORS): + plt.figure() + plt.xlabel(r"distance from center ($\mu$as)") + plt.ylabel(r"$T_{\rm b}$") + #plt.ylim([0, 1]) + plt.xlim([-10, 60]) + plt.title('All Profiles') + for j in range(len(self.profiles)): + plt.plot(self.rs, self.profiles[j], color=colors[j % len(colors)], + linestyle='-', linewidth=1) + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_profiles.png' + + plt.savefig(fname) + plt.close() + else: + #plt.show(block=False) + ehc.show_noblock() + + def plot_prof_band(self, postprocdir=POSTPROCDIR, save_png=False, + color='b', fontsize=14, show=True, axis=None, xlabel=True, ylabel=False): + """2-sided plot of radial profiles, cut across orthogonal to position angle""" + if axis is None: + plt.figure() + ax = plt.gca() + else: + ax = axis + + if xlabel: + plt.xlabel(r"$r$ ($\mu$as)", size=fontsize) + + yticks = [0, 2, 4, 6, 8, 10] + yticklabels = [] + if ylabel: + plt.ylabel(r'Brightness Temperature ($10^9$ K)', size=fontsize) + yticklabels = yticks + + plt.yticks(yticks, yticklabels) + + #plt.ylim([0, 11]) + plt.xlim([-55, 55]) + + # cut the ring in half orthagonal to the position angle + cutloc1 = np.argmin(np.abs(self.thetas-np.mod(self.RingAngle1[0] - np.pi/2., 2*np.pi))) + cutloc2 = np.argmin(np.abs(self.thetas-np.mod(self.RingAngle1[0] + np.pi/2., 2*np.pi))) + + if cutloc1 < cutloc2: + prof_half_1 = self.profiles[cutloc1:cutloc2+1] + prof_half_2 = np.vstack((self.profiles[cutloc2+1:], self.profiles[0:cutloc1])) + else: + prof_half_1 = np.vstack((self.profiles[cutloc1:], self.profiles[0:cutloc2+1])) + prof_half_2 = self.profiles[cutloc2+1:cutloc1] + + # plot left half + radii = -np.flip(self.rs) + tho_m = np.flip(np.median(np.array(prof_half_1), axis=0)) + tho_l = np.flip(np.percentile(np.array(prof_half_1), 0, axis=0)) + tho_u = np.flip(np.percentile(np.array(prof_half_1), 100, axis=0)) + tho_l1 = np.flip(np.percentile(np.array(prof_half_1), 25, axis=0)) + tho_u1 = np.flip(np.percentile(np.array(prof_half_1), 75, axis=0)) + + ax.plot(radii, tho_m/1.e9, ls='-', linewidth=2, color=color) + ax.fill_between(radii, tho_l/1.e9, tho_u/1.e9, alpha=.2, edgecolor=None, facecolor=color) + ax.fill_between(radii, tho_l1/1.e9, tho_u1/1.e9, alpha=.4, edgecolor=None, facecolor=color) + + # plot rights half + radii = self.rs + tho_m = np.median(np.array(prof_half_2), axis=0) + tho_l = np.percentile(np.array(prof_half_2), 0, axis=0) + tho_u = np.percentile(np.array(prof_half_2), 100, axis=0) + tho_l1 = np.percentile(np.array(prof_half_2), 25, axis=0) + tho_u1 = np.percentile(np.array(prof_half_2), 75, axis=0) + + ax.plot(radii, tho_m/1.e9, ls='-', linewidth=2, color=color) + ax.fill_between(radii, tho_l/1.e9, tho_u/1.e9, alpha=.2, edgecolor=None, facecolor=color) + ax.fill_between(radii, tho_l1/1.e9, tho_u1/1.e9, alpha=.4, edgecolor=None, facecolor=color) + + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_band_profile.png' + + plt.savefig(fname) + plt.close() + if show: + #plt.show(block=False) + ehc.show_noblock() + + def plot_meanprof(self, postprocdir=POSTPROCDIR, save_png=False, color='k'): + fig = plt.figure() + plt.plot(self.rs, self.meanprof, + color=color, linestyle='-', linewidth=1) + plt.plot((self.lh, self.rh), (0.5*self.meanpk, 0.5*self.meanpk), + color=color, linestyle='--', linewidth=1) + plt.xlabel(r"distance from center ($\mu$as)") + plt.ylabel(r"Flux (mJy/$\mu$as$^2$)") + #plt.ylim([0, 1]) + plt.xlim([-10, 60]) + plt.title('Mean Profile') + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_meanprofile.png' + + plt.savefig(fname) + plt.close() + else: + #plt.show(block=False) + ehc.show_noblock() + + def plot_meanprof_theta(self, postprocdir=POSTPROCDIR, save_png=False, color='k'): + fig = plt.figure() + plt.plot(self.thetas/ehc.DEGREE, self.meanprof_theta, + color=color, linestyle='-', linewidth=1) + + ang1 = self.RingAngle1[0]/ehc.DEGREE + std1 = self.RingAngle1[1]/ehc.DEGREE + up = np.mod(ang1+std1, 360) + down = np.mod(ang1-std1, 360) + plt.axvline(x=ang1, color='b', linewidth=1) + plt.axvline(x=up, color='b', linewidth=1, linestyle='--') + plt.axvline(x=down, color='b', linewidth=1, linestyle='--') + + ang2 = self.RingAngle2[0]/ehc.DEGREE + std2 = self.RingAngle2[1]/ehc.DEGREE + up = np.mod(ang2+std2, 360) + down = np.mod(ang2-std2, 360) + plt.axvline(x=ang2, color='r', linewidth=1) + plt.axvline(x=up, color='r', linewidth=1, linestyle='--') + plt.axvline(x=down, color='r', linewidth=1, linestyle='--') + + plt.xlabel(r"Angle E of N ($^{\circ}$)") + plt.ylabel("Normalized Flux") + plt.title('Mean Angular Profile') + if save_png: + dirname = os.path.basename(os.path.dirname(self.imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.basename(self.imname) + fname = postprocdir + '/' + dirname + '/' + basename[:-5] + '_meanangprofile.png' + + plt.savefig(fname) + plt.close() + else: + #plt.show(block=False) + ehc.show_noblock() + +################################################################################################### +# Other functions +################################################################################################### + + +def quad_interp_radius(r_max, dr, val_list): + v_L = val_list[0] + v_max = val_list[1] + v_R = val_list[2] + + rpk = r_max + dr*(v_L - v_R) / (2 * (v_L + v_R - 2*v_max)) + + vpk = 8*v_max*(v_L + v_R) - (v_L - v_R)**2 - 16*v_max**2 + vpk /= (8*(v_L + v_R - 2*v_max)) + + return (rpk, vpk) + + +def compute_ring_profile(im, x0, y0, title="", + nrays=NRAYS, nrs=NRS, rmin=RMIN, rmax=RMAX, + flux_norm=NORMFLUX, pol_profs=False, + interptype='cubic'): + """compute a ring profile given a center location + """ + + rs = np.linspace(0, rmax, nrs) + thetas = np.linspace(0, 2*np.pi, nrays) + + factor = 3.254e13/(im.rf**2 * im.psize**2) # convert to brightness temperature + imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K + xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS + ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS + + # TODO: test fiducial images with linear? + interp = scipy.interpolate.interp2d(ys, xs, imarr, kind=interptype) + + def ringVals(theta): + xxs = x0 - rs*np.sin(theta) + yys = y0 + rs*np.cos(theta) + + vals = [interp(xxs[i], yys[i])[0] for i in np.arange(len(rs))] + return vals + + profs = [] + for j in range(nrays): + vals = ringVals(thetas[j]) + profs.append(vals) + + # polarization profiles + profsQ = [] + profsU = [] + if len(im.qvec) > 0 and len(im.uvec > 0) and pol_profs: + qarr = im.qvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K + uarr = im.uvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K + interpQ = scipy.interpolate.interp2d(ys, xs, qarr, kind=interptype) + interpU = scipy.interpolate.interp2d(ys, xs, uarr, kind=interptype) + + def ringValsQ(theta): + xxs = x0 - rs*np.sin(theta) + yys = y0 + rs*np.cos(theta) + + vals = [interpQ(xxs[i], yys[i])[0] for i in np.arange(len(rs))] + return vals + + def ringValsU(theta): + xxs = x0 - rs*np.sin(theta) + yys = y0 + rs*np.cos(theta) + + vals = [interpU(xxs[i], yys[i])[0] for i in np.arange(len(rs))] + return vals + + for j in range(nrays): + valsQ = ringValsQ(thetas[j]) + profsQ.append(valsQ) + valsU = ringValsU(thetas[j]) + profsU.append(valsU) + + + profiles = Profiles(im, x0, y0, profs, thetas, rmin=rmin, rmax=rmax, interptype=interptype, + flux_norm=flux_norm,profsQ=profsQ, profsU=profsU) + + return profiles + + +def findCenter(im, + rmin=RMIN, rmax=RMAX, + rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX, + nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH, + fov_search=FOVP_SEARCH, n_search=NSEARCH, flux_norm=NORMFLUX): + """Find the ring center by looking at profiles over a given range + """ + + print("nrays", nrays_search, "nrs", nrs_search, "fov", fov_search, "n", n_search) + + rs = np.linspace(0, rmax_search, nrs_search) + dr = rs[-1] - rs[-2] + thetas = np.linspace(0, 2*np.pi, nrays_search) + factor = 3.254e13/(im.rf**2 * im.psize**2) # convert to brightness temperature + imarr = im.imvec.reshape(im.ydim, im.xdim)[::-1] * factor # in brightness temperature K + xs = np.arange(im.xdim)*im.psize/ehc.RADPERUAS + ys = np.arange(im.ydim)*im.psize/ehc.RADPERUAS + + # TODO: test fiducial images with linear? + #iminterp = scipy.interpolate.interp2d(ys, xs, imarr, kind='cubic') + iminterp = scipy.interpolate.interp2d(ys, xs, imarr, kind='linear') + #iminterp = scipy.interpolate.RegularGridInterpolator((ys,xs),imarr) + def objFunc(pos): + (x0, y0) = pos + diameters = [] + for j in range(nrays_search): + xxs = x0 - rs*np.sin(thetas[j]) + yys = y0 + rs*np.cos(thetas[j]) + prof = np.array([iminterp(xxs[i], yys[i])[0] for i in np.arange(len(rs))]) + + args = np.argsort(prof) + pkpos = args[-1] + pk = rs[pkpos] + vpk = prof[pkpos] + if pkpos > 0 and pkpos < nrs_search-1: + vals = [prof[pkpos-1], prof[pkpos], prof[pkpos+1]] + pk, vpk = quad_interp_radius(pk, dr, vals) + + diameters.append(2*np.abs(pk)) + + # ring size + mean,std = (np.mean(diameters), np.std(diameters)) + + if mean < rmin_search or mean > rmax_search: + return np.inf + else: + J = np.abs(std/mean) + return J + + fovx = im.fovx()/ehc.RADPERUAS + fovy = im.fovy()/ehc.RADPERUAS + + # brute force search + fmin finisher to find + fovmin_x = (.5-fov_search) * fovx + fovmax_x = (.5+fov_search) * fovx + fovmin_y = (.5-fov_search) * fovy + fovmax_y = (.5+fov_search) * fovy + res = scipy.optimize.brute(objFunc, ranges=((fovmin_x, fovmax_x), (fovmin_y, fovmax_y)), + Ns=n_search) + + return res + +def FindProfile(im, blur=0, pol_profs=False, + imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS, + rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX, + nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH, + thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH, + flux_norm=NORMFLUX,center=False): + """find the best ring profile for an image object and save results + """ + + im_raw = im.copy() + # blur image if requested + if blur > 0: + im_raw = im_raw.blur_circ(blur*ehc.RADPERUAS, blur*ehc.RADPERUAS) + + # center image and regrid to uniform pixel size and fox + if center: + im = di.center_core(im_raw) # TODO -- why isn't this working? + else: + im = im_raw + + im_search = im.regrid_image(imsize, npix) + im = im.regrid_image(imsize, npix) + + # blur image if requested + # if blur > 0: + # im_search = im_search.blur_circ(blur*ehc.RADPERUAS) + # im = im.blur_circ(blur*ehc.RADPERUAS) + + # blur and threshold image FOR SEARCH ONLY + # if blur==0: + # im_search = im.blur_circ(BLUR_VALUE_MIN*ehc.RADPERUAS) + # else: + # im_search = im.copy() + + # threshold the search image to 5% of the maximum + im_search.imvec[im_search.imvec < thresh_search*np.max(im_search.imvec)] = 0 + + # find center + res = findCenter(im_search, rmin=rmin, rmax=rmax, + rmin_search=rmin_search, rmax_search=rmax_search, + nrays_search=nrays_search, nrs_search=nrs_search, + fov_search=fov_search, n_search=n_search, flux_norm=flux_norm) + + # compute profiles using the original (regridded, flux centroid centered) image + print("compute profile") + pp = compute_ring_profile(im, res[0], res[1], nrs=nrs, nrays=nrays, + rmin=rmin, rmax=rmax, flux_norm=flux_norm, + pol_profs=pol_profs) + pp.calc_meanprof_and_stats() + + return pp + +def FindProfileSingle(imname, postprocdir, + save_files=False, blur=0, aipscc=False, tag='', + rerun=True, return_pp=True, + imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS, + rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX, + nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH, + thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH, + flux_norm=NORMFLUX,center=False): + """find the best ring profile for an image fits file and save results + """ + + dirname = os.path.basename(os.path.dirname(imname)) + basename = os.path.basename(imname) + txtname = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '.txt' + if rerun is False and os.path.exists(txtname): + return -1 + + # print("nrays",nrays_search,"nrs",nrs_search,"fov",fov_search) + with ploop.HiddenPrints(): + + # load image + im_raw = load_image(imname, aipscc=aipscc) + + # calculate profile + pp = FindProfile(im, blur=blur, + imsize=imsize, npix=npix, rmin=rmin, rmax=rmax, nrays=nrays, nrs=nrs, + rmin_search=rmin_search, rmax_search=rmax_search, + nrays_search=nrays_search, nrs_search=nrs_search, + thresh_search=thresh_search, fov_search=fov_search, n_search=n_search, + flux_norm=flux_norm,center=center) + + + # save files + if save_files: + print("save files") + dirname = os.path.basename(os.path.dirname(imname)) + if not os.path.exists(postprocdir + '/' + dirname): + subprocess.call(['mkdir', postprocdir + '/' + dirname]) + + basename = os.path.splitext(os.path.basename(imname))[0] + txtname = postprocdir + '/' + dirname + '/' + basename + tag + '.txt' + + if os.path.exists(txtname): + os.remove(txtname) + + f = open(txtname, 'a') + f.write('ring_x0 ' + str(res[0]) + '\n') + f.write('ring_y0 ' + str(res[1]) + '\n') + + f.write('ring_diameter ' + str(pp.RingSize1[0]) + '\n') + f.write('ring_diameter_sigma ' + str(pp.RingSize1[1]) + '\n') + + f.write('meanprof_ring_diameter ' + str(pp.RingSize2[0]) + '\n') + f.write('meanprof_ring_diameter_sigma ' + str(pp.RingSize2[1]) + '\n') + + f.write('ring_orientation: ' + str(pp.RingAngle1[0]) + '\n') + f.write('ring_orientation_sigma: ' + str(pp.RingAngle1[1]) + '\n') + + f.write('meanprof_ring_orientation: ' + str(pp.RingAngle2[0]) + '\n') + f.write('meanprof_ring_orientation_sigma: ' + str(pp.RingAngle2[1]) + '\n') + + f.write('ring_width: ' + str(pp.RingWidth[0]) + '\n') + f.write('ring_width_sigma: ' + str(pp.RingWidth[1]) + '\n') + + f.write('total_flux ' + str(pp.flux) + '\n') + f.write('total_ring_flux ' + str(pp.RingFlux) + '\n') + + f.write('ring_asym_1 ' + str(pp.RingAsym1[0]) + '\n') + f.write('ring_asym_1_sigma ' + str(pp.RingAsym1[1]) + '\n') + f.write('ring_asym_2 ' + str(pp.RingAsym2[0]) + '\n') + f.write('ring_brighthalf_over_dimhalf ' + str(pp.RingAsym2[1]) + '\n') + + f.write('in_flux_mean_ring ' + str(pp.in_level) + '\n') + f.write('out_flux_mean_ring ' + str(pp.out_level) + '\n') + f.write('max_flux_mean_ring ' + str(pp.meanpk) + '\n') + + f.write('max_ring_contrast: ' + str(pp.RingContrast1) + '\n') + f.write('mean_ring_contrast: ' + str(pp.RingContrast2) + '\n') + f.write('dynamic_range ' + str(pp.dynamic_range) + '\n') + + f.write('norm_factor ' + str(pp.normfactor) + '\n') + + f.write('ring_diameter_med ' + str(pp.RingSize1_med[0]) + '\n') + f.write('ring_diameter_medabsdev ' + str(pp.RingSize1_med[1]) + '\n') + + f.write('ring_angle_pol ' + str(pp.RingAnglePol[0]) + '\n') + f.write('ring_angle_pol_sigma ' + str(pp.RingAnglePol[1]) + '\n') + + f.write('ring_pol_ratio_p ' + str(pp.RingAsymPol[0]) + '\n') + f.write('ring_pol_ratio_m ' + str(pp.RingAsymPol[1]) + '\n') + + f.close() + + # save unwrapped and centered fits image + # fitsname = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '.fits' + # fitsname_centered = postprocdir + '/' + dirname + \ + # '/' + basename[:-5] + tag + '_cent.fits' + + # pp.save_unwrapped(fitsname) + # pp.im_center.save_fits(fitsname_centered) + + # save radial profile + # radprof_name = postprocdir + '/' + dirname + '/' + basename[:-5] + tag + '_radprof.txt + # data=np.hstack((pp.rs.reshape(pp.nrs,1), + # pp.meanprof.reshape(pp.nrs,1), + # pp.normfactor * pp.meanprof.reshape(pp.nrs,1))) + # np.savetxt(radprof_name, data) + + # save angular profile + # angprof_name = postprocdir + '/' + dirname + '/' + basename[:-5] + tag+'_angprof.txt' + # data=np.hstack((pp.thetas.reshape(pp.nthetas,1), + # pp.meanprof_theta.reshape(pp.nthetas,1), + # pp.normfactor * pp.meanprof_theta.reshape(pp.nthetas,1))) + # np.savetxt(angprof_name, data) + + # pp.plot_unwrapped(save_png=True) + # pp.plot_img(save_png=True) + # pp.plot_meanprof(save_png=True) + # pp.plot_meanprof_theta(save_png=True) + # plt.close('all') + + if return_pp: + return pp + else: + del pp + return + + +def FindProfiles(foldername, postprocdir, processes=-1, + save_files=False, blur=0, + aipscc=False, tag='', rerun=True, return_pp=True, + imsize=IMSIZE, npix=NPIX, rmin=RMIN, rmax=RMAX, nrays=NRAYS, nrs=NRS, + rmin_search=RPRIOR_MIN, rmax_search=RPRIOR_MAX, + nrays_search=NRAYS_SEARCH, nrs_search=NRS_SEARCH, + thresh_search=THRESH, fov_search=FOVP_SEARCH, n_search=NSEARCH, + flux_norm=NORMFLUX + ): + """find profiles for all image fits files in a directory + """ + + foldername = os.path.abspath(foldername) + imlist = np.array(glob.glob(foldername + '/*.fits')) + ext = '.fits' + + # Look for hdf5 files for EHT library runs + if len(imlist) == 0: + imlist = np.array(glob.glob(foldername + '/*.h5')) + ext = '.h5' + + if len(imlist) == 0: + print("\nfound no image files in ", foldername) + return [] + + print("\nfound ", len(imlist), " ", ext, " files in ", foldername) + + imlist = np.sort(imlist) + arglist = [[imlist[i], postprocdir, + save_files, blur, + aipscc, tag, rerun, return_pp, + imsize, npix, rmin, rmax, nrays, nrs, + rmin_search, rmax_search, nrays_search, nrs_search, + thresh_search, fov_search, n_search, flux_norm] + for i in range(len(imlist))] + + parloop = ploop.Parloop(FindProfileSingle) + pplist = parloop.run_loop(arglist, processes) + return pplist diff --git a/image.py b/image.py new file mode 100644 index 00000000..73ccafaa --- /dev/null +++ b/image.py @@ -0,0 +1,4320 @@ +# image.py +# an interferometric image class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import sys +import copy +import math +import numpy as np +import numpy.matlib as matlib +import matplotlib as mpl +import matplotlib.pyplot as plt +import scipy.optimize as opt +import scipy.signal +import scipy.ndimage.filters as filt +import scipy.interpolate +from scipy import ndimage as ndi + + +try: + from skimage.feature import canny + from skimage.transform import hough_circle, hough_circle_peaks +except ImportError: + print("Warning: scikit-image not installed! Cannot use hough transform") + +import ehtim.observing.obs_simulate as simobs +import ehtim.observing.pulses as pulses +import ehtim.io.save +import ehtim.io.load +import ehtim.const_def as ehc +import ehtim.observing.obs_helpers as obsh + +# TODO : add time to all images +# TODO : add arbitrary center location + +################################################################################################### +# Image object +################################################################################################### + + +class Image(object): + + """A polarimetric image (in units of Jy/pixel). + + Attributes: + pulse (function): The function convolved with the pixel values for continuous image. + psize (float): The pixel dimension in radians + xdim (int): The number of pixels along the x dimension + ydim (int): The number of pixels along the y dimension + mjd (int): The integer MJD of the image + time (float): The observing time of the image (UTC hours) + source (str): The astrophysical source name + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The image frequency in Hz + + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular + _imdict (dict): The dictionary with the polarimetric images + _mflist (list): List of spectral index images (and higher order terms) + """ + + def __init__(self, image, psize, ra, dec, pa=0.0, + polrep='stokes', pol_prim=None, + rf=ehc.RF_DEFAULT, pulse=ehc.PULSE_DEFAULT, source=ehc.SOURCE_DEFAULT, + mjd=ehc.MJD_DEFAULT, time=0.): + """A polarimetric image (in units of Jy/pixel). + + Args: + image (numpy.array): The 2D intensity values in a Jy/pixel array + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + + psize (float): The pixel dimension in radians + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + pa (float): logical positional angle of the image + rf (float): The image frequency in Hz + pulse (function): The function convolved with the pixel values for continuous image. + source (str): The source name + mjd (int): The integer MJD of the image + time (float): The observing time of the image (UTC hours) + + Returns: + (Image): the Image object + """ + + if len(image.shape) != 2: + raise Exception("image must be a 2D numpy array") + if polrep not in ['stokes', 'circ']: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + # Save the image vector + imvec = image.flatten() + + if polrep == 'stokes': + if pol_prim is None: + pol_prim = 'I' + if pol_prim == 'I': + self._imdict = {'I': imvec, 'Q': np.array([]), 'U': np.array([]), 'V': np.array([])} + elif pol_prim == 'V': + self._imdict = {'I': np.array([]), 'Q': np.array([]), 'U': np.array([]), 'V': imvec} + elif pol_prim == 'Q': + self._imdict = {'I': np.array([]), 'Q': imvec, 'U': np.array([]), 'V': np.array([])} + elif pol_prim == 'U': + self._imdict = {'I': np.array([]), 'Q': np.array([]), 'U': imvec, 'V': np.array([])} + else: + raise Exception("for polrep=='stokes', pol_prim must be 'I','Q','U', or 'V'!") + + elif polrep == 'circ': + if pol_prim is None: + print("polrep is 'circ' and no pol_prim specified! Setting pol_prim='RR'") + pol_prim = 'RR' + if pol_prim == 'RR': + self._imdict = {'RR': imvec, 'LL': np.array([]), 'RL': np.array([]), 'LR': np.array([])} + elif pol_prim == 'LL': + self._imdict = {'RR': np.array([]), 'LL': imvec, 'RL': np.array([]), 'LR': np.array([])} + else: + raise Exception("for polrep=='circ', pol_prim must be 'RR' or 'LL'!") + else: + raise Exception("polrep must be 'circ' or 'stokes'!") + + # multifrequency spectral index, curvature arrays + # TODO -- higher orders? + # TODO -- don't initialize to zero? + avec = np.array([]) # np.zeros(imvec.shape) + bvec = np.array([]) # np.zeros(imvec.shape) + self._mflist = [avec, bvec] + + # Save the image dimension data + self.pol_prim = pol_prim + self.polrep = polrep + self.pulse = pulse + self.psize = float(psize) + self.xdim = image.shape[1] + self.ydim = image.shape[0] + + # Save the image metadata + self.ra = float(ra) + self.dec = float(dec) + self.pa = float(pa) + self.rf = float(rf) + self.source = str(source) + self.mjd = int(mjd) + + # Cached FFT of the image + self.cached_fft = {} + + if time > 24: + self.mjd += int((time - time % 24) / 24) + self.time = float(time % 24) + else: + self.time = time + + @property + def imvec(self): + imvec = self._imdict[self.pol_prim] + return imvec + + @imvec.setter + def imvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("imvec size is not consistent with xdim*ydim!") + + self._imdict[self.pol_prim] = vec + + @property + def specvec(self): + specvec = self._mflist[0] + return specvec + + @specvec.setter + def specvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + self._mflist[0] = vec + + @property + def curvvec(self): + curvvec = self._mflist[1] + return curvvec + + @curvvec.setter + def curvvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + self._mflist[1] = vec + + @property + def ivec(self): +# if self.polrep != 'stokes': +# raise Exception("ivec is not defined unless self.polrep=='stokes'") + + ivec = np.array([]) + if self.polrep == 'stokes': + ivec = self._imdict['I'] + elif self.polrep == 'circ': + if len(self.rrvec) != 0 and len(self.llvec) != 0: + ivec = 0.5 * (self.rrvec + self.llvec) + + return ivec + + @ivec.setter + def ivec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'stokes': + raise Exception("ivec cannot be set unless self.polrep=='stokes'") + + self._imdict['I'] = vec + + @property + def qvec(self): +# if self.polrep != 'stokes': +# raise Exception("qvec is not defined unless self.polrep=='stokes'") + + qvec = np.array([]) + if self.polrep == 'stokes': + qvec = self._imdict['Q'] + elif self.polrep == 'circ': + if len(self.rlvec) != 0 and len(self.lrvec) != 0: + qvec = np.real(0.5 * (self.lrvec + self.rlvec)) + + return qvec + + @qvec.setter + def qvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'stokes': + raise Exception("ivec cannot be set unless self.polrep=='stokes'") + + self._imdict['Q'] = vec + + @property + def uvec(self): +# if self.polrep != 'stokes': +# raise Exception("qvec is not defined unless self.polrep=='stokes'") + + uvec = np.array([]) + if self.polrep == 'stokes': + uvec = self._imdict['U'] + elif self.polrep == 'circ': + if len(self.rlvec) != 0 and len(self.lrvec) != 0: + uvec = np.real(0.5j * (self.lrvec - self.rlvec)) + + return uvec + + @uvec.setter + def uvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'stokes': + raise Exception("uvec cannot be set unless self.polrep=='stokes'") + + self._imdict['U'] = vec + + @property + def vvec(self): +# if self.polrep != 'stokes': +# raise Exception("vvec is not defined unless self.polrep=='stokes'") + + vvec = np.array([]) + if self.polrep == 'stokes': + vvec = self._imdict['V'] + elif self.polrep == 'circ': + if len(self.rrvec) != 0 and len(self.llvec) != 0: + vvec = 0.5 * (self.rrvec - self.llvec) + + return vvec + + @vvec.setter + def vvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'stokes': + raise Exception("vvec cannot be set unless self.polrep=='stokes'") + + self._imdict['V'] = vec + + @property + def rrvec(self): +# if self.polrep != 'circ': +# raise Exception("rrvec is not defined unless self.polrep=='circ'") + + rrvec = np.array([]) + if self.polrep == 'circ': + rrvec = self._imdict['RR'] + elif self.polrep == 'stokes': + if len(self.ivec) != 0 and len(self.vvec) != 0: + rrvec = (self.ivec + self.vvec) + + return rrvec + + @rrvec.setter + def rrvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'circ': + raise Exception("rrvec cannot be set unless self.polrep=='circ'") + + self._imdict['RR'] = vec + + @property + def llvec(self): +# if self.polrep != 'circ': +# raise Exception("llvec is not defined unless self.polrep=='circ'") + + llvec = np.array([]) + if self.polrep == 'circ': + llvec = self._imdict['LL'] + elif self.polrep == 'stokes': + if len(self.ivec) != 0 and len(self.vvec) != 0: + llvec = (self.ivec - self.vvec) + + return llvec + + @llvec.setter + def llvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'circ': + raise Exception("llvec cannot be set unless self.polrep=='circ'") + + self._imdict['LL'] = vec + + @property + def rlvec(self): +# if self.polrep != 'circ': +# raise Exception("rlvec is not defined unless self.polrep=='circ'") + + rlvec = np.array([]) + if self.polrep == 'circ': + rlvec = self._imdict['RL'] + elif self.polrep == 'stokes': + if len(self.qvec) != 0 and len(self.uvec) != 0: + rlvec = (self.qvec + 1j * self.uvec) + + return rlvec + + @rlvec.setter + def rlvec(self, vec): + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'circ': + raise Exception("rlvec cannot be set unless self.polrep=='circ'") + + self._imdict['RL'] = vec + + @property + def lrvec(self): + """Return the imvec of LR""" +# if self.polrep != 'circ': +# raise Exception("lrvec is not defined unless self.polrep=='circ'") + + lrvec = np.array([]) + if self.polrep == 'circ': + lrvec = self._imdict['LR'] + elif self.polrep == 'stokes': + if len(self.qvec) != 0 and len(self.uvec) != 0: + lrvec = (self.qvec - 1j * self.uvec) + + + return lrvec + + @lrvec.setter + def lrvec(self, vec): + """Set the imvec""" + + if len(vec) != self.xdim * self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + if self.polrep != 'circ': + raise Exception("lrvec cannot be set unless self.polrep=='circ'") + + self._imdict['LR'] = vec + + @property + def pvec(self): + """Return the polarization magnitude for each pixel""" + if self.polrep == 'circ': + pvec = np.abs(self.rlvec) + elif self.polrep == 'stokes': + pvec = np.abs(self.qvec + 1j * self.uvec) + + return pvec + + @property + def mvec(self): + """Return the fractional polarization for each pixel""" + if self.polrep == 'circ': + mvec = 2 * np.abs(self.rlvec) / (self.rrvec + self.llvec) + elif self.polrep == 'stokes': + mvec = np.abs(self.qvec + 1j * self.uvec) / self.ivec + + return mvec + + @property + def chivec(self): + """Return the fractional polarization angle for each pixel""" + if self.polrep == 'circ': + chivec = 0.5 * np.angle(self.rlvec / (self.rrvec + self.llvec)) + elif self.polrep == 'stokes': + chivec = 0.5 * np.angle((self.qvec + 1j * self.uvec) / self.ivec) + + return chivec + + @property + def evpavec(self): + """Return the fractional polarization angle for each pixel""" + + return self.chivec + + @property + def evec(self): + """Return the E mode image vector""" + if self.polrep == 'circ': + qvec = np.real(0.5 * (self.lrvec + self.rlvec)) + uvec = np.real(0.5j * (self.lrvec - self.rlvec)) + elif self.polrep == 'stokes': + qvec = self.qvec + uvec = self.uvec + + qarr = qvec.reshape((self.ydim, self.xdim)) + uarr = uvec.reshape((self.ydim, self.xdim)) + qarr_fft = np.fft.fftshift(np.fft.fft2(qarr)) + uarr_fft = np.fft.fftshift(np.fft.fft2(uarr)) + + # TODO -- check conventions for u,v angle + s, t = np.meshgrid(np.flip(np.fft.fftshift(np.fft.fftfreq(self.xdim, d=1.0 / self.xdim))), + np.flip(np.fft.fftshift(np.fft.fftfreq(self.ydim, d=1.0 / self.ydim)))) + s = s + .5 # .5 offset to reference to pixel center + t = t + .5 # .5 offset to reference to pixel center + uvangle = np.arctan2(s, t) + + # TODO -- these conventions for e,b are from kaminokowski aara 54:227-69 sec 4.1 + # TODO -- check! + cos2arr = np.round(np.cos(2 * uvangle), decimals=10) + sin2arr = np.round(np.sin(2 * uvangle), decimals=10) + earr_fft = (cos2arr * qarr_fft + sin2arr * uarr_fft) + + earr = np.fft.ifft2(np.fft.ifftshift(earr_fft)) + return np.real(earr.flatten()) + + @property + def bvec(self): + """Return the B mode image vector""" + + if self.polrep == 'circ': + qvec = np.real(0.5 * (self.lrvec + self.rlvec)) + uvec = np.real(0.5j * (self.lrvec - self.rlvec)) + elif self.polrep == 'stokes': + qvec = self.qvec + uvec = self.uvec + + # TODO -- check conventions for u,v angle + qarr = qvec.reshape((self.ydim, self.xdim)) + uarr = uvec.reshape((self.ydim, self.xdim)) + qarr_fft = np.fft.fftshift(np.fft.fft2(qarr)) + uarr_fft = np.fft.fftshift(np.fft.fft2(uarr)) + + # TODO -- are these conventions for u,v right? + s, t = np.meshgrid(np.flip(np.fft.fftshift(np.fft.fftfreq(self.xdim, d=1.0 / self.xdim))), + np.flip(np.fft.fftshift(np.fft.fftfreq(self.ydim, d=1.0 / self.ydim)))) + s = s + .5 # .5 offset to reference to pixel center + t = t + .5 # .5 offset to reference to pixel center + uvangle = np.arctan2(s, t) + + # TODO -- check! + cos2arr = np.round(np.cos(2 * uvangle), decimals=10) + sin2arr = np.round(np.sin(2 * uvangle), decimals=10) + barr_fft = (-sin2arr * qarr_fft + cos2arr * uarr_fft) + + barr = np.fft.ifft2(np.fft.ifftshift(barr_fft)) + return np.real(barr.flatten()) + + def get_polvec(self, pol): + """Get the imvec corresponding to the chosen polarization + """ + if self.polrep == 'stokes' and pol is None: + pol = 'I' + elif self.polrep == 'circ' and pol is None: + pol = 'RR' + + if pol.lower() == 'i': + outvec = self.ivec + elif pol.lower() == 'q': + outvec = self.qvec + elif pol.lower() == 'u': + outvec = self.uvec + elif pol.lower() == 'v': + outvec = self.vvec + elif pol.lower() == 'rr': + outvec = self.rrvec + elif pol.lower() == 'll': + outvec = self.llvec + elif pol.lower() == 'lr': + outvec = self.lrvec + elif pol.lower() == 'rl': + outvec = self.rlvec + elif pol.lower() == 'p': + outvec = self.pvec + elif pol.lower() == 'm': + outvec = self.mvec + elif pol.lower() == 'chi' or pol.lower() =='evpa': + outvec = self.chivec + elif pol.lower() == 'e': + outvec = self.evec + elif pol.lower() == 'b': + outvec = self.bvec + else: + raise Exception("Requested polvec type not recognized!") + return outvec + + def image_args(self): + """Copy arguments for making a new Image into a list and dictonary + """ + + arglist = [self.imarr(), self.psize, self.ra, self.dec] + argdict = {'rf': self.rf, 'pa': self.pa, + 'polrep': self.polrep, 'pol_prim': self.pol_prim, + 'pulse': self.pulse, 'source': self.source, + 'mjd': self.mjd, 'time': self.time} + + return (arglist, argdict) + + def copy(self): + """Return a copy of the Image object. + + Args: + + Returns: + (Image): copy of the Image. + """ + + # Make new image with primary polarization + arglist, argdict = self.image_args() + newim = Image(*arglist, **argdict) + + # Copy over all polarization images + newim.copy_pol_images(self) + + # Copy over spectral index information + newim._mflist = copy.deepcopy(self._mflist) + + return newim + + def copy_pol_images(self, old_image): + """Copy polarization images from old_image over to self. + + Args: + old_image (Image): image object to copy from + + """ + + for pol in list(self._imdict.keys()): + + if (pol == self.pol_prim): + continue + + polvec = old_image._imdict[pol] + if len(polvec): + self.add_pol_image(polvec.reshape(self.ydim, self.xdim), pol) + + def add_pol_image(self, image, pol): + """Add another image polarization. + + Args: + image (list): 2D image frame (possibly complex) in a Jy/pixel array + pol (str): The image type: 'I','Q','U','V' for stokes, 'RR','LL','RL','LR' for circ + """ + + if pol == self.pol_prim: + raise Exception("new pol in add_pol_image is the same as pol_prim!") + if image.shape != (self.ydim, self.xdim): + raise Exception("add_pol_image image shapes incompatible with primary image!") + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol in add_pol_image in " % + self.polrep + ",".join(list(self._imdict.keys()))) + + if self.polrep == 'stokes': + if pol == 'I': + self.ivec = image.flatten() + elif pol == 'Q': + self.qvec = image.flatten() + elif pol == 'U': + self.uvec = image.flatten() + elif pol == 'V': + self.vvec = image.flatten() + + elif self.polrep == 'circ': + if pol == 'RR': + self.rrvec = image.flatten() + elif pol == 'LL': + self.llvec = image.flatten() + elif pol == 'RL': + self.rlvec = image.flatten() + elif pol == 'LR': + self.lrvec = image.flatten() + + return + + # TODO deprecated -- replace with generic add_pol_image + def add_qu(self, qimage, uimage): + """Add Stokes Q and U images. self.polrep must be 'stokes' + + Args: + qimage (numpy.array): The 2D Stokes Q values in Jy/pixel array + uimage (numpy.array): The 2D Stokes U values in Jy/pixel array + + Returns: + """ + + if self.polrep != 'stokes': + raise Exception("polrep must be 'stokes' for add_qu() !") + self.add_pol_image(qimage, 'Q') + self.add_pol_image(uimage, 'U') + + return + + # TODO deprecated -- replace with generic add_pol_image + def add_v(self, vimage): + """Add Stokes V image. self.polrep must be 'stokes' + + Args: + vimage (numpy.array): The 2D Stokes Q values in Jy/pixel array + """ + + if self.polrep != 'stokes': + raise Exception("polrep must be 'stokes' for add_v() !") + self.add_pol_image(vimage, 'V') + + return + + def switch_polrep(self, polrep_out='stokes', pol_prim_out=None): + """Return a new image with the polarization representation changed + Args: + polrep_out (str): the polrep of the output data + pol_prim_out (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for circ + + Returns: + (Image): new Image object with potentially different polrep + """ + + if polrep_out not in ['stokes', 'circ']: + raise Exception("polrep_out must be either 'stokes' or 'circ'") + if pol_prim_out is None: + if polrep_out == 'stokes': + pol_prim_out = 'I' + elif polrep_out == 'circ': + pol_prim_out = 'RR' + + # Simply copy if the polrep is unchanged + if polrep_out == self.polrep and pol_prim_out == self.pol_prim: + return self.copy() + + # Assemble a dictionary of new polarization vectors + if polrep_out == 'stokes': + if self.polrep == 'stokes': + imdict = {'I': self.ivec, 'Q': self.qvec, 'U': self.uvec, 'V': self.vvec} + else: + if len(self.rrvec) == 0 or len(self.llvec) == 0: + ivec = np.array([]) + vvec = np.array([]) + else: + ivec = 0.5 * (self.rrvec + self.llvec) + vvec = 0.5 * (self.rrvec - self.llvec) + + if len(self.rlvec) == 0 or len(self.lrvec) == 0: + qvec = np.array([]) + uvec = np.array([]) + else: + qvec = np.real(0.5 * (self.lrvec + self.rlvec)) + uvec = np.real(0.5j * (self.lrvec - self.rlvec)) + + imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec} + + elif polrep_out == 'circ': + if self.polrep == 'circ': + imdict = {'RR': self.rrvec, 'LL': self.llvec, 'RL': self.rlvec, 'LR': self.lrvec} + else: + if len(self.ivec) == 0 or len(self.vvec) == 0: + rrvec = np.array([]) + llvec = np.array([]) + else: + rrvec = (self.ivec + self.vvec) + llvec = (self.ivec - self.vvec) + + if len(self.qvec) == 0 or len(self.uvec) == 0: + rlvec = np.array([]) + lrvec = np.array([]) + else: + rlvec = (self.qvec + 1j * self.uvec) + lrvec = (self.qvec - 1j * self.uvec) + + imdict = {'RR': rrvec, 'LL': llvec, 'RL': rlvec, 'LR': lrvec} + + # Assemble the new image + imvec = imdict[pol_prim_out] + if len(imvec) == 0: + raise Exception("for switch_polrep to %s with pol_prim_out=%s, \n" % + (polrep_out, pol_prim_out) + "output image is not defined") + + arglist, argdict = self.image_args() + arglist[0] = imvec.reshape(self.ydim, self.xdim) + argdict['polrep'] = polrep_out + argdict['pol_prim'] = pol_prim_out + newim = Image(*arglist, **argdict) + + # Add in any other polarizations + for pol in list(imdict.keys()): + if pol == newim.pol_prim: + continue + polvec = imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim) + newim.add_pol_image(polarr, pol) + + # Add in spectral index + newim._mflist = copy.deepcopy(self._mflist) + + return newim + + def orth_chi(self): + """Rotate the EVPA 90 degrees + + Args: + + Returns: + (Image): image with rotated EVPA + """ + im = self.copy() + if im.polrep == 'stokes': + im.qvec *= -1 + im.uvec *= -1 + elif im.polrep == 'circ': + im.lrvec *= -1# np.conjugate(im.rlvec) + im.rlvec *= -1#np.conjugate(im.rlvec) + #im.lrvec = np.conjugate(im.rlvec) + #im.rlvec = np.conjugate(im.rlvec) + + return im + + def get_image_mf(self, nu): + """Get image at a given frequency given the spectral information in self._mflist + + Args: + nu (float): frequency in Hz + + Returns: + (Image): image at the desired frequency + """ + # TODO -- what to do about polarization? Faraday rotation? + + nuref = self.rf + log_nufrac = np.log(nu / nuref) + log_imvec = np.log(self.imvec) + + for n, mfvec in enumerate(self._mflist): + if len(mfvec): + log_imvec += mfvec * (log_nufrac**(n + 1)) + imvec = np.exp(log_imvec) + + arglist, argdict = self.image_args() + arglist[0] = imvec.reshape(self.ydim, self.xdim) + argdict['rf'] = nu + outim = Image(*arglist, **argdict) + + # Copy over all polarization images -- unchanged for now + outim.copy_pol_images(self) + + # DON'T copy over spectral index information for now + # outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def imarr(self, pol=None): + """Return the 2D image array of a given pol parameter. + + Args: + pol (str): I,Q,U or V for Stokes, or RR,LL,LR,RL for Circ + + Returns: + (numpy.array): 2D image array of dimension (ydim, xdim) + """ + + if pol is None: + pol = self.pol_prim + + imvec = self.get_polvec(pol) + if len(imvec): + imarr = imvec.reshape(self.ydim, self.xdim) + else: + imarr = np.array([]) + return imarr + +# imarr = np.array([]) +# if self.polrep == 'stokes': +# if pol == "I" and len(self.ivec): +# imarr = self.ivec.reshape(self.ydim, self.xdim) +# elif pol == "Q" and len(self.qvec): +# imarr = self.qvec.reshape(self.ydim, self.xdim) +# elif pol == "U" and len(self.uvec): +# imarr = self.uvec.reshape(self.ydim, self.xdim) +# elif pol == "V" and len(self.vvec): +# imarr = self.vvec.reshape(self.ydim, self.xdim) +# elif self.polrep == 'circ': +# if pol == "RR" and len(self.rrvec): +# imarr = self.rrvec.reshape(self.ydim, self.xdim) +# elif pol == "LL" and len(self.llvec): +# imarr = self.llvec.reshape(self.ydim, self.xdim) +# elif pol == "RL" and len(self.rlvec): +# imarr = self.rlvec.reshape(self.ydim, self.xdim) +# elif pol == "LR" and len(self.lrvec): +# imarr = self.lrvec.reshape(self.ydim, self.xdim) + + return imarr + + def sourcevec(self): + """Return the source position vector in geocentric coordinates at 0h GMST. + + Args: + + Returns: + (numpy.array): normal vector pointing to source in geocentric coordinates (m) + """ + + sourcevec = np.array([np.cos(self.dec * ehc.DEGREE), 0, np.sin(self.dec * ehc.DEGREE)]) + return sourcevec + + def fovx(self): + """Return the image fov in x direction in radians. + + Args: + + Returns: + (float) : image fov in x direction (radian) + """ + + return self.psize * self.xdim + + def fovy(self): + """Returns the image fov in y direction in radians. + + Args: + + Returns: + (float) : image fov in y direction (radian) + """ + + return self.psize * self.ydim + + def total_flux(self): + """Return the total flux of the image in Jy. + + Args: + + Returns: + (float) : image total flux (Jy) + """ + if self.polrep == 'stokes': + flux = np.sum(self.ivec) + elif self.polrep == 'circ': + flux = 0.5 * (np.sum(self.rrvec) + np.sum(self.llvec)) + + return flux + + def lin_polfrac(self): + """Return the total fractional linear polarized flux + + Args: + + Returns: + (float) : image fractional linear polarized flux + """ + if self.polrep == 'stokes': + frac = np.abs(np.sum(self.qvec + 1j * self.uvec)) / np.abs(np.sum(self.ivec)) + elif self.polrep == 'circ': + frac = 2 * np.abs(np.sum(self.rlvec)) / np.abs(np.sum(self.rrvec + self.llvec)) + + return frac + + def evpa(self): + """Return the total evpa + + Args: + + Returns: + (float) : image average evpa (E of N) in radian + """ + if self.polrep == 'stokes': + frac = 0.5 * np.angle(np.sum(self.qvec + 1j * self.uvec)) + elif self.polrep == 'circ': + frac = np.angle(np.sum(self.rlvec)) + + return frac + + def circ_polfrac(self): + """Return the total fractional circular polarized flux + + Args: + + Returns: + (float) : image fractional circular polarized flux + """ + if self.polrep == 'stokes': + frac = np.sum(self.vvec) / np.abs(np.sum(self.ivec)) + elif self.polrep == 'circ': + frac = np.sum(self.rrvec - self.llvec) / np.abs(np.sum(self.rrvec + self.llvec)) + + return frac + + def center(self, pol=None): + """Center the image based on the coordinates of the centroid(). + A non-integer shift is used, which wraps the image when rotating. + + Args: + pol (str): The polarization for which to find the image centroid + + Returns: + (np.array): centroid positions (x0,y0) in radians + """ + + return self.shift_fft(-self.centroid(pol=pol)) + + def centroid(self, pol=None): + """Compute the location of the image centroid (corresponding to the polarization pol) + + Args: + pol (str): The polarization for which to find the image centroid + + Returns: + (np.array): centroid positions (x0,y0) in radians + """ + + if pol is None: + pol = self.pol_prim + imvec = self.get_polvec(pol) + pdim = self.psize + +# if not (pol in list(self._imdict.keys())): +# raise Exception("for polrep==%s, pol must be in " % +# self.polrep + ",".join(list(self._imdict.keys()))) +# imvec = self._imdict[pol] + + if len(imvec): + xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0 + ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0 + x0 = np.sum(np.outer(0.0 * ylist + 1.0, xlist).ravel() * imvec) / np.sum(imvec) + y0 = np.sum(np.outer(ylist, 0.0 * xlist + 1.0).ravel() * imvec) / np.sum(imvec) + centroid = np.array([x0, y0]) + else: + raise Exception("No %s image found!" % pol) + + return centroid + + def pad(self, fovx, fovy): + """Pad an image to new fov_x by fov_y in radian. + Args: + fovx (float): new fov in x dimension (rad) + fovy (float): new fov in y dimension (rad) + + Returns: + im_pad (Image): padded image + """ + + # Find pad widths + fovoldx = self.fovx() + fovoldy = self.fovy() + padx = int(0.5 * (fovx - fovoldx) / self.psize) + pady = int(0.5 * (fovy - fovoldy) / self.psize) + + # Pad main image vector + imarr = self.imvec.reshape(self.ydim, self.xdim) + imarr = np.pad(imarr, ((pady, pady), (padx, padx)), 'constant') + + # Make new image + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Pad all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim) + polarr = np.pad(polarr, ((pady, pady), (padx, padx)), 'constant') + outim.add_pol_image(polarr, pol) + + # Add in spectral index + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = mfvec.reshape(self.ydim, self.xdim) + mfarr = np.pad(mfarr, ((pady, pady), (padx, padx)), 'constant') + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def resample_square(self, xdim_new, ker_size=5): + """Exactly resample a square image to new dimensions using the pulse function. + + Args: + xdim_new (int): new pixel dimension + ker_size (int): kernel size for resampling + + Returns: + im_resampled (Image): resampled image + """ + + if self.xdim != self.ydim: + raise Exception("Image must be square to use Image.resample_square!") + if self.pulse == pulses.deltaPulse2D: + raise Exception("Image.resample_squre does not work with delta pulses!") + + ydim_new = xdim_new + fov = self.xdim * self.psize + psize_new = float(fov) / float(xdim_new) + + # Define an interpolation function using the pulse + ij = np.array([[[i * self.psize + (self.psize * self.xdim) / 2.0 - self.psize / 2.0, + j * self.psize + (self.psize * self.ydim) / 2.0 - self.psize / 2.0] + for i in np.arange(0, -self.xdim, -1)] + for j in np.arange(0, -self.ydim, -1)]).reshape((self.xdim * self.ydim, 2)) + + def im_new_val(imvec, x_idx, y_idx): + x = x_idx * psize_new + (psize_new * xdim_new) / 2.0 - psize_new / 2.0 + y = y_idx * psize_new + (psize_new * ydim_new) / 2.0 - psize_new / 2.0 + mask = (((x - ker_size * self.psize / 2.0) < ij[:, 0]) * + (ij[:, 0] < (x + ker_size * self.psize / 2.0)) * + ((y - ker_size * self.psize / 2.0) < ij[:, 1]) * + (ij[:, 1] < (y + ker_size * self.psize / 2.0)) + ).flatten() + interp = np.sum([imvec[n] * self.pulse(x - ij[n, 0], y - ij[n, 1], self.psize, dom="I") + for n in np.arange(len(imvec))[mask]]) + return interp + + def im_new(imvec): + imarr_new = np.array([[im_new_val(imvec, x_idx, y_idx) + for x_idx in np.arange(0, -xdim_new, -1)] + for y_idx in np.arange(0, -ydim_new, -1)]) + return imarr_new + + # Compute new primary image vector + imarr_new = im_new(self.imvec) + + # Normalize + scaling = np.sum(self.imvec) / np.sum(imarr_new) + imarr_new *= scaling + + # Make new image + arglist, argdict = self.image_args() + arglist[0] = imarr_new + arglist[1] = psize_new + outim = Image(*arglist, **argdict) + + # Interpolate all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr_new = im_new(polvec) + polarr_new *= scaling + outim.add_pol_image(polarr_new, pol) + + # Interpolate spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + print("WARNING: resample_squre not debugged for spectral index resampling!") + if len(mfvec): + mfarr = im_new(mfvec) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def regrid_image(self, targetfov, npix, interp='linear'): + """Resample the image to new (square) dimensions. + + Args: + targetfov (float): new field of view (radian) + npix (int): new pixel dimension + interp ('linear', 'cubic', 'quintic'): type of interpolation. default is linear + + Returns: + (Image): resampled image + """ + + psize_new = float(targetfov) / float(npix) + fov_x = self.fovx() + fov_y = self.fovy() + + # define an interpolation function + x = np.linspace(-fov_x / 2, fov_x / 2, self.xdim) + y = np.linspace(-fov_y / 2, fov_y / 2, self.ydim) + + xtarget = np.linspace(-targetfov / 2, targetfov / 2, npix) + ytarget = np.linspace(-targetfov / 2, targetfov / 2, npix) + + def interp_imvec(imvec, specind=False): + if np.any(np.imag(imvec) != 0): + return interp_imvec(np.real(imvec)) + 1j * interp_imvec(np.imag(imvec)) + + interpfunc = scipy.interpolate.interp2d(y, x, np.reshape(imvec, (self.ydim, self.xdim)), + kind=interp) + tmpimg = interpfunc(ytarget, xtarget) + tmpimg[np.abs(xtarget) > fov_x / 2., :] = 0.0 + tmpimg[:, np.abs(ytarget) > fov_y / 2.] = 0.0 + + if not specind: # adjust pixel size if not a spectral index map + tmpimg = tmpimg * (psize_new)**2 / self.psize**2 + return tmpimg + + # Make new image + imarr_new = interp_imvec(self.imvec) + arglist, argdict = self.image_args() + arglist[0] = imarr_new + arglist[1] = psize_new + + outim = Image(*arglist, **argdict) + + # Interpolate all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr_new = interp_imvec(polvec) + outim.add_pol_image(polarr_new, pol) + + # Interpolate spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = interp_imvec(mfvec, specind=True) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def rotate(self, angle, interp='cubic'): + """Rotate the image counterclockwise by the specified angle. + + Args: + angle (float): CCW angle to rotate the image (radian) + interp ('linear', 'cubic', 'quintic'): type of interpolation. default is cubic + Returns: + (Image): resampled image + """ + + order = 3 + if interp == 'linear': + order = 1 + elif interp == 'cubic': + order = 3 + elif interp == 'quintic': + order = 5 + + # Define an interpolation function + def rot_imvec(imvec): + if np.any(np.imag(imvec) != 0): + return rot_imvec(np.real(imvec)) + 1j * rot_imvec(np.imag(imvec)) + imarr_rot = scipy.ndimage.interpolation.rotate(imvec.reshape((self.ydim, self.xdim)), + angle * 180.0 / np.pi, reshape=False, + order=order, mode='constant', + cval=0.0, prefilter=True) + + return imarr_rot + + # pol_prim needs to be RR,LL,I,or V for a simple rotation to work! + if(not (self.pol_prim in ['RR', 'LL', 'I', 'V'])): + raise Exception("im.pol_prim must be a scalar ('I','V','RR','LL') for simple rotation!") + + # Make new image + imarr_rot = rot_imvec(self.imvec) + + arglist, argdict = self.image_args() + arglist[0] = imarr_rot + outim = Image(*arglist, **argdict) + + # Rotate all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr_rot = rot_imvec(polvec) + if pol == 'RL': + polarr_rot *= np.exp(1j * 2 * angle) + elif pol == 'LR': + polarr_rot *= np.exp(-1j * 2 * angle) + elif pol == 'Q': + polarr_rot = polarr_rot + 1j * rot_imvec(self._imdict['U']) + polarr_rot = np.real(np.exp(1j * 2 * angle) * polarr_rot) + elif pol == 'U': + polarr_rot = rot_imvec(self._imdict['Q']) + 1j * polarr_rot + polarr_rot = np.imag(np.exp(1j * 2 * angle) * polarr_rot) + + outim.add_pol_image(polarr_rot, pol) + + # Rotate spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = rot_imvec(mfvec) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def shift(self, shiftidx): + """Shift the image by a given number of pixels. + + Args: + shiftidx (list): pixel offsets [x_offset, y_offset] for the image shift + + Returns: + (Image): shifted images + """ + + # Define shifting function + def shift_imvec(imvec): + im_shift = np.roll(imvec.reshape(self.ydim, self.xdim), shiftidx[0], axis=0) + im_shift = np.roll(im_shift, shiftidx[1], axis=1) + return im_shift + + # Make new image + imarr_shift = shift_imvec(self.imvec) + + arglist, argdict = self.image_args() + arglist[0] = imarr_shift + outim = Image(*arglist, **argdict) + + # Shift all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr_shift = shift_imvec(polvec) + outim.add_pol_image(polarr_shift, pol) + + # Shift spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = shift_imvec(mfvec) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def shift_fft(self, shift): + """Shift the image by a given vector in radians. + This allows non-integer pixel shifts, via FFT. + + Args: + shift (list): offsets [x_offset, y_offset] for the image shift in radians + + Returns: + (Image): shifted image + """ + + Nx = self.xdim + Ny = self.ydim + + [dx_pixels, dy_pixels] = np.array(shift) / self.psize + + s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0 / Nx), np.fft.fftfreq(Ny, d=1.0 / Ny)) + rotate = np.exp(2.0 * np.pi * 1j * (s * dx_pixels + t * dy_pixels) / float(Nx)) + + imarr = self.imvec.reshape((Ny, Nx)) + imarr_rotate = np.real(np.fft.ifft2(np.fft.fft2(imarr) * rotate)) + + # make new Image + arglist, argdict = self.image_args() + arglist[0] = imarr_rotate + outim = Image(*arglist, **argdict) + + # Shift all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + imarr = polvec.reshape((Ny, Nx)) + imarr_rotate = np.real(np.fft.ifft2(np.fft.fft2(imarr) * rotate)) + outim.add_pol_image(imarr_rotate, pol) + + # Shift spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = mfvec.reshape((Ny, Nx)) + mfarr = np.real(np.fft.ifft2(np.fft.fft2(mfarr) * rotate)) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def blur_gauss(self, beamparams, frac=1., frac_pol=0): + """Blur image with a Gaussian beam w/ beamparams [fwhm_max, fwhm_min, theta] in radians. + + Args: + beamparams (list): [fwhm_maj, fwhm_min, theta, x, y] in radians + frac (float): fractional beam size for blurring the main image + frac_pol (float): fractional beam size for blurring the other polarizations + + Returns: + (Image): output image + """ + + if frac <= 0.0 or beamparams[0] <= 0: + return self.copy() + + # Make a Gaussian image + xlist = np.arange(0, -self.xdim, -1) * self.psize + \ + (self.psize * self.xdim) / 2.0 - self.psize / 2.0 + ylist = np.arange(0, -self.ydim, -1) * self.psize + \ + (self.psize * self.ydim) / 2.0 - self.psize / 2.0 + sigma_maj = beamparams[0] / (2. * np.sqrt(2. * np.log(2.))) + sigma_min = beamparams[1] / (2. * np.sqrt(2. * np.log(2.))) + cth = np.cos(beamparams[2]) + sth = np.sin(beamparams[2]) + + def gaussim(blurfrac): + gauss = np.array([[np.exp(-(j * cth + i * sth)**2 / (2 * (blurfrac * sigma_maj)**2) - + (i * cth - j * sth)**2 / (2 * (blurfrac * sigma_min)**2)) + for i in xlist] + for j in ylist]) + gauss = gauss[0:self.ydim, 0:self.xdim] + gauss = gauss / np.sum(gauss) # normalize to 1 + return gauss + + gauss = gaussim(frac) + if frac_pol: + gausspol = gaussim(frac_pol) + + # Define a convolution function + def blur(imarr, gauss): + imarr_blur = scipy.signal.fftconvolve(gauss, imarr, mode='same') + return imarr_blur + + # Convolve the primary image + imarr = (self.imvec).reshape(self.ydim, self.xdim).astype('float64') + imarr_blur = blur(imarr, gauss) + + # Make new image object + arglist, argdict = self.image_args() + arglist[0] = imarr_blur + outim = Image(*arglist, **argdict) + + # Blur all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).astype('float64') + if frac_pol: + polarr = blur(polarr, gausspol) + outim.add_pol_image(polarr, pol) + + # Blur spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = mfvec.reshape(self.ydim, self.xdim).astype('float64') + mfarr = blur(mfarr, gauss) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def blur_circ(self, fwhm_i, fwhm_pol=0, filttype='gauss'): + """Apply a circular gaussian filter to the image, with FWHM in radians. + + Args: + fwhm_i (float): circular beam size for Stokes I blurring in radian + fwhm_pol (float): circular beam size for Stokes Q,U,V blurring in radian + filttype (str): "gauss" or "butter" + + Returns: + (Image): output image + """ + + sigma = fwhm_i / (2. * np.sqrt(2. * np.log(2.))) + sigmap = sigma / self.psize + fwhmp = fwhm_i / self.psize + fwhmp_pol = fwhm_pol / self.psize + + # Define a convolution function + def blur_gauss(imarr, fwhm): + sigma = fwhmp / (2. * np.sqrt(2. * np.log(2.))) + if np.any(np.imag(imarr) != 0): + return blur(np.real(imarr), sigma) + 1j * blur(np.imag(imarr), sigma) + imarr_blur = filt.gaussian_filter(imarr, (sigma, sigma)) + return imarr_blur + + def blur_butter(imarr, size): + + #bfilt = scipy.signal.butter(2,freq,btype='low',output='sos') + #if np.any(np.imag(imarr) != 0): + # return blur(np.real(imarr), sigma) + 1j * blur(np.imag(imarr), sigma) + + #imarr_blur = scipy.signal.sosfilt(bfilt, imarr, axis=0) + #imarr_blur = scipy.signal.sosfilt(bfilt, imarr_blur, axis=1) + + if size==0: + return imarr + + cutoff = 1/size + Nx = self.xdim + Ny = self.ydim + + s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0 ), np.fft.fftfreq(Ny, d=1.0 )) + #s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0 / Nx), np.fft.fftfreq(Ny, d=1.0 / Ny)) + r = np.sqrt(s**2 + t**2) + + bfilt = 1./np.sqrt(1 + (r/cutoff)**4) + + imarr = self.imvec.reshape((Ny, Nx)) + imarr_filt = np.real(np.fft.ifft2(np.fft.fft2(imarr) * bfilt)) + return imarr_filt + + + if filttype=='gauss': + blur = blur_gauss + elif filttype=='butter': + blur = blur_butter + else: + raise Exception("filttype not recognized in blur_circ!") + + # Blur the primary image + imarr = self.imvec.reshape(self.ydim, self.xdim) + imarr_blur = blur(imarr, fwhmp) + + arglist, argdict = self.image_args() + arglist[0] = imarr_blur + outim = Image(*arglist, **argdict) + + # Blur spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = mfvec.reshape(self.ydim, self.xdim) + mfarr = blur(mfarr, fwhmp) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + # Blur all polarizations and copy overi + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim) + if fwhm_pol: + #print("Blurring polarization") + polarr = blur(polarr, fwhmp_pol) + outim.add_pol_image(polarr, pol) + + + return outim + + def blur_mf(self, freqs, fwhm, fit_order=1, filttype='gauss'): + """Blur image correctly across multiple frequencies + WARNING: does not currently do polarization correctly! + + Args: + freqs (float): Frequencies to include in the blurring & spectral index fit + fwhm (float): circular beam size + fit_order (int): how many orders to fit spectrum: 1 or 2 + filttype (str): "gauss" or "butter" + + Returns: + (Image): output image + + """ + if fit_order not in [1,2]: + raise Exception("fit_order must be 1 or 2 in blur_mf!") + + reffreq = self.rf + + # remove any zeros in the images + imlist = [self.get_image_mf(rf).blur_circ(kernel, filttype=filttype) for rf in freqs] + for image in imlist: + image.imvec[image.imvec<=0] = np.min(image.imvec[image.imvec!=0]) + + xfit = np.log(np.array(freqs)/reffreq) + yfit = np.log(np.array([im.imvec for im in imlist])) + + if fit_order == 2: + coeffs = np.polyfit(xfit,yfit,2) + beta = coeffs[0] + alpha = coeffs[1] + elif fit_order == 1: + coeffs = np.polyfit(xfit,yfit,1) + alpha = coeffs[0] + beta = 0*alpha + else: + alpha = 0*yfit + beta = 0*yfit + + outim = self.blur_circ(kernel, filttype=filttype) + outim.specvec = alpha + outim.curvvec = beta + return outim + + def grad(self, gradtype='abs'): + """Return the gradient image + + Args: + gradtype (str): 'x','y',or 'abs' for the image gradient dimension + + Returns: + Image : an image object containing the gradient image + """ + + # Define the desired gradient function + def gradim(imvec): + if np.any(np.imag(imvec) != 0): + return gradim(np.real(imvec)) + 1j * gradim(np.imag(imvec)) + + imarr = imvec.reshape(self.ydim, self.xdim) + + #sx = ndi.sobel(imarr, axis=0, mode='constant') + #sy = ndi.sobel(imarr, axis=1, mode='constant') + sx = ndi.sobel(imarr, axis=0, mode='nearest') + sy = ndi.sobel(imarr, axis=1, mode='nearest') + + # TODO: are these in the right order?? + if gradtype == 'x': + gradarr = sx + if gradtype == 'y': + gradarr = sy + else: + gradarr = np.hypot(sx, sy) + return gradarr + + # Find the gradient for the primary image + gradarr = gradim(self.imvec) + + arglist, argdict = self.image_args() + arglist[0] = gradarr + outim = Image(*arglist, **argdict) + + # Find the gradient for all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + gradarr = gradim(polvec) + outim.add_pol_image(gradarr, pol) + + # Find the spectral index gradients and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfarr = gradim(mfvec) + mfvec_out = mfarr.flatten() + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def mask(self, cutoff=0.05, beamparams=None, frac=0.0): + """Produce an image mask that shows all pixels above the specified cutoff frac of the max + Works off the primary image + + Args: + cutoff (float): mask pixels with intensities greater than cuttoff * max + beamparams (list): either [fwhm_maj, fwhm_min, pos_ang] or a single fwhm + frac (float): the fraction of nominal beam to blur with + + Returns: + (Image): output mask image + + """ + + # Blur the image + if beamparams is not None: + try: + len(beamparams) + except TypeError: + beamparams = [beamparams, beamparams, 0] + if len(beamparams) == 3: + mask = self.blur_gauss(beamparams, frac) + else: + raise Exception("beamparams should be a length 3 array [maj, min, posang]!") + else: + mask = self.copy() + + # Mask pixels outside the desired intensity range + maxval = np.max(mask.imvec) + minval = np.min(mask.imvec) + intensityrange = maxval - minval + thresh = intensityrange * cutoff + minval + maskvec = (mask.imvec > thresh).astype(int) + + # make the primary image + maskarr = maskvec.reshape(mask.ydim, mask.xdim) + + arglist, argdict = self.image_args() + arglist[0] = maskarr + mask = Image(*arglist, **argdict) + + # Replace all polarization imvecs with mask + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + mask.add_pol_image(maskarr, pol) + + # No spectral index information in mask + + return mask + + # TODO make this work with a mask image of different dimensions & fov + def apply_mask(self, mask_im, fill_val=0.): + """Apply a mask to the image + + Args: + mask_im (Image): a mask image with the same dimensions as the Image + fill_val (float): masked pixels of all polarizations are set to this value + + Returns: + (Image): the masked image + + """ + if ((self.psize != mask_im.psize) or + (self.xdim != mask_im.xdim) or (self.ydim != mask_im.ydim)): + raise Exception("mask image does not match dimensions of the current image!") + + # Get the mask vector + maskvec = mask_im.imvec.astype(bool) + maskvec[maskvec <= 0] = 0 + maskvec[maskvec > 0] = 1 + + # Mask the primary image + imvec = self.imvec + imvec[~maskvec] = fill_val + imarr = imvec.reshape(self.ydim, self.xdim) + + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Apply mask to all polarizations and copy over + for pol in list(self._imdict.keys()): + if pol == self.pol_prim: + continue + polvec = self._imdict[pol] + if len(polvec): + polvec[~maskvec] = fill_val + polarr = polvec.reshape(self.ydim, self.xdim) + outim.add_pol_image(polarr, pol) + + # Apply mask to spectral index and copy over + mflist_out = [] + for mfvec in self._mflist: + if len(mfvec): + mfvec_out = copy.deepcopy(mfvec) + mfvec_out[~maskvec] = 0. + else: + mfvec_out = np.array([]) + mflist_out.append(mfvec_out) + outim._mflist = mflist_out + + return outim + + def threshold(self, cutoff=0.05, beamparams=None, frac=0.0, fill_val=None): + """Apply a hard threshold to the primary polarization image. + Leave other polarizations untouched. + + Args: + cutoff (float): Mask pixels with intensities greater than cuttoff * max + beamparams (list): either [fwhm_maj, fwhm_min, pos_ang] or a single fwhm + frac (float): the fraction of nominal beam to blur with + fill_val (float): masked pixels are set to this value. + If fill_val==None, they are set to the min unmasked value + + Returns: + (Image): output mask image + """ + + if fill_val is None or fill_val is False: + maxval = np.max(self.imvec) + minval = np.min(self.imvec) + intensityrange = maxval - minval + fill_val = (intensityrange * cutoff + minval) + + mask = self.mask(cutoff=cutoff, beamparams=beamparams, frac=frac) + out = self.apply_mask(mask, fill_val=fill_val) + return out + + def add_flat(self, flux, pol=None): + """Add a flat background flux to the main polarization image. + + Args: + flux (float): total flux to add to image + pol (str): the polarization to add the flux to. None defaults to pol_prim. + Returns: + (Image): output image + """ + + if pol is None: + pol = self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + if not len(self._imdict[pol]): + raise Exception("no image for pol %s" % pol) + + # Make a flat image array + flatarr = ((flux / float(len(self.imvec))) * np.ones(len(self.imvec))) + flatarr = flatarr.reshape(self.ydim, self.xdim) + + # Add to the main image and create the new image object + imarr = self.imvec.reshape(self.ydim, self.xdim).copy() + if pol == self.pol_prim: + imarr += flatarr + + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + for pol2 in list(self._imdict.keys()): + if pol2 == self.pol_prim: + continue + polvec = self._imdict[pol2] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + if pol2 == pol: + polarr += flatarr + outim.add_pol_image(polarr, pol2) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_tophat(self, flux, radius, pol=None): + """Add centered tophat flux to the Stokes I image inside a given radius. + + Args: + flux (float): total flux to add to image + radius (float): radius of top hat flux in radians + pol (str): the polarization to add the flux to. None defaults to pol_prim + + Returns: + (Image): output image + """ + + if pol is None: + pol = self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + if not len(self._imdict[pol]): + raise Exception("no image for pol %s" % pol) + + # Make a tophat image array + xlist = np.arange(0, -self.xdim, -1) * self.psize + \ + (self.psize * self.xdim) / 2.0 - self.psize / 2.0 + ylist = np.arange(0, -self.ydim, -1) * self.psize + \ + (self.psize * self.ydim) / 2.0 - self.psize / 2.0 + + hatarr = np.array([[1.0 if np.sqrt(i**2 + j**2) <= radius else 0. + for i in xlist] + for j in ylist]) + + hatarr = hatarr[0:self.ydim, 0:self.xdim] + hatarr *= flux / np.sum(hatarr) + + # Add to the main image and create the new image object + imarr = self.imvec.reshape(self.ydim, self.xdim).copy() + if pol == self.pol_prim: + imarr += hatarr + + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + for pol2 in list(self._imdict.keys()): + if pol2 == self.pol_prim: + continue + polvec = self._imdict[pol2] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + if pol2 == pol: + polarr += hatarr + outim.add_pol_image(polarr, pol2) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_gauss(self, flux, beamparams, pol=None): + """Add a gaussian to an image. + + Args: + flux (float): the total flux contained in the Gaussian in Jy + beamparams (list): [fwhm_maj, fwhm_min, theta, x, y], all in radians + pol (str): the polarization to add the flux to. None defaults to pol_prim. + + Returns: + (Image): output image + """ + + if pol is None: + pol = self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + if not len(self._imdict[pol]): + raise Exception("no image for pol %s" % pol) + + # Make a Gaussian image + try: + x = beamparams[3] + y = beamparams[4] + except IndexError: + x = y = 0.0 + + sigma_maj = beamparams[0] / (2. * np.sqrt(2. * np.log(2.))) + sigma_min = beamparams[1] / (2. * np.sqrt(2. * np.log(2.))) + cth = np.cos(beamparams[2]) + sth = np.sin(beamparams[2]) + xlist = np.arange(0, -self.xdim, -1) * self.psize + \ + (self.psize * self.xdim) / 2.0 - self.psize / 2.0 + ylist = np.arange(0, -self.ydim, -1) * self.psize + \ + (self.psize * self.ydim) / 2.0 - self.psize / 2.0 + + def gaussian(x2, y2): + gauss = np.exp(-((y2) * cth + (x2) * sth)**2 / (2 * sigma_maj**2) + + -((x2) * cth - (y2) * sth)**2 / (2 * sigma_min**2)) + return gauss + + gaussarr = np.array([[gaussian(i - x, j - y) for i in xlist] for j in ylist]) + gaussarr = gaussarr[0:self.ydim, 0:self.xdim] + gaussarr *= flux / np.sum(gaussarr) + + # TODO: if we want to add a gaussian to V, we might also want to make sure we add it to I + # Add to the main image and create the new image object + imarr = self.imvec.reshape(self.ydim, self.xdim).copy() + if pol == self.pol_prim: + imarr += gaussarr + + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + for pol2 in list(self._imdict.keys()): + if pol2 == self.pol_prim: + continue + polvec = self._imdict[pol2] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + if pol2 == pol: + polarr += gaussarr + outim.add_pol_image(polarr, pol2) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_crescent(self, flux, Rp, Rn, a, b, x=0, y=0, pol=None): + """Add a crescent to an image; see Kamruddin & Dexter (2013). + + Args: + flux (float): the total flux contained in the crescent in Jy + Rp (float): the larger radius in radians + Rn (float): the smaller radius in radians + a (float): the relative x offset of smaller disk in radians + b (float): the relative y offset of smaller disk in radians + x (float): the center x coordinate of the larger disk in radians + y (float): the center y coordinate of the larger disk in radians + pol (str): the polarization to add the flux to. None defaults to pol_prim. + + Returns: + (Image): output image add_gaus + """ + + if pol is None: + pol = self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + if not len(self._imdict[pol]): + raise Exception("no image for pol %s" % pol) + + # Make a crescent image + xlist = np.arange(0, -self.xdim, -1) * self.psize + \ + (self.psize * self.xdim) / 2.0 - self.psize / 2.0 + ylist = np.arange(0, -self.ydim, -1) * self.psize + \ + (self.psize * self.ydim) / 2.0 - self.psize / 2.0 + + def crescent(x2, y2): + if (x2 - a)**2 + (y2 - b)**2 > Rn**2 and x2**2 + y2**2 < Rp**2: + return 1.0 + else: + return 0.0 + + crescarr = np.array([[crescent(i - x, j - y) for i in xlist] for j in ylist]) + crescarr = crescarr[0:self.ydim, 0:self.xdim] + crescarr *= flux / np.sum(crescarr) + + # Add to the main image and create the new image object + imarr = self.imvec.reshape(self.ydim, self.xdim).copy() + if pol == self.pol_prim: + imarr += crescarr + + arglist, argdict = self.image_args() + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + for pol2 in list(self._imdict.keys()): + if pol2 == self.pol_prim: + continue + polvec = self._imdict[pol2] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + if pol2 == pol: + polarr += crescarr + outim.add_pol_image(polarr, pol2) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_ring_m1(self, I0, I1, r0, phi, sigma, x=0, y=0, pol=None): + """Add a ring to an image with an m=1 mode + + Args: + I0 (float): + I1 (float): + r0 (float): the radius + phi (float): angle of m1 mode + sigma (float): the blurring size + x (float): the center x coordinate of the larger disk in radians + y (float): the center y coordinate of the larger disk in radians + pol (str): the polarization to add the flux to. None defaults to pol_prim. + Returns: + (Image): output image add_gaus + """ + + if pol is None: + pol = self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + if not len(self._imdict[pol]): + raise Exception("no image for pol %s" % pol) + + # Make a ring image + flux = I0 - 0.5 * I1 + phi = phi + np.pi + psize = self.psize + xlist = np.arange(0, -self.xdim, -1) * self.psize + \ + (self.psize * self.xdim) / 2.0 - self.psize / 2.0 + ylist = np.arange(0, -self.ydim, -1) * self.psize + \ + (self.psize * self.ydim) / 2.0 - self.psize / 2.0 + + def ringm1(x2, y2): + if (x2**2 + y2**2) > (r0 - psize)**2 and (x2**2 + y2**2) < (r0 + psize)**2: + theta = np.arctan2(y2, x2) + flux = (I0 - 0.5 * I1 * (1 + np.cos(theta - phi))) / (2 * np.pi * r0) + return flux + else: + return 0.0 + + ringarr = np.array([[ringm1(i - x, j - y) + for i in xlist] + for j in ylist]) + ringarr = ringarr[0:self.ydim, 0:self.xdim] + + arglist, argdict = self.image_args() + arglist[0] = ringarr + outim = Image(*arglist, **argdict) + + outim = outim.blur_circ(sigma) + outim.imvec *= flux / (outim.total_flux()) + ringarr = outim.imvec.reshape(self.ydim, self.xdim) + + # Add to the main image and create the new image object + imarr = self.imvec.reshape(self.ydim, self.xdim).copy() + if pol == self.pol_prim: + imarr += ringarr + + arglist[0] = imarr + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + for pol2 in list(self._imdict.keys()): + if pol2 == self.pol_prim: + continue + polvec = self._imdict[pol2] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + if pol2 == pol: + polarr += ringarr + outim.add_pol_image(polarr, pol2) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_const_pol(self, mag, angle, cmag=0, csign=1): + """Return an with constant fractional linear and circular polarization + + Args: + mag (float): constant polarization fraction to add to the image + angle (float): constant EVPA + cmag (float): constant circular polarization fraction to add to the image + cmag (int): constant circular polarization sign +/- 1 + + Returns: + (Image): output image + """ + + if not (0 <= mag < 1): + raise Exception("fractional polarization magnitude must be between 0 and 1!") + + if not (0 <= cmag < 1): + raise Exception("circular polarization magnitude must be between 0 and 1!") + + if self.polrep == 'stokes': + im_stokes = self + elif self.polrep == 'circ': + im_stokes = self.switch_polrep(polrep_out='stokes') + ivec = im_stokes.ivec.copy() + qvec = obsh.qimage(ivec, mag * np.ones(len(ivec)), angle * np.ones(len(ivec))) + uvec = obsh.uimage(ivec, mag * np.ones(len(ivec)), angle * np.ones(len(ivec))) + vvec = cmag * np.sign(csign) * ivec + + # create the new stokes image object + iarr = ivec.reshape(self.ydim, self.xdim).copy() + + arglist, argdict = self.image_args() + arglist[0] = iarr + argdict['polrep'] = 'stokes' + argdict['pol_prim'] = 'I' + outim = Image(*arglist, **argdict) + + # Copy over the rest of the polarizations + imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec} + for pol in list(imdict.keys()): + if pol == 'I': + continue + polvec = imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + outim.add_pol_image(polarr, pol) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_random_pol(self, mag, corr, cmag=0., ccorr=0., seed=0): + """Return an image random linear and circular polarizations with certain correlation lengths + + Args: + mag (float): linear polarization fraction + corr (float): EVPA correlation length (radians) + cmag (float): circular polarization fraction + ccorr (float): CP correlation length (radians) + seed (int): Seed for random number generation + + Returns: + (Image): output image + """ + + import ehtim.scattering.stochastic_optics as so + + if not (0 <= mag < 1): + raise Exception("fractional polarization magnitude must be between 0 and 1!") + + if not (0 <= cmag < 1): + raise Exception("circular polarization magnitude must be between 0 and 1!") + + if self.polrep == 'stokes': + im_stokes = self + elif self.polrep == 'circ': + im_stokes = self.switch_polrep(polrep_out='stokes') + ivec = im_stokes.ivec.copy() + + # create the new stokes image object + iarr = ivec.reshape(self.ydim, self.xdim).copy() + + arglist, argdict = self.image_args() + arglist[0] = iarr + argdict['polrep'] = 'stokes' + argdict['pol_prim'] = 'I' + outim = Image(*arglist, **argdict) + + # Make a random phase screen using the scattering tools + # Use this screen to define the EVPA + dist = 1.0 * 3.086e21 + rdiff = np.abs(corr) * dist / 1e3 + theta_mas = 0.37 * 1.0 / rdiff * 1000. * 3600. * 180. / np.pi + sm = so.ScatteringModel(scatt_alpha=1.67, observer_screen_distance=dist, + source_screen_distance=1.e5 * dist, + theta_maj_mas_ref=theta_mas, theta_min_mas_ref=theta_mas, + r_in=rdiff * 2, r_out=1e30) + ep = so.MakeEpsilonScreen(self.xdim, self.ydim, rngseed=seed) + ps = np.array(sm.MakePhaseScreen(ep, outim, obs_frequency_Hz=29.979e9).imvec) + ps = ps / 1000**(1.66 / 2) + qvec = ivec * mag * np.sin(ps) + uvec = ivec * mag * np.cos(ps) + + # Make a random phase screen using the scattering tools + # Use this screen to define the CP magnitude + if cmag != 0.0 and ccorr > 0.0: + dist = 1.0 * 3.086e21 + rdiff = np.abs(ccorr) * dist / 1e3 + theta_mas = 0.37 * 1.0 / rdiff * 1000. * 3600. * 180. / np.pi + sm = so.ScatteringModel(scatt_alpha=1.67, observer_screen_distance=dist, + source_screen_distance=1.e5 * dist, + theta_maj_mas_ref=theta_mas, theta_min_mas_ref=theta_mas, + r_in=rdiff * 2, r_out=1e30) + ep = so.MakeEpsilonScreen(self.xdim, self.ydim, rngseed=seed * 2) + ps = np.array(sm.MakePhaseScreen(ep, outim, obs_frequency_Hz=29.979e9).imvec) + ps = ps / 1000**(1.66 / 2) + vvec = ivec * cmag * np.sin(ps) + else: + vvec = ivec * cmag + + # Copy over the rest of the polarizations + imdict = {'I': ivec, 'Q': qvec, 'U': uvec, 'V': vvec} + for pol in list(imdict.keys()): + if pol == 'I': + continue + polvec = imdict[pol] + if len(polvec): + polarr = polvec.reshape(self.ydim, self.xdim).copy() + outim.add_pol_image(polarr, pol) + + # Copy the spectral index (unchanged) + outim._mflist = copy.deepcopy(self._mflist) + + return outim + + def add_const_mf(self, alpha, beta=0.): + """Add a constant spectral index and curvature term + + Args: + alpha (float): spectral index (with no - sign) + beta (float): curvature + + Returns: + (Image): output image with constant mf information added + """ + + avec = alpha * np.ones(len(self.imvec)) + bvec = beta * np.ones(len(self.imvec)) + + # create the new image object + outim = self.copy() + outim._mflist = [avec, bvec] + + return outim + + def add_zblterm(self, obs, uv_min, zblval=None, new_fov=False, + gauss_sz=False, gauss_sz_factor=0.75, debias=True): + """Add a large Gaussian term to account for missing flux in the zero baseline. + + Args: + obs : an Obsdata object to determine min non-zero baseline and 0-bl flux + uv_min (float): The cutoff in Glambada used to determine what is a 0-bl + new_fov (rad): The size of the padded image once the Gaussian is added + (if False it will be set to 3 x the gaussian fwhm) + gauss_sz (rad): The size of the Gaussian added to add flux to the 0-bl. + (if False it is computed from the min non-zero baseline) + gauss_sz_factor (float): The fraction of the min non-zero baseline + used to caluclate the Gaussian FWHM. + debias (bool): True if you use debiased amplitudes to caluclate the 0-bl flux in Jy + + Returns: + (Image): a padded image with a large Gaussian component + """ + + if gauss_sz is False: + obs_flag = obs.flag_uvdist(uv_min=uv_min) + minuvdist = np.min(np.sqrt(obs_flag.data['u']**2 + obs_flag.data['v']**2)) + gauss_sz_sigma = (1 / (gauss_sz_factor * minuvdist)) + gauss_sz = gauss_sz_sigma * 2.355 # convert from stdev to fwhm + + factor = 5.0 + if new_fov is False: + im_fov = np.max((self.xdim * self.psize, self.ydim * self.psize)) + new_fov = np.max((factor * (gauss_sz / 2.355), im_fov)) + + if new_fov < factor * (gauss_sz / 2.355): + print('WARNING! The specified new fov may not be large enough') + + # calculate the amount of flux to include in the Gaussian + obs_zerobl = obs.flag_uvdist(uv_max=uv_min) + obs_zerobl.add_amp(debias=debias) + orig_totflux = np.sum(obs_zerobl.amp['amp'] * (1 / obs_zerobl.amp['sigma']**2)) + orig_totflux /= np.sum(1 / obs_zerobl.amp['sigma']**2) + + if zblval is None: + addedflux = orig_totflux - np.sum(self.imvec) + else: + addedflux = orig_totflux - zblval + + print('Adding a ' + str(addedflux) + ' Jy circular Gaussian of FWHM size ' + + str(gauss_sz / ehc.RADPERUAS) + ' uas') + + im_new = self.copy() + im_new = im_new.pad(new_fov, new_fov) + im_new = im_new.add_gauss(addedflux, (gauss_sz, gauss_sz, 0, 0, 0)) + return im_new + + def sample_uv(self, uv, polrep_obs='stokes', + sgrscat=False, ttype='nfft', + cache=False, fft_pad_factor=2, + zero_empty_pol=True, verbose=True): + """Sample the image on the selected uv points without creating an Obsdata object. + + Args: + uv (ndarray): an array of uv points + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + + ttype (str): "fast" or "nfft" or "direct" + cache (bool): Use cached fft for 'fast' mode -- deprecated, use nfft instead! + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist. + Otherwise return None + verbose (bool): Boolean value controls output prints. + + Returns: + (list): a list of [I,Q,U,V] visibilities + """ + + if polrep_obs not in ['stokes', 'circ']: + raise Exception("polrep_obs must be either 'stokes' or 'circ'") + + data = simobs.sample_vis(self, uv, polrep_obs=polrep_obs, sgrscat=sgrscat, + ttype=ttype, cache=cache, fft_pad_factor=fft_pad_factor, + zero_empty_pol=zero_empty_pol, verbose=verbose) + return data + + def observe_same_nonoise(self, obs, sgrscat=False, ttype="nfft", + cache=False, fft_pad_factor=2, + zero_empty_pol=True, verbose=True, reorder=True): + """Observe the image on the same baselines as an existing observation without noise. + + Args: + obs (Obsdata): the existing observation + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + ttype (str): "fast" or "nfft" or "direct" + cache (bool): Use cached fft for 'fast' mode -- deprecated, use nfft instead! + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist. + Otherwise return None + verbose (bool): Boolean value controls output prints. + + Returns: + (Obsdata): an observation object with no noise + """ + + # Check for agreement in coordinates and frequency + tolerance = 1e-8 + if (np.abs(self.ra - obs.ra) > tolerance) or (np.abs(self.dec - obs.dec) > tolerance): + raise Exception("Image coordinates are not the same as observtion coordinates!") + if (np.abs(self.rf - obs.rf) / obs.rf > tolerance): + raise Exception("Image frequency is not the same as observation frequency!") + + if (ttype == 'direct' or ttype == 'fast' or ttype == 'nfft' or ttype == 'DFT' or ttype == 'DFT_i'): + if verbose: print("Producing clean visibilities from image with " + ttype + " FT . . . ") + else: + raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft', 'DFT', 'DFT_i'" % ttype) + + # Copy data to be safe + obsdata = copy.deepcopy(obs.data) + + # Extract uv datasample + uv = obsh.recarr_to_ndarr(obsdata[['u', 'v']], 'f8') + data = simobs.sample_vis(self, uv, sgrscat=sgrscat, polrep_obs=obs.polrep, + ttype=ttype, cache=cache, fft_pad_factor=fft_pad_factor, + zero_empty_pol=zero_empty_pol, verbose=verbose) + + # put visibilities into the obsdata + if obs.polrep == 'stokes': + obsdata['vis'] = data[0] + if not(data[1] is None): + obsdata['qvis'] = data[1] + obsdata['uvis'] = data[2] + obsdata['vvis'] = data[3] + + elif obs.polrep == 'circ': + obsdata['rrvis'] = data[0] + if not(data[1] is None): + obsdata['llvis'] = data[1] + if not(data[2] is None): + obsdata['rlvis'] = data[2] + obsdata['lrvis'] = data[3] + + obs_no_noise = ehtim.obsdata.Obsdata(self.ra, self.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=self.source, mjd=self.mjd, polrep=obs.polrep, + ampcal=True, phasecal=True, opacitycal=True, + dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) + + return obs_no_noise + + def observe_same(self, obs_in, + ttype='nfft', fft_pad_factor=2, + sgrscat=False, add_th_noise=True, th_noise_factor=1, + jones=False, inv_jones=False, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + phase_std=-1, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0., rlphase_std=0., + sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, + caltable_path=None, seed=False, reorder=True, verbose=True): + """Observe the image on the same baselines as an existing observation object and add noise. + + Args: + obs_in (Obsdata): the existing observation + ttype (str): "fast" or "nfft" or "direct" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to a priori calibrate data + + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + dcal (bool): if False, time-dependent gaussian errors added to D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + phase_std (float): std. dev. of LCP phase, + or a dict giving one std. dev. per site + a negative value samples from uniform + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + sigmat (float): temporal std for a Gaussian Process used to generate gains. + If sigmat=None then an iid gain noise is applied. + phasesigmat (float): temporal std for a Gaussian Process used to generate phases. + If phasesigmat=None then an iid gain noise is applied. + rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios. + If rlgsigmat=None then an iid gain noise is applied. + rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff. + If rlpsigmat=None then an iid gain noise is applied. + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + verbose (bool): print updates and warnings + Returns: + (Obsdata): an observation object + """ + + if seed: + np.random.seed(seed=seed) + + obs = self.observe_same_nonoise(obs_in, sgrscat=sgrscat,ttype=ttype, + cache=False, fft_pad_factor=fft_pad_factor, + zero_empty_pol=True, reorder=reorder, verbose=verbose) + + # Jones Matrix Corruption & Calibration + if jones: + obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, frcal=frcal, dcal=dcal, + rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + phase_std=phase_std, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std, rlphase_std=rlphase_std, + sigmat=sigmat, phasesigmat=phasesigmat, + rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, + caltable_path=caltable_path, seed=seed,verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, + dcal=dcal, frcal=frcal, + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) + + if inv_jones: + obsdata = simobs.apply_jones_inverse(obs, + opacitycal=opacitycal, dcal=dcal, frcal=frcal, + verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) + + # No Jones Matrices, Add noise the old way + # NOTE There is an asymmetry here - in the old way, we don't offer the ability to + # *not* unscale estimated noise. + else: + + if caltable_path: + print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix') + + # TODO -- clean up arguments + obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor, + opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + sigmat=sigmat, + caltable_path=caltable_path, seed=seed, + verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans, reorder=reorder) + + return obs + + def observe(self, array, tint, tadv, tstart, tstop, bw, + mjd=None, timetype='UTC', polrep_obs=None, + elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, + no_elevcut_space=False, + ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False, + sgrscat=False, add_th_noise=True, th_noise_factor=1, + jones=False, inv_jones=False, noise=True, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + tau=ehc.TAUDEF, taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + phase_std=-1, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0.,rlphase_std=0., + sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, + caltable_path=None, seed=False, reorder=True, verbose=True): + """Generate baselines from an array object and observe the image. + + Args: + array (Array): an array object containing sites with which to generate baselines + tint (float): the scan integration time in seconds + tadv (float): the uniform cadence between scans in seconds + tstart (float): the start time of the observation in hours + tstop (float): the end time of the observation in hours + bw (float): the observing bandwidth in Hz + + mjd (int): the mjd of the observation, if set as different from the image mjd + timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC' + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + elevmin (float): station minimum elevation in degrees + elevmax (float): station maximum elevation in degrees + no_elevcut_space (bool): if True, do not apply elevation cut to orbiters + + ttype (str): "fast", "nfft" or "dtft" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in the FFT + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + otherwise uses old formalism without D-terms + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to calibrate data + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrix. + dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + tau (float): the base opacity at all sites, or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + phase_std (float): std. dev. of LCP phase, + or a dict giving one std. dev. per site + a negative value samples from uniform + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + sigmat (float): temporal std for a Gaussian Process used to generate gains. + If sigmat=None then an iid gain noise is applied. + phasesigmat (float): temporal std for a Gaussian Process used to generate phases. + If phasesigmat=None then an iid gain noise is applied. + rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios. + If rlgsigmat=None then an iid gain noise is applied. + rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff. + If rlpsigmat=None then an iid gain noise is applied. + + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + + verbose (bool): print updates and warnings + + Returns: + (Obsdata): an observation object + """ + + # Generate empty observation + if verbose: print("Generating empty observation file . . . ") + + if mjd is None: + mjd = self.mjd + if polrep_obs is None: + polrep_obs = self.polrep + + obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop, mjd=mjd, + polrep=polrep_obs, tau=tau, + elevmin=elevmin, elevmax=elevmax, + no_elevcut_space=no_elevcut_space, + timetype=timetype, fix_theta_GMST=fix_theta_GMST, reorder=reorder) + + + # Observe on the same baselines as the empty observation and add noise + if noise: + obs = self.observe_same(obs, ttype=ttype, fft_pad_factor=fft_pad_factor, + sgrscat=sgrscat, add_th_noise=add_th_noise, th_noise_factor=th_noise_factor, + jones=jones, inv_jones=inv_jones, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, dcal=dcal, + frcal=frcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + phase_std=phase_std, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + sigmat=sigmat,phasesigmat=phasesigmat, + rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, + caltable_path=caltable_path, seed=seed, reorder=reorder, verbose=verbose) + else: + obs = self.observe_same_nonoise(obs, sgrscat=sgrscat, ttype=ttype, + cache=False, fft_pad_factor=fft_pad_factor, + zero_empty_pol=True, reorder=reorder, verbose=verbose) + + + obs.mjd = mjd + + return obs + + def observe_vex(self, vex, source, t_int=0.0, tight_tadv=False, + polrep_obs=None, ttype='nfft', fft_pad_factor=2, + fix_theta_GMST=False, + sgrscat=False, add_th_noise=True, + jones=False, inv_jones=False, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + tau=ehc.TAUDEF, taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + phase_std=-1, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0.,rlphase_std=0., + sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, + caltable_path=None, seed=False, verbose=True): + """Generate baselines from a vex file and observes the image. + + Args: + vex (Vex): an vex object containing sites and scan information + source (str): the source to observe + + t_int (float): if not zero, overrides the vex scan lengths + tight_tadv (float): if True, advance right after each integration, + otherwise advance after 2x the scan length + + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + ttype (str): "fast" or "nfft" or "dtft" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + otherwise uses old formalism without D-terms + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to calibrate data + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrix. + dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + tau (float): the base opacity at all sites, + or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + phase_std (float): std. dev. of LCP phase, + or a dict giving one std. dev. per site + a negative value samples from uniform + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + sigmat (float): temporal std for a Gaussian Process used to generate gains. + If sigmat=None then an iid gain noise is applied. + phasesigmat (float): temporal std for a Gaussian Process used to generate phases. + If phasesigmat=None then an iid gain noise is applied. + rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios. + If rlgsigmat=None then an iid gain noise is applied. + rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff. + If rlpsigmat=None then an iid gain noise is applied. + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + verbose (bool): print updates and warnings + + Returns: + (Obsdata): an observation object + + """ + + if polrep_obs is None: + polrep_obs = self.polrep + + t_int_flag = False + if t_int == 0.0: + t_int_flag = True + + # Loop over all scans and assemble a list of scan observations + obs_List = [] + for i_scan in range(len(vex.sched)): + + if t_int_flag: + t_int = vex.sched[i_scan]['scan'][0]['scan_sec'] + if tight_tadv: + t_adv = t_int + else: + t_adv = 2.0 * vex.sched[i_scan]['scan'][0]['scan_sec'] + + # If this scan doesn't observe the source, advance + if vex.sched[i_scan]['source'] != source: + continue + + # What subarray is observing now? + scankeys = list(vex.sched[i_scan]['scan'].keys()) + subarray = vex.array.make_subarray([vex.sched[i_scan]['scan'][key]['site'] + for key in scankeys]) + + # Observe with the subarray over the scan interval + t_start = vex.sched[i_scan]['start_hr'] + t_stop = t_start + vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0 - ehc.EP + + obs = self.observe(subarray, t_int, t_adv, t_start, t_stop, vex.bw_hz, + mjd=vex.sched[i_scan]['mjd_floor'], timetype='UTC', + polrep_obs=polrep_obs, + elevmin=.01, elevmax=89.99, + ttype=ttype, fft_pad_factor=fft_pad_factor, + fix_theta_GMST=fix_theta_GMST, + sgrscat=sgrscat, + add_th_noise=add_th_noise, + jones=jones, inv_jones=inv_jones, + opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal, + frcal=frcal, dcal=dcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + tau=tau, taup=taup, + gain_offset=gain_offset, gainp=gainp, + phase_std=phase_std, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + sigmat=sigmat,phasesigmat=phasesigmat, + rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, + caltable_path=caltable_path, seed=seed, verbose=verbose) + + obs_List.append(obs) + + # Merge the scans together + obs = ehtim.obsdata.merge_obs(obs_List) + + return obs + + def compare_images(self, im_compare, pol=None, psize=None,target_fov=None, blur_frac=0.0, + beamparams=[1., 1., 1.], metric=['nxcorr', 'nrmse', 'rssd'], + blursmall=False, shift=True): + """Compare to another image by computing normalized cross correlation, + normalized root mean squared error, or square root of the sum of squared differences. + Returns metrics only for the primary polarization imvec! + + Args: + im_compare (Image): the image to compare to + pol (str): which polarization image to compare. Default is self.pol_prim + psize (float): pixel size of comparison image (rad). + If None it is the smallest of the input image pizel sizes + target_fov (float): fov of the comparison image (rad). + If None it is twice the largest fov of the input images + + beamparams (list): the nominal Gaussian beam parameters [fovx, fovy, position angle] + blur_frac (float): fractional beam to blur each image to before comparison + + metric (list) : a list of fidelity metrics from ['nxcorr','nrmse','rssd'] + blursmall (bool) : True to blur the unpadded image rather than the large image. + shift (int): manual image shift, otherwise use shift from maximum cross-correlation + + Returns: + (tuple): [errormetric, im1_pad, im2_shift] + """ + + im1 = self.copy() + im2 = im_compare.switch_polrep(polrep_out=im1.polrep, pol_prim_out=im1.pol_prim) + + if im1.polrep != im2.polrep: + raise Exception("In find_shift, im1 and im2 must have the same polrep!") + if im1.pol_prim != im2.pol_prim: + raise Exception("In find_shift, im1 and im2 must have the same pol_prim!") + + # Shift the comparison image to maximize normalized cross-corr. + [idx, xcorr, im1_pad, im2_pad] = im1.find_shift(im2, psize=psize, target_fov=target_fov, + beamparams=beamparams, pol=pol, + blur_frac=blur_frac, blursmall=blursmall) + + if not isinstance(shift, bool): + idx = shift + + im2_shift = im2_pad.shift(idx) + + # Compute error metrics + error = [] + imvec1 = im1_pad.get_polvec(pol) + imvec2 = im2_shift.get_polvec(pol) + if 'nxcorr' in metric: + error.append(xcorr[idx[0], idx[1]] / (im1_pad.xdim * im1_pad.ydim)) + if 'nrmse' in metric: + error.append(np.sqrt(np.sum((np.abs(imvec1 - imvec2)**2 * im1_pad.psize**2)) / + np.sum((imvec1)**2 * im1_pad.psize**2))) + if 'rssd' in metric: + error.append(np.sqrt(np.sum(np.abs(imvec1 - imvec2)**2) * im1_pad.psize**2)) + + return (error, im1_pad, im2_shift) + + def align_images(self, im_list, pol=None, shift=True, final_fov=False, scale='lin', + gamma=0.5, dynamic_range=[1.e3]): + """Align all the images in im_list to the current image (self) + Aligns all images by comparison of the primary pol image. + + Args: + im_list (list): list of images to align to the current image + shift (list): list of manual image shifts, + otherwise use the shift from maximum cross-correlation + pol (str): which polarization image to compare. Default is self.pol_prim + final_fov (float): fov of the comparison image (rad). + If False it is the largestinput image fov + + scale (str) : compare images in 'log','lin',or 'gamma' scale + gamma (float): exponent for gamma scale comparison + dynamic_range (float): dynamic range for log and gamma scale comparisons + + Returns: + (tuple): (im_list_shift, shifts, im0_pad) + """ + + im0 = self.copy() + if not np.all(im0.polrep == np.array([im.polrep for im in im_list])): + raise Exception("In align_images, all images must have the same polrep!") + if not np.all(im0.pol_prim == np.array([im.pol_prim for im in im_list])): + raise Exception("In find_shift, all images must have the same pol_prim!") + + if len(dynamic_range) == 1: + dynamic_range = dynamic_range * np.ones(len(im_list) + 1) + + useshift = True + if isinstance(shift, bool): + useshift = False + + # Find the minimum psize and the maximum field of view + psize = im0.psize + max_fov = np.max([im0.xdim * im0.psize, im0.ydim * im0.psize]) + for i in range(0, len(im_list)): + psize = np.min([psize, im_list[i].psize]) + max_fov = np.max([max_fov, + im_list[i].xdim * im_list[i].psize, + im_list[i].ydim * im_list[i].psize]) + + if not final_fov: + final_fov = max_fov + + # Shift all images in the list + im_list_shift = [] + shifts = [] + for i in range(0, len(im_list)): + (idx, _, im0_pad_orig, im_pad) = im0.find_shift(im_list[i], target_fov=2 * max_fov, + psize=psize, pol=pol, + scale=scale, gamma=gamma, + dynamic_range=dynamic_range[i + 1]) + + if i == 0: + npix = int(im0_pad_orig.xdim / 2) + im0_pad = im0_pad_orig.regrid_image(final_fov, npix) + if useshift: + idx = shift[i] + + tmp = im_pad.shift(idx) + shifts.append(idx) + im_list_shift.append(tmp.regrid_image(final_fov, npix)) + + return (im_list_shift, shifts, im0_pad) + + def find_shift(self, im_compare, pol=None, psize=None, target_fov=None, + beamparams=[1., 1., 1.], blur_frac=0.0, blursmall=False, + scale='lin', gamma=0.5, dynamic_range=1.e3): + """Find image shift that maximizes normalized cross correlation with a second image im2. + Finds shift only by comparison of the primary pol image. + + Args: + im_compare (Image): image with respect with to switch + pol (str): which polarization image to compare. Default is self.pol_prim + psize (float): pixel size of comparison image (rad). + If None it is the smallest of the input image pizel sizes + target_fov (float): fov of the comparison image (rad). + If None it is twice the largest fov of the input images + + beamparams (list): the nominal Gaussian beam parameters [fovx, fovy, position angle] + blur_frac (float): fractional beam to blur each image to before comparison + blursmall (bool) : True to blur the unpadded image rather than the large image. + + scale (str) : compare images in 'log','lin',or 'gamma' scale + gamma (float): exponent for gamma scale comparison + dynamic_range (float): dynamic range for log and gamma scale comparisons + + Returns: + (tuple): (errormetric, im1_pad, im2_shift) + """ + + im1 = self.copy() + im2 = im_compare.switch_polrep(polrep_out=im1.polrep, pol_prim_out=im1.pol_prim) + if pol=='RL' or pol=='LR': + raise Exception("Find_shift currently doesn't work with complex RL or LR imvecs!") + if im1.polrep != im2.polrep: + raise Exception("In find_shift, im1 and im2 must have the same polrep!") + if im1.pol_prim != im2.pol_prim: + raise Exception("In find_shift, im1 and im2 must have the same pol_prim!") + + # Find maximum FOV and minimum pixel size for comparison + if target_fov is None: + max_fov = np.max([im1.fovx(), im1.fovy(), im2.fovx(), im2.fovy()]) + target_fov = 2 * max_fov + if psize is None: + psize = np.min([im1.psize, im2.psize]) + + npix = int(target_fov / psize) + + # Blur images, then pad + if ((blur_frac > 0.0) and (blursmall is True)): + im1 = im1.blur_gauss(beamparams, blur_frac, blur_frac) + im2 = im2.blur_gauss(beamparams, blur_frac, blur_frac) + + im1_pad = im1.regrid_image(target_fov, npix) + im2_pad = im2.regrid_image(target_fov, npix) + + # or, pad images, then blur + if ((blur_frac > 0.0) and (blursmall is False)): + im1_pad = im1_pad.blur_gauss(beamparams, blur_frac, blur_frac) + im2_pad = im2_pad.blur_gauss(beamparams, blur_frac, blur_frac) + + # Rescale the image vectors into log or gamma scale + # TODO -- what about negative values? complex values? + im1_pad_vec = im1_pad.get_polvec(pol) + im2_pad_vec = im2_pad.get_polvec(pol) + if scale == 'log': + im1_pad_vec[im1_pad_vec < 0.0] = 0.0 + im1_pad_vec = np.log(im1_pad_vec + np.max(im1_pad_vec) / dynamic_range) + im2_pad_vec[im2_pad_vec < 0.0] = 0.0 + im2_pad_vec = np.log(im2_pad_vec + np.max(im2_pad_vec) / dynamic_range) + if scale == 'gamma': + im1_pad_vec[im1_pad_vec < 0.0] = 0.0 + im1_pad_vec = (im1_pad_vec + np.max(im1_pad_vec) / dynamic_range)**(gamma) + im2_pad_vec[im2_pad_vec < 0.0] = 0.0 + im2_pad_vec = (im2_pad_vec + np.max(im2_pad_vec) / dynamic_range)**(gamma) + + # Normalize images and compute cross correlation with FFT + im1_norm = (im1_pad_vec.reshape(im1_pad.ydim, im1_pad.xdim) - np.mean(im1_pad_vec)) + im1_norm /= np.std(im1_pad_vec) + im2_norm = (im2_pad_vec.reshape(im2_pad.ydim, im2_pad.xdim) - np.mean(im2_pad_vec)) + im2_norm /= np.std(im2_pad_vec) + + fft_im1 = np.fft.fft2(im1_norm) + fft_im2 = np.fft.fft2(im2_norm) + + xcorr = np.real(np.fft.ifft2(fft_im1 * np.conj(fft_im2))) + + # Find idx of shift that maximized cross-correlation + idx = np.unravel_index(xcorr.argmax(), xcorr.shape) + + return [idx, xcorr, im1_pad, im2_pad] + + def hough_ring(self, edgetype='canny', thresh=0.2, num_circles=3, radius_range=None, + return_type='rad', display_results=True): + """Use a circular hough transform to find a circle in the image + Returns metrics only for the primary polarization imvec! + + Args: + num_circles (int) : number of circles to return + radius_range (tuple): range of radii to search in Hough transform, in radian + edgetype (str): edge detection type, 'gradient' or 'canny' + thresh(float): fractional threshold for the gradient image + display_results (bool): True to display results of the fit + return_type (str): 'rad' to return in radian, 'pixel' to return in pixel units + + Returns: + list : a list of fitted circles (xpos, ypos, radius, objFunc), in radian + """ + + if 'skimage' not in sys.modules: + raise Exception("scikit-image not installed: cannot use hough_ring!") + + # coordinate values + pdim = self.psize + xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0 + ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0 + + # normalize to range 0, 1 + im = self.copy() + maxval = np.max(im.imvec) + meanval = np.mean(im.imvec) + + im_norm = im.imvec / (maxval + .01 * meanval) + im_norm = im_norm.astype('float') # is it a problem if it's double?? + im_norm[np.isnan(im.imvec)] = 0 # mask nans to 0 + im.imvec = im_norm + + # detect edges + if edgetype == 'canny': + imarr = im.imvec.reshape(self.ydim, self.xdim) + edges = canny(imarr, sigma=0, high_threshold=thresh, low_threshold=0.01) + im_edges = self.copy() + im_edges.imvec = edges.flatten() + + elif edgetype == 'grad': + im_edges = self.grad() + if not (thresh is None): + thresh_val = thresh * np.max(im_edges.imvec) + mask = im_edges.imvec > thresh_val + # im_edges.imvec[mask] = 1 + im_edges.imvec[~mask] = 0 + edges = im_edges.imvec.reshape(self.ydim, self.xdim) + else: + im_edges = im.copy() + if not (thresh is None): + thresh_val = thresh * np.max(im_edges.imvec) + mask = im_edges.imvec > thresh_val + # im_edges.imvec[mask] = 1f + im_edges.imvec[~mask] = 0 + edges = im_edges.imvec.reshape(self.ydim, self.xdim) + + # define radius range for Hough transform search + if radius_range is None: + hough_radii = np.arange(int(10 * ehc.RADPERUAS / self.psize), + int(50 * ehc.RADPERUAS / self.psize)) + else: + hough_radii = np.linspace( + radius_range[0] / + self.psize, + radius_range[0] / + self.psize, + 25) + + # perform the hough transform and select the most prominent circles + hough_res = hough_circle(edges, hough_radii) + accums, cy, cx, radii = hough_circle_peaks(hough_res, hough_radii, + total_num_peaks=num_circles) + accum_tot = np.sum(accums) + + # print results, plot circles, and return + outlist = [] + if display_results: + plt.ion() + fig = self.display() + ax = fig.gca() + + i = 0 + colors = ['b', 'r', 'w', 'lime', 'magenta', 'aqua'] + for accum, center_y, center_x, radius in zip(accums, cy, cx, radii): + accum_frac = accum / accum_tot + if return_type == 'rad': + x_rad = xlist[int(np.round(center_x))] + y_rad = ylist[int(np.round(center_y))] + r_rad = radius * self.psize + outlist.append([x_rad, y_rad, r_rad, accum_frac]) + else: + outlist.append([center_x, center_y, radius, accum_frac]) + print(accum_frac) + print("%i ring diameter: %0.1f microarcsec" % (i, 2 * radius * pdim / ehc.RADPERUAS)) + if display_results: + if i > len(colors): + color = colors[-1] + else: + color = colors[i] + circ = mpl.patches.Circle((center_y, center_x), radius, fill=False, color=color) + ax.add_patch(circ) + i += 1 + + return outlist + + def fit_gauss(self, units='rad'): + """Determine the Gaussian parameters that short baselines would measure for the source + by diagonalizing the image covariance matrix. + Returns parameters only for the primary polarization! + + Args: + units (string): 'rad' returns values in radians, + 'natural' returns FWHM in uas and PA in degrees + + Returns: + (tuple) : a tuple (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters + """ + + (x1, y1) = self.centroid() + pdim = self.psize + im = self.imvec + + xlist = np.arange(0, -self.xdim, -1) * pdim + (pdim * self.xdim) / 2.0 - pdim / 2.0 + ylist = np.arange(0, -self.ydim, -1) * pdim + (pdim * self.ydim) / 2.0 - pdim / 2.0 + + x2 = (np.sum(np.outer(0.0 * ylist + 1.0, (xlist - x1)**2).ravel() * im) / np.sum(im)) + y2 = (np.sum(np.outer((ylist - y1)**2, 0.0 * xlist + 1.0).ravel() * im) / np.sum(im)) + xy = (np.sum(np.outer(ylist - y1, xlist - x1).ravel() * im) / np.sum(im)) + + eig = np.linalg.eigh(np.array(((x2, xy), (xy, y2)))) + gauss_params = np.array((eig[0][1]**0.5 * (8. * np.log(2.))**0.5, + eig[0][0]**0.5 * (8. * np.log(2.))**0.5, + np.mod(np.arctan2(eig[1][1][0], eig[1][1][1]) + np.pi, np.pi))) + if units == 'natural': + gauss_params[0] /= ehc.RADPERUAS + gauss_params[1] /= ehc.RADPERUAS + gauss_params[2] *= 180. / np.pi + + return gauss_params + + def fit_gauss_empirical(self, paramguess=None): + """Determine the Gaussian parameters that short baselines would measure + Returns parameters only for the primary polarization! + + Args: + paramguess (tuple): Initial guess (fwhm_maj, fwhm_min, theta) of fit parameters + + Returns: + (tuple) : a tuple (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters. + """ + + # This could be done using moments of the intensity distribution (self.fit_gauss) + # but we'll use the visibility approach + u_max = 1.0 / (self.psize * self.xdim) / 5.0 + uv = np.array([[u, v] + for u in np.arange(-u_max, u_max * 1.001, u_max / 4.0) + for v in np.arange(-u_max, u_max * 1.001, u_max / 4.0)]) + u = uv[:, 0] + v = uv[:, 1] + vis = np.dot(obsh.ftmatrix(self.psize, self.xdim, self.ydim, uv, pulse=self.pulse), + self.imvec) + + if paramguess is None: + paramguess = (self.psize * self.xdim / 4.0, self.psize * self.xdim / 4.0, 0.) + + def errfunc(p): + vismodel = obsh.gauss_uv(u, v, self.total_flux(), p, x=0., y=0.) + err = np.sum((np.abs(vis) - np.abs(vismodel))**2) + return err + + # minimizer params + optdict = {'maxiter': 5000, 'maxfev': 5000, 'xtol': paramguess[0] / 1e9, 'ftol': 1e-10} + res = opt.minimize(errfunc, paramguess, method='Nelder-Mead', options=optdict) + + # Return in the form [maj, min, PA] + x = res.x + x[0] = np.abs(x[0]) + x[1] = np.abs(x[1]) + x[2] = np.mod(x[2], np.pi) + if x[0] < x[1]: + maj = x[1] + x[1] = x[0] + x[0] = maj + x[2] = np.mod(x[2] + np.pi / 2.0, np.pi) + + return x + + def contour(self, contour_levels=[0.1, 0.25, 0.5, 0.75], + contour_cfun=None, color='w', legend=True, show_im=True, + cfun='afmhot', scale='lin', interp='gaussian', gamma=0.5, dynamic_range=1.e3, + plotp=False, nvec=20, pcut=0.01, mcut=0.1, label_type='ticks', has_title=True, + has_cbar=True, cbar_lims=(), cbar_unit=('Jy', 'pixel'), + contour_im=False, power=0, beamcolor='w', + export_pdf="", show=True, beamparams=None, cbar_orientation="vertical", + scale_lw=1, beam_lw=1, cbar_fontsize=12, axis=None, scale_fontsize=12): + """Display the image in a contour plot. + + Args: + contour_levels (arr): the fractional contour levels relative to the max flux plotted + contour_cfun (pyplot colormap function): the function used to get the RGB colors + legend (bool): True to show a legend that says what each contour line corresponds to + cfun (str): matplotlib.pyplot color function + scale (str): image scaling in ['log','gamma','lin'] + interp (str): image interpolation 'gauss' or 'lin' + + gamma (float): index for gamma scaling + dynamic_range (float): dynamic range for log and gamma scaling + + plotp (bool): True to plot linear polarimetic image + nvec (int): number of polarimetric vectors to plot + pcut (float): minimum stokes P value for displaying polarimetric vectors + as fraction of maximum Stokes I pixel + mcut (float): minimum fractional polarization for plotting vectors + label_type (string): specifies the type of axes labeling: 'ticks', 'scale', 'none' + has_title (bool): True if you want a title on the plot + has_cbar (bool): True if you want a colorbar on the plot + cbar_lims (tuple): specify the lower and upper limit of the colorbar + cbar_unit (tuple of strings): the unit of each pixel for the colorbar: + 'Jy', 'm-Jy', '$\mu$Jy' + + export_pdf (str): path to exported PDF with plot + show (bool): Display the plot if true + show_im (bool): Display the image with the contour plot if True + + Returns: + (matplotlib.figure.Figure): figure object with image + + """ + + image = self.copy() + + # or some generalized version for image sizes + y = np.linspace(0, image.ydim, image.ydim) + x = np.linspace(0, image.xdim, image.xdim) + + # make the image grid + z = image.imvec.reshape((image.ydim, image.xdim)) + maxz = max(image.imvec) + if axis is None: + ax = plt.gca() + + elif axis is not None: + ax = axis + plt.sca(axis) + + if show_im: + if axis is not None: + axis = image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma, + dynamic_range=dynamic_range, + plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, + label_type=label_type, has_title=has_title, + has_cbar=has_cbar, cbar_lims=cbar_lims, + cbar_unit=cbar_unit, + beamparams=beamparams, + cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1, + cbar_fontsize=cbar_fontsize, axis=axis, + scale_fontsize=scale_fontsize, power=power, + beamcolor=beamcolor) + else: + image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma, + dynamic_range=dynamic_range, + plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, label_type=label_type, + has_title=has_title, has_cbar=has_cbar, + cbar_lims=cbar_lims, cbar_unit=cbar_unit, beamparams=beamparams, + cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1, + cbar_fontsize=cbar_fontsize, + axis=None, scale_fontsize=scale_fontsize, + power=power, beamcolor=beamcolor) + else: + if contour_im is False: + image.imvec = 0.0 * image.imvec + else: + image = contour_im.copy() + + if axis is not None: + axis = image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma, + dynamic_range=dynamic_range, + plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, + label_type=label_type, has_title=has_title, + has_cbar=has_cbar, cbar_lims=cbar_lims, cbar_unit=cbar_unit, + beamparams=beamparams, + cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1, + cbar_fontsize=cbar_fontsize, + axis=axis, + scale_fontsize=scale_fontsize, power=power, + beamcolor=beamcolor) + else: + image.display(cfun=cfun, scale=scale, interp=interp, gamma=gamma, + dynamic_range=dynamic_range, + plotp=plotp, nvec=nvec, pcut=pcut, mcut=mcut, label_type=label_type, + has_title=has_title, + has_cbar=has_cbar, cbar_lims=cbar_lims, cbar_unit=cbar_unit, + beamparams=beamparams, + cbar_orientation=cbar_orientation, scale_lw=1, beam_lw=1, + cbar_fontsize=cbar_fontsize, axis=None, + scale_fontsize=scale_fontsize, power=power, beamcolor=beamcolor) + + if axis is None: + ax = plt.gcf() + if axis is not None: + ax = axis + + if axis is not None: + ax = axis + plt.sca(axis) + + count = 0. + + for level in contour_levels: + if not(contour_cfun is None): + rgbval = contour_cfun(count / len(contour_levels)) + rgbstring = '#%02x%02x%02x' % (rgbval[0] * 256, rgbval[1] * 256, rgbval[2] * 256) + else: + rgbstring = color + cs = plt.contour(x, y, z, levels=[level * maxz], colors=rgbstring, cmap=None) + count += 1 + cs.collections[0].set_label(str(int(level * 100)) + '%') + if legend: + plt.legend() + + if show: + #plt.show(block=False) + ehc.show_noblock() + + if export_pdf != "": + ax.savefig(export_pdf, bbox_inches='tight', pad_inches=0) + + elif axis is not None: + return axis + return ax + + def display(self, pol=None, cfun=False, interp='gaussian', + scale='lin', gamma=0.5, dynamic_range=1.e3, + plotp=False, plot_stokes=False, nvec=20, + vec_cfun=None, + scut=0, pcut=0.1, mcut=0.01, scale_ticks=False, + log_offset=False, + label_type='ticks', has_title=True, alpha=1, + has_cbar=True, only_cbar=False, cbar_lims=(), cbar_unit=('Jy', 'pixel'), + export_pdf="", pdf_pad_inches=0.0, show=True, beamparams=None, + cbar_orientation="vertical", scinot=False, + scale_lw=1, beam_lw=1, cbar_fontsize=12, axis=None, + scale_fontsize=12, + power=0, + beamcolor='w', beampos='right', scalecolor='w',dpi=500): + """Display the image. + + Args: + pol (str): which polarization image to plot. Default is self.pol_prim + pol='spec' will plot spectral index + pol='curv' will plot spectral curvature + cfun (str): matplotlib.pyplot color function. + False changes with 'pol', but is 'afmhot' for most + interp (str): image interpolation 'gauss' or 'lin' + + scale (str): image scaling in ['log','gamma','lin'] + gamma (float): index for gamma scaling + dynamic_range (float): dynamic range for log and gamma scaling + + plotp (bool): True to plot linear polarimetic image + plot_stokes (bool): True to plot stokes subplots along with plotp + nvec (int): number of polarimetric vectors to plot + vec_cfun (str): color function for vectors colored by lin pol frac + + scut (float): minimum stokes I value for displaying spectral index + pcut (float): minimum stokes I value for displaying polarimetric vectors + (fraction of maximum Stokes I) + mcut (float): minimum fractional polarization value for displaying vectors + label_type (string): specifies the type of axes labeling: 'ticks', 'scale', 'none' + has_title (bool): True if you want a title on the plot + has_cbar (bool): True if you want a colorbar on the plot + cbar_lims (tuple): specify the lower and upper limit of the colorbar + cbar_unit (tuple): specifies the unit of the colorbar: e.g., + ('Jy','pixel'),('m-Jy','$\mu$as$^2$'),['Tb'] + beamparams (list): [fwhm_maj, fwhm_min, theta], set to plot beam contour + + export_pdf (str): path to exported PDF with plot + show (bool): Display the plot if true + scinot (bool): Display numbers/units in scientific notation + scale_lw (float): Linewidth of the scale overlay + beam_lw (float): Linewidth of the beam overlay + cbar_fontsize (float): Fontsize of the text elements of the colorbar + axis (matplotlib.axes.Axes): An axis object + scale_fontsize (float): Fontsize of the scale label + + power (float): Passed to colorbar for division of ticks by 1e(power) + beamcolor (str): color of the beam overlay + scalecolor (str): color of the scale label overlay + Returns: + (matplotlib.figure.Figure): figure object with image + + """ + + if (interp in ['gauss', 'gaussian', 'Gaussian', 'Gauss']): + interp = 'gaussian' + elif (interp in ['linear','bilinear']): + interp = 'bilinear' + else: + interp = 'none' + + if not(beamparams is None or beamparams is False): + if beamparams[0] > self.fovx() or beamparams[1] > self.fovx(): + raise Exception("beam FWHM must be smaller than fov!") + + if self.polrep == 'stokes' and pol is None: + pol = 'I' + elif self.polrep == 'circ' and pol is None: + pol = 'RR' + + if only_cbar: + has_cbar = True + label_type = 'none' + has_title = False + + if axis is None: + f = plt.figure() + plt.clf() + + if axis is not None: + plt.sca(axis) + f = plt.gcf() + + # Get unit scale factor + factor = 1. + fluxunit = 'Jy' + areaunit = 'pixel' + + if cbar_unit[0] in ['m-Jy', 'mJy']: + fluxunit = 'mJy' + factor *= 1.e3 + elif cbar_unit[0] in ['muJy', r'$\mu$-Jy', r'$\mu$Jy']: + fluxunit = r'$\mu$Jy' + factor *= 1.e6 + elif cbar_unit[0] == 'Tb': + factor = 3.254e13 / (self.rf**2 * self.psize**2) + fluxunit = 'Brightness Temperature (K)' + areaunit = '' + if power != 0: + fluxunit = (r'Brightness Temperature ($10^{{' + str(power) + '}}$ K)') + else: + fluxunit = 'Brightness Temperature (K)' + elif cbar_unit[0] in ['Jy']: + fluxunit = 'Jy' + factor *= 1. + else: + factor = 1 + fluxunit = cbar_unit[0] + areaunit = '' + + if len(cbar_unit) == 1 or cbar_unit[0] == 'Tb': + factor *= 1. + + elif cbar_unit[1] == 'pixel': + factor *= 1. + if power != 0: + areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)') + + elif cbar_unit[1] in ['$arcseconds$^2$', 'as$^2$', 'as2']: + areaunit = 'as$^2$' + fovfactor = self.xdim * self.psize * (1 / ehc.RADPERAS) + factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2 + if power != 0: + areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)') + + elif cbar_unit[1] in [r'$\m-arcseconds$^2$', 'mas$^2$', 'mas2']: + areaunit = 'mas$^2$' + fovfactor = self.xdim * self.psize * (1 / ehc.RADPERUAS) / 1000. + factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2 + if power != 0: + areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)') + + elif cbar_unit[1] in [r'$\mu$-arcseconds$^2$', r'$\mu$as$^2$', 'muas2']: + areaunit = r'$\mu$as$^2$' + fovfactor = self.xdim * self.psize * (1 / ehc.RADPERUAS) + factor *= (1. / fovfactor)**2 / (1. / self.xdim)**2 + if power != 0: + areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)') + + elif cbar_unit[1] == 'beam': + if (beamparams is None or beamparams is False): + print("Cannot convert to Jy/beam without beamparams!") + else: + areaunit = 'beam' + beamarea = (2.0 * np.pi * beamparams[0] * beamparams[1] / (8.0 * np.log(2))) + factor *= beamarea / (self.psize**2) + if power != 0: + areaunit = areaunit + (r' ($10^{{' + str(power) + '}}$ K)') + + else: + raise ValueError('cbar_unit ' + cbar_unit[1] + ' is not a possible option') + + if not plotp: # Plot a single polarization image + cbar_lims_p = () + + if pol.lower() == 'spec': + imvec = self.specvec.copy() + + # mask out low total intensity values + mask = self.imvec < (scut * np.max(self.imvec)) + imvec[mask] = np.nan + + unit = r'$\alpha$' + factor = 1 + cbar_lims_p = [-5, 5] + cfun_p = 'seismic' + elif pol.lower() == 'curv': + imvec = self.curvvec.copy() + + # mask out low total intensity values + mask = self.imvec < (scut * np.max(self.imvec)) + imvec[mask] = np.nan + + unit = r'$\beta$' + factor = 1 + cbar_lims_p = [-5, 5] + cfun_p = 'seismic' + elif pol.lower() == 'm': + imvec = self.mvec.copy() + unit = r'$\|\breve{m}|$' + factor = 1 + cbar_lims_p = [0, 1] + cfun_p = 'cool' + elif pol.lower() == 'p': + imvec = self.mvec * self.ivec + unit = r'$\|P|$' + cfun_p = 'afmhot' + elif pol.lower() == 'chi' or pol.lower() == 'evpa': + imvec = self.chivec.copy() / ehc.DEGREE + unit = r'$\chi (^\circ)$' + factor = 1 + cbar_lims_p = [0, 180] + cfun_p = 'hsv' + elif pol.lower() == 'e': + imvec = self.evec.copy() + unit = r'$E$-mode' + cfun_p = 'Spectral' + elif pol.lower() == 'b': + imvec = self.bvec.copy() + unit = r'$B$-mode' + cfun_p = 'Spectral' + else: + pol = pol.upper() + if pol == 'V': + cfun_p = 'bwr' + else: + cfun_p = 'afmhot' + try: + imvec = np.array(self._imdict[pol]).reshape(-1) / (10.**power) + except KeyError: + try: + if self.polrep == 'stokes': + im2 = self.switch_polrep('circ') + elif self.polrep == 'circ': + im2 = self.switch_polrep('stokes') + imvec = np.array(im2._imdict[pol]).reshape(-1) / (10.**power) + except KeyError: + raise Exception("Cannot make pol %s image in display()!" % pol) + + unit = fluxunit + if areaunit != '': + unit += ' / ' + areaunit + + if np.any(np.imag(imvec)): + print('casting complex image to abs value') + imvec = np.real(imvec) + + imvec = imvec * factor + imarr = imvec.reshape(self.ydim, self.xdim) + + if scale == 'log': + if (imarr < 0.0).any(): + print('clipping values less than 0 in display') + imarr[imarr < 0.0] = 0.0 + if log_offset: + imarr = np.log10(imarr + log_offset / dynamic_range) + else: + imarr = np.log10(imarr + np.max(imarr) / dynamic_range) + unit = r'$\log_{10}$(' + unit + ')' + + if scale == 'gamma': + if (imarr < 0.0).any(): + print('clipping values less than 0 in display') + imarr[imarr < 0.0] = 0.0 + imarr = (imarr + np.max(imarr) / dynamic_range)**(gamma) + unit = '(' + unit + ')^' + str(gamma) + + if not cbar_lims and cbar_lims_p: + cbar_lims = cbar_lims_p + + if cbar_lims: + cbar_lims[0] = cbar_lims[0] / (10.**power) + cbar_lims[1] = cbar_lims[1] / (10.**power) + imarr[imarr > cbar_lims[1]] = cbar_lims[1] + imarr[imarr < cbar_lims[0]] = cbar_lims[0] + + if has_title: + plt.title("%s %.2f GHz %s" % (self.source, self.rf / 1e9, pol), fontsize=16) + + if not cfun: + cfun = cfun_p + cmap = plt.get_cmap(cfun).copy() + cmap.set_bad(color='whitesmoke') + + if cbar_lims: + im = plt.imshow(imarr, alpha=alpha, cmap=cmap, interpolation=interp, + vmin=cbar_lims[0], vmax=cbar_lims[1]) + else: + im = plt.imshow(imarr, alpha=alpha, cmap=cmap, interpolation=interp) + + if not(beamparams is None or beamparams is False): + if beampos=='left': + beamparams = [beamparams[0], beamparams[1], beamparams[2], + +.4 * self.fovx(), -.4 * self.fovy()] + else: + beamparams = [beamparams[0], beamparams[1], beamparams[2], + -.35 * self.fovx(), -.35 * self.fovy()] + beamimage = self.copy() + beamimage.imvec *= 0 + beamimage = beamimage.add_gauss(1, beamparams) + halflevel = 0.5 * np.max(beamimage.imvec) + beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim) + plt.contour(beamimarr, levels=[halflevel], colors=beamcolor, linewidths=beam_lw) + + if has_cbar: + if only_cbar: + im.set_visible(False) + cb = plt.colorbar(im, fraction=0.046, pad=0.04, orientation=cbar_orientation) + cb.set_label(unit, fontsize=float(cbar_fontsize)) + + if cbar_fontsize != 12: + cb.set_label(unit, fontsize=float(cbar_fontsize) / 1.5) + cb.ax.tick_params(labelsize=cbar_fontsize) + + if cbar_lims: + plt.clim(cbar_lims[0], cbar_lims[1]) + if scinot: + cb.formatter.set_powerlimits((0, 0)) + cb.update_ticks() + + else: # plot polarization with ticks! + + im_stokes = self.switch_polrep(polrep_out='stokes') + imvec = np.array(im_stokes.imvec).reshape(-1) / (10**power) + qvec = np.array(im_stokes.qvec).reshape(-1) / (10**power) + uvec = np.array(im_stokes.uvec).reshape(-1) / (10**power) + vvec = np.array(im_stokes.vvec).reshape(-1) / (10**power) + + if len(imvec) == 0: + imvec = np.zeros(im_stokes.ydim * im_stokes.xdim) + if len(qvec) == 0: + qvec = np.zeros(im_stokes.ydim * im_stokes.xdim) + if len(uvec) == 0: + uvec = np.zeros(im_stokes.ydim * im_stokes.xdim) + if len(vvec) == 0: + vvec = np.zeros(im_stokes.ydim * im_stokes.xdim) + + imvec *= factor + qvec *= factor + uvec *= factor + vvec *= factor + + imarr = (imvec).reshape(im_stokes.ydim, im_stokes.xdim) + qarr = (qvec).reshape(im_stokes.ydim, im_stokes.xdim) + uarr = (uvec).reshape(im_stokes.ydim, im_stokes.xdim) + varr = (vvec).reshape(im_stokes.ydim, im_stokes.xdim) + + unit = fluxunit + if areaunit != '': + unit = fluxunit + ' / ' + areaunit + + # only the stokes I image gets transformed! TODO + imarr2 = imarr.copy() + if scale == 'log': + if (imarr2 < 0.0).any(): + print('clipping values less than 0 in display') + imarr2[imarr2 < 0.0] = 0.0 + imarr2 = np.log10(imarr2 + np.max(imarr2) / dynamic_range) + unit = r'$\log_{10}$(' + unit + ')' + + if scale == 'gamma': + if (imarr2 < 0.0).any(): + print('clipping values less than 0 in display') + imarr2[imarr2 < 0.0] = 0.0 + imarr2 = (imarr2 + np.max(imarr2) / dynamic_range)**(gamma) + unit = '(' + unit + ')^gamma' + + if cbar_lims: + cbar_lims[0] = cbar_lims[0] / (10.**power) + cbar_lims[1] = cbar_lims[1] / (10.**power) + imarr2[imarr2 > cbar_lims[1]] = cbar_lims[1] + imarr2[imarr2 < cbar_lims[0]] = cbar_lims[0] + + # polarization ticks + m = (np.abs(qvec + 1j * uvec) / imvec).reshape(self.ydim, self.xdim) + + thin = self.xdim // nvec + maska = (imvec).reshape(self.ydim, self.xdim) > pcut * np.max(imvec) + maskb = (np.abs(qvec + 1j * uvec) / imvec).reshape(self.ydim, self.xdim) > mcut + mask = maska * maskb + mask2 = mask[::thin, ::thin] + x = (np.array([[i for i in range(self.xdim)] + for j in range(self.ydim)])[::thin, ::thin]) + x = x[mask2] + y = (np.array([[j for i in range(self.xdim)] + for j in range(self.ydim)])[::thin, ::thin]) + y = y[mask2] + a = (-np.sin(np.angle(qvec + 1j * uvec) / + 2).reshape(self.ydim, self.xdim)[::thin, ::thin]) + a = a[mask2] + b = (np.cos(np.angle(qvec + 1j * uvec) / + 2).reshape(self.ydim, self.xdim)[::thin, ::thin]) + b = b[mask2] + + m = (np.abs(qvec + 1j * uvec) / imvec).reshape(self.ydim, self.xdim) + p = (np.abs(qvec + 1j * uvec)).reshape(self.ydim, self.xdim) + m[np.logical_not(mask)] = np.nan + p[np.logical_not(mask)] = np.nan + qarr[np.logical_not(mask)] = np.nan + uarr[np.logical_not(mask)] = np.nan + + voi = (vvec / imvec).reshape(self.ydim, self.xdim) + voi[np.logical_not(mask)] = np.nan + + if scale_ticks: + pticks = ((np.abs(qvec + 1j * uvec)).reshape(self.ydim, self.xdim))[::thin, ::thin][mask2] + pscale = (pticks - np.min(pticks))/(np.max(pticks) - np.min(pticks)) + a *= pscale + b *= pscale + + # Little pol plots + if plot_stokes: + + maxval = 1.1 * np.max((np.max(np.abs(uarr)), + np.max(np.abs(qarr)), np.max(np.abs(varr)))) + + # P Plot + ax = plt.subplot2grid((2, 5), (0, 0)) + im = plt.imshow(p, cmap=plt.get_cmap('bwr'), interpolation=interp, + vmin=-maxval, vmax=maxval) + plt.contour(imarr, colors='k', linewidths=.25) + ax.set_xticks([]) + ax.set_yticks([]) + if has_title: + plt.title('P') + if has_cbar: + cbaxes = plt.gcf().add_axes([0.1, 0.2, 0.01, 0.6]) + cbar = plt.colorbar(im, fraction=0.046, pad=0.04, cax=cbaxes, + label=unit, orientation='vertical') + cbar.ax.tick_params(labelsize=cbar_fontsize) + cbaxes.yaxis.set_ticks_position('left') + cbaxes.yaxis.set_label_position('left') + if cbar_lims: + plt.clim(-maxval, maxval) + + cmap = plt.get_cmap('bwr') + cmap.set_bad('whitesmoke') + # V Plot + ax = plt.subplot2grid((2, 5), (0, 1)) + plt.imshow(varr, cmap=cmap, interpolation=interp, + vmin=-maxval, vmax=maxval) + ax.set_xticks([]) + ax.set_yticks([]) + if has_title: + plt.title('V') + + # Q Plot + ax = plt.subplot2grid((2, 5), (1, 0)) + plt.imshow(qarr, cmap=cmap, interpolation=interp, + vmin=-maxval, vmax=maxval) + plt.contour(imarr, colors='k', linewidths=.25) + ax.set_xticks([]) + ax.set_yticks([]) + if has_title: + plt.title('Q') + + # U Plot + ax = plt.subplot2grid((2, 5), (1, 1)) + plt.imshow(uarr, cmap=cmap, interpolation=interp, + vmin=-maxval, vmax=maxval) + plt.contour(imarr, colors='k', linewidths=.25) + ax.set_xticks([]) + ax.set_yticks([]) + if has_title: + plt.title('U') + + # V/I plot + ax = plt.subplot2grid((2, 5), (0, 2)) + cmap = plt.get_cmap('seismic') + cmap.set_bad('whitesmoke') + + im = plt.imshow(voi, cmap=cmap, interpolation=interp, + vmin=-1, vmax=1) + if has_title: + plt.title('V/I') + plt.contour(imarr, colors='k', linewidths=.25) + ax.set_xticks([]) + ax.set_yticks([]) + if has_cbar: + cbaxes = plt.gcf().add_axes([0.125, 0.1, 0.425, 0.01]) + cbar = plt.colorbar(im, fraction=0.046, pad=0.04, cax=cbaxes, + label='|m|', orientation='horizontal') + cbar.ax.tick_params(labelsize=cbar_fontsize) + cbaxes.yaxis.set_ticks_position('right') + cbaxes.yaxis.set_label_position('right') + + if cbar_lims: + plt.clim(-1, 1) + + # m plot + ax = plt.subplot2grid((2, 5), (1, 2)) + plt.imshow(m, cmap=plt.get_cmap('seismic'), interpolation=interp, vmin=-1, vmax=1) + ax.set_xticks([]) + ax.set_yticks([]) + if has_title: + plt.title('m') + plt.contour(imarr, colors='k', linewidths=.25) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01 * self.xdim, units='x', pivot='mid', color='k', angles='uv', + scale=1.0 / thin) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005 * self.xdim, units='x', pivot='mid', color='w', angles='uv', + scale=1.1 / thin) + + # Big Stokes I plot --axis + ax = plt.subplot2grid((2, 5), (0, 3), rowspan=2, colspan=2) + else: + ax = plt.gca() + + if not cfun: + cfun = 'afmhot' + cmap = plt.get_cmap(cfun) + cmap.set_bad(color='whitesmoke') + + # Big Stokes I plot + if cbar_lims: + im = plt.imshow(imarr2, cmap=cmap, interpolation=interp, + vmin=cbar_lims[0], vmax=cbar_lims[1]) + else: + im = plt.imshow(imarr2, cmap, interpolation=interp) + + if vec_cfun is None: + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01 * self.xdim, units='x', pivot='mid', color='k', angles='uv', + scale=1.0 / thin) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005 * self.xdim, units='x', pivot='mid', color='w', angles='uv', + scale=1.1 / thin) + else: + mthin = ( + np.abs( + qvec + + 1j * + uvec) / + imvec).reshape( + self.ydim, + self.xdim)[ + ::thin, + ::thin] + mthin = mthin[mask2] + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01 * self.xdim, units='x', pivot='mid', color='w', angles='uv', + scale=1.0 / thin) + plt.quiver(x, y, a, b, mthin, + norm=mpl.colors.Normalize(vmin=0, vmax=1.), cmap=vec_cfun, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.007 * self.xdim, units='x', pivot='mid', angles='uv', + scale=1.1 / thin) + + if not(beamparams is None or beamparams is False): + beamparams = [beamparams[0], beamparams[1], beamparams[2], + -.35 * self.fovx(), -.35 * self.fovy()] + beamimage = self.copy() + beamimage.imvec *= 0 + beamimage = beamimage.add_gauss(1, beamparams) + halflevel = 0.5 * np.max(beamimage.imvec) + beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim) + plt.contour(beamimarr, levels=[halflevel], colors=beamcolor, linewidths=beam_lw) + + if has_cbar: + + cbar = plt.colorbar(im, fraction=0.046, pad=0.04, + label=unit, orientation=cbar_orientation) + cbar.ax.tick_params(labelsize=cbar_fontsize) + if cbar_lims: + plt.clim(cbar_lims[0], cbar_lims[1]) + if has_title: + plt.title("%s %.1f GHz : m=%.1f%% , v=%.1f%%" % (self.source, self.rf / 1e9, + self.lin_polfrac() * 100, + self.circ_polfrac() * 100), + fontsize=12) + f.subplots_adjust(hspace=.1, wspace=0.3) + + # Label the plot + ax = plt.gca() + if label_type == 'ticks': + xticks = obsh.ticks(self.xdim, self.psize / ehc.RADPERAS / 1e-6) + yticks = obsh.ticks(self.ydim, self.psize / ehc.RADPERAS / 1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel(r'Relative RA ($\mu$as)') + plt.ylabel(r'Relative Dec ($\mu$as)') + + elif label_type == 'scale': + plt.axis('off') + fov_uas = self.xdim * self.psize / ehc.RADPERUAS # get the fov in uas + roughfactor = 1. / 3. # make the bar about 1/3 the fov + fov_scale = int(math.ceil(fov_uas * roughfactor / 10.0)) * 10 + start = self.xdim * roughfactor / 3.0 # select the start location + end = start + fov_scale / fov_uas * self.xdim # determine the end location + plt.plot([start, end], [self.ydim - start - 5, self.ydim - start - 5], + color=scalecolor, lw=scale_lw) # plot a line + plt.text(x=(start + end) / 2.0, y=self.ydim - start + self.ydim / 30, + s=str(fov_scale) + r" $\mu$as", color=scalecolor, + ha="center", va="center", fontsize=scale_fontsize) + ax = plt.gca() + if axis is None: + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + + elif label_type == 'none' or label_type is None: + plt.axis('off') + ax = plt.gca() + if axis is None: + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + + # Show or save to file + if axis is not None: + return axis + if show: + #plt.show(block=False) + ehc.show_noblock() + + if export_pdf != "": + f.savefig(export_pdf, bbox_inches='tight', pad_inches=pdf_pad_inches, dpi=dpi) + + return f + + def overlay_display(self, im_list, color_coding=np.array([[1, 0, 1], [0, 1, 0]]), + export_pdf="", show=True, f=False, + shift=[0, 0], final_fov=False, interp='gaussian', + scale='lin', gamma=0.5, dynamic_range=[1.e3], rescale=True): + """Overlay primary polarization images of a list of images to compare structures. + + Args: + im_list (list): list of images to align to the current image + color_coding (numpy.array): Color coding of each image in the composite + + f (matplotlib.pyplot.figure): Figure to overlay on top of + export_pdf (str): path to exported PDF with plot + show (bool): Display the plot if true + + shift (list): list of manual image shifts, + otherwise use the shift from maximum cross-correlation + final_fov (float): fov of the comparison image (rad). + If False it is the largestinput image fov + + scale (str) : compare images in 'log','lin',or 'gamma' scale + gamma (float): exponent for gamma scale comparison + dynamic_range (float): dynamic range for log and gamma scale comparisons + + Returns: + (matplotlib.figure.Figure): figure object with image + + """ + + if not f: + f = plt.figure() + plt.clf() + + if len(dynamic_range) == 1: + dynamic_range = dynamic_range * np.ones(len(im_list) + 1) + + if not isinstance(shift, np.ndarray) and not isinstance(shift, bool): + shift = matlib.repmat(shift, len(im_list), 1) + + psize = self.psize + max_fov = np.max([self.xdim * self.psize, self.ydim * self.psize]) + for i in range(0, len(im_list)): + psize = np.min([psize, im_list[i].psize]) + max_fov = np.max([max_fov, im_list[i].xdim * im_list[i].psize, + im_list[i].ydim * im_list[i].psize]) + + if not final_fov: + final_fov = max_fov + + (im_list_shift, shifts, im0_pad) = self.align_images(im_list, shift=shift, + final_fov=final_fov, + scale=scale, gamma=gamma, + dynamic_range=dynamic_range) + + # unit = 'Jy/pixel' + if scale == 'log': + # unit = 'log(Jy/pixel)' + log_offset = np.max(im0_pad.imvec) / dynamic_range[0] + im0_pad.imvec = np.log10(im0_pad.imvec + log_offset) + for i in range(0, len(im_list)): + log_offset = np.max(im_list_shift[i].imvec) / dynamic_range[i + 1] + im_list_shift[i].imvec = np.log10(im_list_shift[i].imvec + log_offset) + + if scale == 'gamma': + # unit = '(Jy/pixel)^gamma' + log_offset = np.max(im0_pad.imvec) / dynamic_range[0] + im0_pad.imvec = (im0_pad.imvec + log_offset)**(gamma) + for i in range(0, len(im_list)): + log_offset = np.max(im_list_shift[i].imvec) / dynamic_range[i + 1] + im_list_shift[i].imvec = (im_list_shift[i].imvec + log_offset)**(gamma) + + composite_img = np.zeros((im0_pad.ydim, im0_pad.xdim, 3)) + for i in range(-1, len(im_list)): + + if i == -1: + immtx = im0_pad.imvec.reshape(im0_pad.ydim, im0_pad.xdim) + else: + immtx = im_list_shift[i].imvec.reshape(im0_pad.ydim, im0_pad.xdim) + + if rescale: + immtx = immtx - np.min(np.min(immtx)) + immtx = immtx / np.max(np.max(immtx)) + + for c in range(0, 3): + composite_img[:, :, c] = composite_img[:, :, c] + (color_coding[i + 1, c] * immtx) + + if rescale is False: + composite_img = composite_img - np.min(np.min(np.min(composite_img))) + composite_img = composite_img / np.max(np.max(np.max(composite_img))) + + plt.subplot(111) + plt.title('%s MJD %i %.2f GHz' % (self.source, self.mjd, self.rf / 1e9), fontsize=20) + plt.imshow(composite_img, interpolation=interp) + xticks = obsh.ticks(im0_pad.xdim, im0_pad.psize / ehc.RADPERAS / 1e-6) + yticks = obsh.ticks(im0_pad.ydim, im0_pad.psize / ehc.RADPERAS / 1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel(r'Relative RA ($\mu$as)') + plt.ylabel(r'Relative Dec ($\mu$as)') + + if show: + #plt.show(block=False) + ehc.show_noblock() + if export_pdf != "": + f.savefig(export_pdf, bbox_inches='tight') + + return (f, shift) + + def save_txt(self, fname): + """Save image data to text file. + + Args: + fname (str): path to output text file + + Returns: + """ + + ehtim.io.save.save_im_txt(self, fname) + return + + def save_fits(self, fname): + """Save image data to a fits file. + + Args: + fname (str): path to output fits file + + Returns: + """ + ehtim.io.save.save_im_fits(self, fname) + return + + +################################################################################################### +# Image creation functions +################################################################################################### + +def make_square(obs, npix, fov, pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None): + """Make an empty square image. + + Args: + obs (Obsdata): an obsdata object with the image metadata + npix (int): the pixel size of each axis + fov (float): the field of view of each axis in radians + pulse (function): the function convolved with the pixel values for continuous image + + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular + Returns: + (Image): an image object + """ + + outim = make_empty(npix, fov, obs.ra, obs.dec, rf=obs.rf, source=obs.source, + polrep=polrep, pol_prim=pol_prim, pulse=pulse, + mjd=obs.mjd, time=obs.tstart) + + return outim + + +def make_empty(npix, fov, ra, dec, rf=ehc.RF_DEFAULT, source=ehc.SOURCE_DEFAULT, + polrep='stokes', pol_prim=None, pulse=ehc.PULSE_DEFAULT, + mjd=ehc.MJD_DEFAULT, time=0.): + """Make an empty square image. + + Args: + npix (int): the pixel size of each axis + fov (float): the field of view of each axis in radians + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The image frequency in Hz + + source (str): The source name + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + pulse (function): The function convolved with the pixel values for continuous image. + + mjd (int): The integer MJD of the image + time (float): The observing time of the image (UTC hours) + + Returns: + (Image): an image object + """ + + pdim = fov / float(npix) + npix = int(npix) + imarr = np.zeros((npix, npix)) + outim = Image(imarr, pdim, ra, dec, + polrep=polrep, pol_prim=pol_prim, + rf=rf, source=source, mjd=mjd, time=time, pulse=pulse) + return outim + + +def load_image(image, display=False, aipscc=False): + """Read in an image from a text, .fits, .h5, or ehtim.image.Image object + + Args: + image (str/Image): path to input file + display (boolean): determine whether to display the image default + aipscc (boolean): if True, then AIPS CC table will be loaded instead + of the original brightness distribution. + Returns: + (Image): loaded image object + (boolean): False if the image cannot be read + """ + + is_unicode = False + try: + if isinstance(image, basestring): + is_unicode = True + except NameError: # python 3 + pass + if isinstance(image, str) or is_unicode: + if image.endswith('.fits'): + im = ehtim.io.load.load_im_fits(image, aipscc=aipscc) + elif image.endswith('.txt'): + im = ehtim.io.load.load_im_txt(image) + elif image.endswith('.h5'): + im = ehtim.io.load.load_im_hdf5(image) + else: + print("Image format is not recognized. Was expecting .fits, .txt, or Image.") + print(" Got <.{0}>. Returning False.".format(image.split('.')[-1])) + return False + + elif isinstance(image, ehtim.image.Image): + im = image + + else: + print("Image format is not recognized. Was expecting .fits, .txt, or Image.") + print(" Got {0}. Returning False.".format(type(image))) + return False + + if display: + im.display() + + return im + + +def load_txt(fname, polrep='stokes', pol_prim=None, pulse=ehc.PULSE_DEFAULT, zero_pol=True): + """Read in an image from a text file. + + Args: + fname (str): path to input text file + pulse (function): The function convolved with the pixel values for continuous image. + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + + Returns: + (Image): loaded image object + """ + + return ehtim.io.load.load_im_txt(fname, pulse=pulse, polrep=polrep, + pol_prim=pol_prim, zero_pol=True) + + +def load_fits(fname, aipscc=False, pulse=ehc.PULSE_DEFAULT, + polrep='stokes', pol_prim=None, zero_pol=False): + """Read in an image from a FITS file. + + Args: + fname (str): path to input fits file + aipscc (bool): if True, then AIPS CC table will be loaded + pulse (function): The function convolved with the pixel values for continuous image. + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + + Returns: + (Image): loaded image object + """ + + return ehtim.io.load.load_im_fits(fname, aipscc=aipscc, pulse=pulse, + polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol) + +def avg_imlist(imlist): + """Average a list of images. + + Args: + imlist (list): list of image objects + + Returns: + (Image): average image object + """ + + imavg = imlist[0] + + if np.any(np.array([im.polrep for im in imlist]) != imavg.polrep): + raise Exception("im.polrep in all images are not the same in avg_imlist!") + if np.any(np.array([im.source for im in imlist]) != imavg.source): + raise Exception("im.source in all images are not the same in avg_imlist!") + if np.any(np.array([im.rf for im in imlist]) != imavg.rf): + raise Exception("im.rf in all images are not the same in avg_imlist!") + + keys = imavg._imdict.keys() + + for im in imlist[1:]: + for key in keys: + imavg._imdict[key] += im._imdict[key] + + for key in keys: + imavg._imdict[key] /= float(len(imlist)) + + + return imavg + +def get_specim(imlist, reffreq, fit_order=2): + """get the spectral index/curvature from a list of images""" + freqs = [im.rf for im in imlist] + + # remove any zeros in the images + for im in imlist: + im.imvec[im.imvec<=0] = np.min(im.imvec[im.imvec!=0]) + + # fit + xfit = np.log(np.array(freqs)/reffreq) + yfit = np.log(np.array([im.imvec for im in imlist])) + + if fit_order == 2: + coeffs = np.polyfit(xfit,yfit,2) + beta = coeffs[0] + alpha = coeffs[1] + imvec = np.exp(coeffs[2]) + elif fit_order == 1: + coeffs = np.polyfit(xfit,yfit,1) + alpha = coeffs[0] + beta = 0*alpha + imvec = np.exp(coeffs[1]) + else: + raise Exception() + + outim = imlist[0].copy() #TODO no polarization + outim.imvec = imvec + outim.rf = reffreq + outim.specvec = alpha + outim.curvvec = beta + + return outim + + +def blur_mf(im,freqs,kernel,fit_order=2): + """blur multifrequncy images with the same beam""" + reffreq = im.rf + + # remove any zeros in the images + + + imlist = [im.get_image_mf(rf).blur_circ(kernel) for rf in freqs] + for image in imlist: + image.imvec[image.imvec<=0] = np.min(image.imvec[image.imvec!=0]) + + xfit = np.log(np.array(freqs)/reffreq) + yfit = np.log(np.array([im.imvec for im in imlist])) + + if fit_order == 2: + coeffs = np.polyfit(xfit,yfit,2) + beta = coeffs[0] + alpha = coeffs[1] + elif fit_order == 1: + coeffs = np.polyfit(xfit,yfit,1) + alpha = coeffs[0] + beta = 0*alpha + else: + alpha = 0*yfit + beta = 0*yfit + + outim = im.blur_circ(kernel) + outim.specvec = alpha + outim.curvvec = beta + return outim diff --git a/imager.py b/imager.py new file mode 100644 index 00000000..bae00748 --- /dev/null +++ b/imager.py @@ -0,0 +1,2010 @@ +# imager.py +# a general interferometric imager class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import copy +import time +import numpy as np +import scipy.optimize as opt + +import ehtim.scattering as so +import ehtim.imaging.imager_utils as imutils +import ehtim.imaging.pol_imager_utils as polutils +import ehtim.imaging.multifreq_imager_utils as mfutils +import ehtim.image +import ehtim.const_def as ehc + +MAXIT = 200 # number of iterations +NHIST = 50 # number of steps to store for hessian approx +MAXLS = 40 # maximum number of line search steps in BFGS-B +STOP = 1e-6 # convergence criterion +EPS = 1e-8 + +DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag'] +REGULARIZERS = ['gs', 'tv', 'tvlog','tv2', 'tv2log', 'l1', 'l1w', 'lA', 'patch', + 'flux', 'cm', 'simple', 'compact', 'compact2', 'rgauss'] +REGULARIZERS_SPECIND = ['l2_alpha', 'tv_alpha'] +REGULARIZERS_CURV = ['l2_beta', 'tv_beta'] + + + +DATATERMS_POL = ['pvis', 'm', 'pbs','vvis'] +REGULARIZERS_POL = ['msimple', 'hw', 'ptv','l1v','l2v','vtv','vtv2','vflux'] + +GRIDDER_P_RAD_DEFAULT = 2 +GRIDDER_CONV_FUNC_DEFAULT = 'gaussian' +FFT_PAD_DEFAULT = 2 +FFT_INTERP_DEFAULT = 3 + +REG_DEFAULT = {'simple': 1} +DAT_DEFAULT = {'vis': 100} + +POL_TRANS = True # this means we solve for polarization in the m, chi basis +#POL_WHICH_SOLVE = (0, 1, 1) # this means that pol imaging solves for m & chi (not I), for now + # not used, now determined by 'pol_next' +MF_WHICH_SOLVE = (1, 1, 0) # this means that mf imaging solves for I0 and alpha (not beta), for now + # DEFAULT ONLY: object now uses self.mf_which_solve + +REGPARAMS_DEFAULT = {'major':50*ehc.RADPERUAS, + 'minor':50*ehc.RADPERUAS, + 'PA':0., + 'alpha_A':1.0, + 'epsilon_tv':0.0} + +POLARIZATION_MODES = ['P','QU','IP','IQU','V','IV','IQUV','IPV'] # TODO: treatment of V may be inconsistent + +################################################################################################### +# Imager object +################################################################################################### + + +class Imager(object): + """A general interferometric imager. + """ + + def __init__(self, obs_in, init_im, + prior_im=None, flux=None, data_term=DAT_DEFAULT, reg_term=REG_DEFAULT, **kwargs): + + self.logstr = "" + self._obs_list = [] + self._init_list = [] + self._prior_list = [] + self._out_list = [] + self._out_list_epsilon = [] + self._out_list_scattered = [] + self._reg_term_list = [] + self._dat_term_list = [] + self._clipfloor_list = [] + self._maxset_list = [] + self._pol_list = [] + self._maxit_list = [] + self._stop_list = [] + self._flux_list = [] + self._pflux_list = [] + self._vflux_list = [] + self._snrcut_list = [] + self._debias_list = [] + self._systematic_noise_list = [] + self._systematic_cphase_noise_list = [] + self._transform_list = [] + self._weighting_list = [] + + # Regularizer/data terms for the next imaging iteration + self.reg_term_next = reg_term # e.g. [('simple',1), ('l1',10), ('flux',500), ('cm',500)] + self.dat_term_next = data_term # e.g. [('amp', 1000), ('cphase',100)] + + # Observations, frequencies + self.reffreq = init_im.rf + if isinstance(obs_in, list): + self._obslist_next = obs_in + self.obslist_next = obs_in + else: + self._obslist_next = [obs_in] + self.obslist_next = [obs_in] + + # Init, prior, flux + self.init_next = init_im + + if prior_im is None: + self.prior_next = self.init_next + else: + self.prior_next = prior_im + + if flux is None: + self.flux_next = self.prior_next.total_flux() + else: + self.flux_next = flux + + # set polarimetric flux values equal to Stokes I flux by default + # used in regularizer normalization + self.pflux_next = kwargs.get('pflux', flux) + self.vflux_next = kwargs.get('vflux', flux) + + # Polarization + self.pol_next = kwargs.get('pol', self.init_next.pol_prim) + + # Weighting/debiasing/snr cut/systematic noise + self.debias_next = kwargs.get('debias', True) + snrcut = kwargs.get('snrcut', 0.) + self.snrcut_next = {key: 0. for key in set(DATATERMS+DATATERMS_POL)} + + if type(snrcut) is dict: + for key in snrcut.keys(): + self.snrcut_next[key] = snrcut[key] + else: + for key in self.snrcut_next.keys(): + self.snrcut_next[key] = snrcut + + self.systematic_noise_next = kwargs.get('systematic_noise', 0.) + self.systematic_cphase_noise_next = kwargs.get('systematic_cphase_noise', 0.) + self.weighting_next = kwargs.get('weighting', 'natural') + + # Maximal/minimal closure set + self.maxset_next = kwargs.get('maxset', False) + + # Clippping + self.clipfloor_next = kwargs.get('clipfloor', 0.) + self.maxit_next = kwargs.get('maxit', MAXIT) + self.stop_next = kwargs.get('stop', STOP) + self.transform_next = kwargs.get('transform', ['log','mcv']) + self.transform_next = np.array([self.transform_next]).flatten() #so we can handle multiple transforms + + # Normalize or not? + self.norm_init = kwargs.get('norm_init', True) + self.norm_reg = kwargs.get('norm_reg', False) + self.beam_size = self.obslist_next[0].res() + self.regparams = {k: kwargs.get(k, REGPARAMS_DEFAULT[k]) for k in REGPARAMS_DEFAULT.keys()} + + self.chisq_transform = False + self.chisq_offset_gradient = 0.0 + + # FFT parameters + self._ttype = kwargs.get('ttype', 'nfft') + self._fft_gridder_prad = kwargs.get('fft_gridder_prad', GRIDDER_P_RAD_DEFAULT) + self._fft_conv_func = kwargs.get('fft_conv_func', GRIDDER_CONV_FUNC_DEFAULT) + self._fft_pad_factor = kwargs.get('fft_pad_factor', FFT_PAD_DEFAULT) + self._fft_interp_order = kwargs.get('fft_interp_order', FFT_INTERP_DEFAULT) + + # UV minimum for closure phases + self.cp_uv_min = kwargs.get('cp_uv_min', False) + + # Parameters related to scattering + self.epsilon_list_next = [] + self.scattering_model = kwargs.get('scattering_model', None) + self._sqrtQ = None + self._ea_ker = None + self._ea_ker_gradient_x = None + self._ea_ker_gradient_y = None + self._alpha_phi_list = [] + self.alpha_phi_next = kwargs.get('alpha_phi', 1e4) + + # Imager history + self._change_imgr_params = True + self.nruns = 0 + + # multifrequency + self.mf_next = False + self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf',False) + self.mf_which_solve = kwargs.get('mf_which_solve',MF_WHICH_SOLVE) + + # Set embedding matrices and prepare imager + self.check_params() + self.check_limits() + self.init_imager() + + @property + def obslist_next(self): + return self._obslist_next + + @obslist_next.setter + def obslist_next(self, obslist): + if not isinstance(obslist, list): + raise Exception("obslist_next must be a list!") + self._obslist_next = obslist + self.freq_list = [obs.rf for obs in self.obslist_next] + #self.reffreq = self.freq_list[0] #Changed so that reffreq is determined by initial image/prior rf + self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list] + + @property + def obs_next(self): + """the next Obsdata to be used in imaging + """ + return self.obslist_next[0] + + @obs_next.setter + def obs_next(self, obs): + """the next Obsdata to be used in imaging + """ + self.obslist_next = [obs] + + def make_image(self, pol=None, grads=True, mf=False, **kwargs): + """Make an image using current imager settings. + + Args: + pol (str): which polarization to image + grads (bool): whether or not to use image gradients + mf (bool): whether or not to do multifrequency (spectral index only for now) + + Returns: + (Image): output image + + """ + + self.mf_next = mf + self.reg_all_freq_mf = kwargs.get('reg_all_freq_mf', self.reg_all_freq_mf) + self.mf_which_solve = kwargs.get('mf_which_solve', self.mf_which_solve) + + if pol is None: + pol_prim = self.pol_next + else: + self.pol_next = pol + pol_prim = pol + + print("==============================") + print("Imager run %i " % (int(self.nruns)+1)) + + # For polarimetric imaging, switch polrep to Stokes + if self.pol_next in POLARIZATION_MODES: + print("Imaging Polarization: switching to Stokes!") + self.prior_next = self.prior_next.switch_polrep(polrep_out='stokes', pol_prim_out='I') + self.init_next = self.init_next.switch_polrep(polrep_out='stokes', pol_prim_out='I') + pol_prim = 'I' + + # Checks and initialize + self.check_params() + self.check_limits() + self.init_imager() + + # Print initial stats + self._nit = 0 + self._show_updates = kwargs.get('show_updates', True) + self._update_interval = kwargs.get('update_interval', 1) + + # Plot initial image + self.plotcur(self._xinit, **kwargs) + + # Minimize + optdict = {'maxiter': self.maxit_next, + 'ftol': self.stop_next, 'gtol': self.stop_next, + 'maxcor': NHIST, 'maxls': MAXLS} + def callback_func(xcur): + self.plotcur(xcur, **kwargs) + + print("Imaging . . .") + tstart = time.time() + if grads: + res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B', jac=self.objgrad, + options=optdict, callback=callback_func) + else: + res = opt.minimize(self.objfunc, self._xinit, method='L-BFGS-B', + options=optdict, callback=callback_func) + tstop = time.time() + + # Format output + out = res.x[:] + self.tmpout = res.x + + if self.pol_next in POLARIZATION_MODES: # polarization + if self.pol_next == 'P': + out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (0,1,1)) + if 'mcv' in self.transform_next: + out = polutils.mcv(out) + + elif self.pol_next == 'IP' or self.pol_next == 'IQU': + out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,1,1)) + if 'mcv' in self.transform_next: + out = polutils.mcv(out) + if 'log' in self.transform_next: + out[0] = np.exp(out[0]) + + elif self.pol_next == 'V': + out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (0,0,0,1)) + if 'mcv' in self.transform_next: + out = polutils.mcv(out) + + elif self.pol_next == 'IV': + out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,0,0,1)) + if 'mcv' in self.transform_next: + out = polutils.mcv(out) + if 'log' in self.transform_next: + out[0] = np.exp(out[0]) + + elif self.pol_next == 'IQUV': + out = polutils.unpack_poltuple(out, self._xtuple, self._nimage, (1,1,1,1)) + if 'mcv' in self.transform_next: + out = polutils.mcv(out) + if 'log' in self.transform_next: + out[0] = np.exp(out[0]) + + elif self.mf_next: # multi-frequency + out = mfutils.unpack_mftuple(out, self._xtuple, self._nimage, self.mf_which_solve) + if 'log' in self.transform_next: + out[0] = np.exp(out[0]) + + elif 'log' in self.transform_next: # simple single-frequency + out = np.exp(out) + + # Print final stats + outstr = "" + chi2_term_dict = self.make_chisq_dict(out) + for dname in sorted(self.dat_term_next.keys()): + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key]) + + try: + print("time: %f s" % (tstop - tstart)) + print("J: %f" % res.fun) + print(outstr) + if isinstance(res.message,str): print(res.message) + else: print(res.message.decode()) + except: # TODO -- issues for some users with res.message + pass + + print("==============================") + + # Embed image + if self.pol_next in POLARIZATION_MODES: # polarization + if np.any(np.invert(self._embed_mask)): + out = polutils.embed_pol(out, self._embed_mask) + iimage_out = out[0] + qimage_out = polutils.make_q_image(out, POL_TRANS) + uimage_out = polutils.make_u_image(out, POL_TRANS) + vimage_out = polutils.make_v_image(out, POL_TRANS) + + elif self.mf_next: # multi-frequency + if np.any(np.invert(self._embed_mask)): + out = mfutils.embed_mf(out, self._embed_mask) + iimage_out = out[0] + specind_out = out[1] + curv_out = out[2] + + else: # simple single-pol + if np.any(np.invert(self._embed_mask)): + out = imutils.embed(out, self._embed_mask) + iimage_out = out + + # Return image + arglist, argdict = self.prior_next.image_args() + arglist[0] = iimage_out.reshape(self.prior_next.ydim, self.prior_next.xdim) + argdict['pol_prim'] = pol_prim + outim = ehtim.image.Image(*arglist, **argdict) + + # Copy over other polarizations + for pol2 in list(outim._imdict.keys()): + + # Is it the base image? + if pol2 == outim.pol_prim: + continue + + # Did we solve for polarimeric image or are we copying over old pols? + if self.pol_next in POLARIZATION_MODES and pol2 == 'Q': + polvec = qimage_out + elif self.pol_next in POLARIZATION_MODES and pol2 == 'U': + polvec = uimage_out + elif self.pol_next in POLARIZATION_MODES and pol2 == 'V': + polvec = vimage_out + else: + polvec = self.init_next._imdict[pol2] + + if len(polvec): + polarr = polvec.reshape(outim.ydim, outim.xdim) + outim.add_pol_image(polarr, pol2) + + # Copy over spectral index information + outim._mflist = copy.deepcopy(self.init_next._mflist) + if self.mf_next: + outim._mflist[0] = specind_out + outim._mflist[1] = curv_out + + # Append to history + logstr = str(self.nruns) + ": make_image(pol=%s)" % pol + self._append_image_history(outim, logstr) + self.nruns += 1 + + # Return Image object + return outim + + def converge(self, niter, blur_frac, pol, grads=True, **kwargs): + + blur = blur_frac * self.obs_next.res() + for repeat in range(niter-1): + init = self.out_last() + init = init.blur_circ(blur, blur) + self.init_next = init + self.make_image(pol=pol, grads=grads, **kwargs) + + + def make_image_I(self, grads=True, niter=1, blur_frac=1, **kwargs): + """Make Stokes I image using current imager settings. + """ + pol = 'I' + self.make_image(pol=pol, grads=grads, **kwargs) + self.converge(niter, blur_frac, pol, grads, **kwargs) + + return self.out_last() + + + def make_image_P(self, grads=True, niter=1, blur_frac=1, **kwargs): + """Make Stokes P polarimetric image using current imager settings. + """ + pol = 'P' + self.make_image(pol=pol, grads=grads, **kwargs) + self.converge(niter, blur_frac, pol, grads, **kwargs) + + return self.out_last() + + + def make_image_IP(self, grads=True, niter=1, blur_frac=1, **kwargs): + """Make Stokes I and P polarimetric image simultaneously using current imager settings. + """ + pol = 'IP' + self.make_image(pol=pol, grads=grads, **kwargs) + self.converge(niter, blur_frac, pol, grads, **kwargs) + + return self.out_last() + + def make_image_V(self, grads=True, niter=1, blur_frac=1, **kwargs): + """Make Stokes I image using current imager settings. + """ + pol = 'V' + self.make_image(pol=pol, grads=grads, **kwargs) + self.converge(niter, blur_frac, pol, grads, **kwargs) + + return self.out_last() + + def make_image_IV(self, grads=True, niter=1, blur_frac=1, **kwargs): + """Make Stokes I image using current imager settings. + """ + pol = 'IV' + self.make_image(pol=pol, grads=grads, **kwargs) + self.converge(niter, blur_frac, pol, grads, **kwargs) + + return self.out_last() + + def set_embed(self): + """Set embedding matrix. + """ + self._embed_mask = self.prior_next.imvec > self.clipfloor_next + if not np.any(self._embed_mask): + raise Exception("clipfloor_next too large: all prior pixels have been clipped!") + + xmax = self.prior_next.xdim//2 + ymax = self.prior_next.ydim//2 + + if self.prior_next.xdim % 2: xmin=-xmax-1 + else: xmin=-xmax + + if self.prior_next.ydim % 2: ymin=-ymax-1 + else: ymin=-ymax + + coord = np.array([[[x, y] + for x in np.arange(xmax, xmin, -1)] + for y in np.arange(ymax, ymin, -1)]) + + coord = coord.reshape(self.prior_next.ydim * self.prior_next.xdim, 2) + coord = coord * self.prior_next.psize + + self._coord_matrix = coord[self._embed_mask] + + return + + def check_params(self): + """Check parameter consistency. + """ + if ((self.prior_next.psize != self.init_next.psize) or + (self.prior_next.xdim != self.init_next.xdim) or + (self.prior_next.ydim != self.init_next.ydim)): + raise Exception("Initial image does not match dimensions of the prior image!") + + if ((self.prior_next.rf != self.init_next.rf)): + raise Exception("Initial image does not have same frequency as prior image!") + + if (self.prior_next.polrep != self.init_next.polrep): + raise Exception( + "Initial image polrep does not match prior polrep!") + + if (self.prior_next.polrep == 'circ' and not(self.pol_next in ['RR', 'LL'])): + raise Exception("Initial image polrep is 'circ': pol_next must be 'RR' or 'LL'") + + if (self.prior_next.polrep == 'stokes' and not(self.pol_next in ['I', 'Q', 'U', 'V', 'P','IP','IQU','IV','IQUV'])): + raise Exception( + "Initial image polrep is 'stokes': pol_next must be in 'I', 'Q', 'U', 'V', 'P','IP','IQU','IV','IQUV'!") + + # TODO single-polarization imaging. should we still support? + if ('log' in self.transform_next and self.pol_next in ['Q', 'U', 'V']): + raise Exception("Cannot image Stokes Q, U, V with log image transformation!") + + if(self.pol_next in ['Q', 'U', 'V'] and + ('gs' in self.reg_term_next.keys() or 'simple' in self.reg_term_next.keys())): + raise Exception( + "'simple' and 'gs' methods do not work with Stokes Q, U, or V images!") + + if self._ttype not in ['fast', 'direct', 'nfft']: + raise Exception("Possible ttype values are 'fast', 'direct','nfft'!") + + # Catch errors in multifrequency imaging setup + if self.mf_next and len(set(self.freq_list)) < 2: + raise Exception( + "must have observations at at least two frequencies for multifrequency imaging!") + + # Catch errors for polarimetric imaging setup + if self.pol_next in POLARIZATION_MODES: + if 'mcv' not in self.transform_next: + raise Exception("Polarimetric imaging needs 'mcv' transform!") + if (self._ttype not in ["direct", "nfft"]): + raise Exception("FFT not yet implemented in polarimetric imaging -- use NFFT!") + if 'I' in self.pol_next: + rlist = REGULARIZERS + REGULARIZERS_POL + dlist = DATATERMS + DATATERMS_POL + else: + rlist = REGULARIZERS_POL + dlist = DATATERMS_POL + else: + rlist = REGULARIZERS + REGULARIZERS_SPECIND + REGULARIZERS_CURV + dlist = DATATERMS + + # catch errors in general imaging setup + dt_here = False + dt_type = True + for term in sorted(self.dat_term_next.keys()): + if (term is not None) and (term is not False): + dt_here = True + if not ((term in dlist) or (term is False)): + dt_type = False + + st_here = False + st_type = True + for term in sorted(self.reg_term_next.keys()): + if (term is not None) and (term is not False): + st_here = True + if not ((term in rlist) or (term is False)): + st_type = False + + if not dt_here: + raise Exception("Must have at least one data term!") + if not st_here: + raise Exception("Must have at least one regularizer term!") + if not dt_type: + raise Exception("Invalid data term: valid data terms are: " + ','.join(dlist)) + if not st_type: + raise Exception("Invalid regularizer: valid regularizers are: " + ','.join(rlist)) + + + # Determine if we need to recompute the saved imager parameters on the next imager run + if self.nruns == 0: + return + + if self.pol_next != self.pol_last(): + print("changed polarization!") + self._change_imgr_params = True + return + + if self.obslist_next != self.obslist_last(): + print("changed observation!") + self._change_imgr_params = True + return + + if len(self.reg_term_next) != len(self.reg_terms_last()): + print("changed number of regularizer terms!") + self._change_imgr_params = True + return + + if len(self.dat_term_next) != len(self.dat_terms_last()): + print("changed number of data terms!") + self._change_imgr_params = True + return + + for term in sorted(self.dat_term_next.keys()): + if term not in self.dat_terms_last().keys(): + print("added %s to data terms" % term) + self._change_imgr_params = True + return + + for term in sorted(self.reg_term_next.keys()): + if term not in self.reg_terms_last().keys(): + print("added %s to regularizers!" % term) + self._change_imgr_params = True + return + + if ((self.prior_next.psize != self.prior_last().psize) or + (self.prior_next.xdim != self.prior_last().xdim) or + (self.prior_next.ydim != self.prior_last().ydim)): + print("changed prior dimensions!") + self._change_imgr_params = True + + if self.debias_next != self.debias_last(): + print("changed debiasing!") + self._change_imgr_params = True + return + if self.snrcut_next != self.snrcut_last(): + print("changed snrcut!") + self._change_imgr_params = True + return + if self.weighting_next != self.weighting_last(): + print("changed data weighting!") + self._change_imgr_params = True + return + if self.systematic_noise_next != self.systematic_noise_last(): + print("changed systematic noise!") + self._change_imgr_params = True + return + if self.systematic_cphase_noise_next != self.systematic_cphase_noise_last(): + print("changed systematic cphase noise!") + self._change_imgr_params = True + return + + def check_limits(self): + """Check image parameter consistency with observation. + """ + uvmax = 1.0/self.prior_next.psize + uvmin = 1.0/(self.prior_next.psize*np.max((self.prior_next.xdim, self.prior_next.ydim))) + uvdists = self.obs_next.unpack('uvdist')['uvdist'] + maxbl = np.max(uvdists) + minbl = np.max(uvdists[uvdists > 0]) + + if uvmax < maxbl: + print("Warning! Pixel size is larger than smallest spatial wavelength!") + if uvmin > minbl: + print("Warning! Field of View is smaller than largest nonzero spatial wavelength!") + + if self.pol_next in ['I', 'RR', 'LL']: + maxamp = np.max(np.abs(self.obs_next.unpack('amp')['amp'])) + if self.flux_next > 1.2*maxamp: + print("Warning! Specified flux is > 120% of maximum visibility amplitude!") + if self.flux_next < .8*maxamp: + print("Warning! Specified flux is < 80% of maximum visibility amplitude!") + + def reg_terms_last(self): + """Return last used regularizer terms. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._reg_term_list[-1] + + def dat_terms_last(self): + """Return last used data terms. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._dat_term_list[-1] + + def obslist_last(self): + """Return last used observation. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._obs_list[-1] + + def obs_last(self): + """Return last used observation. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._obs_list[-1][0] + + def prior_last(self): + """Return last used prior image. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._prior_list[-1] + + def out_last(self): + """Return last result. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._out_list[-1] + + def out_scattered_last(self): + """Return last result with scattering. + """ + if self.nruns == 0 or len(self._out_list_scattered) == 0: + print("No stochastic optics imager runs yet!") + return + return self._out_list_scattered[-1] + + def out_epsilon_last(self): + """Return last result with scattering. + """ + if self.nruns == 0 or len(self._out_list_epsilon) == 0: + print("No stochastic optics imager runs yet!") + return + return self._out_list_epsilon[-1] + + def init_last(self): + """Return last initial image. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._init_list[-1] + + def flux_last(self): + """Return last total flux constraint. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._flux_list[-1] + + def pflux_last(self): + """Return last total linear polarimetric flux constraint. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._pflux_list[-1] + + def vflux_last(self): + """Return last total circular polarimetric flux constraint. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._vflux_list[-1] + + def clipfloor_last(self): + """Return last clip floor. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._clipfloor_list[-1] + + def pol_last(self): + """Return last polarization imaged. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._pol_list[-1] + + def maxit_last(self): + """Return last max_iterations value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._maxit_list[-1] + + def debias_last(self): + """Return last debias value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._debias_list[-1] + + def snrcut_last(self): + """Return last snrcut value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._snrcut_list[-1] + + def weighting_last(self): + """Return last weighting value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._weighting_list[-1] + + def systematic_noise_last(self): + """Return last systematic_noise value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._systematic_noise_list[-1] + + def systematic_cphase_noise_last(self): + """Return last closure phase systematic noise value (in degree). + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._systematic_cphase_noise_list[-1] + + def stop_last(self): + """Return last convergence value. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._stop_list[-1] + + def transform_last(self): + """Return last image transform used. + """ + if self.nruns == 0: + print("No imager runs yet!") + return + return self._transform_list[-1] + + def init_imager(self): + """Set up Stokes I imager. + """ + # Set embedding + self.set_embed() + + # Set prior & initial image vectors for polarimetric imaging + if self.pol_next in POLARIZATION_MODES: + + # initial I image + if self.norm_init and ('I' in self.pol_next): + self._nprior = (self.flux_next * self.prior_next.imvec / + np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask] + iinit = (self.flux_next * self.init_next.imvec / + np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask] + else: + self._nprior = self.prior_next.imvec[self._embed_mask] + iinit = self.init_next.imvec[self._embed_mask] + + self._nimage = len(iinit) + + # Initialize m & phi & v + if (len(self.init_next.qvec) and + (np.any(self.init_next.qvec != 0) or np.any(self.init_next.uvec != 0))): + + init1 = np.abs(self.init_next.qvec + 1j*self.init_next.uvec) / self.init_next.imvec + init1 = init1[self._embed_mask] + init2 = (np.arctan2(self.init_next.uvec, self.init_next.qvec) / 2.0) + init2 = init2[self._embed_mask] + else: + # !AC TODO get the actual zero baseline polarization fraction from the data? + print("No polarimetric image in init_next!") + print("--initializing with 20% pol and random orientation!") + init1 = 0.2 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage)) + init2 = np.zeros(self._nimage) + 1e-2 * np.random.rand(self._nimage) + + # Initialize v + if 'V' in self.pol_next: + if len(self.init_next.vvec) and (np.any(self.init_next.vvec != 0)): + init3 = self.init_next.vvec / self.init_next.imvec + init3 = init3[self._embed_mask] + else: + # !AC TODO get the actual zero baseline polarization fraction from the data? + print("No V polarimetric image in init_next!") + print("--initializing with random vector") + #init3 = 0.05 * np.random.randn(self._nimage) + init3 = 0.01 * (np.ones(self._nimage) + 1e-2 * np.random.rand(self._nimage)) + self._inittuple = np.array((iinit, init1, init2, init3)) + else: + self._inittuple = np.array((iinit, init1, init2)) + + # Change of variables + if 'mcv' in self.transform_next: + self._xtuple = polutils.mcv_r(self._inittuple) + else: + raise Exception("Polarimetric imaging only works with mcv transform!") + + # Only apply log transformation to Stokes I if simultaneous imaging + if ('log' in self.transform_next) and ('I' in self.pol_next): + self._xtuple[0] = np.log(self._xtuple[0]) + + # Determine pol_which_solve + if self.pol_next in ['P','QU']: + self._pol_which_solve = (0,1,1) + elif self.pol_next in ['IP','IQU']: + self._pol_which_solve = (1,1,1) + elif self.pol_next in ['V']: + self._pol_which_solve = (0,0,0,1) + elif self.pol_next in ['IV']: + self._pol_which_solve = (1,0,0,1) + elif self.pol_next in ['IQUV']: + self._pol_which_solve = (1,1,1,1) + else: + raise Exception("Do not know correct pol_which_solve for self.pol_next=%s!"%self.pol_next) + + # Pack into single vector + self._xinit = polutils.pack_poltuple(self._xtuple, self._pol_which_solve) + + # Set prior & initial image vectors for multifrequency imaging + elif self.mf_next: + + self.reffreq = self.init_next.rf # set reference frequency to same as prior + # reset logfreqratios in case reference frequency changed + self._logfreqratio_list = [np.log(nu/self.reffreq) for nu in self.freq_list] + + if self.norm_init: + nprior_I = (self.flux_next * self.prior_next.imvec / + np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask] + ninit_I = (self.flux_next * self.init_next.imvec / + np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask] + else: + nprior_I = self.prior_next.imvec[self._embed_mask] + ninit_I = self.init_next.imvec[self._embed_mask] + + if len(self.init_next.specvec): + ninit_a = self.init_next.specvec[self._embed_mask] + else: + ninit_a = np.zeros(self._nimage)[self._embed_mask] + if len(self.prior_next.specvec): + nprior_a = self.prior_next.specvec[self._embed_mask] + else: + nprior_a = np.zeros(self._nimage)[self._embed_mask] + + if len(self.init_next.curvvec): + ninit_b = self.init_next.curvvec[self._embed_mask] + else: + ninit_b = np.zeros(self._nimage)[self._embed_mask] + if len(self.prior_next.curvvec): + nprior_b = self.init_next.curvvec[self._embed_mask] + else: + nprior_b = np.zeros(self._nimage)[self._embed_mask] + + self._nimage = len(ninit_I) + + self.inittuple = np.array((ninit_I, ninit_a, ninit_b)) + self.priortuple = np.array((nprior_I, nprior_a, nprior_b)) + + # Change of variables + if 'log' in self.transform_next: + self._xtuple = np.array((np.log(ninit_I), ninit_a, ninit_b)) + else: + self._xtuple = self.inittuple + + # Pack into single vector + self._xinit = mfutils.pack_mftuple(self._xtuple, self.mf_which_solve) + + # Set prior & initial image vectors for single stokes or RR, LL imaging + else: + + if self.norm_init: + self._nprior = (self.flux_next * self.prior_next.imvec / + np.sum((self.prior_next.imvec)[self._embed_mask]))[self._embed_mask] + ninit = (self.flux_next * self.init_next.imvec / + np.sum((self.init_next.imvec)[self._embed_mask]))[self._embed_mask] + else: + self._nprior = self.prior_next.imvec[self._embed_mask] + ninit = self.init_next.imvec[self._embed_mask] + + self._nimage = len(ninit) + # Change of variables + if 'log' in self.transform_next: + self._xinit = np.log(ninit) + else: + self._xinit = ninit + + # Make data term tuples + if self._change_imgr_params: + if self.nruns == 0: + print("Initializing imager data products . . .") + if self.nruns > 0: + print("Recomputing imager data products . . .") + + self._data_tuples = {} + + # Loop over all data term types + for dname in sorted(self.dat_term_next.keys()): + + # Loop over all observations in the list + for i, obs in enumerate(self.obslist_next): + # Each entry in the dterm dictionary past the first has an appended number + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + + # Polarimetric data products + if dname in DATATERMS_POL: + tup = polutils.polchisqdata(obs, self.prior_next, self._embed_mask, dname, + ttype=self._ttype, + fft_pad_factor=self._fft_pad_factor, + conv_func=self._fft_conv_func, + p_rad=self._fft_gridder_prad) + + # Single polarization data products + elif dname in DATATERMS: + if self.pol_next in POLARIZATION_MODES: + if not 'I' in self.pol_next: + raise Exception("cannot use dterm %s with pol=%s"%(dname,self.pol_next)) + pol_next = 'I' + else: + pol_next = self.pol_next + + tup = imutils.chisqdata(obs, self.prior_next, self._embed_mask, dname, + pol=pol_next, maxset=self.maxset_next, + debias=self.debias_next, + snrcut=self.snrcut_next[dname], + weighting=self.weighting_next, + systematic_noise=self.systematic_noise_next, + systematic_cphase_noise=self.systematic_cphase_noise_next, + ttype=self._ttype, order=self._fft_interp_order, + fft_pad_factor=self._fft_pad_factor, + conv_func=self._fft_conv_func, + p_rad=self._fft_gridder_prad, + cp_uv_min=self.cp_uv_min) + else: + raise Exception("data term %s not recognized!" % dname) + + self._data_tuples[dname_key] = tup + + self._change_imgr_params = False + + return + + def init_imager_scattering(self): + """Set up scattering imager. + """ + N = self.prior_next.xdim + + if self.scattering_model is None: + self.scattering_model = so.ScatteringModel() + + # First some preliminary definitions + wavelength = ehc.C/self.obs_next.rf*100.0 # Observing wavelength [cm] + N = self.prior_next.xdim + + # Field of view, in cm, at the scattering screen + FOV = self.prior_next.psize * N * self.scattering_model.observer_screen_distance + + # The ensemble-average convolution kernel and its gradients + self._ea_ker = self.scattering_model.Ensemble_Average_Kernel( + self.prior_next, wavelength_cm=wavelength) + ea_ker_gradient = so.Wrapped_Gradient(self._ea_ker/(FOV/N)) + self._ea_ker_gradient_x = -ea_ker_gradient[1] + self._ea_ker_gradient_y = -ea_ker_gradient[0] + + # The power spectrum + # Note: rotation is not currently implemented; + # the gradients would need to be modified slightly + self._sqrtQ = np.real(self.scattering_model.sqrtQ_Matrix(self.prior_next, t_hr=0.0)) + + # Generate the initial image+screen vector. + # By default, the screen is re-initialized to zero each time. + if len(self.epsilon_list_next) == 0: + self._xinit = np.concatenate((self._xinit, np.zeros(N**2-1))) + else: + self._xinit = np.concatenate((self._xinit, self.epsilon_list_next)) + + def make_chisq_dict(self, imcur): + """Make a dictionary of current chi^2 term values + i indexes the observation number in self.obslist_next + """ + + chi2_dict = {} + for dname in sorted(self.dat_term_next.keys()): + # Loop over all observations in the list + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + + (data, sigma, A) = self._data_tuples[dname_key] + + if dname in DATATERMS_POL: + chi2 = polutils.polchisq(imcur, A, data, sigma, dname, + ttype=self._ttype, mask=self._embed_mask, + pol_trans=POL_TRANS) + + elif dname in DATATERMS: + if self.mf_next: # multifrequency + logfreqratio = self._logfreqratio_list[i] + imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio) + elif self.pol_next in POLARIZATION_MODES: # polarization + imcur_nu = imcur[0] + else: # normal imaging + imcur_nu = imcur + + chi2 = imutils.chisq(imcur_nu, A, data, sigma, dname, + ttype=self._ttype, mask=self._embed_mask) + + else: + raise Exception("data term %s not recognized!" % dname) + + chi2_dict[dname_key] = chi2 + + return chi2_dict + + def make_chisqgrad_dict(self, imcur, i=0): + """Make a dictionary of current chi^2 term gradient values + i indexes the observation number in self.obslist_next + """ + chi2grad_dict = {} + for dname in sorted(self.dat_term_next.keys()): + # Loop over all observations in the list + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + + (data, sigma, A) = self._data_tuples[dname_key] + + # Polarimetric data products + if dname in DATATERMS_POL: + chi2grad = polutils.polchisqgrad(imcur, A, data, sigma, dname, + ttype=self._ttype, mask=self._embed_mask, + pol_solve=self._pol_which_solve, + pol_trans=POL_TRANS) + + # Single polarization data products + elif dname in DATATERMS: + if self.mf_next: # multifrequency + logfreqratio = self._logfreqratio_list[i] + imref = imcur[0] + imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio) + elif self.pol_next in POLARIZATION_MODES: # polarization + imcur_nu = imcur[0] + else: # normal imaging + imcur_nu = imcur + + chi2grad = imutils.chisqgrad(imcur_nu, A, data, sigma, dname, + ttype=self._ttype, mask=self._embed_mask) + + # If multifrequency imaging, + # transform the image gradients for all the solved quantities + if self.mf_next: + logfreqratio = self._logfreqratio_list[i] + chi2grad = mfutils.mf_all_grads_chain(chi2grad, imcur_nu, imref, logfreqratio) + + # If imaging polarization simultaneously, bundle the gradient properly + if self.pol_next in POLARIZATION_MODES: + if 'V' in self.pol_next: + chi2grad = np.array((chi2grad, np.zeros(self._nimage), np.zeros(self._nimage), np.zeros(self._nimage))) + else: + chi2grad = np.array((chi2grad, np.zeros(self._nimage), np.zeros(self._nimage))) + + else: + raise Exception("data term %s not recognized!" % dname) + + chi2grad_dict[dname_key] = np.array(chi2grad) + + return chi2grad_dict + + def make_reg_dict(self, imcur): + """Make a dictionary of current regularizer values + """ + reg_dict = {} + + for regname in sorted(self.reg_term_next.keys()): + + # Polarimetric regularizer + if regname in REGULARIZERS_POL: + reg = polutils.polregularizer(imcur, self._embed_mask, + self.flux_next, self.pflux_next, self.vflux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + pol_trans=POL_TRANS) + + # Multifrequency regularizers + elif self.mf_next: + + # Image regularizer(s) + if regname in REGULARIZERS: + # new option to regularize ALL the images in multifrequency imaging + # TODO total fluxes not right? + if self.reg_all_freq_mf: + for i in range(len(self.obslist_next)): + regname_key = regname + ('_%i' % i) + logfreqratio = self._logfreqratio_list[i] + imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio) + prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio) + imref =imcur[0] + + reg = imutils.regularizer(imcur_nu, prior_nu, self._embed_mask, + self.flux_next, self.prior_next.xdim, + self.prior_next.ydim, self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + reg_dict[regname_key] = reg + + # normally we only regularize reference frequency image + else: + reg = imutils.regularizer(imcur[0], self.priortuple[0], self._embed_mask, + self.flux_next, self.prior_next.xdim, + self.prior_next.ydim, self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + + # Spectral index regularizer(s) + elif regname in REGULARIZERS_SPECIND: + reg = mfutils.regularizer_mf(imcur[1], self.priortuple[1], self._embed_mask, + self.flux_next, self.prior_next.xdim, + self.prior_next.ydim, self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + + # Curvature index regularizer(s) + elif regname in REGULARIZERS_CURV: + reg = mfutils.regularizer_mf(imcur[2], self.priortuple[2], self._embed_mask, + self.flux_next, self.prior_next.xdim, + self.prior_next.ydim, self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + + # Normal, single polarization, single-frequency regularizer + elif regname in REGULARIZERS: + if self.pol_next in POLARIZATION_MODES: + imcur0 = imcur[0] + else: + imcur0 = imcur + + reg = imutils.regularizer(imcur0, self._nprior, self._embed_mask, + self.flux_next, self.prior_next.xdim, + self.prior_next.ydim, self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + else: + raise Exception("regularizer term %s not recognized!" % regname) + + # multifrequency regularizer terms are already in the dictionary + # if we regularize all images with self.reg_all_freq_mf + if not(self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS)): + reg_dict[regname] = reg + + return reg_dict + + def make_reggrad_dict(self, imcur): + """Make a dictionary of current regularizer gradient values + """ + + reggrad_dict = {} + + + for regname in sorted(self.reg_term_next.keys()): + + # Polarimetric regularizer + if regname in REGULARIZERS_POL: + reg = polutils.polregularizergrad(imcur, self._embed_mask, + self.flux_next, self.pflux_next, self.vflux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + pol_solve=self._pol_which_solve, + pol_trans=POL_TRANS) + + # Multifrequency regularizer + elif self.mf_next: + + # Image regularizer(s) + if regname in REGULARIZERS: + # new option to regularize ALL the images in multifrequency imaging + # TODO total fluxes not right? + if self.reg_all_freq_mf: + for i in range(len(self.obslist_next)): + regname_key = regname + ('_%i' % i) + logfreqratio = self._logfreqratio_list[i] + imcur_nu = mfutils.imvec_at_freq(imcur, logfreqratio) + prior_nu = mfutils.imvec_at_freq(self.priortuple, logfreqratio) + imref =imcur[0] + + reg = imutils.regularizergrad(imcur_nu, prior_nu, + self._embed_mask, self.flux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, + beam_size=self.beam_size, + **self.regparams) + + + reg = mfutils.mf_all_grads_chain(reg, imcur_nu, imref, logfreqratio) + reg_dict[regname_key] = reg + + # normally we only regularize the reference frequency image + else: + reg = imutils.regularizergrad(imcur[0], self.priortuple[0], + self._embed_mask, self.flux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, + beam_size=self.beam_size, + **self.regparams) + reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage))) + + + # Spectral index regularizer(s) + elif regname in REGULARIZERS_SPECIND: + reg = mfutils.regularizergrad_mf(imcur[1], self.priortuple[1], + self._embed_mask, self.flux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, + beam_size=self.beam_size, + **self.regparams) + reg = np.array((np.zeros(self._nimage), reg, np.zeros(self._nimage))) + + # Curvature index regularizer(s) + elif regname in REGULARIZERS_CURV: + reg = mfutils.regularizergrad_mf(imcur[2], self.priortuple[2], + self._embed_mask, self.flux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, regname, + norm_reg=self.norm_reg, + beam_size=self.beam_size, + **self.regparams) + reg = np.array((np.zeros(self._nimage), np.zeros(self._nimage), reg)) + + # Normal, single polarization, single-frequency regularizer + elif regname in REGULARIZERS: + if self.pol_next in POLARIZATION_MODES: + imcur0 = imcur[0] + else: + imcur0 = imcur + reg = imutils.regularizergrad(imcur0, self._nprior, self._embed_mask, self.flux_next, + self.prior_next.xdim, self.prior_next.ydim, + self.prior_next.psize, + regname, + norm_reg=self.norm_reg, beam_size=self.beam_size, + **self.regparams) + if self.pol_next in POLARIZATION_MODES: + if 'V' in self.pol_next: + reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage), np.zeros(self._nimage))) + else: + reg = np.array((reg, np.zeros(self._nimage), np.zeros(self._nimage))) + else: + raise Exception("regularizer term %s not recognized!" % regname) + + # multifrequency regularizer gradient terms are already in the dictionary + # if we regularize all images with self.reg_all_freq_mf + if not(self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS)): + reggrad_dict[regname] = reg + + return reggrad_dict + + def objfunc(self, imvec): + """Current objective function. + """ + + # Unpack polarimetric/multifrequency vector into an array + if self.pol_next in POLARIZATION_MODES: + imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve) + elif self.mf_next: + imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve) + else: + imcur = imvec + + # Image change of variables + if self.pol_next in POLARIZATION_MODES and 'mcv' in self.transform_next: + imcur = polutils.mcv(imcur) + + if 'log' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + imcur[0] = np.exp(imcur[0]) + elif self.mf_next: + imcur[0] = np.exp(imcur[0]) + else: + imcur = np.exp(imcur) + + # Data terms + datterm = 0. + chi2_term_dict = self.make_chisq_dict(imcur) + for dname in sorted(self.dat_term_next.keys()): + hyperparameter = self.dat_term_next[dname] + + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + + chi2 = chi2_term_dict[dname_key] + + if self.chisq_transform: + datterm += hyperparameter * (chi2 + 1./chi2 - 1.) + else: + datterm += hyperparameter * (chi2 - 1.) + + # Regularizer terms + regterm = 0 + reg_term_dict = self.make_reg_dict(imcur) + for regname in sorted(self.reg_term_next.keys()): + hyperparameter = self.reg_term_next[regname] + # multifrequency imaging, regularize every frequency + if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS): + for i in range(len(self.obslist_next)): + regname_key = regname + ('_%i' % i) + regularizer = reg_term_dict[regname_key] + regterm += hyperparameter * regularizer + + # but normally just one regularizer + else: + regularizer = reg_term_dict[regname] + regterm += hyperparameter * regularizer + + # Total cost + cost = datterm + regterm + + return cost + + def objgrad(self, imvec): + """Current objective function gradient. + """ + + # Unpack polarimetric/multifrequency vector into an array + if self.pol_next in POLARIZATION_MODES: + imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve) + elif self.mf_next: + imcur = mfutils.unpack_mftuple(imvec, self._xtuple, self._nimage, self.mf_which_solve) + else: + imcur = imvec + + # Image change of variables + if 'mcv' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + cvcur = imcur.copy() + imcur = polutils.mcv(imcur) + + if 'log' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + imcur[0] = np.exp(imcur[0]) + elif self.mf_next: + imcur[0] = np.exp(imcur[0]) + else: + imcur = np.exp(imcur) + + # Data terms + datterm = 0. + chi2_term_dict = self.make_chisqgrad_dict(imcur) + if self.chisq_transform: + chi2_value_dict = self.make_chisq_dict(imcur) + for dname in sorted(self.dat_term_next.keys()): + hyperparameter = self.dat_term_next[dname] + + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + + chi2_grad = chi2_term_dict[dname_key] + + if self.chisq_transform: + chi2_val = chi2_value_dict[dname] + datterm += hyperparameter * chi2_grad * (1. - 1./(chi2_val**2)) + else: + datterm += hyperparameter * (chi2_grad + self.chisq_offset_gradient) + + # Regularizer terms + regterm = 0 + reg_term_dict = self.make_reggrad_dict(imcur) + for regname in sorted(self.reg_term_next.keys()): + hyperparameter = self.reg_term_next[regname] + + # multifrequency imaging, regularize every frequency + if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS): + for i in range(len(self.obslist_next)): + regname_key = regname + ('_%i' % i) + regularizer = reg_term_dict[regname_key] + regterm += hyperparameter * regularizer + + # but normally just one regularizer + else: + regularizer_grad = reg_term_dict[regname] + regterm += hyperparameter * regularizer_grad + + # Total gradient + grad = datterm + regterm + + # Chain rule term for change of variables + if 'mcv' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + grad *= polutils.mchain(cvcur) + + if 'log' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + grad[0] *= imcur[0] + elif self.mf_next: + grad[0] *= imcur[0] + else: + grad *= imcur + + # Repack gradient for polarimetric imaging + if self.pol_next in POLARIZATION_MODES: + grad = polutils.pack_poltuple(grad, self._pol_which_solve) + + # repack gradient for multifrequency imaging + elif self.mf_next: + grad = mfutils.pack_mftuple(grad, self.mf_which_solve) + + return grad + + def plotcur(self, imvec, **kwargs): + """Plot current image. + """ + + if self._show_updates: + if self._nit % self._update_interval == 0: + if self.pol_next in POLARIZATION_MODES: + + imcur = polutils.unpack_poltuple(imvec, self._xtuple, self._nimage, self._pol_which_solve) + elif self.mf_next: + imcur = mfutils.unpack_mftuple( + imvec, self._xtuple, self._nimage, self.mf_which_solve) + else: + imcur = imvec + + # Image change of variables + + if 'mcv' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + imcur = polutils.mcv(imcur) + + if 'log' in self.transform_next: + if self.pol_next in POLARIZATION_MODES: + imcur[0] = np.exp(imcur[0]) + elif self.mf_next: + imcur[0] = np.exp(imcur[0]) + else: + imcur = np.exp(imcur) + + # Get chi^2 and regularizer + chi2_term_dict = self.make_chisq_dict(imcur) + reg_term_dict = self.make_reg_dict(imcur) + + # Format print string + outstr = "------------------------------------------------------------------" + outstr += "\n%4d | " % self._nit + for dname in sorted(self.dat_term_next.keys()): + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + outstr += "chi2_%s : %0.2f " % (dname_key, chi2_term_dict[dname_key]) + outstr += "\n " + for dname in sorted(self.dat_term_next.keys()): + for i, obs in enumerate(self.obslist_next): + if len(self.obslist_next)==1: + dname_key = dname + else: + dname_key = dname + ('_%i' % i) + dval = chi2_term_dict[dname_key]*self.dat_term_next[dname] + outstr += "%s : %0.1f " % (dname_key, dval) + + outstr += "\n " + for regname in sorted(self.reg_term_next.keys()): + + if self.mf_next and self.reg_all_freq_mf and (regname in REGULARIZERS): + for i in range(len(self.obslist_next)): + regname_key = regname + ('_%i' % i) + rval = reg_term_dict[regname_key]*self.reg_term_next[regname] + outstr += "%s : %0.1f " % (regname_key, rval) + else: + rval = reg_term_dict[regname]*self.reg_term_next[regname] + outstr += "%s : %0.1f " % (regname, rval) + + # Embed and plot the image + if self.pol_next in POLARIZATION_MODES: + if np.any(np.invert(self._embed_mask)): + imcur = polutils.embed_pol(imcur, self._embed_mask) + polutils.plot_m(imcur, self.prior_next, self._nit, chi2_term_dict, **kwargs) + + else: + if self.mf_next: + implot = imcur[0] + else: + implot = imcur + if np.any(np.invert(self._embed_mask)): + implot = imutils.embed(implot, self._embed_mask) + + imutils.plot_i(implot, self.prior_next, self._nit, + chi2_term_dict, pol=self.pol_next, **kwargs) + + if self._nit == 0: + print() + print(outstr) + + self._nit += 1 + + def objfunc_scattering(self, minvec): + """Current stochastic optics objective function. + """ + N = self.prior_next.xdim + + imvec = minvec[:N**2] + EpsilonList = minvec[N**2:] + if 'log' in self.transform_next: + imvec = np.exp(imvec) + + IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize, + self.prior_next.ra, self.prior_next.dec, + self.prior_next.pa, rf=self.obs_next.rf, + source=self.prior_next.source, mjd=self.prior_next.mjd) + + # The scattered image vector + screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen, + ea_ker=self._ea_ker, sqrtQ=self._sqrtQ, + Linearized_Approximation=True) + scatt_im = scatt_im.imvec + + # Calculate the chi^2 using the scattered image + datterm = 0. + chi2_term_dict = self.make_chisq_dict(scatt_im) + for dname in sorted(self.dat_term_next.keys()): + datterm += self.dat_term_next[dname] * (chi2_term_dict[dname] - 1.) + + # Calculate the entropy using the unscattered image + regterm = 0 + reg_term_dict = self.make_reg_dict(imvec) + + # Make dict also for scattered image + reg_term_dict_scatt = self.make_reg_dict(scatt_im) + + for regname in sorted(self.reg_term_next.keys()): + if regname == 'rgauss': + # Get gradient of the scattered image vector + regterm += self.reg_term_next[regname] * reg_term_dict_scatt[regname] + + else: + regterm += self.reg_term_next[regname] * reg_term_dict[regname] + + # Scattering screen regularization term + chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0) + regterm_scattering = self.alpha_phi_next * (chisq_epsilon - 1.0) + + return datterm + regterm + regterm_scattering + + def objgrad_scattering(self, minvec): + """Current stochastic optics objective function gradient + """ + wavelength = ehc.C/self.obs_next.rf*100.0 # Observing wavelength [cm] + wavelengthbar = wavelength/(2.0*np.pi) # lambda/(2pi) [cm] + N = self.prior_next.xdim + + # Field of view, in cm, at the scattering screen + FOV = self.prior_next.psize * N * self.scattering_model.observer_screen_distance + rF = self.scattering_model.rF(wavelength) + + imvec = minvec[:N**2] + EpsilonList = minvec[N**2:] + if 'log' in self.transform_next: + imvec = np.exp(imvec) + + IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize, + self.prior_next.ra, self.prior_next.dec, + self.prior_next.pa, rf=self.obs_next.rf, + source=self.prior_next.source, mjd=self.prior_next.mjd) + + # The scattered image vector + screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen, + ea_ker=self._ea_ker, sqrtQ=self._sqrtQ, + Linearized_Approximation=True) + scatt_im = scatt_im.imvec + + EA_Image = self.scattering_model.Ensemble_Average_Blur(IM, ker=self._ea_ker) + EA_Gradient = so.Wrapped_Gradient((EA_Image.imvec/(FOV/N)).reshape(N, N)) + + # The gradient signs don't actually matter, but let's make them match intuition + # (i.e., right to left, bottom to top) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + + Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + phi_scr = self.scattering_model.MakePhaseScreen(Epsilon_Screen, IM, + obs_frequency_Hz=self.obs_next.rf, + sqrtQ_init=self._sqrtQ) + phi = phi_scr.imvec.reshape((N, N)) + phi_Gradient = so.Wrapped_Gradient(phi/(FOV/N)) + phi_Gradient_x = -phi_Gradient[1] + phi_Gradient_y = -phi_Gradient[0] + + # Entropy gradient; wrt unscattered image so unchanged by scattering + regterm = 0 + reg_term_dict = self.make_reggrad_dict(imvec) + + # Make dict also for scattered image + reg_term_dict_scatt = self.make_reggrad_dict(scatt_im) + + for regname in sorted(self.reg_term_next.keys()): + # We need an exception if the regularizer is 'rgauss' + if regname == 'rgauss': + # Get gradient of the scattered image vector + gaussterm = self.reg_term_next[regname] * reg_term_dict_scatt[regname] + dgauss_dIa = gaussterm.reshape((N, N)) + + # Now the chain rule factor to get the gauss gradient wrt the unscattered image + gx = so.Wrapped_Convolve( + self._ea_ker_gradient_x[::-1, ::-1], phi_Gradient_x * (dgauss_dIa)) + gx = (rF**2.0 * gx).flatten() + + gy = so.Wrapped_Convolve( + self._ea_ker_gradient_y[::-1, ::-1], phi_Gradient_y * (dgauss_dIa)) + gy = (rF**2.0 * gy).flatten() + + # Now we add the gradient for the unscattered image + regterm += so.Wrapped_Convolve(self._ea_ker[::-1, ::-1], + (dgauss_dIa)).flatten() + gx + gy + + else: + regterm += self.reg_term_next[regname] * reg_term_dict[regname] + + # Chi^2 gradient wrt the unscattered image + # First, the chi^2 gradient wrt to the scattered image + datterm = 0. + chi2_term_dict = self.make_chisqgrad_dict(scatt_im) + for dname in sorted(self.dat_term_next.keys()): + datterm += self.dat_term_next[dname] * (chi2_term_dict[dname]) + dchisq_dIa = datterm.reshape((N, N)) + + # Now the chain rule factor to get the chi^2 gradient wrt the unscattered image + gx = so.Wrapped_Convolve(self._ea_ker_gradient_x[::-1, ::-1], phi_Gradient_x * (dchisq_dIa)) + gx = (rF**2.0 * gx).flatten() + + gy = so.Wrapped_Convolve(self._ea_ker_gradient_y[::-1, ::-1], phi_Gradient_y * (dchisq_dIa)) + gy = (rF**2.0 * gy).flatten() + + chisq_grad_im = so.Wrapped_Convolve( + self._ea_ker[::-1, ::-1], (dchisq_dIa)).flatten() + gx + gy + + # Gradient of the data chi^2 wrt to the epsilon screen + # Preliminary Definitions + chisq_grad_epsilon = np.zeros(N**2-1) + i_grad = 0 + ell_mat = np.zeros((N, N)) + m_mat = np.zeros((N, N)) + for ell in range(0, N): + for m in range(0, N): + ell_mat[ell, m] = ell + m_mat[ell, m] = m + + # Real part; top row + for t in range(1, (N+1)//2): + s = 0 + grad_term = (wavelengthbar/FOV*self._sqrtQ[s][t] * + 2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term = so.Wrapped_Gradient(grad_term) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + + cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y) + chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term) + + i_grad = i_grad + 1 + + # Real part; remainder + for s in range(1, (N+1)//2): + for t in range(N): + grad_term = (wavelengthbar/FOV*self._sqrtQ[s][t] * + 2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term = so.Wrapped_Gradient(grad_term) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + + cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y) + chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term) + + i_grad = i_grad + 1 + + # Imaginary part; top row + for t in range(1, (N+1)//2): + s = 0 + grad_term = (-wavelengthbar/FOV*self._sqrtQ[s][t] * + 2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term = so.Wrapped_Gradient(grad_term) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + + cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y) + chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term) + + i_grad = i_grad + 1 + + # Imaginary part; remainder + for s in range(1, (N+1)//2): + for t in range(N): + grad_term = (-wavelengthbar/FOV*self._sqrtQ[s][t] * + 2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term = so.Wrapped_Gradient(grad_term) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + + cge_term = (EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y) + chisq_grad_epsilon[i_grad] = np.sum(dchisq_dIa * rF**2 * cge_term) + i_grad = i_grad + 1 + + # Gradient of the chi^2 regularization term for the epsilon screen + chisq_epsilon_grad = self.alpha_phi_next * 2.0*EpsilonList/((N*N-1)/2.0) + + # Chain rule term for change of variables + if 'log' in self.transform_next: + regterm *= imvec + chisq_grad_im *= imvec + + out = np.concatenate(((regterm + chisq_grad_im), (chisq_grad_epsilon + chisq_epsilon_grad))) + return out + + def plotcur_scattering(self, minvec): + """Plot current stochastic optics image/screen + """ + if self._show_updates: + if self._nit % self._update_interval == 0: + N = self.prior_next.xdim + + imvec = minvec[:N**2] + EpsilonList = minvec[N**2:] + if 'log' in self.transform_next: + imvec = np.exp(imvec) + + IM = ehtim.image.Image(imvec.reshape(N, N), self.prior_next.psize, + self.prior_next.ra, self.prior_next.dec, + self.prior_next.pa, rf=self.obs_next.rf, + source=self.prior_next.source, mjd=self.prior_next.mjd) + + # The scattered image vector + screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + scatt_im = self.scattering_model.Scatter(IM, Epsilon_Screen=screen, + ea_ker=self._ea_ker, sqrtQ=self._sqrtQ, + Linearized_Approximation=True) + scatt_im = scatt_im.imvec + + # Calculate the chi^2 using the scattered image + datterm = 0. + chi2_term_dict = self.make_chisq_dict(scatt_im) + for dname in sorted(self.dat_term_next.keys()): + datterm += self.dat_term_next[dname] * (chi2_term_dict[dname] - 1.) + + # Calculate the entropy using the unscattered image + regterm = 0 + reg_term_dict = self.make_reg_dict(imvec) + for regname in sorted(self.reg_term_next.keys()): + regterm += self.reg_term_next[regname] * reg_term_dict[regname] + + # Scattering screen regularization term + chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0) + # regterm_scattering = self.alpha_phi_next * (chisq_epsilon - 1.0) + + outstr = "i: %d " % self._nit + + for dname in sorted(self.dat_term_next.keys()): + outstr += "%s : %0.2f " % (dname, chi2_term_dict[dname]) + for regname in sorted(self.reg_term_next.keys()): + outstr += "%s : %0.2f " % (regname, reg_term_dict[regname]) + outstr += "Epsilon chi^2 : %0.2f " % (chisq_epsilon) + outstr += "Max |Epsilon| : %0.2f " % (max(abs(EpsilonList))) + print(outstr) + + self._nit += 1 + + def make_image_I_stochastic_optics(self, grads=True, **kwargs): + """Reconstructs an image of total flux density + using the stochastic optics scattering mitigation technique. + + Uses the scattering model in Imager.scattering_model. + If none has been specified, defaults to standard model for Sgr A*. + Returns the estimated unscattered image. + + Args: + grads (bool): Flag for whether or not to use analytic gradients. + show_updates (bool): Flag for whether or not to show updates + Returns: + out (Image): The estimated *unscattered* image. + """ + + N = self.prior_next.xdim + + # Checks and initialize + self.check_params() + self.check_limits() + self.init_imager() + self.init_imager_scattering() + self._nit = 0 + + # Print stats + self._show_updates = kwargs.get('show_updates', True) + self._update_interval = kwargs.get('update_interval', 1) + self.plotcur_scattering(self._xinit) + + # Minimize + optdict = {'maxiter': self.maxit_next, 'ftol': self.stop_next, 'maxcor': NHIST} + tstart = time.time() + if grads: + res = opt.minimize(self.objfunc_scattering, self._xinit, method='L-BFGS-B', + jac=self.objgrad_scattering, options=optdict, + callback=self.plotcur_scattering) + else: + res = opt.minimize(self.objfunc_scattering, self._xinit, method='L-BFGS-B', + options=optdict, callback=self.plotcur_scattering) + tstop = time.time() + + # Format output + out = res.x[:N**2] + if 'log' in self.transform_next: + out = np.exp(out) + if np.any(np.invert(self._embed_mask)): + raise Exception("Embedding is not currently implemented!") + out = imutils.embed(out, self._embed_mask) + + outim = ehtim.image.Image(out.reshape(N, N), self.prior_next.psize, + self.prior_next.ra, self.prior_next.dec, self.prior_next.pa, + rf=self.prior_next.rf, source=self.prior_next.source, + mjd=self.prior_next.mjd, pulse=self.prior_next.pulse) + outep = res.x[N**2:] + screen = so.MakeEpsilonScreenFromList(outep, N) + outscatt = self.scattering_model.Scatter(outim, + Epsilon_Screen=screen, + ea_ker=self._ea_ker, sqrtQ=self._sqrtQ, + Linearized_Approximation=True) + + # Preserving image complex polarization fractions + if len(self.prior_next.qvec): + qvec = self.prior_next.qvec * out / self.prior_next.imvec + uvec = self.prior_next.uvec * out / self.prior_next.imvec + outim.add_qu(qvec.reshape(N, N), + uvec.reshape(N, N)) + + # Print stats + print("time: %f s" % (tstop - tstart)) + print("J: %f" % res.fun) + print(res.message) + + # Append to history + logstr = str(self.nruns) + ": make_image_I_stochastic_optics()" + self._append_image_history(outim, logstr) + self._out_list_epsilon.append(res.x[N**2:]) + self._out_list_scattered.append(outscatt) + + self.nruns += 1 + + # Return Image object + return outim + + def _append_image_history(self, outim, logstr): + self.logstr += (logstr + "\n") + self._obs_list.append(self.obslist_next) + self._init_list.append(self.init_next) + self._prior_list.append(self.prior_next) + self._debias_list.append(self.debias_next) + self._weighting_list.append(self.weighting_next) + self._systematic_noise_list.append(self.systematic_noise_next) + self._systematic_cphase_noise_list.append(self.systematic_cphase_noise_next) + self._snrcut_list.append(self.snrcut_next) + self._flux_list.append(self.flux_next) + self._pflux_list.append(self.pflux_next) + self._vflux_list.append(self.vflux_next) + self._pol_list.append(self.pol_next) + self._clipfloor_list.append(self.clipfloor_next) + self._maxset_list.append(self.clipfloor_next) + self._maxit_list.append(self.maxit_next) + self._stop_list.append(self.stop_next) + self._transform_list.append(self.transform_next) + self._reg_term_list.append(self.reg_term_next) + self._dat_term_list.append(self.dat_term_next) + self._alpha_phi_list.append(self.alpha_phi_next) + + self._out_list.append(outim) + return diff --git a/imaging/__init__.py b/imaging/__init__.py new file mode 100644 index 00000000..9cd8c613 --- /dev/null +++ b/imaging/__init__.py @@ -0,0 +1,9 @@ +""" +.. module:: ehtim.imaging + :platform: Unix + :synopsis: EHT Imaging Utilities: imaging functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from ..const_def import * diff --git a/imaging/clean.py b/imaging/clean.py new file mode 100644 index 00000000..5ed177ae --- /dev/null +++ b/imaging/clean.py @@ -0,0 +1,1241 @@ +# clean.py +# Clean-like imagers +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from builtins import range + +import string +import time +import numpy as np +import scipy.optimize as opt +import scipy.ndimage as nd +import scipy.ndimage.filters as filt +import matplotlib.pyplot as plt +try: + from pynfft.nfft import NFFT +except ImportError: + pass + #print("Warning: No NFFT installed!") +import numpy.polynomial.polynomial as p + +import ehtim.image as image + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * +from ehtim.imaging.imager_utils import * + +################################################################################################## +# Constants & Definitions +################################################################################################## + + +NHIST = 50 # number of steps to store for hessian approx +MAXIT = 100. + +DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp'] +REGULARIZERS = ['gs', 'tv', 'tv2','l1', 'patch', 'simple', 'compact', 'compact2'] + +NFFT_KERSIZE_DEFAULT = 20 +GRIDDER_P_RAD_DEFAULT = 2 +GRIDDER_CONV_FUNC_DEFAULT = 'gaussian' +FFT_PAD_DEFAULT = 2 +FFT_INTERP_DEFAULT = 3 + +nit = 0 # global variable to track the iteration number in the plotting callback + +################################################################################################## +# Imagers +################################################################################################## +def plot_i(Image, nit, chi2, fig=1, cmap='afmhot'): + """Plot the total intensity image at each iteration + """ + + plt.ion() + plt.figure(fig) + plt.pause(0.00001) + plt.clf() + + plt.imshow(Image.imvec.reshape(Image.ydim,Image.xdim), cmap=plt.get_cmap(cmap), interpolation='gaussian') + xticks = ticks(Image.xdim, Image.psize/RADPERAS/1e-6) + yticks = ticks(Image.ydim, Image.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title("step: %i $\chi^2$: %f " % (nit, chi2), fontsize=20) + +def dd_clean_vis(Obsdata, InitIm, niter=1, clipfloor=-1, ttype="direct", loop_gain=1, method='min_chisq', weighting='uniform', + fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT, show_updates=False): + + # limit imager range to prior values > clipfloor + embed_mask = InitIm.imvec >= clipfloor + + # get data + vis = Obsdata.data['vis'] + sigma = Obsdata.data['sigma'] + uv = np.hstack((Obsdata.data['u'].reshape(-1,1), Obsdata.data['v'].reshape(-1,1))) + + # necessary nfft infos + npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim))) + nfft_info = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv) + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # Weights + weights_nat = 1./(sigma**2) + + if weighting=='uniform': + weights = np.ones(len(weights_nat)) + elif weighting=='natural': + weights=weights_nat + else: + raise Exception("weighting must be 'uniform' or 'natural'") + weights_norm = np.sum(weights) + + # Coordinate matrix + coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)] + for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)]) + coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2) + coord = coord[embed_mask] + + # Initial imvec and visibilities + # TODO currently always initialized to zero!! + OutputIm = InitIm.copy() + DeltasIm = InitIm.copy() + ChisqIm = InitIm.copy() + + res = Obsdata.res() + beamparams = Obsdata.fit_beam() + + imvec_init = 0*InitIm.imvec[embed_mask] + vis_init = np.zeros(len(vis),dtype='complex128') + + imvec_current = imvec_init + vis_current = vis_init + + chisq_init = np.sum(weights*np.abs(vis-vis_init)**2) + rchisq_init = np.sum(weights_nat*np.abs(vis-vis_init)**2)/(2*len(weights_nat)) + chisq_current = chisq_init + rchisq_current = rchisq_init + + # clean loop + print("\n") + for it in range(niter): + + resid_current = vis - vis_current + + plan.f = weights * resid_current + plan.adjoint() + out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)) + deltas_all = out/ weights_norm + deltas = deltas_all[embed_mask] + + chisq_map_all = chisq_current - (deltas_all**2)*weights_norm + chisq_map = chisq_map_all[embed_mask] + + # Visibility space clean + if method=='min_chisq': + component_loc_idx = np.argmin(chisq_map) + + # Image space clean + elif method=='max_delta': + component_loc_idx = np.argmax(deltas) + + else: + raise Exception("method should be 'min_chisq' or 'max_delta'!") + + # display images of delta and chisq + if show_updates: + DeltasIm.imvec = deltas_all + plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot') + + ChisqIm.imvec = -chisq_map_all + plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool') + + # clean component location + component_loc_x = coord[component_loc_idx][0] + component_loc_y = coord[component_loc_idx][1] + component_strength = loop_gain*deltas[component_loc_idx] + + # update vis and imvec + imvec_current[component_loc_idx] += component_strength + + #TODO how to incorporate pulse function? + vis_current += component_strength*np.exp(2*np.pi*1j*(uv[:,0]*component_loc_x + uv[:,1]*component_loc_y)) + + # update chi^2 and output image + chisq_current = np.sum(weights*np.abs(vis-vis_current)**2) + rchisq_current = np.sum(weights_nat*np.abs(vis-vis_current)**2)/(2*len(weights_nat)) + + print(it+1,component_strength, chisq_current, rchisq_current, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS) + + OutputIm.imvec = imvec_current + if show_updates: + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + OutputImBlur = OutputIm.blur_gauss(beamparams) + plot_i(OutputImBlur, it, rchisq_current, fig=2) + + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + return OutputIm + + +#solve full 5th order polynomial +def dd_clean_bispec_full(Obsdata, InitIm, niter=1, clipfloor=-1, loop_gain=.1, + weighting='uniform', bscount="min",show_updates=True, + fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT): + + + # limit imager range to prior values > clipfloor + embed_mask = InitIm.imvec >= clipfloor + + # get data + biarr = Obsdata.bispectra(mode="all", count=bscount) + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bs = biarr['bispec'] + sigma = biarr['sigmab'] + + # necessary nfft infos + npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim))) + nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1) + nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2) + nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3) + + nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1) + nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2) + nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3) + + nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2) + nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3) + nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1) + + # TODO do we use pulse factors? + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + plan11 = nfft_info11.plan + pulsefac11 = nfft_info11.pulsefac + plan22 = nfft_info22.plan + pulsefac22 = nfft_info22.pulsefac + plan33 = nfft_info33.plan + pulsefac33 = nfft_info33.pulsefac + + plan12 = nfft_info12.plan + pulsefac12 = nfft_info12.pulsefac + plan23 = nfft_info23.plan + pulsefac23 = nfft_info23.pulsefac + plan31 = nfft_info31.plan + pulsefac31 = nfft_info31.pulsefac + + # Weights + weights_nat = 1./(sigma**2) + if weighting=='uniform': + weights = np.ones(len(weights_nat)) + elif weighting=='natural': + weights=weights_nat + else: + raise Exception("weighting must be 'uniform' or 'natural'") + weights_norm = np.sum(weights) + + # Coordinate matrix + # TODO what if the image is odd? + coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)] + for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)]) + coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2) + coord = coord[embed_mask] + + # Initial imvec and visibilities + # TODO currently initialized to zero!! + OutputIm = InitIm.copy() + DeltasIm = InitIm.copy() + ChisqIm = InitIm.copy() + + res = Obsdata.res() + beamparams = Obsdata.fit_beam() + + imvec_init = 0*InitIm.imvec[embed_mask] + vis1_init = np.zeros(len(bs), dtype='complex128') + vis2_init = np.zeros(len(bs), dtype='complex128') + vis3_init = np.zeros(len(bs), dtype='complex128') + bs_init = vis1_init*vis2_init*vis3_init + chisq_init = np.sum(weights*np.abs(bs - bs_init)**2) + rchisq_init = np.sum(weights_nat*np.abs(bs - bs_init)**2)/(2*len(weights_nat)) + + imvec_current = imvec_init + vis1_current = vis1_init + vis2_current = vis2_init + vis3_current = vis3_init + bs_current = bs_init + chisq_current = chisq_init + rchisq_current = rchisq_init + + # clean loop + print("\n") + for it in range(niter): + t = time.time() + # compute delta at each location + resid_current = bs - bs_current + vis12_current = vis1_current*vis2_current + vis13_current = vis1_current*vis3_current + vis23_current = vis2_current*vis3_current + + # center the first component automatically + # since initial image is empty, must go to higher order in delta in solution + # TODO generalize to non-empty initial image! + if it==0: + + A = np.sum(weights*np.real(resid_current)) + B = np.sum(weights) + + component_strength = np.cbrt(A/B) + component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1?? + + else: + # First calculate P (1st order) + plan1.f = weights * resid_current * vis23_current.conj() + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * resid_current * vis13_current.conj() + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * resid_current * vis12_current.conj() + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + P = -2 * (out1 + out2 + out3) + + # Then calculate Q (2nd order) + plan12.f = weights * vis13_current*vis23_current.conj() + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * vis12_current*vis13_current.conj() + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * vis23_current*vis12_current.conj() + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights * resid_current.conj() * vis1_current + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * resid_current.conj() * vis2_current + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * resid_current.conj() * vis3_current + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + Q0 = np.sum(weights*(np.abs(vis12_current)**2 + np.abs(vis23_current)**2 + np.abs(vis13_current)**2)) + Q1 = 2 * (out12 + out23 + out31) + Q2 = -2 * (out1 + out2 + out3) + Q = 2*(Q0 + Q1 + Q2) + + #Calculate R (3rd order) + plan11.f = weights * vis1_current.conj() * vis23_current + plan11.adjoint() + out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim)) + plan22.f = weights * vis2_current.conj() * vis13_current + plan22.adjoint() + out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim)) + plan33.f = weights * vis3_current.conj() * vis12_current + plan33.adjoint() + out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim)) + + plan1.f = weights * vis1_current * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2) + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * vis2_current * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2) + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * vis3_current * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2) + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + R0 = -np.sum(weights*np.real(resid_current)) + R1 = np.real(out11 + out22 + out33) + R2 = np.real(out1 + out2 + out3) + R = 6*(R0 + R1 + R2) + + # Now find S (4th order) + plan12.f = weights * vis1_current*vis2_current.conj() + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * vis2_current*vis3_current.conj() + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * vis3_current*vis1_current.conj() + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights * vis23_current.conj() + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * vis13_current.conj() + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * vis12_current.conj() + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + S0 = np.sum(weights*(np.abs(vis1_current)**2 + np.abs(vis2_current)**2 + np.abs(vis3_current)**2)) + S1 = 2 * (out12 + out23 + out31) + S2 = 2 * (out1 + out2 + out3) + S = 4*(S0 + S1 + S2) + + + # T (5th order) + plan1.f = weights * vis1_current + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * vis2_current + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * vis3_current + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + T = 10*(out1 + out2 + out3) + + # Finally U (6th order) + U = 6*weights_norm * np.ones(T.shape) + + # Find Component + deltas = np.zeros(len(P)) + chisq_map = np.zeros(len(P)) + for i in range(len(P)): + polynomial_params = np.array([P[i], Q[i], R[i], S[i], T[i], U[i]]) + allroots = p.polyroots(polynomial_params) + + # test roots to see which minimizes chi^2 + newchisq = chisq_current + delta = 0 + for j in range(len(allroots)): + root = allroots[j] + if np.imag(root)!=0: continue + + trialchisq = chisq_current + P[i]*root + 0.5*Q[i]*root**2 + (1./3.)*R[i]*root**3 + 0.25*S[i]*root**4 + 0.2*T[i]*root**5 + (1./6.)*U[i]*root**6 + if trialchisq < newchisq: + delta = root + newchisq = trialchisq + + deltas[i] = delta + chisq_map[i] = newchisq + + #plot deltas and chi^2 map + if show_updates: + DeltasIm.imvec = deltas + plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot') + + ChisqIm.imvec = -chisq_map + plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool') + + + component_loc_idx = np.argmin(chisq_map[embed_mask]) + component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx] + + # clean component location + component_loc_x = coord[component_loc_idx][0] + component_loc_y = coord[component_loc_idx][1] + + # update imvec, vis, bispec + imvec_current[component_loc_idx] += component_strength + + #TODO how to incorporate pulse function? + vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y)) + vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y)) + vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y)) + bs_current = vis1_current * vis2_current * vis3_current + + # update chi^2 and output image + chisq_current = np.sum(weights*np.abs(bs - bs_current)**2) + rchisq_current = np.sum(weights_nat*np.abs(bs - bs_current)**2)/(2*len(weights_nat)) + + print("it %i: %f (%.2f , %.2f) %.4f" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_current)) + + OutputIm.imvec = imvec_current + if show_updates: + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + OutputImBlur = OutputIm.blur_gauss(beamparams) + plot_i(OutputImBlur, it, rchisq_current,fig=2) + + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + return OutputIm + + + + +#solve full 5th order polynomial +#weight imaginary term differently +def dd_clean_bispec_imweight(Obsdata, InitIm, niter=1, clipfloor=-1, ttype="direct", loop_gain=.1, loop_gain_init=1, + weighting='uniform', bscount="min", imweight=1, show_updates=True, + fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT): + + + imag_weight=imweight + # limit imager range to prior values > clipfloor + embed_mask = InitIm.imvec >= clipfloor + + # get data + biarr = Obsdata.bispectra(mode="all", count=bscount) + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bs = biarr['bispec'] + sigma = biarr['sigmab'] + + # necessary nfft infos + npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim))) + nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1) + nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2) + nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3) + + nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1) + nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2) + nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3) + + nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2) + nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3) + nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1) + + # TODO do we use pulse factors? + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + plan11 = nfft_info11.plan + pulsefac11 = nfft_info11.pulsefac + plan22 = nfft_info22.plan + pulsefac22 = nfft_info22.pulsefac + plan33 = nfft_info33.plan + pulsefac33 = nfft_info33.pulsefac + + plan12 = nfft_info12.plan + pulsefac12 = nfft_info12.pulsefac + plan23 = nfft_info23.plan + pulsefac23 = nfft_info23.pulsefac + plan31 = nfft_info31.plan + pulsefac31 = nfft_info31.pulsefac + + # Weights + weights_nat = 1./(sigma**2) + if weighting=='uniform': + weights = np.ones(len(weights_nat)) + elif weighting=='natural': + weights=weights_nat + else: + raise Exception("weighting must be 'uniform' or 'natural'") + weights_norm = np.sum(weights) + + # Coordinate matrix + # TODO do we need to make sure this corresponds exactly with what NFFT is doing? + # TODO what if the image is odd? + coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)] + for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)]) + coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2) + coord = coord[embed_mask] + + # Initial imvec and visibilities + # TODO currently initialized to zero!! + OutputIm = InitIm.copy() + DeltasIm = InitIm.copy() + ChisqIm = InitIm.copy() + + res = Obsdata.res() + beamparams = Obsdata.fit_beam() + + imvec_init = 0*InitIm.imvec[embed_mask] + + vis1_init = np.zeros(len(bs), dtype='complex128') + vis2_init = np.zeros(len(bs), dtype='complex128') + vis3_init = np.zeros(len(bs), dtype='complex128') + bs_init = vis1_init*vis2_init*vis3_init + chisq_init = np.sum(weights*(np.real(bs - bs_init)**2 + imweight*np.imag(bs-bs_init)**2)) + rchisq_init = np.sum(weights_nat*np.abs(bs - bs_init)**2)/(2*len(weights_nat)) + + imvec_current = imvec_init + vis1_current = vis1_init + vis2_current = vis2_init + vis3_current = vis3_init + bs_current = bs_init + chisq_current = chisq_init + rchisq_current = rchisq_init + + # clean loop + print("\n") + for it in range(niter): + t = time.time() + # compute delta at each location + resid_current = bs - bs_current + vis12_current = vis1_current*vis2_current + vis13_current = vis1_current*vis3_current + vis23_current = vis2_current*vis3_current + + # center the first component automatically + # since initial image is empty, must go to higher order in delta in solution + # TODO generalize to non-empty initial image! + if it==0: + + A = np.sum(weights*np.real(resid_current)) + B = np.sum(weights) + + component_strength = loop_gain_init*np.cbrt(A/B) + component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1?? + + else: + # First calculate P (1st order) + plan1.f = weights * np.real(resid_current.conj() * vis23_current) + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.real(resid_current.conj() * vis13_current) + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.real(resid_current.conj() * vis12_current) + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + PRe = -2 * (out1 + out2 + out3) + + plan1.f = weights * np.imag(resid_current.conj() * vis23_current) + plan1.adjoint() + out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.imag(resid_current.conj() * vis13_current) + plan2.adjoint() + out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.imag(resid_current.conj() * vis12_current) + plan3.adjoint() + out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + PIm = -2 * (out1 + out2 + out3) + + P = PRe + imag_weight*PIm + + # Then calculate Q (2nd order) + plan12.f = weights * np.real(vis13_current.conj()*vis23_current) + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * np.real(vis12_current.conj()*vis13_current) + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * np.real(vis23_current.conj()*vis12_current) + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights * np.real(resid_current * vis1_current.conj()) + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.real(resid_current * vis2_current.conj()) + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.real(resid_current * vis3_current.conj()) + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + Q0Re = np.sum(weights*(np.real(vis12_current)**2 + np.real(vis23_current)**2 + np.real(vis13_current)**2)) + Q1Re = 2 * (out12 + out23 + out31) + Q2Re = -2*(out1 + out2 + out3) + QRe = 2*(Q0Re + Q1Re + Q2Re) + + plan12.f = weights * np.imag(vis13_current.conj()*vis23_current) + plan12.adjoint() + out12 = np.imag((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * np.imag(vis12_current.conj()*vis13_current) + plan23.adjoint() + out23 = np.imag((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * np.imag(vis23_current.conj()*vis12_current) + plan31.adjoint() + out31 = np.imag((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights * np.imag(resid_current * vis1_current.conj()) + plan1.adjoint() + out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.imag(resid_current * vis2_current.conj()) + plan2.adjoint() + out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.imag(resid_current * vis3_current.conj()) + plan3.adjoint() + out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + Q0Im = np.sum(weights*(np.imag(vis12_current)**2 + np.imag(vis23_current)**2 + np.imag(vis13_current)**2)) + Q1Im = 2 * (out12 + out23 + out31) + Q2Im = -2*(out1 + out2 + out3) + QIm = 2*(Q0Im + Q1Im + Q2Im) + + Q = QRe + imag_weight*QIm + + #Calculate R (3rd order) + plan11.f = weights * np.real(vis1_current * vis23_current.conj()) + plan11.adjoint() + out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim)) + plan22.f = weights * np.real(vis2_current * vis13_current.conj()) + plan22.adjoint() + out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim)) + plan33.f = weights * np.real(vis3_current * vis12_current.conj()) + plan33.adjoint() + out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim)) + + plan1.f = weights * np.real(vis1_current.conj()) * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2) + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.real(vis2_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2) + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.real(vis3_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2) + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + R0Re = -np.sum(weights*np.real(resid_current)) + R1Re = out11 + out22 + out33 + R2Re = out1 + out2 + out3 + RRe = 6*(R0Re + R1Re + R2Re) + + plan11.f = weights * np.imag(vis1_current * vis23_current.conj()) + plan11.adjoint() + out11 = np.imag((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim)) + plan22.f = weights * np.imag(vis2_current * vis13_current.conj()) + plan22.adjoint() + out22 = np.imag((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim)) + plan33.f = weights * np.imag(vis3_current * vis12_current.conj()) + plan33.adjoint() + out33 = np.imag((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim)) + + plan1.f = weights * np.imag(vis1_current.conj()) * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2) + plan1.adjoint() + out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * np.imag(vis2_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2) + plan2.adjoint() + out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * np.imag(vis3_current.conj()) * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2) + plan3.adjoint() + out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + R0Im = 0 + R1Im = out11 + out22 + out33 + R2Im = out1 + out2 + out3 + RIm = 6*(R0Im + R1Im + R2Im) + + R = RRe + imag_weight*RIm + + # Now find S (4th order) + plan12.f = weights * np.real(vis1_current.conj()*vis2_current) + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * np.real(vis2_current.conj()*vis3_current) + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * np.real(vis3_current.conj()*vis1_current) + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights * vis23_current.conj() + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * vis13_current.conj() + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * vis12_current.conj() + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + S0Re = np.sum(weights*(np.real(vis1_current)**2 + np.real(vis2_current)**2 + np.real(vis3_current)**2)) + S1Re = 2 * (out12 + out23 + out31) + S2Re = 2 * (out1 + out2 + out3) + SRe = 4*(S0Re + S1Re + S2Re) + + plan12.f = weights * np.imag(vis1_current.conj()*vis2_current) + plan12.adjoint() + out12 = np.imag((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights * np.imag(vis2_current.conj()*vis3_current) + plan23.adjoint() + out23 = np.imag((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights * np.imag(vis3_current.conj()*vis1_current) + plan31.adjoint() + out31 = np.imag((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + S0Im = np.sum(weights*(np.imag(vis1_current)**2 + np.imag(vis2_current)**2 + np.imag(vis3_current)**2)) + S1Im = 2 * (out12 + out23 + out31) + S2Im = 0 + SIm = 4*(S0Im + S1Im + S2Im) + + S = SRe + imag_weight*SIm + + # T (5th order) + plan1.f = weights * vis1_current + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights * vis2_current + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights * vis3_current + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + TRe = 10*(out1 + out2 + out3) + TIm = 0. + T = TRe + imag_weight*TIm + + # Finally U (6th order) + URe = 6*weights_norm * np.ones(T.shape) + UIm = 0. + U = URe + imag_weight*UIm + + #Find Component + deltas = np.zeros(len(P)) + chisq_map = np.zeros(len(P)) + for i in range(len(P)): + if embed_mask[i]: + polynomial_params = np.array([P[i], Q[i], R[i], S[i], T[i], U[i]]) + allroots = p.polyroots(polynomial_params) + + # test roots to see which minimizes chi^2 + newchisq = chisq_current + delta = 0 + for j in range(len(allroots)): + root = allroots[j] + if np.imag(root)!=0: continue + + trialchisq = chisq_current + P[i]*root + 0.5*Q[i]*root**2 + (1./3.)*R[i]*root**3 + 0.25*S[i]*root**4 + 0.2*T[i]*root**5 + (1./6.)*U[i]*root**6 + if trialchisq < newchisq: + delta = root + newchisq = trialchisq + + deltas[i] = delta + chisq_map[i] = newchisq + else: + deltas[i]=0 + chisq_map[i]=chisq_current + + #print ("step time %i: %f s" % (it+1, time.time() -t)) + #chisq_map = chisq_current + P*deltas + 0.5*Q*deltas**2 + (1./3.)*R*deltas**3 + 0.25*S*deltas**4 + 0.2*T*deltas**5 + (1./6.)*U*deltas**6 + + #Plot delta and chi^2 map + if show_updates: + DeltasIm.imvec = deltas + plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot') + + ChisqIm.imvec = -chisq_map + plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool') + + component_loc_idx = np.argmin(chisq_map[embed_mask]) + component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx] + + # clean component location + component_loc_x = coord[component_loc_idx][0] + component_loc_y = coord[component_loc_idx][1] + + # update imvec, vis, bispec + imvec_current[component_loc_idx] += component_strength + + #TODO how to incorporate pulse function? + vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y)) + vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y)) + vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y)) + bs_current = vis1_current * vis2_current * vis3_current + + # update chi^2 and output image + chisq_current = np.sum(weights*np.abs(bs - bs_current)**2) + rchisq_current = np.sum(weights_nat*np.abs(bs - bs_current)**2)/(2*len(weights_nat)) + + print("it %i: %f (%.2f , %.2f) %.4f" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_current)) + + OutputIm.imvec = imvec_current + if show_updates: + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + OutputImBlur = OutputIm.blur_gauss(beamparams) + plot_i(OutputImBlur, it, rchisq_current,fig=2) + + return OutputIm + + +#amplitude and "closure phase" term +def dd_clean_amp_cphase(Obsdata, InitIm, niter=1, clipfloor=-1, loop_gain=.1, loop_gain_init=1,phaseweight=1, + weighting='uniform', bscount="min",no_neg_comps=False, + fft_pad_factor=FFT_PAD_DEFAULT, p_rad=NFFT_KERSIZE_DEFAULT, show_updates=True): + + + # limit imager range to prior values > clipfloor + embed_mask = InitIm.imvec >= clipfloor + + # get data + amp2 = np.abs(Obsdata.data['vis'])**2 #TODO debias?? + sigma_amp2 = Obsdata.data['sigma']**2 + uv = np.hstack((Obsdata.data['u'].reshape(-1,1), Obsdata.data['v'].reshape(-1,1))) + + biarr = Obsdata.bispectra(mode="all", count=bscount) + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bs = biarr['bispec'] + sigma_bs = biarr['sigmab'] + + # necessary nfft infos + npad = int(fft_pad_factor * np.max((InitIm.xdim, InitIm.ydim))) + + nfft_infoA = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv) + nfft_infoB = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, 2*uv) + + nfft_info1 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1) + nfft_info2 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2) + nfft_info3 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3) + + nfft_info11 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv1) + nfft_info22 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv2) + nfft_info33 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, -2*uv3) + + nfft_info12 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv1-uv2) + nfft_info23 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv2-uv3) + nfft_info31 = NFFTInfo(InitIm.xdim, InitIm.ydim, InitIm.psize, InitIm.pulse, npad, p_rad, uv3-uv1) + + # TODO do we use pulse factors? + planA = nfft_infoA.plan + pulsefacA = nfft_infoA.pulsefac + planB = nfft_infoB.plan + pulsefacB = nfft_infoB.pulsefac + + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + plan11 = nfft_info11.plan + pulsefac11 = nfft_info11.pulsefac + plan22 = nfft_info22.plan + pulsefac22 = nfft_info22.pulsefac + plan33 = nfft_info33.plan + pulsefac33 = nfft_info33.pulsefac + + plan12 = nfft_info12.plan + pulsefac12 = nfft_info12.pulsefac + plan23 = nfft_info23.plan + pulsefac23 = nfft_info23.pulsefac + plan31 = nfft_info31.plan + pulsefac31 = nfft_info31.pulsefac + + # Weights + weights_amp2_nat = 1./(sigma_amp2**2) + if weighting=='uniform': + weights_amp2 = np.ones(len(weights_amp2_nat)) * np.median(weights_amp2_nat) #TODO for scaling weights are all given by median natural weight?? + elif weighting=='natural': + weights_amp2=weights_amp2_nat + else: + raise Exception("weighting must be 'uniform' or 'natural'") + weights_amp2 = weights_amp2 / float(len(weights_amp2_nat)) + weights_amp2_norm = np.sum(weights_amp2) + + weights_bs_nat = (np.abs(bs)**2) / (sigma_bs**2) + if weighting=='uniform': + weights_bs = np.ones(len(weights_bs_nat)) * np.median(weights_bs_nat) #TODO for scaling weights are all given by median natural weight?? + elif weighting=='natural': + weights_bs = weights_bs_nat + else: + raise Exception("weighting must be 'uniform' or 'natural'") + weights_bs = weights_bs / np.abs(bs)**2 #weight down by 1/bs^2 only works for uniform?? + weights_bs = weights_bs / float(len(weights_bs_nat)) + weights_bs_norm = np.sum(weights_bs) + + # Coordinate matrix + # TODO do we need to make sure this corresponds exactly with what NFFT is doing? + # TODO what if the image is odd? + coord = InitIm.psize * np.array([[[x,y] for x in np.arange(InitIm.xdim//2,-InitIm.xdim//2,-1)] + for y in np.arange(InitIm.ydim//2,-InitIm.ydim//2,-1)]) + coord = coord.reshape(InitIm.ydim*InitIm.xdim, 2) + coord = coord[embed_mask] + + # Initial imvec and visibilities + # TODO currently initialized to zero!! + OutputIm = InitIm.copy() + DeltasIm = InitIm.copy() + ChisqIm = InitIm.copy() + + res = Obsdata.res() + beamparams = Obsdata.fit_beam() + + imvec_init = 0*InitIm.imvec[embed_mask] + vis_init = np.zeros(len(amp2), dtype='complex128') + chisq_amp2_init = np.sum(weights_amp2*(amp2 - np.abs(vis_init)**2)**2) + rchisq_amp2_init = np.sum(weights_amp2_nat*(amp2 - np.abs(vis_init)**2)**2) + + vis1_init = np.zeros(len(bs), dtype='complex128') + vis2_init = np.zeros(len(bs), dtype='complex128') + vis3_init = np.zeros(len(bs), dtype='complex128') + bs_init = vis1_init*vis2_init*vis3_init + chisq_bs_init = np.sum(weights_bs*np.abs(bs - bs_init)**2) + rchisq_bs_init = np.sum(weights_bs_nat*np.abs(bs - bs_init)**2) + + chisq_init = chisq_amp2_init + phaseweight*chisq_bs_init + + imvec_current = imvec_init.copy() + vis_current = vis_init.copy() + vis1_current = vis1_init.copy() + vis2_current = vis2_init.copy() + vis3_current = vis3_init.copy() + bs_current = bs_init.copy() + chisq_amp2_current = chisq_amp2_init + rchisq_amp2_current = rchisq_amp2_init + chisq_bs_current = chisq_bs_init + rchisq_bs_current = rchisq_bs_init + chisq_current = chisq_init + + # clean loop + print("\n") + for it in range(niter): + t = time.time() + + # center the first component automatically + # since initial image is empty, must go to higher order in delta in solution + # TODO generalize to non-empty initial image! + # BASE INITIAL POINT SOURCE ENTIRELY ON VISIBILITY AMPLITUDES + if it==0: + + A = np.sum(weights_amp2*amp2) + B = weights_amp2_norm + + component_strength = loop_gain_init*np.sqrt(A/B) + component_loc_idx = (InitIm.ydim//2)*InitIm.xdim + InitIm.xdim//2 #TODO is this right for odd images??+/- 1?? + + else: + + #Amplitude part + # First calculate A (1st order) + planA.f = weights_amp2 *(amp2_current-amp2)*vis_current + planA.adjoint() + out = np.real((planA.f_hat.copy().T).reshape(nfft_infoA.xdim*nfft_infoA.ydim)) + A = 4*out + + # Then calculate B (2nd order) + planB.f = weights_amp2 * vis_current*vis_current + planB.adjoint() + out = np.real((planB.f_hat.copy().T).reshape(nfft_infoB.xdim*nfft_infoB.ydim)) + + B0 = np.sum(weights_amp2*(2*amp2_current - amp2)) + B1 = out + B = 4*(B0 + B1) + + #Calculate C (3rd order) + planA.f = weights_amp2 * vis_current + planA.adjoint() + out = np.real((planA.f_hat.copy().T).reshape(nfft_infoA.xdim*nfft_infoA.ydim)) + + C = 12*out + + # Now find D (4th order) + D = 4*weights_amp2_norm * np.ones(C.shape) + + #"Closure Phase" part + resid_current = bs - bs_current + vis12_current = vis1_current*vis2_current + vis13_current = vis1_current*vis3_current + vis23_current = vis2_current*vis3_current + + # First calculate P (1st order) + plan1.f = weights_bs * resid_current * vis23_current.conj() + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights_bs * resid_current * vis13_current.conj() + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights_bs * resid_current * vis12_current.conj() + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + P = -2 * (out1 + out2 + out3) + + # Then calculate Q (2nd order) + plan12.f = weights_bs * vis13_current*vis23_current.conj() + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights_bs * vis12_current*vis13_current.conj() + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights_bs * vis23_current*vis12_current.conj() + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights_bs * resid_current.conj() * vis1_current + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights_bs * resid_current.conj() * vis2_current + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights_bs * resid_current.conj() * vis3_current + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + Q0 = np.sum(weights_bs*(np.abs(vis12_current)**2 + np.abs(vis23_current)**2 + np.abs(vis13_current)**2)) + Q1 = 2 * (out12 + out23 + out31) + Q2 = -2 * (out1 + out2 + out3) + Q = 2*(Q0 + Q1 + Q2) + + #Calculate R (3rd order) + plan11.f = weights_bs * vis1_current.conj() * vis23_current + plan11.adjoint() + out11 = np.real((plan11.f_hat.copy().T).reshape(nfft_info11.xdim*nfft_info11.ydim)) + plan22.f = weights_bs * vis2_current.conj() * vis13_current + plan22.adjoint() + out22 = np.real((plan22.f_hat.copy().T).reshape(nfft_info22.xdim*nfft_info22.ydim)) + plan33.f = weights_bs * vis3_current.conj() * vis12_current + plan33.adjoint() + out33 = np.real((plan33.f_hat.copy().T).reshape(nfft_info33.xdim*nfft_info33.ydim)) + + plan1.f = weights_bs * vis1_current * (np.abs(vis2_current)**2 + np.abs(vis3_current)**2) + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights_bs * vis2_current * (np.abs(vis1_current)**2 + np.abs(vis3_current)**2) + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights_bs * vis3_current * (np.abs(vis1_current)**2 + np.abs(vis2_current)**2) + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + R0 = -np.sum(weights_bs*np.real(resid_current)) + R1 = np.real(out11 + out22 + out33) + R2 = np.real(out1 + out2 + out3) + R = 6*(R0 + R1 + R2) + + # Now find S (4th order) + plan12.f = weights_bs * vis1_current*vis2_current.conj() + plan12.adjoint() + out12 = np.real((plan12.f_hat.copy().T).reshape(nfft_info12.xdim*nfft_info12.ydim)) + plan23.f = weights_bs * vis2_current*vis3_current.conj() + plan23.adjoint() + out23 = np.real((plan23.f_hat.copy().T).reshape(nfft_info23.xdim*nfft_info23.ydim)) + plan31.f = weights_bs * vis3_current*vis1_current.conj() + plan31.adjoint() + out31 = np.real((plan31.f_hat.copy().T).reshape(nfft_info31.xdim*nfft_info31.ydim)) + + plan1.f = weights_bs * vis23_current.conj() + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights_bs * vis13_current.conj() + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights_bs * vis12_current.conj() + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + S0 = np.sum(weights_bs*(np.abs(vis1_current)**2 + np.abs(vis2_current)**2 + np.abs(vis3_current)**2)) + S1 = 2 * (out12 + out23 + out31) + S2 = 2 * (out1 + out2 + out3) + S = 4*(S0 + S1 + S2) + + # T (5th order) + plan1.f = weights_bs * vis1_current + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + plan2.f = weights_bs * vis2_current + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + plan3.f = weights_bs * vis3_current + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + T = 10*(out1 + out2 + out3) + + # Finally U (6th order) + U = 6*weights_bs_norm * np.ones(T.shape) + + # Find Component based on minimizing chi^2 + deltas = np.zeros(len(A)) + chisq_map = np.zeros(len(A)) + for i in range(len(P)): + if embed_mask[i]: + coeffs = np.array([A[i] + phaseweight*P[i], + B[i] + phaseweight*Q[i], + C[i] + phaseweight*R[i], + D[i] + phaseweight*S[i], + phaseweight*T[i], + phaseweight*U[i] + ]) + allroots = p.polyroots(coeffs) + + # test roots to see which minimizes chi^2 + newchisq = chisq_current + delta = 0 + for j in range(len(allroots)): + root = allroots[j] + if np.imag(root)!=0: continue + if (no_neg_comps and root<0):continue + trialchisq = chisq_current + coeffs[0]*root + 0.5*coeffs[1]*root**2 + (1./3.)*coeffs[2]*root**3 + 0.25*coeffs[3]*root**4 + 0.2*coeffs[4]*root**5 + (1./6.)*coeffs[5]*root**6 + if trialchisq < newchisq: + delta = root + newchisq = trialchisq + + else: + delta = 0. + newchisq = chisq_current + deltas[i] = delta + chisq_map[i] = newchisq + + #plot deltas and chi^2 map + if show_updates: + DeltasIm.imvec = deltas + plot_i(DeltasIm, it, chisq_current,fig=0, cmap='afmhot') + ChisqIm.imvec = -chisq_map + plot_i(ChisqIm, it, chisq_current,fig=1, cmap='cool') + + #chisq_map = chisq_current + P*deltas + 0.5*Q*deltas**2 + (1./3.)*R*deltas**3 + 0.25*S*deltas**4 + 0.2*T*deltas**5 + (1./6.)*U*deltas**6 + component_loc_idx = np.argmin(chisq_map[embed_mask]) + component_strength = loop_gain*(deltas[embed_mask])[component_loc_idx] + + # clean component location + component_loc_x = coord[component_loc_idx][0] + component_loc_y = coord[component_loc_idx][1] + + # update imvec, vis, bispec + imvec_current[component_loc_idx] += component_strength + + #TODO how to incorporate pulse function? + vis_current += component_strength*np.exp(2*np.pi*1j*(uv[:,0]*component_loc_x + uv[:,1]*component_loc_y)) + amp2_current = np.abs(vis_current)**2 + + vis1_current += component_strength*np.exp(2*np.pi*1j*(uv1[:,0]*component_loc_x + uv1[:,1]*component_loc_y)) + vis2_current += component_strength*np.exp(2*np.pi*1j*(uv2[:,0]*component_loc_x + uv2[:,1]*component_loc_y)) + vis3_current += component_strength*np.exp(2*np.pi*1j*(uv3[:,0]*component_loc_x + uv3[:,1]*component_loc_y)) + bs_current = vis1_current * vis2_current * vis3_current + + # update chi^2 and output image + chisq_amp2_current = np.sum(weights_amp2*(amp2 - amp2_current)**2) + rchisq_amp2_current = np.sum(weights_amp2_nat*(amp2 - amp2_current)**2) + + chisq_bs_current = np.sum(weights_bs*np.abs(bs - bs_current)**2) + rchisq_bs_current = np.sum(weights_bs_nat*np.abs(bs - bs_current)**2) + + chisq_current = chisq_amp2_current + phaseweight*chisq_bs_current + + print("it %i| %.4e (%.1f , %.1f) | %.4e %.4e | %.4e" % (it+1, component_strength, component_loc_x/RADPERUAS, component_loc_y/RADPERUAS, chisq_amp2_current, chisq_bs_current, chisq_current)) + + OutputIm.imvec = imvec_current + if show_updates: + OutputIm.imvec = embed(OutputIm.imvec, embed_mask, clipfloor=0., randomfloor=False) + OutputImBlur = OutputIm.blur_gauss(beamparams) + plot_i(OutputImBlur, it, chisq_current, fig=2) + + return OutputIm diff --git a/imaging/dynamical_imaging.py b/imaging/dynamical_imaging.py new file mode 100644 index 00000000..a19d917f --- /dev/null +++ b/imaging/dynamical_imaging.py @@ -0,0 +1,2586 @@ +# dynamical_imaging.py +# imaging movies with interferometric data +# +# Copyright (C) 2018 Michael Johnson +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +# Note: This library is still under very active development and is likely to change considerably +# Contact Michael Johnson (mjohnson@cfa.harvard.edu) with any questions +# The methods/techniques used in this are described in http://adsabs.harvard.edu/abs/2017ApJ...850..172J + +import time +import numpy as np + +import scipy.optimize as opt +import scipy.ndimage.filters as filt +import scipy.signal + +import matplotlib.pyplot as plt + +import itertools as it + +from ehtim.const_def import * #Note: C is m/s rather than cm/s. +from ehtim.observing.obs_helpers import * +import ehtim.obsdata as obsdata +import ehtim.image as image +import ehtim.movie as movie +from ehtim.imaging.imager_utils import * + +import ehtim.scattering as so + +from multiprocessing import Pool +from functools import partial + +#imports from the blazarFileDownloader +import calendar +import requests +import os +from html.parser import HTMLParser +#from HTMLParser import HTMLParser + +Fast_Convolve = True # This option will not wrap the convolution around the image + +# These parameters are only global to allow parallelizing the chi^2 calculation without huge memory overhead. It would be nice to do this locally, using the parallel array capabilities. +# Fourier matrices: +A1_List = [None,] +A2_List = [None,] +A3_List = [None,] +# Data used: +data1_List = [None,] +data2_List = [None,] +data3_List = [None,] +# Standard deviation of data used: +sigma1_List = [None,] +sigma2_List = [None,] +sigma3_List = [None,] + +################################################################################################## +# Constants +################################################################################################## +#NHIST = 25 # number of steps to store for hessian approx +nit = 0 # global variable to track the iteration number in the plotting callback + +################################################################################################## +# Tools for AGN images +################################################################################################## + +def get_KLS(im1, im2, shift = [0,0], blur_size_uas=100, dynamic_range=200): + # Symmetrized Kullback-Liebler Divergence, with optional blurring and max dynamic range + ep = np.max(im1.imvec)/dynamic_range + A = im1.blur_circ(blur_size_uas*RADPERUAS).imvec.reshape((im1.ydim,im1.xdim)) + ep + B = im2.blur_circ(blur_size_uas*RADPERUAS).imvec.reshape((im2.ydim,im2.xdim)) + ep + + B = np.roll(B, shift, (0,1)) + return np.sum( (B - A)*np.log(B/A) ) + +def get_core_position(im, blur_size_uas=100): + #Estimate the core position (i.e., the brightest region) + #Convolve the input image with the beam, then find the brightest pixel + im_blur = im.blur_circ(blur_size_uas*RADPERUAS).imvec + return np.array(np.unravel_index(im_blur.argmax(),(im.ydim,im.xdim))) + +def center_core(im, blur_size_uas=100): + im_rotate = im.copy() + core_pos = get_core_position(im,blur_size_uas) + center = np.array((int((im.ydim-1)/2),int((im.xdim-1)/2))) + print ("Rotating By",center - core_pos) + im_rotate.imvec = np.roll(im.imvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten() + if len(im.qvec): + im_rotate.qvec = np.roll(im.qvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten() + if len(im.uvec): + im_rotate.uvec = np.roll(im.uvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten() + if len(im.vvec): + im_rotate.vvec = np.roll(im.vvec.reshape((im.ydim,im.xdim)), center - core_pos, (0,1)).flatten() + + return im_rotate + +def align_left(im,min_frac=0.1,opposite_frac_thresh=0.05): + #Aligns the core at the middle, assuming that there is no appreciable flux to the left of the core + im_rotate = im.copy() + center = np.array((int((im.ydim-1)/2),int((im.xdim-1)/2))) + projected_flux = np.sum(im.imvec.reshape((im.ydim,im.xdim)),axis=0) + thresh = np.max(projected_flux)*min_frac + #opposite_thresh = np.max(projected_flux[-int(opposite_frac_thresh*im.xdim):]) + + for j in range(im.xdim): + if projected_flux[j] > thresh: #or projected_flux[j] > opposite_thresh + break + + im_rotate.imvec = np.roll(im.imvec.reshape((im.ydim,im.xdim)), center[0] - j, (1)).flatten() + if len(im.qvec): + im_rotate.qvec = np.roll(im.qvec.reshape((im.ydim,im.xdim)), center[0] - j, (1)).flatten() + if len(im.uvec): + im_rotate.uvec = np.roll(im.uvec.reshape((im.ydim,im.xdim)), center[0] - j, (1)).flatten() + if len(im.vvec): + im_rotate.vvec = np.roll(im.vvec.reshape((im.ydim,im.xdim)), center[0] - j, (1)).flatten() + + return im_rotate + + +################################################################################################## +# Movie Export Tools +################################################################################################## +def export_multipanel_movie(im_List_Set, out='movie.mp4', fps=10, dpi=120, scale='linear', dynamic_range=1000.0, pad_factor=1, verbose=False, xlim = None, ylim = None, titles = [], size=8.0): + # Example: di.export_multipanel_movie([im_List,im_List_2],scale='log',xlim=[1000,-1000],ylim=[-3000,500],dynamic_range=[1000,5000], titles = ['43 GHz (BU)','15 GHz (MOJAVE)']) + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + fig = plt.figure() + + mjd_step = (im_List_Set[0][0].mjd - im_List_Set[0][-1].mjd)/len(im_List_Set[0]) + + #if len(im_List_Set)%2 == 1: + # verticalalignment = 'bottom' + #else: + # verticalalignment = 'top' + + N_set = len(im_List_Set) + extent = [np.array((1,-1,-1,1))]*N_set + maxi = np.zeros(N_set) + plt_im = [None,]*N_set + + if type(dynamic_range) == float or type(dynamic_range) == int: + dynamic_range = np.zeros(N_set) + dynamic_range + + for j in range(N_set): + extent[j] = im_List_Set[j][0].psize/RADPERUAS*im_List_Set[j][0].xdim*np.array((1,-1,-1,1)) / 2. + maxi[j] = np.max(np.concatenate([im.imvec for im in im_List_Set[j]])) + + def im_data(i_set, n): + n_data = (n-n%pad_factor)//pad_factor + if scale == 'linear': + return im_List_Set[i_set][n_data].imvec.reshape((im_List_Set[i_set][n_data].ydim,im_List_Set[i_set][n_data].xdim)) + else: + return np.log(im_List_Set[i_set][n_data].imvec.reshape((im_List_Set[i_set][n_data].ydim,im_List_Set[i_set][n_data].xdim)) + maxi[i_set]/dynamic_range[i_set]) + + for j in range(N_set): + ax = plt.subplot(1, N_set, j+1) + plt_im[j] = plt.imshow(im_data(j, 0), extent=extent[j], cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + + if xlim != None: + ax.set_xlim(xlim) + if ylim != None: + ax.set_ylim(ylim) + + if scale == 'linear': + plt_im[j].set_clim([0,maxi[j]]) + else: + plt_im[j].set_clim([np.log(maxi[j]/dynamic_range[j]),np.log(maxi[j])]) + + if j == 0: + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + + if len(titles) > 0: + ax.set_title(titles[j]) + + fig.set_size_inches([size,size/len(im_List_Set)]) + plt.tight_layout() + + def update_img(n): + if verbose: + print ("processing frame {0} of {1}".format(n, len(im_List)*pad_factor)) + for j in range(N_set): + plt_im[j].set_data(im_data(j, n)) + + if mjd_step > 0.1: + fig.suptitle('MJD: ' + str(im_List_Set[0][int((n-n%pad_factor)//pad_factor)].mjd), verticalalignment = verticalalignment) + else: + time = im_List_Set[0][int((n-n%pad_factor)//pad_factor)].time + time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60)) + fig.suptitle(time_str) + + return plt_im + + ani = animation.FuncAnimation(fig,update_img,len(im_List_Set[0])*pad_factor,interval=1e3/fps) + writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6) + ani.save(out,writer=writer,dpi=dpi) + +def export_movie(im_List, out='movie.mp4', fps=10, dpi=120, scale='linear', cbar_unit = 'Jy', gamma=0.5, dynamic_range=1000.0, pad_factor=1, verbose=False): + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + mjd_range = im_List[-1].mjd - im_List[0].mjd + + fig = plt.figure() + + extent = im_List[0].psize/RADPERUAS*im_List[0].xdim*np.array((1,-1,-1,1)) / 2. + maxi = np.max(np.concatenate([im.imvec for im in im_List])) + +# TODO: fix this +# if cbar_unit == 'mJy': +# imvec = imvec * 1.e3 +# qvec = qvec * 1.e3 +# uvec = uvec * 1.e3 +# elif cbar_unit == '$\mu$Jy': +# imvec = imvec * 1.e6 +# qvec = qvec * 1.e6 +# uvec = uvec * 1.e6 + + unit = cbar_unit + '/pixel' + + if scale=='log': + unit = 'log(' + cbar_unit + '/pixel)' + + if scale=='gamma': + unit = '(' + cbar_unit + '/pixel)^gamma' + + def im_data(n): + n_data = (n-n%pad_factor)//pad_factor + if scale == 'linear': + return im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim)) + elif scale == 'log': + return np.log(im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim)) + maxi/dynamic_range) + elif scale == 'gamma': + return (im_List[n_data].imvec.reshape((im_List[n_data].ydim,im_List[n_data].xdim)) + maxi/dynamic_range)**(gamma) + + plt_im = plt.imshow(im_data(0), extent=extent, cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + if scale == 'linear': + plt_im.set_clim([0,maxi]) + elif scale == 'log': + plt_im.set_clim([np.log(maxi/dynamic_range),np.log(maxi)]) + elif scale == 'gamma': + plt_im.set_clim([(maxi/dynamic_range)**gamma,(maxi)**(gamma)]) + + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + + fig.set_size_inches([5,5]) + plt.tight_layout() + + def update_img(n): + if verbose: + print ("processing frame {0} of {1}".format(n, len(im_List)*pad_factor)) + plt_im.set_data(im_data(n)) + if mjd_range != 0: + fig.suptitle('MJD: ' + str(im_List[int((n-n%pad_factor)//pad_factor)].mjd)) + else: + time = im_List[int((n-n%pad_factor)//pad_factor)].time + time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60)) + fig.suptitle(time_str) + + return plt_im + + ani = animation.FuncAnimation(fig,update_img,len(im_List)*pad_factor,interval=1e3/fps) + writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6) + ani.save(out,writer=writer,dpi=dpi) + + +################################################################################################## +# Convenience Functions for Data Processing +################################################################################################## + +def split_obs(obs): + """Split single observation into multiple observation files, one per scan + """ + + print ("Splitting Observation File into " + str(len(obs.tlist())) + " scans") + return [ obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, tdata, obs.tarr, source=obs.source, mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal) for tdata in obs.tlist() ] + +def merge_obs(obs_List): + """Merge a list of observations into a single observation file + """ + + if len(set([obs.ra for obs in obs_List])) > 1 or len(set([obs.dec for obs in obs_List])) > 1 or len(set([obs.rf for obs in obs_List])) > 1 or len(set([obs.bw for obs in obs_List])) > 1 or len(set([obs.source for obs in obs_List])) > 1: + print ("All observations must have the same parameters!") + return + + #The important things to merge are the mjd and the data + data_merge = np.hstack([obs.data for obs in obs_List]) + + return obsdata.Obsdata(obs_List[0].ra, obs_List[0].dec, obs_List[0].rf, obs_List[0].bw, data_merge, obs_List[0].tarr, polrep=obs_List[0].polrep, source=obs_List[0].source, mjd=obs_List[0].mjd, ampcal=obs_List[0].ampcal, phasecal=obs_List[0].phasecal) + +def average_im_list(im_List): + """Return the average of a list of images + """ + avg_im = im_List[0].copy() + avg_im.imvec = np.mean([im.imvec for im in im_List],axis=0) + if len(im_List[0].qvec): + avg_im.qvec = np.mean([im.qvec for im in im_List],axis=0) + if len(im_List[0].uvec): + avg_im.uvec = np.mean([im.uvec for im in im_List],axis=0) + if len(im_List[0].vvec): + avg_im.vvec = np.mean([im.vvec for im in im_List],axis=0) + + return avg_im + +def blur_im_list(im_List, fwhm_x, fwhm_t): + """Apply a gaussian filter to a list of images, with fwhm_x in radians and fwhm_t in frames. Currently only for Stokes I. + + Args: + fwhm_x (float): circular beam size for spatial blurring in radians + fwhm_t (float): temporal blurring in frames + Returns: + (Image): output image list + """ + + # Blur Stokes I + sigma_x = fwhm_x / im_List[0].psize / (2. * np.sqrt(2. * np.log(2.))) + sigma_t = fwhm_t / (2. * np.sqrt(2. * np.log(2.))) + + arr = np.array([im.imvec.reshape(im.ydim, im.xdim) for im in im_List]) + arr = filt.gaussian_filter(arr, (sigma_t, sigma_x, sigma_x)) + + ret = [] + for j in range(len(im_List)): + ret.append(image.Image(arr[j], im_List[0].psize, im_List[0].ra, im_List[0].dec, rf=im_List[0].rf, source=im_List[0].source, mjd=im_List[j].mjd)) + + return ret + +################################################################################################## +# Convenience Functions for Analytical Work +################################################################################################## + +def Wrapped_Convolve(sig,ker): + if np.sum(ker) == 0.0: + return sig + + N = sig.shape[0] + + if Fast_Convolve == False: + return scipy.signal.fftconvolve(np.pad(sig,((N, N), (N, N)), 'wrap'), np.pad(ker,((N, N), (N, N)), 'constant'),mode='same')[N:(2*N),N:(2*N)] + else: + return scipy.signal.fftconvolve(sig, ker,mode='same') + +def Wrapped_Gradient(M): + G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap')) + Gx = G[0][1:-1,1:-1] + Gy = G[1][1:-1,1:-1] + return (Gx, Gy) + +def Wrapped_Gradient_Reorder(M): + G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap')) + Gx = G[0][1:-1,1:-1] + Gy = G[1][1:-1,1:-1] + return np.transpose(np.array([Gx, Gy]),axes=[1,2,0]) + +def Wrapped_Divergence( vectorfield ): + Gx = np.gradient(np.pad(vectorfield[:,:,0],((1, 1), (1, 1)), 'wrap'), axis=0)[1:-1,1:-1] + Gy = np.gradient(np.pad(vectorfield[:,:,1],((1, 1), (1, 1)), 'wrap'), axis=1)[1:-1,1:-1] + return Gx+Gy + +def Wrapped_Weighted_Divergence( weight, M ): #(weight \cdot \nabla) M + grad = Wrapped_Gradient(M) + return weight[:,:,0]*grad[0] + weight[:,:,1]*grad[1] + +################################################################################################## +# Dynamic Regularizers and their Gradients +################################################################################################## + +#RdF Regularizer (continuity of total flux density from frame to frame) +def RdF_clip(Frame_List, embed_mask_List): + F_List = [np.sum(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))] + return np.sum(np.diff(F_List)**2) + +def RdF_gradient_clip(Frame_List, embed_mask_List): + N_frame = Frame_List.shape[0] + F_List = [np.sum(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))] + F_grad_List = [1.0 + 0.0*np.copy(Frame_List[j].ravel()[embed_mask_List[j]]) for j in range(len(Frame_List))] + grad = np.copy(F_grad_List)*0.0 + + for i in range(1,N_frame): + grad[i] = grad[i] + 2.0*(F_List[i] - F_List[i-1])*F_grad_List[i] + + for i in range(N_frame-1): + grad[i] = grad[i] + 2.0*(F_List[i] - F_List[i+1])*F_grad_List[i] + + return np.concatenate([grad[i]*(Frame_List[i].ravel()[embed_mask_List[i]]) for i in range(N_frame)]) + +#RdS Regularizer (continuity of entropy from frame to frame) +def RdS(Frame_List, Prior_List, embed_mask_List, entropy="simple", norm_reg=True, **kwargs): + S_List = [static_regularizer(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, entropy=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))] + return np.sum(np.diff(S_List)**2) + +def RdS_gradient(Frame_List, Prior_List, embed_mask_List, entropy="simple", norm_reg=True, **kwargs): + #The Jacobian_Factor is already part of the entropy gradient that this function calls + N_frame = Frame_List.shape[0] + S_List = [static_regularizer(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, entropy=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))] + S_grad_List = np.array([static_regularizer_gradient(np.array([Frame_List[j]]), np.array([Prior_List[j]]), np.array([embed_mask_List[j]]), Prior_List[0].total_flux(), Prior_List[0].psize, entropy=entropy, norm_reg=norm_reg, **kwargs) for j in range(len(Frame_List))]) + + grad = np.copy(S_grad_List)*0.0 + + for i in range(1,N_frame): + grad[i] = grad[i] + 2.0*(S_List[i] - S_List[i-1])*S_grad_List[i] + + for i in range(N_frame-1): + grad[i] = grad[i] + 2.0*(S_List[i] - S_List[i+1])*S_grad_List[i] + + return np.concatenate([grad[i] for i in range(N_frame)]) + + +######## Rdt, RdI, and Rflow master functions ######## +def Rdt(Frames, ker, metric='SymKL', p=2.0, **kwargs): + if metric == 'KL': + return Rdt_KL(Frames, ker) + elif metric == 'SymKL': + return Rdt_SymKL(Frames, ker) + elif metric == 'D2': + return Rdt_Dp(Frames, ker, p=2.0) + elif metric == 'Dp': + return Rdt_Dp(Frames, ker, p=p) + else: + return 0.0 + +def Rdt_gradient(Frames, ker, metric='SymKL', p=2.0, **kwargs): + if metric == 'KL': + return Rdt_KL_gradient(Frames, ker) + elif metric == 'SymKL': + return Rdt_SymKL_gradient(Frames, ker) + elif metric == 'D2': + return Rdt_Dp_gradient(Frames, ker, p=2.0) + elif metric == 'Dp': + return Rdt_Dp_gradient(Frames, ker, p=p) + else: + return 0.0 + +def RdI(Frames, metric='SymKL', p=2.0, **kwargs): + if metric == 'KL': + return RdI_KL(Frames) + elif metric == 'SymKL': + return RdI_SymKL(Frames) + elif metric == 'D2': + return RdI_Dp(Frames, p=2.0) + elif metric == 'Dp': + return RdI_Dp(Frames, p=p) + else: + return 0.0 + +def RdI_gradient(Frames, metric='SymKL', p=2.0, **kwargs): + if metric == 'KL': + return RdI_KL_gradient(Frames) + elif metric == 'SymKL': + return RdI_SymKL_gradient(Frames) + elif metric == 'D2': + return RdI_Dp_gradient(Frames, p=2.0) + elif metric == 'Dp': + return RdI_Dp_gradient(Frames, p=p) + else: + return 0.0 + +def Rflow(Frames, Flow, metric='D2', p=2.0, **kwargs): + if metric == 'KL': + return Rflow_KL(Frames, Flow) + elif metric == 'SymKL': + return Rflow_SymKL(Frames, Flow) + elif metric == 'D2': + return Rflow_D2(Frames, Flow) + elif metric == 'Dp': + return Rflow_Dp(Frames, Flow, p=p) + else: + return 0.0 + +def Rflow_gradient_I(Frames, Flow, metric='D2', p=2.0, **kwargs): + if metric == 'KL': + return Rflow_KL_gradient_I(Frames, Flow) + elif metric == 'SymKL': + return Rflow_SymKL_gradient_I(Frames, Flow) + elif metric == 'D2': + return Rflow_D2_gradient_I(Frames, Flow) + elif metric == 'Dp': + return Rflow_Dp_gradient_I(Frames, Flow, p=p) + else: + return 0.0 + +def Rflow_gradient_m(Frames, Flow, metric='D2', p=2.0, **kwargs): + if metric == 'KL': + return Rflow_KL_gradient_m(Frames, Flow) + elif metric == 'SymKL': + return Rflow_SymKL_gradient_m(Frames, Flow) + elif metric == 'D2': + return Rflow_Dp_gradient_m(Frames, Flow) + elif metric == 'Dp': + return Rflow_Dp_gradient_m(Frames, Flow, p=p) + else: + return 0.0 + +#################################### + +#Rdt Regularizer with relative entropy (Kullback-Leibler Divergence) +def Rdt_KL(Frames, ker): + ep=1e-10 + N_frame = Frames.shape[0] + Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames]) + + R = 0.0 + for i in range(1,N_frame): + R += np.sum( Blur_Frames[i]*np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) ) + + return R/N_frame + +def Rdt_KL_gradient(Frames, ker): + #The Jacobian_Factor accounts for the frames being written as log(frame) in the imaging algorithm + ep=1e-10 + N_frame = Frames.shape[0] + Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames]) + + grad = np.copy(Frames)*0.0 + + for i in range(1,len(Frames)): + grad[i] = grad[i] + np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) + 1.0 + + for i in range(len(Frames)-1): + grad[i] = grad[i] - (Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep) + + return np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten() + +#Rdt Regularizer with symmetrized relative entropy +def Rdt_SymKL(Frames, ker): + ep=1e-10 + N_frame = Frames.shape[0] + Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames]) + + R = 0.0 + for i in range(1,N_frame): + R += np.sum( (Blur_Frames[i] - Blur_Frames[i-1])*np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) ) + + return 0.5*R/N_frame + +def Rdt_SymKL_gradient(Frames, ker): + #The Jacobian_Factor accounts for the frames being written as log(frame) in the imaging algorithm + ep=1e-10 + N_frame = Frames.shape[0] + Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames]) + + grad = np.copy(Frames)*0.0 + + for i in range(1,len(Frames)): + grad[i] = grad[i] + 1.0 - (Blur_Frames[i-1]+ep)/(Blur_Frames[i]+ep) + np.log((Blur_Frames[i]+ep)/(Blur_Frames[i-1]+ep)) + + for i in range(len(Frames)-1): + grad[i] = grad[i] + 1.0 - (Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep) - np.log((Blur_Frames[i+1]+ep)/(Blur_Frames[i]+ep)) + + return 0.5*np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten() + +#Rdt Regularizer with MSE (or l_p norm) +def Rdt_Dp(Frames, ker, p=2.0): + N_frame = Frames.shape[0] + Blur_Frames = np.array([Wrapped_Convolve(f, ker) for f in Frames]) + return np.sum(np.abs(np.diff(Blur_Frames,axis=0))**p)/N_frame + +def Rdt_Dp_gradient(Frames, ker, p=2.0): + N_frame = Frames.shape[0] + + grad = np.copy(Frames)*0.0 + + if p==2.0: + for i in range(1,len(Frames)): + grad[i] = grad[i] + 2.0*(Frames[i] - Frames[i-1]) + + for i in range(len(Frames)-1): + grad[i] = grad[i] + 2.0*(Frames[i] - Frames[i+1]) + + return np.array([Wrapped_Convolve(Wrapped_Convolve(grad[i],ker),ker)/N_frame*Frames[i] for i in range(len(Frames))]).flatten() + else: + for i in range(1,len(Frames)): + grad_temp = Wrapped_Convolve(Frames[i] - Frames[i-1],ker) + grad[i] = grad[i] + p*(np.abs(grad_temp)**(p-1.0)*np.sign(grad_temp)) + + for i in range(len(Frames)-1): + grad_temp = Wrapped_Convolve(Frames[i] - Frames[i+1],ker) + grad[i] = grad[i] + p*(np.abs(grad_temp)**(p-1.0)*np.sign(grad_temp)) + + return np.array([Wrapped_Convolve(grad[i],ker)/N_frame*Frames[i] for i in range(N_frame)]).flatten() + +#RdI Regularizer +def RdI_KL(Frames): + N_frame = Frames.shape[0] + ep=1e-10 + avg_Image = np.mean(Frames,axis=0) + + return np.sum(Frames * np.log((Frames + ep)/(avg_Image + ep)))/N_frame + +def RdI_KL_gradient(Frames): + N_frame = Frames.shape[0] + ep=1e-10 + avg_Image = np.mean(Frames,axis=0) + return np.concatenate([(np.log((Frames[i]+ep)/(avg_Image+ep))/N_frame*Frames[i]).ravel() for i in range(N_frame)]) + +def RdI_SymKL(Frames): + N_frame = Frames.shape[0] + ep=1e-10 + avg_Image = np.mean(Frames,axis=0) + + return np.sum(0.5 * (Frames - avg_Image) * np.log((Frames + ep)/(avg_Image + ep)))/N_frame + +def RdI_SymKL_gradient(Frames): + N_frame = Frames.shape[0] + ep=1e-10 + avg_Image = np.mean(Frames,axis=0) + term2 = 1.0/N_frame * np.sum(np.log((Frames + ep)/(avg_Image + ep)), axis=0) + + return np.concatenate([(0.5*( (Frames[i]-avg_Image)/(Frames[i]+ep) + np.log((Frames[i]+ep)/(avg_Image+ep)) - term2)/N_frame*Frames[i]).ravel() for i in range(N_frame)]) + +def RdI_D2(Frames): + N_frame = Frames.shape[0] + avg_Image = np.mean(Frames,axis=0) + return np.sum((Frames - avg_Image)**2)/N_frame + +def RdI_D2_gradient(Frames): + N_frame = Frames.shape[0] + avg_Image = np.mean(Frames,axis=0) + return np.concatenate([(2.0*(Frames[i] - avg_Image)/N_frame*Frames[i]).ravel() for i in range(Frames.shape[0])]) + +def RdI_Dp(Frames, p=2.0): + N_frame = Frames.shape[0] + avg_Image = np.mean(Frames,axis=0) + return np.sum(np.abs(Frames - avg_Image)**p)/N_frame + +def RdI_Dp_gradient(Frames, p=2.0): + N_frame = Frames.shape[0] + avg_Image = np.mean(Frames,axis=0) + term2 = -p/N_frame * np.sum([np.abs(Frames[i] - avg_Image)**(p-1.0)*np.sign(Frames[i] - avg_Image) for i in range(N_frame)],axis=0) + return np.concatenate([((p*np.abs(Frames[i] - avg_Image)**(p-1.0)*np.sign(Frames[i] - avg_Image) + term2)/N_frame*Frames[i]).ravel() for i in range(N_frame)]) + +#Rflow Regularizer +def Rflow_D2(Frames, Flow): + N_frame = Frames.shape[0] + val = 0.0 + + for j in range(len(Frames)-1): + #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow ) #this is not the same as the expanded version (for discrete derivatives) + dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow)) + val = val + np.sum( (Frames[j+1] - (Frames[j] + dI_dt))**2 ) + + return val/N_frame + +def Rflow_D2_gradient_I(Frames, Flow): + N_frame = Frames.shape[0] + grad = 0.0*np.copy(Frames) + + for j in range(1,len(Frames)): + #dI_dt = -Wrapped_Divergence( Frames[j-1,:,:,None]*Flow ) + dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j-1]) + Frames[j-1]*Wrapped_Divergence(Flow)) + deltaI = Frames[j] - (Frames[j-1] + dI_dt) + grad[j] = grad[j] + 2.0*deltaI + + for j in range(0,len(Frames)-1): + #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow ) + dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow)) + deltaI = Frames[j+1] - (Frames[j] + dI_dt) + grad[j] = grad[j] - 2.0*(deltaI + Wrapped_Weighted_Divergence(Flow, deltaI)) + + for j in range(len(Frames)): + grad[j] = grad[j]*Frames[j] + + return np.concatenate([g.flatten() for g in grad])/N_frame + +def Rflow_D2_gradient_m(Frames, Flow): + N_frame = Frames.shape[0] + grad = 0.0*np.copy(Flow) + + for j in range(len(Frames)-1): + #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow ) + dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow)) + deltaI = Frames[j+1] - (Frames[j] + dI_dt) + grad = grad + 2.0 * deltaI[:,:,None] * Wrapped_Gradient_Reorder( Frames[j] ) + grad = grad - 2.0 * Wrapped_Gradient_Reorder( deltaI*Frames[j] ) + + return np.array(grad).flatten()/N_frame + +def Rflow_D2_gradient_m_alt(Frames, Flow): #not sure if this is correct + N_frame = Frames.shape[0] + grad = 0.0*np.copy(Flow) + + for j in range(len(Frames)-1): + #dI_dt = -Wrapped_Divergence( Frames[j,:,:,None]*Flow ) + dI_dt = -(Wrapped_Weighted_Divergence(Flow,Frames[j]) + Frames[j]*Wrapped_Divergence(Flow)) + deltaI = Frames[j+1] - (Frames[j] + dI_dt) + grad = -2.0 * Frames[j][:,:,None] * Wrapped_Gradient_Reorder( deltaI ) + + return np.array(grad).flatten()/N_frame + +#### Helper functions for the flow #### + +def squared_gradient_flow(flow): + """Total squared gradient of flow""" + + return np.sum(np.array(Wrapped_Gradient(flow[:,:,0]))**2 + np.array(Wrapped_Gradient(flow[:,:,1]))**2) + +def squared_gradient_flow_grad(flow): + """Total squared gradient of flow gradient wrt flow""" + + grad_x = -2.0*Wrapped_Divergence(Wrapped_Gradient_Reorder(flow[:,:,0])) + grad_y = -2.0*Wrapped_Divergence(Wrapped_Gradient_Reorder(flow[:,:,1])) + + return np.transpose([grad_x.ravel(),grad_y.ravel()]).ravel() + +###### Static Regularizer Master Functions ####### + +def static_regularizer(Frame_List, Prior_List, embed_mask_List, flux, psize, stype="simple", norm_reg=True, **kwargs): + N_frame = Frame_List.shape[0] + xdim = int(len(Frame_List[0].ravel())**0.5) + + s = np.sum( regularizer(Frame_List[i].ravel()[embed_mask_List[i]], Prior_List[i].ravel()[embed_mask_List[i]], embed_mask_List[i], flux=flux, xdim=xdim, ydim=xdim, psize=psize, stype=stype, norm_reg=norm_reg, **kwargs) for i in range(N_frame)) + + return s/N_frame + +def static_regularizer_gradient(Frame_List, Prior_List, embed_mask_List, flux, psize, stype="simple", norm_reg=True, **kwargs): + # Note: this function includes Jacobian factor to account for the frames being written as log(frame) + N_frame = Frame_List.shape[0] + xdim = int(len(Frame_List[0].ravel())**0.5) + + s = np.concatenate([regularizergrad((Frame_List[i].ravel())[embed_mask_List[i]], Prior_List[i].ravel()[embed_mask_List[i]], embed_mask_List[i], flux=flux, xdim=xdim, ydim=xdim, psize=psize, stype=stype, norm_reg=norm_reg, **kwargs)*(Frame_List[i].ravel())[embed_mask_List[i]] for i in range(N_frame)]) + + return s/N_frame + +################################################################################################## +# Other Regularization Functions +################################################################################################## + +def centroid(Frame_List, coord): + return np.sum(np.sum(im.ravel() * coord[:,0])**2 + np.sum(im.ravel() * coord[:,1])**2 for im in Frame_List)/len(Frame_List) + +def centroid_gradient(Frame_List, coord): #Includes Jacobian factor to account for the frames being written as log(frame) + return 2.0 * np.concatenate([(np.sum(im.ravel() * coord[:,0])*coord[:,0] + np.sum(im.ravel() * coord[:,1])*coord[:,1])*im.ravel() for im in Frame_List])/len(Frame_List) + +def movie_flux_constraint(Frame_List, flux_List): + # This is the mean squared *fractional* difference in image total flux density + # Negative means ignore + norm = float(np.sum([f > 0.0 for f in flux_List])) + return np.sum([(np.sum(Frame_List[j]) - flux_List[j])**2/flux_List[j]**2/norm*(flux_List[j] >= 0.0) for j in range(len(Frame_List))]) + +def movie_flux_constraint_grad(Frame_List, flux_List): #Includes Jacobian factor to account for the frames being written as log(frame) + norm = float(np.sum([f > 0.0 for f in flux_List])) + return np.concatenate([2.0*(np.sum(Frame_List[j]) - flux_List[j])/flux_List[j]**2/norm*Frame_List[j].ravel()*(flux_List[j] >= 0.0) for j in range(len(Frame_List))]) + +################################################################################################## +# chi^2 estimation routines +################################################################################################## + + +def get_chisq(i, imvec_embed, d1, d2, d3, ttype, mask): + global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List + chisq1 = chisq2 = chisq3 = 1.0 + + if d1 != False and len(data1_List[i])>0: + + chisq1 = chisq(imvec_embed, A1_List[i], data1_List[i], sigma1_List[i], d1, ttype=ttype, mask=mask) + + if d2 != False and len(data2_List[i])>0: + chisq2 = chisq(imvec_embed, A2_List[i], data2_List[i], sigma2_List[i], d2, ttype=ttype, mask=mask) + + if d3 != False and len(data3_List[i])>0: + chisq3 = chisq(imvec_embed, A3_List[i], data3_List[i], sigma3_List[i], d3, ttype=ttype, mask=mask) + + return [chisq1, chisq2, chisq3] + +def get_chisq_wrap(args): + return get_chisq(*args) + + +def get_chisqgrad(i, imvec_embed, d1, d2, d3, ttype, mask): + global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List + chisqgrad1 = chisqgrad2 = chisqgrad3 = 0.0*imvec_embed + + if d1 != False and len(data1_List[i])>0: + + chisqgrad1 = chisqgrad(imvec_embed, A1_List[i], data1_List[i], sigma1_List[i], d1, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor + + if d2 != False and len(data2_List[i])>0: + chisqgrad2 = chisqgrad(imvec_embed, A2_List[i], data2_List[i], sigma2_List[i], d2, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor + + if d3 != False and len(data3_List[i])>0: + chisqgrad3 = chisqgrad(imvec_embed, A3_List[i], data3_List[i], sigma3_List[i], d3, ttype=ttype, mask=mask) #This *does not* include the Jacobian factor + + return [chisqgrad1, chisqgrad2, chisqgrad3] + +def get_chisqgrad_wrap(args): + return get_chisqgrad(*args) + +################################################################################################## +# Imagers +################################################################################################## + + +def dynamical_imaging_minimal(Obsdata_List, InitIm_List, Prior, flux_List = [], +d1='vis', d2=False, d3=False, +alpha_d1=10, alpha_d2=10, alpha_d3=10, +systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0, +entropy1="tv2", entropy2="l1", +alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0, +R_dt ={'alpha':0.0, 'metric':'SymKL', 'p':2.0}, +maxit=200, J_factor = 0.001, stop=1.0e-10, ipynb=False, refresh_interval = 1000, +minimizer_method = 'L-BFGS-B', NHIST = 25, update_interval = 1, clipfloor=0., +ttype = 'nfft', fft_pad_factor=2): + + global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List + + N_frame = len(Obsdata_List) + N_pixel = Prior.xdim #pixel dimension + + # Determine the appropriate final resolution + all_res = [] + for obs in Obsdata_List: + if len(obs.data) > 0: + all_res.append(obs.res()) + + beam_size = np.min(np.array(all_res)) + print("Maximal Resolution:",beam_size) + + # Find an observation with data + for j in range(N_frame): + if len(Obsdata_List[j].data) > 0: + first_obs = Obsdata_List[j] + break + + # Catch problem if uvrange < largest baseline + if 1./Prior.psize < np.max(first_obs.unpack(['uvdist'])['uvdist']): + raise Exception("pixel spacing is larger than smallest spatial wavelength!") + + if alpha_flux > 0.0 and len(flux_List) != N_frame: + raise Exception("Number of elements in the list of total flux densities does not match the number of frames!") + + # Make the blurring kernel for R_dt + # Note: There are odd problems when sigma_dt is too small. I can't figure out why it causes the convolution to crash. + # However, having sigma_dt -> 0 is not a problem in theory. So we'll just set the kernel to be zero in that case and then ignore it later in the convolution. + + B_dt = np.zeros((Prior.ydim,Prior.xdim)) + + embed_mask_List = [Prior.imvec > clipfloor for j in range(N_frame)] + embed_mask_All = np.array(embed_mask_List).flatten() + + embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List] + + logprior_List = [None,] * N_frame + loginit_List = [None,] * N_frame + + nprior_embed_List = [None,] * N_frame + nprior_List = [None,] * N_frame + + ninit_embed_List = [InitIm_List[i].imvec for i in range(N_frame)] + ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_frame)] + + print ("Calculating lists/matrices for chi-squared terms...") + A1_List = [None for _ in range(N_frame)] + A2_List = [None for _ in range(N_frame)] + A3_List = [None for _ in range(N_frame)] + data1_List = [[] for _ in range(N_frame)] + data2_List = [[] for _ in range(N_frame)] + data3_List = [[] for _ in range(N_frame)] + sigma1_List = [None for _ in range(N_frame)] + sigma2_List = [None for _ in range(N_frame)] + sigma3_List = [None for _ in range(N_frame)] + + # Get data and Fourier matrices for the data terms + for i in range(N_frame): + pixel_max = np.max(InitIm_List[i].imvec) + prior_flux_rescale = 1.0 + if len(flux_List) > 0: + prior_flux_rescale = flux_List[i]/Prior.total_flux() + + nprior_embed_List[i] = Prior.imvec * prior_flux_rescale + + nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]] + logprior_List[i] = np.log(nprior_List[i]) + loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here + + if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data + continue + + (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1) + (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2) + (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3) + + # Define the objective function and gradient + def objfunc(x): + # Frames is a list of the *unscattered* frames + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2 + + s_dynamic = 0.0 + + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha'] + + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + chisq = ((np.sum(chisq[:,0])/N_frame - 1.0)*alpha_d1 + + (np.sum(chisq[:,1])/N_frame - 1.0)*alpha_d2 + + (np.sum(chisq[:,2])/N_frame - 1.0)*alpha_d3) + + return (s1 + s2 + s_dynamic + chisq)*J_factor + + def objgrad(x): + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + s1 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2 + + s_dynamic_grad = 0.0 + if R_dt['alpha'] != 0.0: s_dynamic_grad += Rdt_gradient(Frames, B_dt, **R_dt)*R_dt['alpha'] + + + chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + # Now add the Jacobian factor and concatenate + for j in range(N_frame): + chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]] + + chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d1 + + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d2 + + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d3) + + return (np.concatenate((s1 + s2 + (s_dynamic_grad + chisq_grad)[embed_mask_All]))*J_factor) + + # Plotting function for each iteration + global nit + nit = 0 + def plotcur(x, final=False): + global nit + nit += 1 + + if nit%update_interval == 0 or final == True: + print ("iteration %d" % nit) + + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + + s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), + Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), + Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_s2 + + s_dynamic = 0.0 + + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha'] + + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + chisq1_List = chisq[:,0] + chisq2_List = chisq[:,1] + chisq3_List = chisq[:,2] + chisq1 = np.sum(chisq1_List)/N_frame + chisq2 = np.sum(chisq2_List)/N_frame + chisq3 = np.sum(chisq3_List)/N_frame + chisq1_max = np.max(chisq1_List) + chisq2_max = np.max(chisq2_List) + chisq3_max = np.max(chisq3_List) + if d1 != False: print ("chi2_1: %f" % chisq1) + if d2 != False: print ("chi2_2: %f" % chisq2) + if d3 != False: print ("chi2_3: %f" % chisq3) + if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1)) + if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2)) + if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3)) + if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max) + if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max) + if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max) + + if final == True: + if d1 != False: print ("All chisq1:",chisq1_List) + if d2 != False: print ("All chisq2:",chisq2_List) + if d3 != False: print ("All chisq3:",chisq3_List) + + if s1 != 0.0: print ("weighted s1: %f" % (s1)) + if s2 != 0.0: print ("weighted s2: %f" % (s2)) + print ("weighted s_dynamic: %f" % (s_dynamic)) + + if nit%refresh_interval == 0: + print ("Plotting Functionality Temporarily Disabled...") + + loginit = np.hstack(loginit_List).flatten() + x0 = loginit + + print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame)) + print ("Clipped Pixel #: ",(len(loginit))) + + print ("Initial Values:") + plotcur(x0) + + # Minimize + optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params + tstart = time.time() + res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur) + tstop = time.time() + + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + #Impose the prior mask in linear space for the output + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + plotcur(res.x, final=True) + + # Print stats + print ("time: %f s" % (tstop - tstart)) + print ("J: %f" % res.fun) + print (res.message) + + outim = [image.Image(Frames[i].reshape(Prior.ydim, Prior.xdim), Prior.psize, + Prior.ra, Prior.dec, rf=Obsdata_List[i].rf, source=Prior.source, + mjd=Prior.mjd, pulse=Prior.pulse) for i in range(N_frame)] + + return outim + + +def dynamical_imaging(obs_input, init_ims, Prior, Flow_Init = None, flux_List = None, +d1='vis', d2=False, d3=False, +alpha_d1=10, alpha_d2=10, alpha_d3=10, +systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0, +entropy1="tv2", entropy2="l1", +alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0, +R_dI ={'alpha':0.0, 'metric':'SymKL', 'p':2.0}, +R_dt ={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0}, +R_flow={'alpha':0.0, 'metric':'SymKL', 'p':2.0, 'alpha_flow_tv':50.0}, +alpha_centroid=0.0, alpha_flux=0.0, alpha_dF=0.0, alpha_dS1=0.0, alpha_dS2=0.0, #other regularizers +stochastic_optics=False, scattering_model=False, alpha_phi = 1.e4, #options for scattering +Target_Dynamic_Range = 10000.0, +maxit=200, J_factor = 0.001, stop=1.0e-10, ipynb=False, refresh_interval = 1000, +minimizer_method = 'L-BFGS-B', NHIST = 25, update_interval = 1, clipfloor=0., processes = -1, +recalculate_chisqdata = True, ttype = 'nfft', fft_pad_factor=2, **kwargs): + + """Run dynamical imaging. + + Args: + obs_input (List or Obsdata): Observation. Form can be either: + 1. List of Obsdata objects, one per reconstructed frame. Some can have empty data arrays. + 2. Single Obsdata object. + init_ims (List or Movie): List of initial images. List can be either: + 1. Each an Image object, one per reconstructed frame. + 2. A Movie object, where the frames will be used + Prior (Image): The Image object with the prior image + Flow_Init: Optional initialization for imaging with R_flow + flux_List (List): Optional specification of the total flux density for each frame + d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp' + d2 (str): The second data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp' + d3 (str): The third data term; options are 'vis', 'bs', 'amp', 'cphase', 'camp', 'logcamp' + systematic_noise1 (float): Systematic noise on the first data term, as a fraction of the visibility amplitude + systematic_noise2 (float): Systematic noise on the second data term, as a fraction of the visibility amplitude + systematic_noise3 (float): Systematic noise on the third data term, as a fraction of the visibility amplitude + entropy1 (str): The first regularizer; options are 'simple', 'gs', 'tv', 'tv2', 'l1', 'patch','compact','compact2','rgauss' + entropy2 (str): The second regularizer; options are 'simple', 'gs', 'tv', 'tv2','l1', 'patch','compact','compact2','rgauss' + alpha_d1 (float): The first data term weighting + alpha_d2 (float): The second data term weighting + alpha_s1 (float): The first regularizer term weighting + alpha_s2 (float): The second regularizer term weighting + alpha_flux (float): The weighting for the total flux constraint + alpha_centroid (float): The weighting for the center of mass constraint + alpha_dF (float): The weighting for temporal continuity of the total flux density. + alpha_dS1 (float): The weighting for temporal continuity of entropy1. + alpha_dS2 (float): The weighting for temporal continuity of entropy2. + + maxit (int): Maximum number of minimizer iterations + stop (float): The convergence criterion + minimizer_method (str): Minimizer method (e.g., 'L-BFGS-B' or 'CG') + update_interval (int): Print convergence status every update_interval steps + norm_reg (bool): If True, normalizes regularizer terms + ttype (str): The Fourier transform type; options are 'fast', 'direct', 'nfft' + + stochastic_optics (bool): If True, stochastic optics imaging is used. + scattering_model (ScatteringModel): Optional specification of the ScatteringModel object. + alpha_phi (float): Weighting for screen phase regularization in stochastic optics. +minimizer_method = 'L-BFGS-B', update_interval = 1 + + Returns: + List or Dictionary: A list of Image objects, one per frame, unless a flow or stochastic optics is used in which case it returns a dictionary {'Frames', 'Flow', 'EpsilonList' }. + """ + + global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List + + # Make a list of frames if a movie is passed + if type(init_ims) == list: + InitIm_List = init_ims + else: + InitIm_List = init_ims.im_list() + + # Make a list of observations if a single Obsdata object is passed + if type(obs_input) == list: + Obsdata_List = obs_input + else: + # Create one obsdata object for every frame + Obsdata_List = [obs_input.copy() for _ in InitIm_List] + # Populate each frame; for now just gather observations by nearest frame + tlist = obs_input.tlist() + obs_mjds = [obs_input.mjd + o['time'][0]/24.0 for o in tlist] + frame_mjds = [im.mjd + im.time/24.0 for im in InitIm_List] + idx_list = [np.argmin(np.abs(frame_mjds - obs_mjds[j])) for j in range(len(obs_mjds))] + + c = 0 + for j in range(len(frame_mjds)): + Obsdata_List[j].mjd = InitIm_List[j].mjd + try: + Obsdata_List[j].data = np.concatenate(tlist[[x == j for x in idx_list]]) + except: + Obsdata_List[j].data = [] + c = c + 1 + pass + if c > 0: + print("%d/%d frames have no data"%(c,len(frame_mjds))) + + N_frame = len(Obsdata_List) + N_pixel = Prior.xdim #pixel dimension + + # Determine the appropriate final resolution + all_res = [] + for obs in Obsdata_List: + if len(obs.data) > 0: + all_res.append(obs.res()) + + beam_size = np.min(np.array(all_res)) + print("Maximal Resolution:",beam_size) + + # Find an observation with data + for j in range(N_frame): + if len(Obsdata_List[j].data) > 0: + first_obs = Obsdata_List[j] + break + + # Catch problem if uvrange < largest baseline + if 1./Prior.psize < np.max(first_obs.unpack(['uvdist'])['uvdist']): + raise Exception("pixel spacing is larger than smallest spatial wavelength!") + + if alpha_flux > 0.0 and len(flux_List) != N_frame: + raise Exception("Number of elements in the list of total flux densities does not match the number of frames!") + + # If using stochastic optics, do some preliminary calculations + if stochastic_optics == True: + # Doesn't yet work with clipping + clipfloor = -1.0 + + if scattering_model == False: + print("No scattering model specified. Assuming the default scattering for Sgr A*.") + scattering_model = so.ScatteringModel() + + # First some preliminary definitions + N = InitIm_List[0].xdim + FOV = InitIm_List[0].psize * N * scattering_model.observer_screen_distance #Field of view, in cm, at the scattering screen + + # The ensemble-average convolution kernel and its gradients + wavelength_List = np.array([C/obs.rf*100.0 for obs in Obsdata_List]) #Observing wavelength for each frame [cm] + wavelengthbar_List = wavelength_List/(2.0*np.pi) #lambda/(2pi) [cm] + rF_List = [scattering_model.rF(wavelength) for wavelength in wavelength_List] + + print("Computing the Ensemble-Average Kernel for Each Frame...") + ea_ker = [scattering_model.Ensemble_Average_Kernel(InitIm_List[0], wavelength_cm = wavelength_List[j]) for j in range(N_frame)] + ea_ker_gradient = [so.Wrapped_Gradient(ea_ker[j]/(FOV/N)) for j in range(N_frame)] + ea_ker_gradient_x = [-ea_ker_gradient[j][1] for j in range(N_frame)] + ea_ker_gradient_y = [-ea_ker_gradient[j][0] for j in range(N_frame)] + + # The power spectrum (note: rotation is not currently implemented; the gradients would need to be modified slightly) + sqrtQ = np.real(scattering_model.sqrtQ_Matrix(InitIm_List[0],t_hr=0.0)) + + # Make the blurring kernel for R_dt + # Note: There are odd problems when sigma_dt is too small. I can't figure out why it causes the convolution to crash. + # However, having sigma_dt -> 0 is not a problem in theory. So we'll just set the kernel to be zero in that case and then ignore it later in the convolution. + + if R_dt['sigma_dt'] > 0.0: + B_dt = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt['sigma_dt']**2)) + for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)] + for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)])) + if np.max(B_dt) == 0.0 or np.sum(B_dt) == 0.0: + raise Exception("Error with the blurring kernel!") + B_dt = B_dt / np.sum(B_dt) # normalize to be flux preserving + else: + B_dt = np.zeros((Prior.ydim,Prior.xdim)) + + embed_mask_List = [Prior.imvec > clipfloor for j in range(N_frame)] + embed_mask_All = np.array(embed_mask_List).flatten() + + embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List] + if len(set(embed_totals)) > 1 and R_flow['alpha'] > 0.0: + print ("If a flow is used, then each frame must have the same prior!") + return + + logprior_List = [None for _ in range(N_frame)] + loginit_List = [None for _ in range(N_frame)] + + nprior_embed_List = [None for _ in range(N_frame)] + nprior_List = [None for _ in range(N_frame)] + + ninit_embed_List = [InitIm_List[i].imvec for i in range(N_frame)] + ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_frame)] + + if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct': + print ("Calculating lists/matrices for chi-squared terms...") + A1_List = [None for _ in range(N_frame)] + A2_List = [None for _ in range(N_frame)] + A3_List = [None for _ in range(N_frame)] + data1_List = [[] for _ in range(N_frame)] + data2_List = [[] for _ in range(N_frame)] + data3_List = [[] for _ in range(N_frame)] + sigma1_List = [None for _ in range(N_frame)] + sigma2_List = [None for _ in range(N_frame)] + sigma3_List = [None for _ in range(N_frame)] + + # Get data and Fourier matrices for the data terms + for i in range(N_frame): + pixel_max = np.max(InitIm_List[i].imvec) + prior_flux_rescale = 1.0 + if not flux_List: + pass + else: + prior_flux_rescale = flux_List[i]/Prior.total_flux() + + nprior_embed_List[i] = Prior.imvec * prior_flux_rescale + + nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]] + logprior_List[i] = np.log(nprior_List[i]) + loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here + + if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data + continue + + if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct': + # Try to create the chisqdata. These can throw errors, for instance when no closure quantities exist in a frame with data + try: + (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1) + except: + pass + try: + (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2) + except: + pass + try: + (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3) + except: + pass + + # Coordinate matrix for COM constraint + coord = np.array([[[x,y] for x in np.linspace(Prior.xdim/2,-Prior.xdim/2,Prior.xdim)] + for y in np.linspace(Prior.ydim/2,-Prior.ydim/2,Prior.ydim)]) + coord = coord.reshape(Prior.ydim*Prior.xdim, 2) + + # Make the pool for parallel processing + if processes > 0: + print("Using Multiprocessing") + pool = Pool(processes=processes) + elif processes == 0: + processes = int(cpu_count()) + print("Using Multiprocessing with %d Processes" % processes) + pool = Pool(processes=processes) + else: + print("Not Using Multiprocessing") + + # Define the objective function and gradient + def objfunc(x): + # Frames is a list of the *unscattered* frames + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + if R_flow['alpha'] != 0.0: + cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding + Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2)) + init_i += 2*cur_len + + if stochastic_optics == True: + EpsilonList = x[init_i:(init_i + N**2-1)] + im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)] + #the list of scattered image vectors + scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=so.MakeEpsilonScreenFromList(EpsilonList, N), ea_ker = ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec for j in range(N_frame)] + init_i += len(EpsilonList) + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2 + + s_dynamic = cm = flux = s_dS = s_dF = 0.0 + + if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames, **R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS2 + + if alpha_dF != 0.0: s_dF += RdF_clip(Frames, embed_mask_List)*alpha_dF + + if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid + + if alpha_flux > 0.0: + flux = alpha_flux * movie_flux_constraint(Frames, flux_List) + + if stochastic_optics == False: + if processes > 0: + chisq = np.array(pool.map(get_chisq_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + else: + if processes > 0: + chisq = np.array(pool.map(get_chisq_wrap, [[j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq = np.array([get_chisq(j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + chisq = ((np.sum(chisq[:,0])/N_frame - 1.0)*alpha_d1 + + (np.sum(chisq[:,1])/N_frame - 1.0)*alpha_d2 + + (np.sum(chisq[:,2])/N_frame - 1.0)*alpha_d3) + + if R_flow['alpha'] != 0.0: + flow_tv = squared_gradient_flow(Flow) + s_dynamic += flow_tv*R_flow['alpha_flow_tv'] + s_dynamic += Rflow(Frames, Flow, **R_flow)*R_flow['alpha'] + + # Scattering screen regularization term + regterm_scattering = 0.0 + if stochastic_optics == True: + chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0) + regterm_scattering = alpha_phi * (chisq_epsilon - 1.0) + + return (s1 + s2 + s_dF + s_dS + s_dynamic + chisq + cm + flux + regterm_scattering)*J_factor + + def objgrad(x): + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + if R_flow['alpha'] != 0.0: + cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding + Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2)) + init_i += 2*cur_len + + if stochastic_optics == True: + EpsilonList = x[init_i:(init_i + N**2-1)] + Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)] + scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=so.MakeEpsilonScreenFromList(EpsilonList, N), ea_ker = ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec for j in range(N_frame)] #the list of scattered image vectors + Epsilon_Screen = so.MakeEpsilonScreenFromList(EpsilonList, N) + init_i += len(EpsilonList) + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + s1 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer_gradient(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2 + + s_dynamic_grad = cm_grad = flux_grad = s_dS = s_dF = 0.0 + if R_dI['alpha'] != 0.0: s_dynamic_grad += RdI_gradient(Frames,**R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic_grad += Rdt_gradient(Frames, B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS += RdS_gradient(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS += RdS_gradient(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A)*alpha_dS2 + + if alpha_dF != 0.0: s_dF += RdF_gradient_clip(Frames, embed_mask_List)*alpha_dF + + if alpha_centroid != 0.0: cm_grad = centroid_gradient(Frames, coord) * alpha_centroid + + if alpha_flux > 0.0: + flux_grad = alpha_flux * movie_flux_constraint_grad(Frames, flux_List) + + dchisq_dIa_List = [] + + + # Michael -- can we do something about this + if stochastic_optics == False: + if processes > 0: + chisq_grad = np.array(pool.map(get_chisqgrad_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + else: + if processes > 0: + chisq_grad = np.array(pool.map(get_chisqgrad_wrap, [[j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq_grad = np.array([get_chisqgrad(j, scatt_im_List[j][embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + # Now, the chi^2 gradient must be modified so that it corresponds to the gradient wrt the unscattered image + for j in range(N_frame): + rF = rF_List[j] + phi = scattering_model.MakePhaseScreen(Epsilon_Screen, im_List[0], obs_frequency_Hz=im_List[j].rf,sqrtQ_init=sqrtQ).imvec.reshape((N, N)) + phi_Gradient = so.Wrapped_Gradient(phi/(FOV/N)) + phi_Gradient_x = -phi_Gradient[1] + phi_Gradient_y = -phi_Gradient[0] + dchisq_dIa_List.append( ((chisq_grad[j,0]*alpha_d1 + chisq_grad[j,1]*alpha_d2 + chisq_grad[j,2]*alpha_d3)/N_frame).reshape((N,N)) ) + + dchisq_dIa = chisq_grad[j,0].reshape((N,N)) + gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten() + gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten() + chisq_grad[j,0] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy + + dchisq_dIa = chisq_grad[j,1].reshape((N,N)) + gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten() + gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten() + chisq_grad[j,1] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy + + dchisq_dIa = chisq_grad[j,2].reshape((N,N)) + gx = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_x[j][::-1,::-1], phi_Gradient_x * (dchisq_dIa))).flatten() + gy = (rF**2.0 * so.Wrapped_Convolve(ea_ker_gradient_y[j][::-1,::-1], phi_Gradient_y * (dchisq_dIa))).flatten() + chisq_grad[j,2] = so.Wrapped_Convolve(ea_ker[j][::-1,::-1], (dchisq_dIa)).flatten() + gx + gy + + # Now add the Jacobian factor and concatenate + for j in range(N_frame): + chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]] + + chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d1 + + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d2 + + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_frame)])/N_frame*alpha_d3) + + # Gradient of the data chi^2 wrt to the epsilon screen -- this is the really difficult one + chisq_grad_epsilon = np.array([]) + if stochastic_optics == True: + #Preliminary Definitions + chisq_grad_epsilon = np.zeros(N**2-1) + ell_mat = np.zeros((N,N)) + m_mat = np.zeros((N,N)) + for ell in range(0, N): + for m in range(0, N): + ell_mat[ell,m] = ell + m_mat[ell,m] = m + + for j in range(N_frame): + rF = rF_List[j] + dchisq_dIa = dchisq_dIa_List[j] + EA_Image = scattering_model.Ensemble_Average_Blur(im_List[j], ker = ea_ker[j]) + EA_Gradient = so.Wrapped_Gradient((EA_Image.imvec/(FOV/N)).reshape(N, N)) + #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + + i_grad = 0 + #Real part; top row + for t in range(1, (N+1)//2): + s=0 + grad_term = so.Wrapped_Gradient(wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) ) + i_grad = i_grad + 1 + + #Real part; remainder + for s in range(1,(N+1)//2): + for t in range(N): + grad_term = so.Wrapped_Gradient(wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.cos(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) ) + i_grad = i_grad + 1 + + #Imaginary part; top row + for t in range(1, (N+1)//2): + s=0 + grad_term = so.Wrapped_Gradient(-wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) ) + i_grad = i_grad + 1 + + #Imaginary part; remainder + for s in range(1,(N+1)//2): + for t in range(N): + grad_term = so.Wrapped_Gradient(-wavelengthbar_List[j]/FOV*sqrtQ[s][t]*2.0*np.sin(2.0*np.pi/N*(ell_mat*s + m_mat*t))/(FOV/N)) + grad_term_x = -grad_term[1] + grad_term_y = -grad_term[0] + chisq_grad_epsilon[i_grad] += np.sum( dchisq_dIa * rF**2 * ( EA_Gradient_x * grad_term_x + EA_Gradient_y * grad_term_y ) ) + i_grad = i_grad + 1 + + # Gradients related to the flow + flow_grad = np.array([]) + if R_flow['alpha'] != 0.0: + cur_len = np.sum(embed_mask_List[0]) + Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2)) + flow_tv_grad = squared_gradient_flow_grad(Flow) + s_dynamic_grad_Frames = s_dynamic_grad_Flow = 0.0 + s_dynamic_grad_Frames = Rflow_gradient_I(Frames, Flow, R_flow) + s_dynamic_grad_Flow = Rflow_gradient_m(Frames, Flow, R_flow) + + s_dynamic_grad += s_dynamic_grad_Frames*R_flow['alpha'] + flow_grad = s_dynamic_grad_Flow*R_flow['alpha'] + flow_tv_grad*R_flow['alpha_flow_tv'] + # now handle the embedding + flow_grad = np.transpose([flow_grad[::2][embed_mask_List[0]], flow_grad[1::2][embed_mask_List[0]]]).ravel() + + # Gradient of the chi^2 regularization term for the epsilon screen + chisq_epsilon_grad = np.array([]) + if stochastic_optics == True: + chisq_epsilon_grad = alpha_phi * 2.0*EpsilonList/((N*N-1)/2.0) + + return (np.concatenate((s1 + s2 + s_dF + s_dS + (s_dynamic_grad + chisq_grad + cm_grad + cm_grad + flux_grad)[embed_mask_All], flow_grad, chisq_grad_epsilon + chisq_epsilon_grad))*J_factor) + + # Plotting function for each iteration + global nit + nit = 0 + def plotcur(x, final=False): + global nit + nit += 1 + + if nit%update_interval == 0 or final == True: + print ("iteration %d" % nit) + + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + if R_flow['alpha'] != 0.0: + cur_len = np.sum(embed_mask_List[0]) #assumes all the priors have the same embedding + Flow_x = embed(x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow_y = embed(x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2)) + init_i += 2*cur_len + + if stochastic_optics == True: + EpsilonList = x[init_i:(init_i + N**2-1)] + im_List = [image.Image(Frames[j], Prior.psize, Prior.ra, Prior.dec, rf=Obsdata_List[j].rf, source=Prior.source, mjd=Prior.mjd) for j in range(N_frame)] + + scatt_im_List = [scattering_model.Scatter(im_List[j], Epsilon_Screen=so.MakeEpsilonScreenFromList(EpsilonList, N), ea_ker = ea_ker[j], sqrtQ=sqrtQ, Linearized_Approximation=True).imvec + for j in range(N_frame)] #the list of scattered image vectors + + s1 = s2 = 0.0 + + if alpha_s1 != 0.0: + + s1 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), + Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s1 + if alpha_s2 != 0.0: + s2 = static_regularizer(Frames, nprior_embed_List, embed_mask_List, Prior.total_flux(), + Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_s2 + + s_dynamic = cm = s_dS = s_dF = 0.0 + + if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames, **R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames, B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy1, norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS += RdS(Frames, nprior_embed_List, embed_mask_List, entropy2, norm_reg, beam_size=beam_size, alpha_A=alpha_A, **kwargs)*alpha_dS2 + + if alpha_dF != 0.0: s_dF += RdF_clip(Frames, embed_mask_List)*alpha_dF + + if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid + + if stochastic_optics == False: + if processes > 0: + + chisq = np.array(pool.map(get_chisq_wrap, [[j, Frames[j].ravel()[embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + else: + if processes > 0: + chisq = np.array(pool.map(get_chisq_wrap, [[j, scatt_im_List[j][embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]] for j in range(N_frame)])) + else: + chisq = np.array([get_chisq(j, scatt_im_List[j][embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame)]) + + chisq1_List = chisq[:,0] + chisq2_List = chisq[:,1] + chisq3_List = chisq[:,2] + chisq1 = np.sum(chisq1_List)/N_frame + chisq2 = np.sum(chisq2_List)/N_frame + chisq3 = np.sum(chisq3_List)/N_frame + chisq1_max = np.max(chisq1_List) + chisq2_max = np.max(chisq2_List) + chisq3_max = np.max(chisq3_List) + if d1 != False: print ("chi2_1: %f" % chisq1) + if d2 != False: print ("chi2_2: %f" % chisq2) + if d3 != False: print ("chi2_3: %f" % chisq3) + if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1)) + if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2)) + if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3)) + if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max) + if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max) + if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max) + + if final == True: + if d1 != False: print ("All chisq1:",chisq1_List) + if d2 != False: print ("All chisq2:",chisq2_List) + if d3 != False: print ("All chisq3:",chisq3_List) + + # Now deal with the a flow, if necessary + if R_flow['alpha'] != 0.0: + flow_tv = squared_gradient_flow(Flow) + s_dynamic += flow_tv*R_flow['alpha_flow_tv'] + print ("Weighted Flow TV: %f" % (flow_tv*R_flow['alpha_flow_tv'])) + s_dynamic += Rflow(Frames, Flow, **R_flow)*R_flow['alpha'] + print ("Weighted R_Flow: %f" % (Rflow(Frames, Flow, **R_flow)*R_flow['alpha'])) + + if s1 != 0.0: print ("weighted s1: %f" % (s1)) + if s2 != 0.0: print ("weighted s2: %f" % (s2)) + if s_dF != 0.0: print ("weighted s_dF: %f" % (s_dF)) + if s_dS != 0.0: print ("weighted s_dS: %f" % (s_dS)) + print ("weighted s_dynamic: %f" % (s_dynamic)) + if alpha_centroid > 0.0: print ("weighted COM: %f" % cm) + + if alpha_flux > 0.0: + print ("weighted flux constraint: %f" % (alpha_flux * movie_flux_constraint(Frames, flux_List))) + + if stochastic_optics == True: + chisq_epsilon = sum(EpsilonList*EpsilonList)/((N*N-1.0)/2.0) + regterm_scattering = alpha_phi * (chisq_epsilon - 1.0) + print("Epsilon chi^2 : %0.2f " % (chisq_epsilon)) + print("Weighted Epsilon chi^2 : %0.2f " % (regterm_scattering)) + print("Max |Epsilon| : %0.2f " % (max(abs(EpsilonList)))) + + if nit%refresh_interval == 0: + print ("Plotting Functionality Temporarily Disabled...") + + loginit = np.hstack(loginit_List).flatten() + if R_flow['alpha'] == 0.0: + x0 = loginit + else: + Flow_Init_embed = np.transpose([Flow_Init.ravel()[::2][embed_mask_List[0]],Flow_Init.ravel()[1::2][embed_mask_List[0]]]).ravel() + x0 = np.concatenate( (loginit, Flow_Init_embed) ) + + if stochastic_optics == True: + x0 = np.concatenate((x0,np.zeros(N**2-1))) + + + print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame)) + print ("Clipped Pixel #: ",(len(loginit))) + + print ("Initial Values:") + plotcur(x0) + + # Minimize + optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params + tstart = time.time() + res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur) + tstop = time.time() + + Frames = np.zeros((N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + #Impose the prior mask in linear space for the output + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + Flow = EpsilonList = False + + if R_flow['alpha'] != 0.0: + print ("Collecting Flow...") + cur_len = np.sum(embed_mask_List[0]) + Flow_x = embed(res.x[init_i:(init_i+2*cur_len-1):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow_y = embed(res.x[(init_i+1):(init_i+2*cur_len):2], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Flow = np.transpose([Flow_x.ravel(),Flow_y.ravel()]).reshape((N_pixel, N_pixel,2)) + init_i += 2*cur_len + + if stochastic_optics == True: + EpsilonList = res.x[init_i:(init_i + N**2-1)] + init_i += len(EpsilonList) + + + plotcur(res.x, final=True) + + # Print stats + print ("time: %f s" % (tstop - tstart)) + print ("J: %f" % res.fun) + print (res.message) + + #Note: the global variables are *not* released to avoid recalculation + if processes != -1: + pool.close() + + #Return Frames + outim = [image.Image(Frames[i].reshape(Prior.ydim, Prior.xdim), Prior.psize, + Prior.ra, Prior.dec, rf=Obsdata_List[i].rf, source=Prior.source, + mjd=InitIm_List[i].mjd, time=InitIm_List[i].time, pulse=Prior.pulse) for i in range(N_frame)] + + if type(init_ims) == list: + pass + else: + outim = movie.merge_im_list(outim) + + if R_flow['alpha'] == 0.0 and stochastic_optics == False: + return outim + else: + return {'Movie':outim, 'Frames':outim, 'Flow':Flow, 'EpsilonList':EpsilonList } + + +def multifreq_dynamical_imaging(Obsdata_Multifreq_List, InitIm_Multifreq_List, Prior, flux_Multifreq_List = [], +d1='vis', d2=False, d3=False, +alpha_d1=10, alpha_d2=10, alpha_d3=10, +systematic_noise1=0.0, systematic_noise2=0.0, systematic_noise3=0.0, +entropy1="tv2", entropy2="l1", +alpha_s1=1.0, alpha_s2=1.0, norm_reg=True, alpha_A=1.0, +R_dI ={'alpha':0.0, 'metric':'SymKL', 'p':2.0}, +R_dt ={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0}, +R_dt_multifreq ={'alpha':0.0, 'metric':'SymKL', 'sigma_dt':0.0, 'p':2.0}, +alpha_centroid=0.0, alpha_flux=0.0, alpha_dF=0.0, alpha_dS1=0.0, alpha_dS2=0.0, #other regularizers +Target_Dynamic_Range = 10000.0, +maxit=200, J_factor = 0.001, stop=1.0e-10, ipynb=False, refresh_interval = 1000, minimizer_method = 'L-BFGS-B', NHIST = 25, update_interval = 1, clipfloor=0., recalculate_chisqdata = True, ttype = 'nfft', fft_pad_factor=2): + """Run dynamic imager + Uses I = exp(I') change of variables. + Obsdata_List is a list of Obsdata objects, InitIm_List is a list of Image objects, and Prior is an Image object. + Returns list of Image objects, one per frame (unless a flow or stochastic optics is used) + ttype = 'direct' or 'fast' or 'nfft' + """ + + global A1_List, A2_List, A3_List, data1_List, data2_List, data3_List, sigma1_List, sigma2_List, sigma3_List + + N_freq = len(Obsdata_Multifreq_List) + N_frame = len(Obsdata_Multifreq_List[0]) + N_pixel = Prior.xdim #pixel dimension + + # Flatten the input lists + flux_List = [x for y in flux_Multifreq_List for x in y] + InitIm_List = [x for y in InitIm_Multifreq_List for x in y] + Obsdata_List = [x for y in Obsdata_Multifreq_List for x in y] + + # Determine the appropriate final resolution + all_res = [[] for j in range(N_freq)] + for j in range(len(Obsdata_Multifreq_List)): + for obs in Obsdata_Multifreq_List[j]: + if len(obs.data) > 0: + all_res[j].append(obs.res()) + + # Determine the beam size for each frequency + beam_size = [np.min(all_res[j]) for j in range(N_freq)] + print("Maximal Resolutions:",beam_size) + + if alpha_flux > 0.0 and len(flux_Multifreq_List[0]) != N_frame: + raise Exception("Number of elements in the list of total flux densities does not match the number of frames!") + + # Make the blurring kernel for R_dt + # Note: There are odd problems when sigma_dt is too small. I can't figure out why it causes the convolution to crash. + # However, having sigma_dt -> 0 is not a problem in theory. So we'll just set the kernel to be zero in that case and then ignore it later in the convolution. + + if R_dt['sigma_dt'] > 0.0: + B_dt = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt['sigma_dt']**2)) + for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)] + for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)])) + if np.max(B_dt) == 0.0 or np.sum(B_dt) == 0.0: + raise Exception("Error with the blurring kernel!") + B_dt = B_dt / np.sum(B_dt) # normalize to be flux preserving + else: + B_dt = np.zeros((Prior.ydim,Prior.xdim)) + + if R_dt_multifreq['sigma_dt'] > 0.0: + B_dt_multifreq = np.abs(np.array([[np.exp(-1.0*(float(i)**2+float(j)**2)/(2.*R_dt_multifreq['sigma_dt']**2)) + for i in np.linspace((Prior.xdim-1)/2., -(Prior.xdim-1)/2., num=Prior.xdim)] + for j in np.linspace((Prior.ydim-1)/2., -(Prior.ydim-1)/2., num=Prior.ydim)])) + if np.max(B_dt_multifreq) == 0.0 or np.sum(B_dt_multifreq) == 0.0: + raise Exception("Error with the blurring kernel!") + B_dt_multifreq = B_dt_multifreq / np.sum(B_dt_multifreq) # normalize to be flux preserving + else: + B_dt_multifreq = np.zeros((Prior.ydim,Prior.xdim)) + + embed_mask_List = [Prior.imvec > clipfloor for j in range(N_freq*N_frame)] + embed_mask_All = np.array(embed_mask_List).flatten() + + embed_totals = [np.sum(embed_mask) for embed_mask in embed_mask_List] + + logprior_List = [None,] * (N_freq * N_frame) + loginit_List = [None,] * (N_freq * N_frame) + + nprior_embed_List = [None,] * (N_freq * N_frame) + nprior_List = [None,] * (N_freq * N_frame) + + ninit_embed_List = [InitIm_List[i].imvec for i in range(N_freq * N_frame)] + ninit_List = [ninit_embed_List[i][embed_mask_List[i]] for i in range(N_freq * N_frame)] + + if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct': + print ("Calculating lists/matrices for chi-squared terms...") + A1_List = [None,] * N_freq * N_frame + A2_List = [None,] * N_freq * N_frame + A3_List = [None,] * N_freq * N_frame + data1_List = [[],] * N_freq * N_frame + data2_List = [[],] * N_freq * N_frame + data3_List = [[],] * N_freq * N_frame + sigma1_List = [None,] * N_freq * N_frame + sigma2_List = [None,] * N_freq * N_frame + sigma3_List = [None,] * N_freq * N_frame + + # Get data and Fourier matrices for the data terms + for i in range(N_frame*N_freq): + pixel_max = np.max(InitIm_List[i].imvec) + prior_flux_rescale = 1.0 + if len(flux_List) > 0: + prior_flux_rescale = flux_List[i]/Prior.total_flux() + + nprior_embed_List[i] = Prior.imvec * prior_flux_rescale + + nprior_List[i] = nprior_embed_List[i][embed_mask_List[i]] + logprior_List[i] = np.log(nprior_List[i]) + loginit_List[i] = np.log(ninit_List[i] + pixel_max/Target_Dynamic_Range/1.e6) #add the dynamic range floor here + + if len(Obsdata_List[i].data) == 0: #This allows the algorithm to create frames for periods with no data + continue + + if (recalculate_chisqdata == True and ttype == 'direct') or ttype != 'direct': + (data1_List[i], sigma1_List[i], A1_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d1, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise1) + (data2_List[i], sigma2_List[i], A2_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d2, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise2) + (data3_List[i], sigma3_List[i], A3_List[i]) = chisqdata(Obsdata_List[i], Prior, embed_mask_List[i], d3, ttype=ttype, fft_pad_factor=fft_pad_factor, systematic_noise=systematic_noise3) + + # Coordinate matrix for COM constraint + coord = np.array([[[x,y] for x in np.linspace(Prior.xdim/2,-Prior.xdim/2,Prior.xdim)] + for y in np.linspace(Prior.ydim/2,-Prior.ydim/2,Prior.ydim)]) + coord = coord.reshape(Prior.ydim*Prior.xdim, 2) + + # Define the objective function and gradient + def objfunc(x): + Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_freq*N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = s_multifreq = s_dynamic = cm = flux = s_dS = s_dF = 0.0 + + # Multifrequency part + if R_dt_multifreq['alpha'] != 0.0: + for j in range(N_frame): + s_multifreq += Rdt(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq)*R_dt_multifreq['alpha'] + + # Individual frequencies + for j in range(N_freq): + i1 = j*N_frame + i2 = (j+1)*N_frame + + if alpha_s1 != 0.0: + s1 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2 + + if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames[i1:i2], **R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2 + + if alpha_dF != 0.0: s_dF += RdF_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF + + if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid + + if alpha_flux > 0.0: + flux = alpha_flux * movie_flux_constraint(Frames, flux_List) + + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_frame*N_freq)]) + + chisq = ((np.sum(chisq[:,0])/(N_freq*N_frame) - 1.0)*alpha_d1 + + (np.sum(chisq[:,1])/(N_freq*N_frame) - 1.0)*alpha_d2 + + (np.sum(chisq[:,2])/(N_freq*N_frame) - 1.0)*alpha_d3) + + return (s1 + s2 + s_dF + s_dS + s_multifreq + s_dynamic + chisq + cm + flux)*J_factor + + def objgrad(x): + Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_freq*N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = s_dS = s_dF = np.zeros((N_freq*N_frame*cur_len)) + s_dynamic_grad = cm_grad = flux_grad = np.zeros((N_freq*N_frame*N_pixel*N_pixel)) + s_multifreq = 0.0 + + # Multifrequency part + if R_dt_multifreq['alpha'] != 0.0: + s_multifreq = np.zeros((N_freq*N_frame, N_pixel*N_pixel)) + for j in range(N_frame): + s_multifreq[j::N_frame] += Rdt_gradient(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq).reshape((N_freq,N_pixel*N_pixel))*R_dt_multifreq['alpha'] + s_multifreq = s_multifreq.reshape(N_freq*N_frame*N_pixel*N_pixel) + + # Individual frequencies + for j in range(N_freq): + i1 = j*N_frame + i2 = (j+1)*N_frame + f1 = j*N_frame*N_pixel*N_pixel + f2 = (j+1)*N_frame*N_pixel*N_pixel + mf1 = j*N_frame*cur_len # Note: This assumes that all priors have the same number of masked pixels! + mf2 = (j+1)*N_frame*cur_len + + + if alpha_s1 != 0.0: + s1[mf1:mf2] = static_regularizer_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2[mf1:mf2] = static_regularizer_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2 + + if R_dI['alpha'] != 0.0: s_dynamic_grad[f1:f2] += RdI_gradient(Frames[i1:i2],**R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic_grad[f1:f2] += Rdt_gradient(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS[mf1:mf2] += RdS_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS[mf1:mf2] += RdS_gradient(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2 + + if alpha_dF != 0.0: s_dF[mf1:mf2] += RdF_gradient_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF + + if alpha_centroid != 0.0: cm_grad = centroid_gradient(Frames, coord) * alpha_centroid + + if alpha_flux > 0.0: + flux_grad = alpha_flux * movie_flux_constraint_grad(Frames, flux_List) + + chisq_grad = np.array([get_chisqgrad(j, Frames[j].ravel()[embed_mask_List[j]], d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_freq*N_frame)]) + + # Now add the Jacobian factor and concatenate + for j in range(N_freq*N_frame): + chisq_grad[j,0] = chisq_grad[j,0]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,1] = chisq_grad[j,1]*Frames[j].ravel()[embed_mask_List[j]] + chisq_grad[j,2] = chisq_grad[j,2]*Frames[j].ravel()[embed_mask_List[j]] + + chisq_grad = (np.concatenate([embed(chisq_grad[i,0], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d1 + + np.concatenate([embed(chisq_grad[i,1], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d2 + + np.concatenate([embed(chisq_grad[i,2], embed_mask_List[i]) for i in range(N_freq*N_frame)])/(N_freq*N_frame)*alpha_d3) + +# print((s1.shape, s2.shape, s_dF.shape, s_dS.shape)) +# print((s_multifreq.shape, s_dynamic_grad.shape, chisq_grad.shape, cm_grad.shape, flux_grad.shape)) +# print((s_multifreq[embed_mask_All].shape)) + + return ((s1 + s2 + s_dF + s_dS + (s_multifreq + s_dynamic_grad + chisq_grad + cm_grad + flux_grad)[embed_mask_All])*J_factor) + + # Plotting function for each iteration + global nit + nit = 0 + def plotcur(x, final=False): + global nit + nit += 1 + + if nit%update_interval == 0 or final == True: + print ("iteration %d" % nit) + + Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_freq*N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + s1 = s2 = s_multifreq = s_dynamic = cm = s_dS = s_dF = 0.0 + + # Multifrequency part + if R_dt_multifreq['alpha'] != 0.0: + for j in range(N_frame): + s_multifreq += Rdt(Frames[j::N_frame], B_dt_multifreq, **R_dt_multifreq)*R_dt_multifreq['alpha'] + + # Individual frequencies + for j in range(N_freq): + i1 = j*N_frame + i2 = (j+1)*N_frame + + if alpha_s1 != 0.0: + s1 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy1, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s1 + if alpha_s2 != 0.0: + s2 += static_regularizer(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], Prior.total_flux(), Prior.psize, entropy2, norm_reg=norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_s2 + + if R_dI['alpha'] != 0.0: s_dynamic += RdI(Frames[i1:i2], **R_dI)*R_dI['alpha'] + if R_dt['alpha'] != 0.0: s_dynamic += Rdt(Frames[i1:i2], B_dt, **R_dt)*R_dt['alpha'] + + if alpha_dS1 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy1, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS1 + if alpha_dS2 != 0.0: s_dS += RdS(Frames[i1:i2], nprior_embed_List[i1:i2], embed_mask_List[i1:i2], entropy2, norm_reg, beam_size=beam_size[j], alpha_A=alpha_A)*alpha_dS2 + + if alpha_dF != 0.0: s_dF += RdF_clip(Frames[i1:i2], embed_mask_List[i1:i2])*alpha_dF + + if alpha_centroid != 0.0: cm = centroid(Frames, coord) * alpha_centroid + + chisq = np.array([get_chisq(j, Frames[j].ravel()[embed_mask_List[j]], + d1, d2, d3, ttype, embed_mask_List[j]) for j in range(N_freq*N_frame)]) + + chisq1_List = chisq[:,0] + chisq2_List = chisq[:,1] + chisq3_List = chisq[:,2] + chisq1 = np.sum(chisq1_List)/len(chisq1_List) + chisq2 = np.sum(chisq2_List)/len(chisq1_List) + chisq3 = np.sum(chisq3_List)/len(chisq1_List) + chisq1_max = np.max(chisq1_List) + chisq2_max = np.max(chisq2_List) + chisq3_max = np.max(chisq3_List) + if d1 != False: print ("chi2_1: %f" % chisq1) + if d2 != False: print ("chi2_2: %f" % chisq2) + if d3 != False: print ("chi2_3: %f" % chisq3) + if d1 != False: print ("weighted chi2_1: %f" % (chisq1 * alpha_d1)) + if d2 != False: print ("weighted chi2_2: %f" % (chisq2 * alpha_d2)) + if d3 != False: print ("weighted chi2_3: %f" % (chisq3 * alpha_d3)) + if d1 != False: print ("Max Frame chi2_1: %f" % chisq1_max) + if d2 != False: print ("Max Frame chi2_2: %f" % chisq2_max) + if d3 != False: print ("Max Frame chi2_3: %f" % chisq3_max) + + if final == True: + if d1 != False: print ("All chisq1:",chisq1_List) + if d2 != False: print ("All chisq2:",chisq2_List) + if d3 != False: print ("All chisq3:",chisq3_List) + + if s1 != 0.0: print ("weighted s1: %f" % (s1)) + if s2 != 0.0: print ("weighted s2: %f" % (s2)) + if s_dF != 0.0: print ("weighted s_dF: %f" % (s_dF)) + if s_dS != 0.0: print ("weighted s_dS: %f" % (s_dS)) + print ("weighted s_dynamic: %f" % (s_dynamic)) + print ("weighted s_multifreq: %f" % (s_multifreq)) + if alpha_centroid > 0.0: print ("weighted COM: %f" % cm) + + if alpha_flux > 0.0: + print ("weighted flux constraint: %f" % (alpha_flux * movie_flux_constraint(Frames, flux_List))) + + if nit%refresh_interval == 0: + print ("Plotting Functionality Temporarily Disabled...") + + loginit = np.hstack(loginit_List).flatten() + + x0 = loginit + + print ("Total Pixel #: ",(N_pixel*N_pixel*N_frame*N_freq)) + print ("Clipped Pixel #: ",(len(loginit))) + + print ("Initial Values:") + plotcur(x0) + + # Minimize + optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST, 'gtol': 1e-10} # minimizer params + tstart = time.time() + res = opt.minimize(objfunc, x0, method=minimizer_method, jac=objgrad, options=optdict, callback=plotcur) + tstop = time.time() + + Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + log_Frames = np.zeros((N_freq*N_frame, N_pixel, N_pixel)) + + init_i = 0 + for i in range(N_freq*N_frame): + cur_len = np.sum(embed_mask_List[i]) + log_Frames[i] = embed(res.x[init_i:(init_i+cur_len)], embed_mask_List[i]).reshape((N_pixel, N_pixel)) + #Impose the prior mask in linear space for the output + Frames[i] = np.exp(log_Frames[i])*(embed_mask_List[i].reshape((N_pixel, N_pixel))) + init_i += cur_len + + plotcur(res.x, final=True) + + # Print stats + print ("time: %f s" % (tstop - tstart)) + print ("J: %f" % res.fun) + print (res.message) + + #Return Frames + outim = [[image.Image(Frames[i + j*N_frame].reshape(Prior.ydim, Prior.xdim), Prior.psize, + Prior.ra, Prior.dec, rf=Obsdata_Multifreq_List[j][i].rf, source=Prior.source, + mjd=Prior.mjd, pulse=Prior.pulse) for i in range(N_frame)] for j in range(N_freq)] + + return outim + + + + + + +################################################################################################## +# Plotting Functions +################################################################################################## + +def plot_im_List_Set(im_List_List, plot_log_amplitude=False, ipynb=False): + plt.ion() + plt.clf() + + Prior = im_List_List[0][0] + + xnum = len(im_List_List[0]) + ynum = len(im_List_List) + + for i in range(xnum*ynum): + plt.subplot(ynum, xnum, i+1) + im = im_List_List[(i-i%xnum)//xnum][i%xnum] + if plot_log_amplitude == False: + plt.imshow(im.imvec.reshape(im.ydim,im.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + else: + plt.imshow(np.log(im.imvec.reshape(im.ydim,im.xdim)), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + xticks = ticks(im.xdim, im.psize/RADPERAS/1e-6) + yticks = ticks(im.ydim, im.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + if i == 0: + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + else: + plt.xlabel('') + plt.ylabel('') + plt.title('') + + plt.draw() + +def plot_im_List(im_List, plot_log_amplitude=False, ipynb=False): + + plt.ion() + plt.clf() + + Prior = im_List[0] + + for i in range(len(im_List)): + plt.subplot(1, len(im_List), i+1) + if plot_log_amplitude == False: + plt.imshow(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + else: + plt.imshow(np.log(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim)), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + if i == 0: + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + else: + plt.xlabel('') + plt.ylabel('') + plt.title('') + + + plt.draw() + +def plot_i_dynamic(im_List, Prior, nit, chi2, s, s_dynamic, ipynb=False): + + plt.ion() + plt.clf() + + for i in range(len(im_List)): + plt.subplot(1, len(im_List), i+1) + plt.imshow(im_List[i].reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + if i == 0: + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title("step: %i $\chi^2$: %f $s$: %f $s_{t}$: %f" % (nit, chi2, s, s_dynamic), fontsize=20) + else: + plt.xlabel('') + plt.ylabel('') + plt.title('') + + + plt.draw() + +################################################################################################## +#BU blazar CLEAN file loading functions +################################################################################################## + +class MOJAVEHTMLParser(HTMLParser): + #standard overriding of the python HTMLParser to suit the format of the BU blazar library + dates = [] + mod_files = [] + uv_files = [] + + #checks for CLEAN files linked on the page by looking for file ending + def handle_starttag(self, tag, attrs): + for attr in attrs: + if len(attr) == 2: + attr1s = attr[1].strip().split('.') + if (len(attr1s) >1): + fileType = attr1s[1] + if fileType == 'icn': + self.mod_files.append(str(attr[1]).strip()) + if fileType == 'uvf': + self.uv_files.append(str(attr[1]).strip()) + + #extracts date information and standardizes format to DD + month string + YYYY = DDMONYYYY + def handle_data(self,data): + ds = data.split(' ') + lds = len(ds) + if ds[lds-1].isdigit(): + if lds == 3 or lds == 4: + day = str(ds[lds-3]) + month = str(ds[lds-2]) + year = str(ds[lds-1]) + if len(day) == 1: + day = '0' + day + month = month[0:3] + monthNum = str(list(calendar.month_abbr).index(month)) + if len(monthNum) == 1: + monthNum = '0' + monthNum + newDate = year +monthNum + day + self.dates.append(str(newDate)) + +class BlazarHTMLParser(HTMLParser): + #standard overriding of the python HTMLParser to suit the format of the BU blazar library + dates = [] + mod_files = [] + uv_files = [] + + #checks for CLEAN files linked on the page by looking for file ending + def handle_starttag(self, tag, attrs): + for attr in attrs: + if len(attr) == 2: + attr1s = attr[1].strip().split('.') + if (len(attr1s) >1): + fileType = attr1s[1] + if fileType == 'MOD' or fileType=='mod': + self.mod_files.append(str(attr[1]).strip()) + if fileType == 'UVP' or fileType=='uvp': + self.uv_files.append(str(attr[1]).strip()) + + #extracts date information and standardizes format to DD + month string + YYYY = DDMONYYYY + def handle_data(self,data): + ds = data.split(' ') + lds = len(ds) + if ds[lds-1].isdigit(): + if lds == 3 or lds == 4: + day = str(ds[lds-3]) + month = str(ds[lds-2]) + year = str(ds[lds-1]) + if len(day) == 1: + day = '0' + day + month = month[0:3] + monthNum = str(list(calendar.month_abbr).index(month)) + if len(monthNum) == 1: + monthNum = '0' + monthNum + newDate = year +monthNum + day + self.dates.append(str(newDate)) + +def generateMOJAVEdates(url, sourceName, path = './'): + #Creates a list of observation dates and .mod filenames for a particular source in the MOJAVE library + #Returns the filename of the output file + #url is the URL of the MOJAVE library page of files for a particular source + r = requests.get(url) + parser = MOJAVEHTMLParser() + parser.feed(r.text) + outputFileName = sourceName+'_dates_and_CLEAN_filenames.txt' + outputFile = open(outputFileName, 'w') + outputFile.write("#Observation dates, UV files, and CLEAN models obtained from " + url + '\n') + for i in range(len(parser.dates)): + outputFile.write(parser.dates[i]+','+parser.uv_files[i]+','+parser.mod_files[i]+'\n') + outputFile.close() + return outputFileName + +def generateCLEANdates(url, sourceName, path = './'): + #Creates a list of observation dates and .mod filenames for a particular source in the BU blazar library + #Returns the filename of the output file + #url is the URL of the BU blazar library page of files for a particular source + r = requests.get(url) + parser = BlazarHTMLParser() + parser.feed(r.text) + outputFileName = sourceName+'_dates_and_CLEAN_filenames.txt' + outputFile = open(outputFileName, 'w') + outputFile.write("#Observation dates, UV files, and CLEAN models obtained from " + url + '\n') + for i in range(len(parser.dates)): + outputFile.write(parser.dates[i]+','+parser.uv_files[i]+','+parser.mod_files[i]+'\n') + outputFile.close() + return outputFileName + +def sourceNameFromURL(url): + #returns a string containing the BU url designation for a particular source, i.e. "3c454" for 3c454.3 + urls = url.split('/') + lurls = len(urls) + sourcehtmls = urls[lurls-1].split('.') + sn = sourcehtmls[0] + return sn + +def downloadMOJAVEfiles(url, path = './'): + #Downloads data and image from MOJAVE library + #url is the URL of the MOJAVE page of files for a particular source + + sn = sourceNameFromURL(url) + #Find the base directory in the BU library from the url. This is a really hacky way to do this. + #source name in BU library + lsn = len(sn) + baseurl = url[:-(5+lsn)] + + #Make a new directory for the downloaded files + CLEANpath = path + '/' + sn + '_CLEAN_images' + if not os.path.exists(CLEANpath): + os.mkdir(CLEANpath) + + print ("Downloading CLEAN images to " + CLEANpath) + + UVpath = path + '/' + sn+"_uvf_files" + if not os.path.exists(UVpath): + os.mkdir(UVpath) + + print ("Downloading uvfits files to " + UVpath) + + #Generate the bookkeeping file and iterate on it, downloading files and saving them in a better format + #Files are saved in a new directory named after the source i.e. "3c454_CLEAN_files" + guideFileName = generateCLEANdates(url, sn, path=path) + observations = np.loadtxt(guideFileName, dtype = str, delimiter = ',', skiprows = 1) + for obs in observations: + date = obs[0] + UVbuFileName = obs[1] + CLEANbuFileName = obs[2] + UVurl = baseurl+UVbuFileName + CLEANurl = baseurl + CLEANbuFileName + # If it doesn't already exist, download the UV file + if os.path.isfile(UVpath+'/'+date+'_'+ sn+".uvf") == False: + print ("Downloading " + (UVpath+'/'+date+'_'+ sn+".uvf")) + response = requests.get(UVurl, stream=True) + with open(UVpath+'/'+date+'_'+ sn+".uvf",'wb') as handle: + handle.write(response.raw.read()) + else: + print ("Already Downloaded " + (UVpath+'/'+date+'_'+ sn+".uvf")) + + # If it doesn't already exist, download the CLEAN file + if os.path.isfile(CLEANpath+'/'+date+'_'+sn+".icn.fits.gz") == False: + print ("Downloading " + (CLEANpath+'/'+date+'_'+sn+".icn.fits.gz")) + response = requests.get(CLEANurl, stream=True) + with open(CLEANpath+'/'+date+'_'+sn+".icn.fits.gz",'wb') as handle: + for chunk in response.iter_content(chunk_size = 128): + handle.write(chunk) + else: + print ("Already Downloaded " + (CLEANpath+'/'+date+'_'+sn+".icn.fits.gz")) + +def downloadCLEANfiles(url, path = './'): + #Downloads all CLEAN files from a single source in the BU blazar library + #url is the URL of the BU blazar library page of files for a particular source + + sn = sourceNameFromURL(url) + #Find the base directory in the BU library from the url. This is a really hacky way to do this. + #source name in BU library + lsn = len(sn) + baseurl = url[:-(5+lsn)] + + #Make a new directory for the downloaded files + CLEANpath = path + '/' + sn + '_CLEAN_files' + if not os.path.exists(CLEANpath): + os.mkdir(CLEANpath) + + print ("Downloading CLEAN files to " + CLEANpath) + + UVpath = path + '/' + sn+"_UVP.gz_files" + if not os.path.exists(UVpath): + os.mkdir(UVpath) + + print ("Downloading uvfits files to " + UVpath) + + #Generate the bookkeeping file and iterate on it, downloading files and saving them in a better format + #Files are saved in a new directory named after the source i.e. "3c454_CLEAN_files" + guideFileName = generateCLEANdates(url, sn, path=path) + observations = np.loadtxt(guideFileName, dtype = str, delimiter = ',', skiprows = 1) + for obs in observations: + date = obs[0] + UVbuFileName = obs[1] + CLEANbuFileName = obs[2] + UVurl = baseurl+UVbuFileName + CLEANurl = baseurl + CLEANbuFileName + # If it doesn't already exist, download the UV file + if os.path.isfile(UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz") == False: + print ("Downloading " + (UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz")) + response = requests.get(UVurl, stream=True) + with open(UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz",'wb') as handle: + handle.write(response.raw.read()) + else: + print ("Already Downloaded " + (UVpath+'/'+date+'_'+ sn+"_UV.UVP.gz")) + + # If it doesn't already exist, download the CLEAN file + if os.path.isfile(CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod") == False: + print ("Downloading " + (CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod")) + response = requests.get(CLEANurl, stream=True) + with open(CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod",'wb') as handle: + for chunk in response.iter_content(chunk_size = 128): + handle.write(chunk) + else: + print ("Already Downloaded " + (CLEANpath+'/'+date+'_'+sn+"_CLEAN.mod")) + +def minDeltaMJD(inputMJD, im_List): + #returns the image whose MJD most closely matches the inputMJD + index = 0 + minDelta = 10000 + for i in range(len(im_List)): + oldDelta = minDelta + minDelta = min(minDelta, abs(inputMJD - im_List[i].mjd)) + if not(minDelta == oldDelta): + index = i + #return im_List[index].copy() + return index + +def ReadCLEAN(nameF, reference_obs, npix, fov=0, beamPar=(0,0,0.)): +#This should be able to load CLEAN Model data +#such as given here https://www.bu.edu/blazars/VLBA_GLAST/3c454.html +#nameF - name of the CLEAN Model file to load (3columns: Flux in Jy, r in mas, theta in deg) +#npix - number of pixels in one dimension +#fov - field of view (radians) +#beamPar - parameters of Gaussian beam, same as beamparams in image.blur_gauss + + #read data + #first remove multiple models in a single file + linesF = open(nameF).readlines() + DelTrig = 0 + TrigString = '!' + linesMax = 0 + for cou in range(len(linesF)): + if (linesF[cou].find(TrigString) != -1)*(cou>8): + open(nameF, 'w').writelines(linesF[:cou]) + break + + #skip headline + TableMOD = np.genfromtxt(nameF, skip_header=4) + ScaleR = 1. + FluxConst = 1. + Flux = FluxConst*TableMOD[:,0] + xPS = ScaleR*TableMOD[:,1]*np.cos(np.pi/2.-(np.pi/180.)*TableMOD[:,2])*(1.e3)*RADPERUAS #to radians + yPS = ScaleR*TableMOD[:,1]*np.sin(np.pi/2.-(np.pi/180.)*TableMOD[:,2])*(1.e3)*RADPERUAS #to radians + NumbPoints = np.shape(yPS)[0] + + #set image parameters + if fov==0: + MaxR = np.amax(TableMOD[:,1]) #in mas + fov = 1.*MaxR*(1.e3)*RADPERUAS + + image0 = np.zeros((int(npix),int(npix))) + im = image.Image(image0, fov/npix, 0., 0., rf=86e9) + + beamMaj = beamPar[0] + if beamMaj==0: + beamMaj = 4.*fov/npix + + beamMin = beamPar[1] + if beamMin==0: + beamMin = 4.*fov/npix + + beamTh = beamPar[2] + + sigma_maj = beamMaj / (2. * np.sqrt(2. * np.log(2.))) + sigma_min = beamMin / (2. * np.sqrt(2. * np.log(2.))) + cth = np.cos(beamTh) + sth = np.sin(beamTh) + xfov = im.xdim * im.psize + yfov = im.ydim * im.psize + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + gauss = image0 + for couP in range(NumbPoints): + x = xPS[couP] + y = yPS[couP] + xM, yM = np.meshgrid(xlist, ylist) + gaussNew = np.exp(-((yM-y)*cth + (xM-x)*sth)**2/(2*sigma_maj**2) - ((xM-x)*cth - (yM-y)*sth)**2/(2.*sigma_min**2)) + gauss = gauss + gaussNew*Flux[couP] + + gauss /= (2.0*np.pi*sigma_maj*sigma_min)/(fov/npix)**2 #Normalize the Gaussian + gauss = (gauss > 0.)*gauss + 1.e-10 + + imageCLEAN = image.Image(gauss, fov/npix, reference_obs.ra, reference_obs.dec, rf=reference_obs.rf) + imageCLEAN.mjd = reference_obs.mjd + return imageCLEAN + +def Cont(imG): +#This is meant to create plots similar to the ones from +#https://www.bu.edu/blazars/VLBA_GLAST/3c454.html +#for the visual comparison + + import matplotlib.pyplot as plt + plt.figure() + Z = np.reshape(imG.imvec,(imG.xdim,imG.ydim)) + pov = imG.xdim*imG.psize + pov_mas = pov/(RADPERUAS*1.e3) + Zmax = np.amax(Z) + print(Zmax) + + levels = np.array((-0.00125*Zmax,0.00125*Zmax,0.0025*Zmax, 0.005*Zmax, 0.01*Zmax, + 0.02*Zmax, 0.04*Zmax, 0.08*Zmax, 0.16*Zmax, 0.32*Zmax, 0.64*Zmax)) + CS = plt.contour(Z, levels, + origin='lower', + linewidths=2, + extent=(-pov_mas/2., pov_mas/2., -pov_mas/2., pov_mas/2.)) + plt.show() + + + +def ReadSeriesImages(pathCLEAN, Obs, npix,fov,beamPar, obsNumbList): + + listCLEAN = os.listdir(pathCLEAN) + listCLEAN = sorted(listCLEAN) + listCLEAN = list( listCLEAN[i] for i in obsNumbList ) + + im_List = [None]*len(listCLEAN) + for cou in range(len(listCLEAN)): + nameF = pathCLEAN+listCLEAN[cou] + print(nameF) + im_List[cou] = ReadCLEAN(nameF,Obs[cou], npix, fov, beamPar) + + return im_List + +def SaveSeriesImages(pathCLEAN, im_List, sourceName, outputDirectory='default'): + #saves a list of images returned by ReadSeriesImages according to the naming convention in pathCLEAN and the source name + if outputDirectory == 'default': + outputDirectory = sourceName+'_READ_CLEAN_files' + if not os.path.exists(outputDirectory): + os.mkdir(outputDirectory) + outputSuffix = '_'+sourceName+'_READ_CLEAN.txt' + + listCLEAN = os.listdir(pathCLEAN) + listCLEAN = sorted(listCLEAN) + datesCLEAN = [filename.split('_')[0] for filename in listCLEAN] + for i in range(len(datesCLEAN)): + outputName = outputDirectory+'/'+datesCLEAN[i] + outputSuffix + im_List[i].save_txt(outputName) + +def LoadSeriesImages(sourceName): + #loads a list of images saved according to the BU Blazar library convention + dirName = sourceName+'_READ_CLEAN_files' + fileList = os.listdir(dirName) + filenameList = [dirName+'/'+filename for filename in fileList] + im_List = [image.load_txt(filename) for filename in filenameList] + return im_List diff --git a/imaging/imager_utils.py b/imaging/imager_utils.py new file mode 100644 index 00000000..5f565c07 --- /dev/null +++ b/imaging/imager_utils.py @@ -0,0 +1,4397 @@ +# imager_utils.py +# General imager functions for total intensity VLBI data +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import time +import numpy as np +import scipy.optimize as opt +import matplotlib.pyplot as plt + +import ehtim.image as image +import ehtim.observing.obs_helpers as obsh +import ehtim.const_def as ehc + + +################################################################################################## +# Constants & Definitions +################################################################################################## + +NORM_REGULARIZER = False # ANDREW TODO change this default in the future + +MAXLS = 100 # maximum number of line searches in L-BFGS-B +NHIST = 100 # number of steps to store for hessian approx +MAXIT = 100 # maximum number of iterations +STOP = 1.e-8 # convergence criterion + +DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', + 'camp', 'logcamp', 'logcamp_diag', 'logamp'] +REGULARIZERS = ['gs', 'tv', 'tvlog','tv2', 'tv2log','l1w', 'lA', 'patch', 'simple', 'compact', 'compact2', 'rgauss'] + +nit = 0 # global variable to track the iteration number in the plotting callback + +################################################################################################## +# Total Intensity Imager +################################################################################################## + + +def imager_func(Obsdata, InitIm, Prior, flux, + d1='vis', d2=False, d3=False, + alpha_d1=100, alpha_d2=100, alpha_d3=100, + s1='simple', s2=False, s3=False, + alpha_s1=1, alpha_s2=1, alpha_s3=1, + alpha_flux=500, alpha_cm=500, + **kwargs): + """Run a general interferometric imager. Only works directly on the image's primary polarization. + + Args: + Obsdata (Obsdata): The Obsdata object with VLBI data + InitIm (Image): The Image object with the initial image for the minimization + Prior (Image): The Image object with the prior image + flux (float): The total flux of the output image in Jy + + d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag' + d2 (str): The second data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag' + d3 (str): The third data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag' + + s1 (str): The first regularizer; options are 'simple', 'gs', 'tv', 'tv2', 'l1', 'patch','compact','compact2','rgauss' + s2 (str): The second regularizer; options are 'simple', 'gs', 'tv', 'tv2','l1', 'patch','compact','compact2','rgauss' + s3 (str): The third regularizer; options are 'simple', 'gs', 'tv', 'tv2','l1', 'patch','compact','compact2','rgauss' + + alpha_d1 (float): The first data term weighting + alpha_d2 (float): The second data term weighting + alpha_d2 (float): The third data term weighting + + alpha_s1 (float): The first regularizer term weighting + alpha_s2 (float): The second regularizer term weighting + alpha_s3 (float): The third regularizer term weighting + + alpha_flux (float): The weighting for the total flux constraint + alpha_cm (float): The weighting for the center of mass constraint + + maxit (int): Maximum number of minimizer iterations + stop (float): The convergence criterion + clipfloor (float): The Jy/pixel level above which prior image pixels are varied + + grads (bool): If True, analytic gradients are used + logim (bool): If True, uses I = exp(I') change of variables + norm_reg (bool): If True, normalizes regularizer terms + norm_init (bool): If True, normalizes initial image to given total flux + show_updates (bool): If True, displays the progress of the minimizer + + weighting (str): 'natural' or 'uniform' + debias (bool): if True then apply debiasing to amplitudes/closure amplitudes + systematic_noise (float): a fractional systematic noise tolerance to add to thermal sigmas + snrcut (float): a snr cutoff for including data in the chi^2 sum + beam_size (float): beam size in radians for normalizing the regularizers + + maxset (bool): if True, use maximal set instead of minimal for closure quantities + systematic_cphase_noise (float): a value in degrees to add to the closure phase sigmas + cp_uv_min (float): flag baselines shorter than this before forming closure quantities + + ttype (str): The Fourier transform type; options are 'fast', 'direct', 'nfft' + fft_pad_factor (float): The FFT will pre-pad the image by this factor x the original size + order (int): Interpolation order for sampling the FFT + conv_func (str): The convolving function for gridding; options are 'gaussian', 'pill', and 'cubic' + p_rad (int): The pixel radius for the convolving function in gridding for FFTs + + Returns: + Image: Image object with result + """ + + # some kwarg default values + maxit = kwargs.get('maxit', MAXIT) + stop = kwargs.get('stop', STOP) + clipfloor = kwargs.get('clipfloor', 0) + ttype = kwargs.get('ttype', 'direct') + + grads = kwargs.get('grads', True) + logim = kwargs.get('logim', True) + norm_init = kwargs.get('norm_init', False) + show_updates = kwargs.get('show_updates', True) + + beam_size = kwargs.get('beam_size', Obsdata.res()) + kwargs['beam_size'] = beam_size + + # Make sure data and regularizer options are ok + if not d1 and not d2: + raise Exception("Must have at least one data term!") + if not s1 and not s2: + raise Exception("Must have at least one regularizer term!") + if (not ((d1 in DATATERMS) or d1 is False)) or (not ((d2 in DATATERMS) or d2 is False)): + raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS)) + if (not ((s1 in REGULARIZERS) or s1 is False)) or (not ((s2 in REGULARIZERS) or s2 is False)): + raise Exception("Invalid regularizer: valid regularizers are: " + ' '.join(REGULARIZERS)) + if (Prior.psize != InitIm.psize) or (Prior.xdim != InitIm.xdim) or (Prior.ydim != InitIm.ydim): + raise Exception("Initial image does not match dimensions of the prior image!") + if (InitIm.polrep != Prior.polrep): + raise Exception( + "Initial image pol. representation does not match pol. representation of the prior image!") + if (logim and Prior.pol_prim in ['Q', 'U', 'V']): + raise Exception( + "Cannot image Stokes Q,U,or V with log image transformation! Set logim=False in imager_func") + + pol = Prior.pol_prim + print("Generating %s image..." % pol) + + # Catch scale and dimension problems + imsize = np.max([Prior.xdim, Prior.ydim]) * Prior.psize + uvmax = 1.0/Prior.psize + uvmin = 1.0/imsize + uvdists = Obsdata.unpack('uvdist')['uvdist'] + maxbl = np.max(uvdists) + minbl = np.max(uvdists[uvdists > 0]) + maxamp = np.max(np.abs(Obsdata.unpack('amp')['amp'])) + + if uvmax < maxbl: + print("Warning! Pixel Spacing is larger than smallest spatial wavelength!") + if uvmin > minbl: + print("Warning! Field of View is smaller than largest nonzero spatial wavelength!") + if flux > 1.2*maxamp: + print("Warning! Specified flux is > 120% of maximum visibility amplitude!") + if flux < .8*maxamp: + print("Warning! Specified flux is < 80% of maximum visibility amplitude!") + + # Define embedding mask + embed_mask = Prior.imvec > clipfloor + + # Normalize prior image to total flux and limit imager range to prior values > clipfloor + if (not norm_init): + nprior = Prior.imvec[embed_mask] + ninit = InitIm.imvec[embed_mask] + else: + nprior = (flux * Prior.imvec / np.sum((Prior.imvec)[embed_mask]))[embed_mask] + ninit = (flux * InitIm.imvec / np.sum((InitIm.imvec)[embed_mask]))[embed_mask] + + if len(nprior) == 0: + raise Exception("clipfloor too large: all prior pixels have been clipped!") + + # Get data and fourier matrices for the data terms + (data1, sigma1, A1) = chisqdata(Obsdata, Prior, embed_mask, d1, pol=pol, **kwargs) + (data2, sigma2, A2) = chisqdata(Obsdata, Prior, embed_mask, d2, pol=pol, **kwargs) + (data3, sigma3, A3) = chisqdata(Obsdata, Prior, embed_mask, d3, pol=pol, **kwargs) + + # Define the chi^2 and chi^2 gradient + def chisq1(imvec): + return chisq(imvec, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask) + + def chisq1grad(imvec): + c = chisqgrad(imvec, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask) + return c + + def chisq2(imvec): + return chisq(imvec, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask) + + def chisq2grad(imvec): + c = chisqgrad(imvec, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask) + return c + + def chisq3(imvec): + return chisq(imvec, A3, data3, sigma3, d3, ttype=ttype, mask=embed_mask) + + def chisq3grad(imvec): + c = chisqgrad(imvec, A3, data3, sigma3, d3, ttype=ttype, mask=embed_mask) + return c + + # Define the regularizer and regularizer gradient + def reg1(imvec): + return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs) + + def reg1grad(imvec): + return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs) + + def reg2(imvec): + return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs) + + def reg2grad(imvec): + return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs) + + def reg3(imvec): + return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s3, **kwargs) + + def reg3grad(imvec): + return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, s3, **kwargs) + + # Define constraint functions + def flux_constraint(imvec): + return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "flux", **kwargs) + + def flux_constraint_grad(imvec): + return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "flux", **kwargs) + + def cm_constraint(imvec): + return regularizer(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "cm", **kwargs) + + def cm_constraint_grad(imvec): + return regularizergrad(imvec, nprior, embed_mask, flux, Prior.xdim, Prior.ydim, Prior.psize, "cm", **kwargs) + + # Define the objective function and gradient + def objfunc(imvec): + if logim: + imvec = np.exp(imvec) + + datterm = alpha_d1 * (chisq1(imvec) - 1) + alpha_d2 * \ + (chisq2(imvec) - 1) + alpha_d3 * (chisq3(imvec) - 1) + regterm = alpha_s1 * reg1(imvec) + alpha_s2 * reg2(imvec) + alpha_s3 * reg3(imvec) + conterm = alpha_flux * flux_constraint(imvec) + alpha_cm * cm_constraint(imvec) + + return datterm + regterm + conterm + + def objgrad(imvec): + if logim: + imvec = np.exp(imvec) + + datterm = alpha_d1 * chisq1grad(imvec) + alpha_d2 * \ + chisq2grad(imvec) + alpha_d3 * chisq3grad(imvec) + regterm = alpha_s1 * reg1grad(imvec) + alpha_s2 * \ + reg2grad(imvec) + alpha_s3 * reg3grad(imvec) + conterm = alpha_flux * flux_constraint_grad(imvec) + alpha_cm * cm_constraint_grad(imvec) + + grad = datterm + regterm + conterm + + # chain rule term for change of variables + if logim: + grad *= imvec + + return grad + + # Define plotting function for each iteration + global nit + nit = 0 + + def plotcur(im_step): + global nit + if logim: + im_step = np.exp(im_step) + if show_updates: + chi2_1 = chisq1(im_step) + chi2_2 = chisq2(im_step) + chi2_3 = chisq3(im_step) + s_1 = reg1(im_step) + s_2 = reg2(im_step) + s_3 = reg3(im_step) + if np.any(np.invert(embed_mask)): + im_step = embed(im_step, embed_mask) + plot_i(im_step, Prior, nit, {d1: chi2_1, d2: chi2_2, d3: chi2_3}, pol=pol) + print("i: %d chi2_1: %0.2f chi2_2: %0.2f chi2_3: %0.2f s_1: %0.2f s_2: %0.2f s_3: %0.2f" % ( + nit, chi2_1, chi2_2, chi2_3, s_1, s_2, s_3)) + nit += 1 + + # Generate and the initial image + if logim: + xinit = np.log(ninit) + else: + xinit = ninit + + # Print stats + print("Initial S_1: %f S_2: %f S_3: %f" % (reg1(ninit), reg2(ninit), reg3(ninit))) + print("Initial Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f" % + (chisq1(ninit), chisq2(ninit), chisq3(ninit))) + print("Initial Objective Function: %f" % (objfunc(xinit))) + + if d1 in DATATERMS: + print("Total Data 1: ", (len(data1))) + if d2 in DATATERMS: + print("Total Data 2: ", (len(data2))) + if d3 in DATATERMS: + print("Total Data 3: ", (len(data3))) + + print("Total Pixel #: ", (len(Prior.imvec))) + print("Clipped Pixel #: ", (len(ninit))) + print() + plotcur(xinit) + + # Minimize + optdict = {'maxiter': maxit, 'ftol': stop, 'maxcor': NHIST, + 'gtol': stop, 'maxls': MAXLS} # minimizer dict params + tstart = time.time() + if grads: + res = opt.minimize(objfunc, xinit, method='L-BFGS-B', jac=objgrad, + options=optdict, callback=plotcur) + else: + res = opt.minimize(objfunc, xinit, method='L-BFGS-B', + options=optdict, callback=plotcur) + + tstop = time.time() + + # Format output + out = res.x + if logim: + out = np.exp(res.x) + if np.any(np.invert(embed_mask)): + out = embed(out, embed_mask) + + outim = image.Image(out.reshape(Prior.ydim, Prior.xdim), + Prior.psize, Prior.ra, Prior.dec, + rf=Prior.rf, source=Prior.source, + polrep=Prior.polrep, pol_prim=pol, + mjd=Prior.mjd, time=Prior.time, pulse=Prior.pulse) + + # copy over other polarizations + outim.copy_pol_images(InitIm) + + # Print stats + print("time: %f s" % (tstop - tstart)) + print("J: %f" % res.fun) + print("Final Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f" % + (chisq1(out[embed_mask]), chisq2(out[embed_mask]), chisq3(out[embed_mask]))) + print(res.message) + + # Return Image object + return outim + +################################################################################################## +# Wrapper Functions +################################################################################################## + + +def chisq(imvec, A, data, sigma, dtype, ttype='direct', mask=None): + """return the chi^2 for the appropriate dtype + """ + + if mask is None: + mask = [] + chisq = 1 + if dtype not in DATATERMS: + return chisq + + if ttype not in ['fast', 'direct', 'nfft']: + raise Exception("Possible ttype values are 'fast', 'direct'!, 'nfft!'") + + if ttype == 'direct': + if dtype == 'vis': + chisq = chisq_vis(imvec, A, data, sigma) + elif dtype == 'amp': + chisq = chisq_amp(imvec, A, data, sigma) + elif dtype == 'logamp': + chisq = chisq_logamp(imvec, A, data, sigma) + elif dtype == 'bs': + chisq = chisq_bs(imvec, A, data, sigma) + elif dtype == 'cphase': + chisq = chisq_cphase(imvec, A, data, sigma) + elif dtype == 'cphase_diag': + chisq = chisq_cphase_diag(imvec, A, data, sigma) + elif dtype == 'camp': + chisq = chisq_camp(imvec, A, data, sigma) + elif dtype == 'logcamp': + chisq = chisq_logcamp(imvec, A, data, sigma) + elif dtype == 'logcamp_diag': + chisq = chisq_logcamp_diag(imvec, A, data, sigma) + + elif ttype == 'fast': + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + + if dtype not in ['cphase_diag', 'logcamp_diag']: + vis_arr = obsh.fft_imvec(imvec, A[0]) + + if dtype == 'vis': + chisq = chisq_vis_fft(vis_arr, A, data, sigma) + elif dtype == 'amp': + chisq = chisq_amp_fft(vis_arr, A, data, sigma) + elif dtype == 'logamp': + chisq = chisq_logamp_fft(vis_arr, A, data, sigma) + elif dtype == 'bs': + chisq = chisq_bs_fft(vis_arr, A, data, sigma) + elif dtype == 'cphase': + chisq = chisq_cphase_fft(vis_arr, A, data, sigma) + elif dtype == 'cphase_diag': + chisq = chisq_cphase_diag_fft(imvec, A, data, sigma) + elif dtype == 'camp': + chisq = chisq_camp_fft(vis_arr, A, data, sigma) + elif dtype == 'logcamp': + chisq = chisq_logcamp_fft(vis_arr, A, data, sigma) + elif dtype == 'logcamp_diag': + chisq = chisq_logcamp_diag_fft(imvec, A, data, sigma) + + elif ttype == 'nfft': + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + + if dtype == 'vis': + chisq = chisq_vis_nfft(imvec, A, data, sigma) + elif dtype == 'amp': + chisq = chisq_amp_nfft(imvec, A, data, sigma) + elif dtype == 'logamp': + chisq = chisq_logamp_nfft(imvec, A, data, sigma) + elif dtype == 'bs': + chisq = chisq_bs_nfft(imvec, A, data, sigma) + elif dtype == 'cphase': + chisq = chisq_cphase_nfft(imvec, A, data, sigma) + elif dtype == 'cphase_diag': + chisq = chisq_cphase_diag_nfft(imvec, A, data, sigma) + elif dtype == 'camp': + chisq = chisq_camp_nfft(imvec, A, data, sigma) + elif dtype == 'logcamp': + chisq = chisq_logcamp_nfft(imvec, A, data, sigma) + elif dtype == 'logcamp_diag': + chisq = chisq_logcamp_diag_nfft(imvec, A, data, sigma) + + return chisq + + +def chisqgrad(imvec, A, data, sigma, dtype, ttype='direct', mask=None): + """return the chi^2 gradient for the appropriate dtype + """ + + if mask is None: + mask = [] + chisqgrad = np.zeros(len(imvec)) + if dtype not in DATATERMS: + return chisqgrad + + if ttype not in ['fast', 'direct', 'nfft']: + raise Exception("Possible ttype values are 'fast', 'direct', 'nfft'!") + + if ttype == 'direct': + if dtype == 'vis': + chisqgrad = chisqgrad_vis(imvec, A, data, sigma) + elif dtype == 'amp': + chisqgrad = chisqgrad_amp(imvec, A, data, sigma) + elif dtype == 'logamp': + chisqgrad = chisqgrad_logamp(imvec, A, data, sigma) + elif dtype == 'bs': + chisqgrad = chisqgrad_bs(imvec, A, data, sigma) + elif dtype == 'cphase': + chisqgrad = chisqgrad_cphase(imvec, A, data, sigma) + elif dtype == 'cphase_diag': + chisqgrad = chisqgrad_cphase_diag(imvec, A, data, sigma) + elif dtype == 'camp': + chisqgrad = chisqgrad_camp(imvec, A, data, sigma) + elif dtype == 'logcamp': + chisqgrad = chisqgrad_logcamp(imvec, A, data, sigma) + elif dtype == 'logcamp_diag': + chisqgrad = chisqgrad_logcamp_diag(imvec, A, data, sigma) + + elif ttype == 'fast': + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + + if dtype not in ['cphase_diag', 'logcamp_diag']: + vis_arr = obsh.fft_imvec(imvec, A[0]) + + if dtype == 'vis': + chisqgrad = chisqgrad_vis_fft(vis_arr, A, data, sigma) + elif dtype == 'amp': + chisqgrad = chisqgrad_amp_fft(vis_arr, A, data, sigma) + elif dtype == 'logamp': + chisqgrad = chisqgrad_logamp_fft(vis_arr, A, data, sigma) + elif dtype == 'bs': + chisqgrad = chisqgrad_bs_fft(vis_arr, A, data, sigma) + elif dtype == 'cphase': + chisqgrad = chisqgrad_cphase_fft(vis_arr, A, data, sigma) + elif dtype == 'cphase_diag': + chisqgrad = chisqgrad_cphase_diag_fft(imvec, A, data, sigma) + elif dtype == 'camp': + chisqgrad = chisqgrad_camp_fft(vis_arr, A, data, sigma) + elif dtype == 'logcamp': + chisqgrad = chisqgrad_logcamp_fft(vis_arr, A, data, sigma) + elif dtype == 'logcamp_diag': + chisqgrad = chisqgrad_logcamp_diag_fft(imvec, A, data, sigma) + + if len(mask) > 0 and np.any(np.invert(mask)): + chisqgrad = chisqgrad[mask] + + elif ttype == 'nfft': + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + + if dtype == 'vis': + chisqgrad = chisqgrad_vis_nfft(imvec, A, data, sigma) + elif dtype == 'amp': + chisqgrad = chisqgrad_amp_nfft(imvec, A, data, sigma) + elif dtype == 'logamp': + chisqgrad = chisqgrad_logamp_nfft(imvec, A, data, sigma) + elif dtype == 'bs': + chisqgrad = chisqgrad_bs_nfft(imvec, A, data, sigma) + elif dtype == 'cphase': + chisqgrad = chisqgrad_cphase_nfft(imvec, A, data, sigma) + elif dtype == 'cphase_diag': + chisqgrad = chisqgrad_cphase_diag_nfft(imvec, A, data, sigma) + elif dtype == 'camp': + chisqgrad = chisqgrad_camp_nfft(imvec, A, data, sigma) + elif dtype == 'logcamp': + chisqgrad = chisqgrad_logcamp_nfft(imvec, A, data, sigma) + elif dtype == 'logcamp_diag': + chisqgrad = chisqgrad_logcamp_diag_nfft(imvec, A, data, sigma) + + if len(mask) > 0 and np.any(np.invert(mask)): + chisqgrad = chisqgrad[mask] + + return chisqgrad + + +def regularizer(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs): + """return the regularizer value + """ + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + beam_size = kwargs.get('beam_size', psize) + alpha_A = kwargs.get('alpha_A', 1.0) + epsilon = kwargs.get('epsilon_tv', 0.) + + if stype == "flux": + s = -sflux(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "cm": + s = -scm(imvec, xdim, ydim, psize, flux, mask, norm_reg=norm_reg, beam_size=beam_size) + elif stype == "simple": + s = -ssimple(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "l1": + s = -sl1(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "l1w": + s = -sl1w(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "lA": + s = -slA(imvec, nprior, psize, flux, beam_size, alpha_A, norm_reg) + elif stype == "gs": + s = -sgs(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "patch": + s = -spatch(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "tv": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -stv(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon) + elif stype == "tvlog": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, clipfloor=epsilon, randomfloor=True) + npix = xdim*ydim + logvec = np.log(imvec) + logflux = npix*np.abs(np.log(flux/npix)) + s = -stv(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon) + elif stype == "tv2": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -stv2(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg, beam_size=beam_size) + elif stype == "tv2log": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + npix = xdim*ydim + logvec = np.log(imvec) + logflux = npix*np.abs(np.log(flux/npix)) + s = -stv2(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg, beam_size=beam_size) + elif stype == "compact": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -scompact(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg) + elif stype == "compact2": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -scompact2(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg) + elif stype == "rgauss": + # additional key words for gaussian regularizer + major = kwargs.get('major', 1.0) + minor = kwargs.get('minor', 1.0) + PA = kwargs.get('PA', 1.0) + + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -sgauss(imvec, xdim, ydim, psize, major=major, minor=minor, PA=PA) + else: + s = 0 + + return s + + +def regularizergrad(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs): + """return the regularizer gradient + """ + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + beam_size = kwargs.get('beam_size', psize) + alpha_A = kwargs.get('alpha_A', 1.0) + epsilon = kwargs.get('epsilon_tv', 0.) + + if stype == "flux": + s = -sfluxgrad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "cm": + s = -scmgrad(imvec, xdim, ydim, psize, flux, mask, norm_reg=norm_reg, beam_size=beam_size) + elif stype == "simple": + s = -ssimplegrad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "l1": + s = -sl1grad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "l1w": + s = -sl1wgrad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "lA": + s = -slAgrad(imvec, nprior, psize, flux, beam_size, alpha_A, norm_reg) + elif stype == "gs": + s = -sgsgrad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "patch": + s = -spatchgrad(imvec, nprior, flux, norm_reg=norm_reg) + elif stype == "tv": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -stvgrad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg, + beam_size=beam_size, epsilon=epsilon) + s = s[mask] + elif stype == "tvlog": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, clipfloor=epsilon, randomfloor=True) + npix = xdim*ydim + logvec = np.log(imvec) + logflux = npix*np.abs(np.log(flux/npix)) + s = -stvgrad(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg,beam_size=beam_size, epsilon=epsilon) + s = s / imvec + s = s[mask] + elif stype == "tv2": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -stv2grad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg, beam_size=beam_size) + s = s[mask] + elif stype == "tv2log": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + npix = xdim*ydim + logvec = np.log(imvec) + logflux = npix*np.abs(np.log(flux/npix)) + s = -stv2grad(logvec, xdim, ydim, psize, logflux, norm_reg=norm_reg, beam_size=beam_size) + s = s / imvec + s = s[mask] + elif stype == "compact": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -scompactgrad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg) + s = s[mask] + elif stype == "compact2": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -scompact2grad(imvec, xdim, ydim, psize, flux, norm_reg=norm_reg) + s = s[mask] + elif stype == "rgauss": + # additional key words for gaussian regularizer + major = kwargs.get('major', 1.0) + minor = kwargs.get('minor', 1.0) + PA = kwargs.get('PA', 1.0) + + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=True) + s = -sgauss_grad(imvec, xdim, ydim, psize, major, minor, PA) + s = s[mask] + else: + s = np.zeros(len(imvec)) + + return s + + +def chisqdata(Obsdata, Prior, mask, dtype, pol='I', **kwargs): + """Return the data, sigma, and matrices for the appropriate dtype + """ + + ttype = kwargs.get('ttype', 'direct') + (data, sigma, A) = (False, False, False) + if ttype not in ['fast', 'direct', 'nfft']: + raise Exception("Possible ttype values are 'fast', 'direct', 'nfft'!") + + if ttype == 'direct': + if dtype == 'vis': + (data, sigma, A) = chisqdata_vis(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'amp' or dtype == 'logamp': + (data, sigma, A) = chisqdata_amp(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'bs': + (data, sigma, A) = chisqdata_bs(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'cphase': + (data, sigma, A) = chisqdata_cphase(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'cphase_diag': + (data, sigma, A) = chisqdata_cphase_diag(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'camp': + (data, sigma, A) = chisqdata_camp(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'logcamp': + (data, sigma, A) = chisqdata_logcamp(Obsdata, Prior, mask, pol=pol, **kwargs) + elif dtype == 'logcamp_diag': + (data, sigma, A) = chisqdata_logcamp_diag(Obsdata, Prior, mask, pol=pol, **kwargs) + + elif ttype == 'fast': + if dtype == 'vis': + (data, sigma, A) = chisqdata_vis_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'amp' or dtype == 'logamp': + (data, sigma, A) = chisqdata_amp_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'bs': + (data, sigma, A) = chisqdata_bs_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'cphase': + (data, sigma, A) = chisqdata_cphase_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'cphase_diag': + (data, sigma, A) = chisqdata_cphase_diag_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'camp': + (data, sigma, A) = chisqdata_camp_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'logcamp': + (data, sigma, A) = chisqdata_logcamp_fft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'logcamp_diag': + (data, sigma, A) = chisqdata_logcamp_diag_fft(Obsdata, Prior, pol=pol, **kwargs) + + elif ttype == 'nfft': + if dtype == 'vis': + (data, sigma, A) = chisqdata_vis_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'amp' or dtype == 'logamp': + (data, sigma, A) = chisqdata_amp_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'bs': + (data, sigma, A) = chisqdata_bs_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'cphase': + (data, sigma, A) = chisqdata_cphase_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'cphase_diag': + (data, sigma, A) = chisqdata_cphase_diag_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'camp': + (data, sigma, A) = chisqdata_camp_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'logcamp': + (data, sigma, A) = chisqdata_logcamp_nfft(Obsdata, Prior, pol=pol, **kwargs) + elif dtype == 'logcamp_diag': + (data, sigma, A) = chisqdata_logcamp_diag_nfft(Obsdata, Prior, pol=pol, **kwargs) + + return (data, sigma, A) + + +################################################################################################## +# DFT Chi-squared and Gradient Functions +################################################################################################## + +def chisq_vis(imvec, Amatrix, vis, sigma): + """Visibility chi-squared""" + + samples = np.dot(Amatrix, imvec) + chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis)) + return chisq + +def chisqgrad_vis(imvec, Amatrix, vis, sigma): + """The gradient of the visibility chi-squared""" + + samples = np.dot(Amatrix, imvec) + wdiff = (vis - samples)/(sigma**2) + + out = -np.real(np.dot(Amatrix.conj().T, wdiff))/len(vis) + return out + + +def chisq_amp(imvec, A, amp, sigma): + """Visibility Amplitudes (normalized) chi-squared""" + + amp_samples = np.abs(np.dot(A, imvec)) + return np.sum(np.abs((amp - amp_samples)/sigma)**2)/len(amp) + + +def chisqgrad_amp(imvec, A, amp, sigma): + """The gradient of the amplitude chi-squared""" + + i1 = np.dot(A, imvec) + amp_samples = np.abs(i1) + + pp = ((amp - amp_samples) * amp_samples) / (sigma**2) / i1 + out = (-2.0/len(amp)) * np.real(np.dot(pp, A)) + return out + + +def chisq_bs(imvec, Amatrices, bis, sigma): + """Bispectrum chi-squared""" + + bisamples = (np.dot(Amatrices[0], imvec) * + np.dot(Amatrices[1], imvec) * + np.dot(Amatrices[2], imvec)) + chisq = np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis)) + return chisq + + +def chisqgrad_bs(imvec, Amatrices, bis, sigma): + """The gradient of the bispectrum chi-squared""" + + bisamples = (np.dot(Amatrices[0], imvec) * + np.dot(Amatrices[1], imvec) * + np.dot(Amatrices[2], imvec)) + + wdiff = ((bis - bisamples).conj())/(sigma**2) + pt1 = wdiff * np.dot(Amatrices[1], imvec) * np.dot(Amatrices[2], imvec) + pt2 = wdiff * np.dot(Amatrices[0], imvec) * np.dot(Amatrices[2], imvec) + pt3 = wdiff * np.dot(Amatrices[0], imvec) * np.dot(Amatrices[1], imvec) + out = (np.dot(pt1, Amatrices[0]) + + np.dot(pt2, Amatrices[1]) + + np.dot(pt3, Amatrices[2])) + + out = -np.real(out) / len(bis) + return out + + +def chisq_cphase(imvec, Amatrices, clphase, sigma): + """Closure Phases (normalized) chi-squared""" + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + clphase_samples = np.angle(i1 * i2 * i3) + + chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2)) + return chisq + + +def chisqgrad_cphase(imvec, Amatrices, clphase, sigma): + """The gradient of the closure phase chi-squared""" + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + clphase_samples = np.angle(i1 * i2 * i3) + + pref = np.sin(clphase - clphase_samples)/(sigma**2) + pt1 = pref/i1 + pt2 = pref/i2 + pt3 = pref/i3 + out = np.dot(pt1, Amatrices[0]) + np.dot(pt2, Amatrices[1]) + np.dot(pt3, Amatrices[2]) + out = (-2.0/len(clphase)) * np.imag(out) + return out + + +def chisq_cphase_diag(imvec, Amatrices, clphase_diag, sigma): + """Diagonalized closure phases (normalized) chi-squared""" + clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE + sigma = np.concatenate(sigma) * ehc.DEGREE + + A3_diag = Amatrices[0] + tform_mats = Amatrices[1] + + clphase_diag_samples = [] + for iA, A3 in enumerate(A3_diag): + clphase_samples = np.angle(np.dot(A3[0], imvec) * + np.dot(A3[1], imvec) * + np.dot(A3[2], imvec)) + clphase_diag_samples.append(np.dot(tform_mats[iA], clphase_samples)) + clphase_diag_samples = np.concatenate(clphase_diag_samples) + + chisq = np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2)) + chisq *= (2.0/len(clphase_diag)) + return chisq + + +def chisqgrad_cphase_diag(imvec, Amatrices, clphase_diag, sigma): + """The gradient of the diagonalized closure phase chi-squared""" + clphase_diag = clphase_diag * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + A3_diag = Amatrices[0] + tform_mats = Amatrices[1] + + deriv = np.zeros_like(imvec) + for iA, A3 in enumerate(A3_diag): + + i1 = np.dot(A3[0], imvec) + i2 = np.dot(A3[1], imvec) + i3 = np.dot(A3[2], imvec) + clphase_samples = np.angle(i1 * i2 * i3) + clphase_diag_samples = np.dot(tform_mats[iA], clphase_samples) + + clphase_diag_measured = clphase_diag[iA] + clphase_diag_sigma = sigma[iA] + + term1 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) / + (clphase_diag_sigma**2.0)), (tform_mats[iA]/i1)), A3[0]) + term2 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) / + (clphase_diag_sigma**2.0)), (tform_mats[iA]/i2)), A3[1]) + term3 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples) / + (clphase_diag_sigma**2.0)), (tform_mats[iA]/i3)), A3[2]) + deriv += -2.0*np.imag(term1 + term2 + term3) + + deriv *= 1.0/np.float(len(np.concatenate(clphase_diag))) + + return deriv + + +def chisq_camp(imvec, Amatrices, clamp, sigma): + """Closure Amplitudes (normalized) chi-squared""" + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + clamp_samples = np.abs((i1 * i2)/(i3 * i4)) + + chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp) + return chisq + + +def chisqgrad_camp(imvec, Amatrices, clamp, sigma): + """The gradient of the closure amplitude chi-squared""" + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + clamp_samples = np.abs((i1 * i2)/(i3 * i4)) + + pp = ((clamp - clamp_samples) * clamp_samples)/(sigma**2) + pt1 = pp/i1 + pt2 = pp/i2 + pt3 = -pp/i3 + pt4 = -pp/i4 + out = (np.dot(pt1, Amatrices[0]) + + np.dot(pt2, Amatrices[1]) + + np.dot(pt3, Amatrices[2]) + + np.dot(pt4, Amatrices[3])) + out *= (-2.0/len(clamp)) * np.real(out) + return out + + +def chisq_logcamp(imvec, Amatrices, log_clamp, sigma): + """Log Closure Amplitudes (normalized) chi-squared""" + + a1 = np.abs(np.dot(Amatrices[0], imvec)) + a2 = np.abs(np.dot(Amatrices[1], imvec)) + a3 = np.abs(np.dot(Amatrices[2], imvec)) + a4 = np.abs(np.dot(Amatrices[3], imvec)) + + samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4) + chisq = np.sum(np.abs((log_clamp - samples)/sigma)**2) / (len(log_clamp)) + return chisq + + +def chisqgrad_logcamp(imvec, Amatrices, log_clamp, sigma): + """The gradient of the Log closure amplitude chi-squared""" + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + log_clamp_samples = (np.log(np.abs(i1)) + + np.log(np.abs(i2)) - + np.log(np.abs(i3)) - + np.log(np.abs(i4))) + + pp = (log_clamp - log_clamp_samples) / (sigma**2) + pt1 = pp / i1 + pt2 = pp / i2 + pt3 = -pp / i3 + pt4 = -pp / i4 + out = (np.dot(pt1, Amatrices[0]) + + np.dot(pt2, Amatrices[1]) + + np.dot(pt3, Amatrices[2]) + + np.dot(pt4, Amatrices[3])) + out = (-2.0/len(log_clamp)) * np.real(out) + return out + + +def chisq_logcamp_diag(imvec, Amatrices, log_clamp_diag, sigma): + """Diagonalized log closure amplitudes (normalized) chi-squared""" + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + A4_diag = Amatrices[0] + tform_mats = Amatrices[1] + + log_clamp_diag_samples = [] + for iA, A4 in enumerate(A4_diag): + + a1 = np.abs(np.dot(A4[0], imvec)) + a2 = np.abs(np.dot(A4[1], imvec)) + a3 = np.abs(np.dot(A4[2], imvec)) + a4 = np.abs(np.dot(A4[3], imvec)) + + log_clamp_samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4) + log_clamp_diag_samples.append(np.dot(tform_mats[iA], log_clamp_samples)) + + log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples) + + chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) + chisq /= (len(log_clamp_diag)) + + return chisq + + +def chisqgrad_logcamp_diag(imvec, Amatrices, log_clamp_diag, sigma): + """The gradient of the diagonalized log closure amplitude chi-squared""" + + A4_diag = Amatrices[0] + tform_mats = Amatrices[1] + + deriv = np.zeros_like(imvec) + for iA, A4 in enumerate(A4_diag): + + i1 = np.dot(A4[0], imvec) + i2 = np.dot(A4[1], imvec) + i3 = np.dot(A4[2], imvec) + i4 = np.dot(A4[3], imvec) + log_clamp_samples = np.log(np.abs(i1)) + np.log(np.abs(i2)) - \ + np.log(np.abs(i3)) - np.log(np.abs(i4)) + log_clamp_diag_samples = np.dot(tform_mats[iA], log_clamp_samples) + + log_clamp_diag_measured = log_clamp_diag[iA] + log_clamp_diag_sigma = sigma[iA] + + term1 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) / + (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i1)), A4[0]) + term2 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) / + (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i2)), A4[1]) + term3 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) / + (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i3)), A4[2]) + term4 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples) / + (log_clamp_diag_sigma**2.0)), (tform_mats[iA]/i4)), A4[3]) + deriv += -2.0*np.real(term1 + term2 - term3 - term4) + + deriv *= 1.0/np.float(len(np.concatenate(log_clamp_diag))) + + return deriv + + +def chisq_logamp(imvec, A, amp, sigma): + """Log Visibility Amplitudes (normalized) chi-squared""" + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + amp_samples = np.abs(np.dot(A, imvec)) + chisq = np.sum(np.abs((np.log(amp) - np.log(amp_samples))/logsigma)**2)/len(amp) + return chisq + +def chisqgrad_logamp(imvec, A, amp, sigma): + """The gradient of the Log amplitude chi-squared""" + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + i1 = np.dot(A, imvec) + amp_samples = np.abs(i1) + + pp = ((np.log(amp) - np.log(amp_samples))) / (logsigma**2) / i1 + out = (-2.0/len(amp)) * np.real(np.dot(pp, A)) + return out + +################################################################################################## +# FFT Chi-squared and Gradient Functions +################################################################################################## + + +def chisq_vis_fft(vis_arr, A, vis, sigma): + """Visibility chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis") + + chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis)) + + return chisq + + +def chisqgrad_vis_fft(vis_arr, A, vis, sigma): + """The gradient of the visibility chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + + # samples and gradient FT + pulsefac = sampler_info_list[0].pulsefac + samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis") + wdiff_vec = (-1.0/len(vis)*(vis - samples)/(sigma**2)) * pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + + +def chisq_amp_fft(vis_arr, A, amp, sigma): + """Visibility amplitude chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + amp_samples = np.abs(obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")) + chisq = np.sum(np.abs((amp_samples-amp)/sigma)**2)/(len(amp)) + return chisq + + +def chisqgrad_amp_fft(vis_arr, A, amp, sigma): + """The gradient of the amplitude chi-kernesquared + """ + + im_info, sampler_info_list, gridder_info_list = A + + # samples + samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis") + amp_samples = np.abs(samples) + + # gradient FT + pulsefac = sampler_info_list[0].pulsefac + wdiff_vec = (-2.0/len(amp)*((amp - amp_samples) * amp_samples) / + (sigma**2) / samples.conj()) * pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevent cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + + +def chisq_bs_fft(vis_arr, A, bis, sigma): + """Bispectrum chi-squared from fft""" + + im_info, sampler_info_list, gridder_info_list = A + bisamples = obsh.sampler(vis_arr, sampler_info_list, sample_type="bs") + + return np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis)) + + +def chisqgrad_bs_fft(vis_arr, A, bis, sigma): + """The gradient of the amplitude chi-squared + """ + im_info, sampler_info_list, gridder_info_list = A + + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + bisamples = v1*v2*v3 + + wdiff = -1.0/len(bis)*(bis - bisamples)/(sigma**2) + + pt1 = wdiff * (v2 * v3).conj() * sampler_info_list[0].pulsefac.conj() + pt2 = wdiff * (v1 * v3).conj() * sampler_info_list[1].pulsefac.conj() + pt3 = wdiff * (v1 * v2).conj() * sampler_info_list[2].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + return out + + +def chisq_cphase_fft(vis_arr, A, clphase, sigma): + """Closure Phases (normalized) chi-squared from fft + """ + + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + im_info, sampler_info_list, gridder_info_list = A + clphase_samples = np.angle(obsh.sampler(vis_arr, sampler_info_list, sample_type="bs")) + + chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2)) + return chisq + + +def chisqgrad_cphase_fft(vis_arr, A, clphase, sigma): + """The gradient of the closure phase chi-squared from fft""" + + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + im_info, sampler_info_list, gridder_info_list = A + + # sample visibilities and closure phases + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + clphase_samples = np.angle(v1*v2*v3) + + pref = (2.0/len(clphase)) * np.sin(clphase - clphase_samples)/(sigma**2) + pt1 = pref/v1.conj() * sampler_info_list[0].pulsefac.conj() + pt2 = pref/v2.conj() * sampler_info_list[1].pulsefac.conj() + pt3 = pref/v3.conj() * sampler_info_list[2].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + # TODO or is x<-->y?? + out = np.imag(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + + +def chisq_cphase_diag_fft(imvec, A, clphase_diag, sigma): + """Diagonalized closure phases (normalized) chi-squared from fft + """ + + clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE + sigma = np.concatenate(sigma) * ehc.DEGREE + + A3 = A[0] + tform_mats = A[1] + + im_info, sampler_info_list, gridder_info_list = A3 + vis_arr = obsh.fft_imvec(imvec, A3[0]) + clphase_samples = np.angle(obsh.sampler(vis_arr, sampler_info_list, sample_type="bs")) + + count = 0 + clphase_diag_samples = [] + for tform_mat in tform_mats: + clphase_samples_here = clphase_samples[count:count+len(tform_mat)] + clphase_diag_samples.append(np.dot(tform_mat, clphase_samples_here)) + count += len(tform_mat) + + clphase_diag_samples = np.concatenate(clphase_diag_samples) + + chisq = np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2)) + chisq *= (2.0/len(clphase_diag)) + return chisq + + +def chisqgrad_cphase_diag_fft(imvec, A, clphase_diag, sigma): + """The gradient of the closure phase chi-squared from fft""" + + clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE + sigma = np.concatenate(sigma) * ehc.DEGREE + + A3 = A[0] + tform_mats = A[1] + + im_info, sampler_info_list, gridder_info_list = A3 + vis_arr = obsh.fft_imvec(imvec, A3[0]) + + # sample visibilities and closure phases + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + clphase_samples = np.angle(v1*v2*v3) + + # gradient vec stuff + count = 0 + pref = np.zeros_like(clphase_samples) + for tform_mat in tform_mats: + + clphase_diag_samples = np.dot(tform_mat, clphase_samples[count:count+len(tform_mat)]) + clphase_diag_measured = clphase_diag[count:count+len(tform_mat)] + clphase_diag_sigma = sigma[count:count+len(tform_mat)] + + for j in range(len(clphase_diag_measured)): + pref[count:count+len(tform_mat)] += 2.0 * tform_mat[j, :] * np.sin( + clphase_diag_measured[j] - clphase_diag_samples[j])/(clphase_diag_sigma[j]**2) + + count += len(tform_mat) + + pt1 = pref/v1.conj() * sampler_info_list[0].pulsefac.conj() + pt2 = pref/v2.conj() * sampler_info_list[1].pulsefac.conj() + pt3 = pref/v3.conj() * sampler_info_list[2].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + deriv = np.imag(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + deriv *= 1.0/np.float(len(clphase_diag)) + + return deriv + + +def chisq_camp_fft(vis_arr, A, clamp, sigma): + """Closure Amplitudes (normalized) chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + clamp_samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="camp") + chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp) + return chisq + + +def chisqgrad_camp_fft(vis_arr, A, clamp, sigma): + """The gradient of the closure amplitude chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + + # sampled visibility and closure amplitudes + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis") + clamp_samples = np.abs((v1 * v2)/(v3 * v4)) + + # gradient components + pp = (-2.0/len(clamp)) * ((clamp - clamp_samples) * clamp_samples)/(sigma**2) + pt1 = pp/v1.conj() * sampler_info_list[0].pulsefac.conj() + pt2 = pp/v2.conj() * sampler_info_list[1].pulsefac.conj() + pt3 = -pp/v3.conj() * sampler_info_list[2].pulsefac.conj() + pt4 = -pp/v4.conj() * sampler_info_list[3].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + + +def chisq_logcamp_fft(vis_arr, A, log_clamp, sigma): + """Closure Amplitudes (normalized) chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + log_clamp_samples = np.log(obsh.sampler(vis_arr, sampler_info_list, sample_type='camp')) + + chisq = np.sum(np.abs((log_clamp - log_clamp_samples)/sigma)**2) / (len(log_clamp)) + + return chisq + + +def chisqgrad_logcamp_fft(vis_arr, A, log_clamp, sigma): + """The gradient of the closure amplitude chi-squared from fft + """ + + im_info, sampler_info_list, gridder_info_list = A + + # sampled visibility and closure amplitudes + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis") + + log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4))) + + # gradient components + pp = (-2.0/len(log_clamp)) * (log_clamp - log_clamp_samples) / (sigma**2) + pt1 = pp / v1.conj() * sampler_info_list[0].pulsefac.conj() + pt2 = pp / v2.conj() * sampler_info_list[1].pulsefac.conj() + pt3 = -pp / v3.conj() * sampler_info_list[2].pulsefac.conj() + pt4 = -pp / v4.conj() * sampler_info_list[3].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + + +def chisq_logcamp_diag_fft(imvec, A, log_clamp_diag, sigma): + """Diagonalized log closure amplitudes (normalized) chi-squared from fft + """ + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + A4 = A[0] + tform_mats = A[1] + + im_info, sampler_info_list, gridder_info_list = A4 + vis_arr = obsh.fft_imvec(imvec, A4[0]) + log_clamp_samples = np.log(obsh.sampler(vis_arr, sampler_info_list, sample_type='camp')) + + count = 0 + log_clamp_diag_samples = [] + for tform_mat in tform_mats: + log_clamp_samples_here = log_clamp_samples[count:count+len(tform_mat)] + log_clamp_diag_samples.append(np.dot(tform_mat, log_clamp_samples_here)) + count += len(tform_mat) + log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples) + + chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) + chisq /= (len(log_clamp_diag)) + return chisq + + +def chisqgrad_logcamp_diag_fft(imvec, A, log_clamp_diag, sigma): + """The gradient of the diagonalized log closure amplitude chi-squared from fft + """ + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + A4 = A[0] + tform_mats = A[1] + + im_info, sampler_info_list, gridder_info_list = A4 + vis_arr = obsh.fft_imvec(imvec, A4[0]) + + # sampled visibility and closure amplitudes + v1 = obsh.sampler(vis_arr, [sampler_info_list[0]], sample_type="vis") + v2 = obsh.sampler(vis_arr, [sampler_info_list[1]], sample_type="vis") + v3 = obsh.sampler(vis_arr, [sampler_info_list[2]], sample_type="vis") + v4 = obsh.sampler(vis_arr, [sampler_info_list[3]], sample_type="vis") + log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4))) + + # gradient vec stuff + count = 0 + pref = np.zeros_like(log_clamp_samples) + for tform_mat in tform_mats: + + log_clamp_diag_samples = np.dot(tform_mat, log_clamp_samples[count:count+len(tform_mat)]) + log_clamp_diag_measured = log_clamp_diag[count:count+len(tform_mat)] + log_clamp_diag_sigma = sigma[count:count+len(tform_mat)] + + for j in range(len(log_clamp_diag_measured)): + pref[count:count+len(tform_mat)] += -2.0 * tform_mat[j, :] * \ + (log_clamp_diag_measured[j] - log_clamp_diag_samples[j]) / \ + (log_clamp_diag_sigma[j]**2) + + count += len(tform_mat) + + pt1 = pref / v1.conj() * sampler_info_list[0].pulsefac.conj() + pt2 = pref / v2.conj() * sampler_info_list[1].pulsefac.conj() + pt3 = -pref / v3.conj() * sampler_info_list[2].pulsefac.conj() + pt4 = -pref / v4.conj() * sampler_info_list[3].pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff = obsh.gridder([pt1, pt2, pt3, pt4], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevant cells and flatten + deriv = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + deriv *= 1.0/np.float(len(log_clamp_diag)) + + return deriv + + +def chisq_logamp_fft(vis_arr, A, amp, sigma): + """Visibility amplitude chi-squared from fft + """ + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + im_info, sampler_info_list, gridder_info_list = A + amp_samples = np.abs(obsh.sampler(vis_arr, sampler_info_list, sample_type="vis")) + chisq = np.sum(np.abs((np.log(amp_samples)-np.log(amp))/logsigma)**2)/(len(amp)) + return chisq + + +def chisqgrad_logamp_fft(vis_arr, A, amp, sigma): + """The gradient of the amplitude chi-kernesquared + """ + + im_info, sampler_info_list, gridder_info_list = A + + # samples + samples = obsh.sampler(vis_arr, sampler_info_list, sample_type="vis") + amp_samples = np.abs(samples) + + # gradient FT + logsigma = sigma / amp + pulsefac = sampler_info_list[0].pulsefac + wdiff_vec = (-2.0/len(amp)*((np.log(amp) - np.log(amp_samples))) / + (logsigma**2) / samples.conj()) * pulsefac.conj() + + # Setup and perform the inverse FFT + wdiff_arr = obsh.gridder([wdiff_vec], gridder_info_list) + grad_arr = np.fft.ifftshift(np.fft.ifft2(np.fft.fftshift(wdiff_arr))) + grad_arr = grad_arr * (im_info.npad * im_info.npad) + + # extract relevent cells and flatten + # TODO or is x<-->y?? + out = np.real(grad_arr[im_info.padvalx1:-im_info.padvalx2, + im_info.padvaly1:-im_info.padvaly2].flatten()) + + return out + +################################################################################################## +# NFFT Chi-squared and Gradient Functions +################################################################################################## + + +def chisq_vis_nfft(imvec, A, vis, sigma): + """Visibility chi-squared from nfft + """ + + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + + # compute chi^2 + chisq = np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis)) + + return chisq + + +def chisqgrad_vis_nfft(imvec, A, vis, sigma): + """The gradient of the visibility chi-squared from nfft + """ + + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + + # gradient vec for adjoint FT + wdiff_vec = (-1.0/len(vis)*(vis - samples)/(sigma**2)) * pulsefac.conj() + plan.f = wdiff_vec + plan.adjoint() + grad = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)) + + return grad + + +def chisq_amp_nfft(imvec, A, amp, sigma): + """Visibility amplitude chi-squared from nfft + """ + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + + # compute chi^2 + amp_samples = np.abs(samples) + chisq = np.sum(np.abs((amp_samples-amp)/sigma)**2)/(len(amp)) + + return chisq + + +def chisqgrad_amp_nfft(imvec, A, amp, sigma): + """The gradient of the amplitude chi-squared from nfft + """ + + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + amp_samples = np.abs(samples) + + # gradient vec for adjoint FT + wdiff_vec = (-2.0/len(amp)*((amp - amp_samples) * samples) / + (sigma**2) / amp_samples) * pulsefac.conj() + plan.f = wdiff_vec + plan.adjoint() + out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)) + + return out + + +def chisq_bs_nfft(imvec, A, bis, sigma): + """Bispectrum chi-squared from fft""" + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + # compute chi^2 + bisamples = samples1*samples2*samples3 + chisq = np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis)) + return chisq + + +def chisqgrad_bs_nfft(imvec, A, bis, sigma): + """The gradient of the amplitude chi-squared from the nfft + """ + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + # gradient vec for adjoint FT + bisamples = v1*v2*v3 + wdiff = -1.0/len(bis)*(bis - bisamples)/(sigma**2) + pt1 = wdiff * (v2 * v3).conj() * pulsefac1.conj() + pt2 = wdiff * (v1 * v3).conj() * pulsefac2.conj() + pt3 = wdiff * (v1 * v2).conj() * pulsefac3.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + out = out1 + out2 + out3 + return out + + +def chisq_cphase_nfft(imvec, A, clphase, sigma): + """Closure Phases (normalized) chi-squared from nfft + """ + + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + # compute chi^2 + clphase_samples = np.angle(samples1*samples2*samples3) + chisq = (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2)) + + return chisq + + +def chisqgrad_cphase_nfft(imvec, A, clphase, sigma): + """The gradient of the closure phase chi-squared from nfft""" + + clphase = clphase * ehc.DEGREE + sigma = sigma * ehc.DEGREE + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + # gradient vec for adjoint FT + clphase_samples = np.angle(v1*v2*v3) + pref = (2.0/len(clphase)) * np.sin(clphase - clphase_samples)/(sigma**2) + pt1 = pref/v1.conj() * pulsefac1.conj() + pt2 = pref/v2.conj() * pulsefac2.conj() + pt3 = pref/v3.conj() * pulsefac3.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + out = out1 + out2 + out3 + return out + + +def chisq_cphase_diag_nfft(imvec, A, clphase_diag, sigma): + """Diagonalized closure phases (normalized) chi-squared from nfft + """ + + clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE + sigma = np.concatenate(sigma) * ehc.DEGREE + + A3 = A[0] + tform_mats = A[1] + + # get nfft objects + nfft_info1 = A3[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A3[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A3[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + clphase_samples = np.angle(samples1*samples2*samples3) + + count = 0 + clphase_diag_samples = [] + for tform_mat in tform_mats: + clphase_samples_here = clphase_samples[count:count+len(tform_mat)] + clphase_diag_samples.append(np.dot(tform_mat, clphase_samples_here)) + count += len(tform_mat) + + clphase_diag_samples = np.concatenate(clphase_diag_samples) + + # compute chi^2 + chisq = (2.0/len(clphase_diag)) * \ + np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2)) + + return chisq + + +def chisqgrad_cphase_diag_nfft(imvec, A, clphase_diag, sigma): + """The gradient of the diagonalized closure phase chi-squared from nfft""" + + clphase_diag = np.concatenate(clphase_diag) * ehc.DEGREE + sigma = np.concatenate(sigma) * ehc.DEGREE + + A3 = A[0] + tform_mats = A[1] + + # get nfft objects + nfft_info1 = A3[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A3[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A3[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + clphase_samples = np.angle(v1*v2*v3) + + # gradient vec for adjoint FT + count = 0 + pref = np.zeros_like(clphase_samples) + for tform_mat in tform_mats: + + clphase_diag_samples = np.dot(tform_mat, clphase_samples[count:count+len(tform_mat)]) + clphase_diag_measured = clphase_diag[count:count+len(tform_mat)] + clphase_diag_sigma = sigma[count:count+len(tform_mat)] + + for j in range(len(clphase_diag_measured)): + pref[count:count+len(tform_mat)] += 2.0 * tform_mat[j, :] * np.sin( + clphase_diag_measured[j] - clphase_diag_samples[j])/(clphase_diag_sigma[j]**2) + + count += len(tform_mat) + + pt1 = pref/v1.conj() * pulsefac1.conj() + pt2 = pref/v2.conj() * pulsefac2.conj() + pt3 = pref/v3.conj() * pulsefac3.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.imag((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.imag((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.imag((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + deriv = out1 + out2 + out3 + deriv *= 1.0/np.float(len(clphase_diag)) + + return deriv + + +def chisq_camp_nfft(imvec, A, clamp, sigma): + """Closure Amplitudes (normalized) chi-squared from fft + """ + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + samples4 = plan4.f.copy()*pulsefac4 + + # compute chi^2 + clamp_samples = np.abs((samples1*samples2)/(samples3*samples4)) + chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp) + return chisq + + +def chisqgrad_camp_nfft(imvec, A, clamp, sigma): + """The gradient of the closure amplitude chi-squared from fft + """ + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + v4 = plan4.f.copy()*pulsefac4 + + # gradient vec for adjoint FT + clamp_samples = np.abs((v1 * v2)/(v3 * v4)) + + pp = (-2.0/len(clamp)) * ((clamp - clamp_samples) * clamp_samples)/(sigma**2) + pt1 = pp/v1.conj() * pulsefac1.conj() + pt2 = pp/v2.conj() * pulsefac2.conj() + pt3 = -pp/v3.conj() * pulsefac3.conj() + pt4 = -pp/v4.conj() * pulsefac4.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + plan4.f = pt4 + plan4.adjoint() + out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim)) + + out = out1 + out2 + out3 + out4 + return out + + +def chisq_logcamp_nfft(imvec, A, log_clamp, sigma): + """Log Closure Amplitudes (normalized) chi-squared from fft + """ + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + samples4 = plan4.f.copy()*pulsefac4 + + # compute chi^2 + log_clamp_samples = (np.log(np.abs(samples1)) + np.log(np.abs(samples2)) - + np.log(np.abs(samples3)) - np.log(np.abs(samples4))) + chisq = np.sum(np.abs((log_clamp - log_clamp_samples)/sigma)**2) / (len(log_clamp)) + return chisq + + +def chisqgrad_logcamp_nfft(imvec, A, log_clamp, sigma): + """The gradient of the log closure amplitude chi-squared from fft + """ + + # get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + v4 = plan4.f.copy()*pulsefac4 + + # gradient vec for adjoint FT + log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4))) + + pp = (-2.0/len(log_clamp)) * (log_clamp - log_clamp_samples) / (sigma**2) + pt1 = pp / v1.conj() * pulsefac1.conj() + pt2 = pp / v2.conj() * pulsefac2.conj() + pt3 = -pp / v3.conj() * pulsefac3.conj() + pt4 = -pp / v4.conj() * pulsefac4.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + plan4.f = pt4 + plan4.adjoint() + out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim)) + + out = out1 + out2 + out3 + out4 + return out + + +def chisq_logcamp_diag_nfft(imvec, A, log_clamp_diag, sigma): + """Diagonalized log closure amplitudes (normalized) chi-squared from nfft + """ + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + A4 = A[0] + tform_mats = A[1] + + # get nfft objects + nfft_info1 = A4[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A4[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A4[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A4[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + samples4 = plan4.f.copy()*pulsefac4 + + log_clamp_samples = (np.log(np.abs(samples1)) + np.log(np.abs(samples2)) - + np.log(np.abs(samples3)) - np.log(np.abs(samples4))) + + count = 0 + log_clamp_diag_samples = [] + for tform_mat in tform_mats: + log_clamp_samples_here = log_clamp_samples[count:count+len(tform_mat)] + log_clamp_diag_samples.append(np.dot(tform_mat, log_clamp_samples_here)) + count += len(tform_mat) + + log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples) + + # compute chi^2 + chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) / \ + (len(log_clamp_diag)) + + return chisq + + +def chisqgrad_logcamp_diag_nfft(imvec, A, log_clamp_diag, sigma): + """The gradient of the diagonalized log closure amplitude chi-squared from fft + """ + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + A4 = A[0] + tform_mats = A[1] + + # get nfft objects + nfft_info1 = A4[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A4[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A4[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + nfft_info4 = A4[3] + plan4 = nfft_info4.plan + pulsefac4 = nfft_info4.pulsefac + + # compute uniform --> nonuniform transforms + plan1.f_hat = imvec.copy().reshape((nfft_info1.ydim, nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = imvec.copy().reshape((nfft_info2.ydim, nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = imvec.copy().reshape((nfft_info3.ydim, nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + plan4.f_hat = imvec.copy().reshape((nfft_info4.ydim, nfft_info4.xdim)).T + plan4.trafo() + v4 = plan4.f.copy()*pulsefac4 + + log_clamp_samples = np.log(np.abs((v1 * v2)/(v3 * v4))) + + # gradient vec for adjoint FT + count = 0 + pp = np.zeros_like(log_clamp_samples) + for tform_mat in tform_mats: + + log_clamp_diag_samples = np.dot(tform_mat, log_clamp_samples[count:count+len(tform_mat)]) + log_clamp_diag_measured = log_clamp_diag[count:count+len(tform_mat)] + log_clamp_diag_sigma = sigma[count:count+len(tform_mat)] + + for j in range(len(log_clamp_diag_measured)): + pp[count:count+len(tform_mat)] += -2.0 * tform_mat[j, :] * \ + (log_clamp_diag_measured[j] - log_clamp_diag_samples[j]) / \ + (log_clamp_diag_sigma[j]**2) + + count += len(tform_mat) + + pt1 = pp / v1.conj() * pulsefac1.conj() + pt2 = pp / v2.conj() * pulsefac2.conj() + pt3 = -pp / v3.conj() * pulsefac3.conj() + pt4 = -pp / v4.conj() * pulsefac4.conj() + + # Setup and perform the inverse FFT + plan1.f = pt1 + plan1.adjoint() + out1 = np.real((plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim)) + + plan2.f = pt2 + plan2.adjoint() + out2 = np.real((plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim)) + + plan3.f = pt3 + plan3.adjoint() + out3 = np.real((plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim)) + + plan4.f = pt4 + plan4.adjoint() + out4 = np.real((plan4.f_hat.copy().T).reshape(nfft_info4.xdim*nfft_info4.ydim)) + + deriv = out1 + out2 + out3 + out4 + deriv *= 1.0/np.float(len(log_clamp_diag)) + + return deriv + + +def chisq_logamp_nfft(imvec, A, amp, sigma): + """Visibility log amplitude chi-squared from nfft + """ + + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + # compute chi^2 + amp_samples = np.abs(samples) + chisq = np.sum(np.abs((np.log(amp_samples)-np.log(amp))/logsigma)**2)/(len(amp)) + + return chisq + + +def chisqgrad_logamp_nfft(imvec, A, amp, sigma): + """The gradient of the log amplitude chi-squared from nfft + """ + + # get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + # compute uniform --> nonuniform transform + plan.f_hat = imvec.copy().reshape((nfft_info.ydim, nfft_info.xdim)).T + plan.trafo() + samples = plan.f.copy()*pulsefac + amp_samples = np.abs(samples) + + # gradient vec for adjoint FT + logsigma = sigma / amp + wdiff_vec = (-2.0/len(amp)*((np.log(amp) - np.log(amp_samples))) / + (logsigma**2) / samples.conj()) * pulsefac.conj() + plan.f = wdiff_vec + plan.adjoint() + out = np.real((plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim)) + + return out + + +################################################################################################## +# Regularizer and Gradient Functions +################################################################################################## + +def sflux(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Total flux constraint + """ + if norm_reg: + norm = flux**2 + else: + norm = 1 + + out = -(np.sum(imvec) - flux)**2 + return out/norm + + +def sfluxgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Total flux constraint gradient + """ + if norm_reg: + norm = flux**2 + else: + norm = 1 + + out = -2*(np.sum(imvec) - flux)*np.ones(len(imvec)) + return out / norm + + +def scm(imvec, nx, ny, psize, flux, embed_mask, norm_reg=NORM_REGULARIZER, beam_size=None): + """Center-of-mass constraint + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = beam_size**2 * flux**2 + else: + norm = 1 + + xx, yy = np.meshgrid(range(nx//2, -nx//2, -1), range(ny//2, -ny//2, -1)) + xx = psize*xx.flatten()[embed_mask] + yy = psize*yy.flatten()[embed_mask] + + out = -(np.sum(imvec*xx)**2 + np.sum(imvec*yy)**2) + return out/norm + + +def scmgrad(imvec, nx, ny, psize, flux, embed_mask, norm_reg=NORM_REGULARIZER, beam_size=None): + """Center-of-mass constraint gradient + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = beam_size**2 * flux**2 + else: + norm = 1 + + xx, yy = np.meshgrid(range(nx//2, -nx//2, -1), range(ny//2, -ny//2, -1)) + xx = psize*xx.flatten()[embed_mask] + yy = psize*yy.flatten()[embed_mask] + + out = -2*(np.sum(imvec*xx)*xx + np.sum(imvec*yy)*yy) + return out/norm + + +def ssimple(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Simple entropy + """ + if norm_reg: + norm = flux + else: + norm = 1 + + entropy = -np.sum(imvec*np.log(imvec/priorvec)) + return entropy/norm + + +def ssimplegrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Simple entropy gradient + """ + if norm_reg: + norm = flux + else: + norm = 1 + + entropygrad = -np.log(imvec/priorvec) - 1 + return entropygrad/norm + + +def sl1(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """L1 norm regularizer + """ + if norm_reg: + norm = flux + else: + norm = 1 + + # l1 = -np.sum(np.abs(imvec - priorvec)) + l1 = -np.sum(np.abs(imvec)) + return l1/norm + + +def sl1grad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """L1 norm gradient + """ + if norm_reg: + norm = flux + else: + norm = 1 + + # l1grad = -np.sign(imvec - priorvec) + l1grad = -np.sign(imvec) + return l1grad/norm + + +def sl1w(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER, epsilon=ehc.EP): + """Weighted L1 norm regularizer a la SMILI + """ + + if norm_reg: + norm = 1 # should be ok? + # This is SMILI normalization + # norm = np.sum((np.sqrt(priorvec**2 + epsilon) + epsilon)/np.sqrt(priorvec**2 + epsilon)) + else: + norm = 1 + + num = np.sqrt(imvec**2 + epsilon) + denom = np.sqrt(priorvec**2 + epsilon) + epsilon + + l1w = -np.sum(num/denom) + return l1w/norm + + +def sl1wgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER, epsilon=ehc.EP): + """Weighted L1 norm gradient + """ + if norm_reg: + norm = 1 # should be ok? + # This is SMILI normalization + # norm = np.sum((np.sqrt(priorvec**2 + epsilon) + epsilon)/np.sqrt(priorvec**2 + epsilon)) + else: + norm = 1 + + num = imvec / np.sqrt(imvec**2 + epsilon) + denom = np.sqrt(priorvec**2 + epsilon) + epsilon + + l1wgrad = - num / denom + return l1wgrad/norm + + +def fA(imvec, I_ref=1.0, alpha_A=1.0): + """Function to take imvec to itself in the limit alpha_A -> 0 + and to a binary representation in the limit alpha_A -> infinity + """ + return 2.0/np.pi * (1.0 + alpha_A)/alpha_A * np.arctan(np.pi*alpha_A/2.0*np.abs(imvec)/I_ref) + + +def fAgrad(imvec, I_ref=1.0, alpha_A=1.0): + """Function to take imvec to itself in the limit alpha_A -> 0 + and to a binary representation in the limit alpha_A -> infinity + """ + return (1.0 + alpha_A) / (I_ref * (1.0 + (np.pi*alpha_A/2.0*imvec/I_ref)**2)) + + +def slA(imvec, priorvec, psize, flux, beam_size=None, alpha_A=1.0, norm_reg=NORM_REGULARIZER): + """l_A regularizer + """ + + # The appropriate I_ref is something like the total flux divided by the # of pixels per beam + if beam_size is None: + beam_size = psize + I_ref = flux + + if norm_reg: + norm_l1 = 1.0 # as alpha_A ->0 + norm_l0 = (beam_size/psize)**2 # as alpha_A ->\infty + weight_l1 = 1.0/(1.0 + alpha_A) + weight_l0 = alpha_A + norm = (norm_l1 * weight_l1 + norm_l0 * weight_l0)/(weight_l0 + weight_l1) + else: + norm = 1 + + return -np.sum(fA(imvec, I_ref, alpha_A))/norm + + +def slAgrad(imvec, priorvec, psize, flux, beam_size=None, alpha_A=1.0, norm_reg=NORM_REGULARIZER): + """l_A gradient + """ + + # The appropriate I_ref is something like the total flux divided by the # of pixels per beam + if beam_size is None: + beam_size = psize + I_ref = flux + + if norm_reg: + norm_l1 = 1.0 # as alpha_A ->0 + norm_l0 = (beam_size/psize)**2 # as alpha_A ->\infty + weight_l1 = 1.0/(1.0 + alpha_A) + weight_l0 = alpha_A + norm = (norm_l1 * weight_l1 + norm_l0 * weight_l0)/(weight_l0 + weight_l1) + else: + norm = 1 + + return -fAgrad(imvec, I_ref, alpha_A)/norm + + +def sgs(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Gull-skilling entropy + """ + if norm_reg: + norm = flux + else: + norm = 1 + + entropy = np.sum(imvec - priorvec - imvec*np.log(imvec/priorvec)) + return entropy/norm + + +def sgsgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Gull-Skilling gradient + """ + if norm_reg: + norm = flux + else: + norm = 1 + + entropygrad = -np.log(imvec/priorvec) + return entropygrad/norm + +# TODO: epsilon is 0 by default for backwards compatibilitys +def stv(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.): + """Total variation regularizer + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux*psize / beam_size + else: + norm = 1 + + im = imvec.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + out = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2 + epsilon)) + return out/norm + +# TODO: epsilon is 0 by default for backwards compatibility +def stvgrad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.): + """Total variation gradient + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux*psize / beam_size + else: + norm = 1 + + im = imvec.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + + # rotate images + im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1] + im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1] + + # add together terms and return + g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon) + g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon) + g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon) + + # mask the first row column gradient terms that don't exist + mask1 = np.zeros(im.shape) + mask2 = np.zeros(im.shape) + mask1[0, :] = 1 + mask2[:, 0] = 1 + g2[mask1.astype(bool)] = 0 + g3[mask2.astype(bool)] = 0 + + # add terms together and return + out = -(g1 + g2 + g3).flatten() + return out/norm + + +def stv2(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """Squared Total variation regularizer + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = psize**4 * flux**2 / beam_size**4 + else: + norm = 1 + + im = imvec.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + out = -np.sum((im_l1 - im)**2 + (im_l2 - im)**2) + return out/norm + + +def stv2grad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """Squared Total variation gradient + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = psize**4 * flux**2 / beam_size**4 + else: + norm = 1 + + im = imvec.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + + g1 = (2*im - im_l1 - im_l2) + g2 = (im - im_r1) + g3 = (im - im_r2) + + # mask the first row column gradient terms that don't exist + mask1 = np.zeros(im.shape) + mask2 = np.zeros(im.shape) + mask1[0, :] = 1 + mask2[:, 0] = 1 + g2[mask1.astype(bool)] = 0 + g3[mask2.astype(bool)] = 0 + + # add together terms and return + out = -2*(g1 + g2 + g3).flatten() + return out/norm + + +def spatch(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Patch prior regularizer + """ + if norm_reg: + norm = flux**2 + else: + norm = 1 + + out = -0.5*np.sum((imvec - priorvec) ** 2) + return out/norm + + +def spatchgrad(imvec, priorvec, flux, norm_reg=NORM_REGULARIZER): + """Patch prior gradient + """ + if norm_reg: + norm = flux**2 + else: + norm = 1 + + out = -(imvec - priorvec) + return out/norm + +# TODO FIGURE OUT NORMALIZATIONS FOR COMPACT 1 & 2 REGULARIZERS + + +def scompact(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """I r^2 source size regularizer + """ + + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux * (beam_size**2) + else: + norm = 1 + + im = imvec.reshape(ny, nx) + + xx, yy = np.meshgrid(range(nx), range(ny)) + xx = xx - (nx-1)/2.0 + yy = yy - (ny-1)/2.0 + xxpsize = xx * psize + yypsize = yy * psize + + x0 = np.sum(np.sum(im * xxpsize))/flux + y0 = np.sum(np.sum(im * yypsize))/flux + + out = -np.sum(np.sum(im * ((xxpsize - x0)**2 + (yypsize - y0)**2))) + return out/norm + + +def scompactgrad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """Gradient for I r^2 source size regularizer + """ + + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux * beam_size**2 + else: + norm = 1 + + im = imvec.reshape(ny, nx) + + xx, yy = np.meshgrid(range(nx), range(ny)) + xx = xx - (nx-1)/2.0 + yy = yy - (ny-1)/2.0 + xxpsize = xx * psize + yypsize = yy * psize + + x0 = np.sum(np.sum(im * xxpsize))/flux + y0 = np.sum(np.sum(im * yypsize))/flux + + term1 = np.sum(np.sum(im * ((xxpsize - x0)))) + term2 = np.sum(np.sum(im * ((yypsize - y0)))) + + grad = -2*xxpsize*term1 - 2*yypsize*term2 + (xxpsize - x0)**2 + (yypsize - y0)**2 + + return -grad.reshape(-1)/norm + + +def scompact2(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """I^2r^2 source size regularizer + """ + + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux**2 * beam_size**2 + else: + norm = 1 + + im = imvec.reshape(ny, nx) + + xx, yy = np.meshgrid(range(nx), range(ny)) + xx = xx - (nx-1)/2.0 + yy = yy - (ny-1)/2.0 + xxpsize = xx * psize + yypsize = yy * psize + + out = -np.sum(np.sum(im**2 * (xxpsize**2 + yypsize**2))) + return out/norm + + +def scompact2grad(imvec, nx, ny, psize, flux, norm_reg=NORM_REGULARIZER, beam_size=None): + """Gradient for I^2r^2 source size regularizer + """ + + if beam_size is None: + beam_size = psize + if norm_reg: + norm = flux**2 * beam_size**2 + else: + norm = 1 + + im = imvec.reshape(ny, nx) + + xx, yy = np.meshgrid(range(nx), range(ny)) + xx = xx - (nx-1)/2.0 + yy = yy - (ny-1)/2.0 + xxpsize = xx * psize + yypsize = yy * psize + + grad = -2*im*(xxpsize**2 + yypsize**2) + + return grad.reshape(-1)/norm + + +def sgauss(imvec, xdim, ydim, psize, major, minor, PA): + """Gaussian source size regularizer + """ + + # major, minor and PA are all in radians + phi = PA + + # eigenvalues of covariance matrix + lambda1 = minor**2./(8.*np.log(2.)) + lambda2 = major**2./(8.*np.log(2.)) + + # now compute covariance matrix elements from user inputs + sigxx_prime = lambda1*(np.cos(phi)**2.) + lambda2*(np.sin(phi)**2.) + sigyy_prime = lambda1*(np.sin(phi)**2.) + lambda2*(np.cos(phi)**2.) + sigxy_prime = (lambda2 - lambda1)*np.cos(phi)*np.sin(phi) + + # we get the dimensions and image vector + im = imvec.reshape(xdim, ydim) + xlist, ylist = np.meshgrid(range(xdim), range(ydim)) + xlist = xlist - (xdim-1)/2.0 + ylist = ylist - (ydim-1)/2.0 + + xx = xlist * psize + yy = ylist * psize + + # the centroid parameters + x0 = np.sum(xx*im) / np.sum(im) + y0 = np.sum(yy*im) / np.sum(im) + + # we calculate the elements of the covariance matrix + sigxx = (np.sum((xx - x0)**2.*im)/np.sum(im)) + sigyy = (np.sum((yy - y0)**2.*im)/np.sum(im)) + sigxy = (np.sum((xx - x0)*(yy - y0)*im)/np.sum(im)) + + # We calculate the regularizer #this line was CHANGED + rgauss = -((sigxx - sigxx_prime)**2. + (sigyy - sigyy_prime)**2. + 2*(sigxy - sigxy_prime)**2.) + # normalization will need to be redone, right now requires alpha~1000 + rgauss = rgauss/(major**2. * minor**2.) + return rgauss + + +def sgauss_grad(imvec, xdim, ydim, psize, major, minor, PA): + """Gradient for Gaussian source size regularizer + """ + + # major, minor and PA are all in radians + phi = PA + + # computing eigenvalues of the covariance matrix + lambda1 = (minor**2.)/(8.*np.log(2.)) + lambda2 = (major**2.)/(8.*np.log(2.)) + + # now compute covariance matrix elements from user inputs + + sigxx_prime = lambda1*(np.cos(phi)**2.) + lambda2*(np.sin(phi)**2.) + sigyy_prime = lambda1*(np.sin(phi)**2.) + lambda2*(np.cos(phi)**2.) + sigxy_prime = (lambda2 - lambda1)*np.cos(phi)*np.sin(phi) + + # we get the dimensions and image vector + im = imvec.reshape(xdim, ydim) + xlist, ylist = np.meshgrid(range(xdim), range(ydim)) + xlist = xlist - (xdim-1)/2.0 + ylist = ylist - (ydim-1)/2.0 + + xx = xlist * psize + yy = ylist * psize + + # the centroid parameters + x0 = np.sum(xx*im) / np.sum(im) + y0 = np.sum(yy*im) / np.sum(im) + + # we calculate the elements of the covariance matrix of the image + sigxx = (np.sum((xx - x0)**2.*im)/np.sum(im)) + sigyy = (np.sum((yy - y0)**2.*im)/np.sum(im)) + sigxy = (np.sum((xx - x0)*(yy - y0)*im)/np.sum(im)) + + # now we compute the gradients of all quantities + # gradient of centroid + dx0 = (xx - x0) / np.sum(im) + dy0 = (yy - y0) / np.sum(im) + + # gradients of covariance matrix elements + dxx = (((xx - x0)**2. - 2.*(xx - x0)*dx0*im) - sigxx) / np.sum(im) + + dyy = (((yy - y0)**2. - 2.*(yy - y0)*dx0*im) - sigyy) / np.sum(im) + + dxy = (((xx - x0)*(yy - y0) - (yy - y0)*dx0*im - (xx - x0)*dy0*im) - sigxy) / np.sum(im) + + # gradient of the regularizer #this line was CHANGED + drgauss = (2.*(sigxx - sigxx_prime)*dxx + + 2.*(sigyy - sigyy_prime)*dyy + + 4.*(sigxy - sigxy_prime)*dxy) + + # normalization will need to be redone, right now requires alpha~1000 + drgauss = drgauss/(major**2. * minor**2.) + + return -drgauss.reshape(-1) + + +################################################################################################## +# Chi^2 Data functions +################################################################################################## +def apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol): + """apply systematic noise to VISIBILITIES or AMPLITUDES + data_arr should have fields 't1','t2','u','v','vis','amp','sigma' + + returns: (uv, vis, amp, sigma) + """ + + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + + t1 = data_arr['t1'] + t2 = data_arr['t2'] + + sigma = data_arr[etype] + amp = data_arr[atype] + try: + vis = data_arr[vtype] + except ValueError: + vis = amp.astype('c16') + + snrmask = np.abs(amp/sigma) >= snrcut + + if type(systematic_noise) is dict: + sys_level = np.zeros(len(t1)) + for i in range(len(t1)): + if t1[i] in systematic_noise.keys(): + t1sys = systematic_noise[t1[i]] + else: + t1sys = 0. + if t2[i] in systematic_noise.keys(): + t2sys = systematic_noise[t2[i]] + else: + t2sys = 0. + + if t1sys < 0 or t2sys < 0: + sys_level[i] = -1 + else: + sys_level[i] = np.sqrt(t1sys**2 + t2sys**2) + else: + sys_level = np.sqrt(2)*systematic_noise*np.ones(len(t1)) + + mask = sys_level >= 0. + mask = snrmask * mask + + sigma = np.linalg.norm([sigma, sys_level*np.abs(amp)], axis=0)[mask] + vis = vis[mask] + amp = amp[mask] + uv = np.hstack((data_arr['u'].reshape(-1, 1), data_arr['v'].reshape(-1, 1)))[mask] + return (uv, vis, amp, sigma) + + +def chisqdata_vis(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrix for visibilities + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrix + A = obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask) + + return (vis, sigma, A) + + +def chisqdata_amp(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrix for visibility amplitudes + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I': + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed amplitude table in amplitude chi^2!") + if not type(Obsdata.amp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed amplitude table is not a numpy rec array!") + data_arr = Obsdata.amp + + # apply systematic noise and SNR cut + # TODO -- after pre-computed?? + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrix + A = obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask) + + return (amp, sigma, A) + + +def chisqdata_bs(Obsdata, Prior, mask, pol='I', **kwargs): + """return the data, sigmas, and fourier matrices for bispectra + """ + + # unpack keyword args + # systematic_noise = kwargs.get('systematic_noise',0.) + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + weighting = kwargs.get('weighting', 'natural') + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I': + biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut) + + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed bispectrum table in cphase chi^2!") + if not type(Obsdata.bispec) in [np.ndarray, np.recarray]: + raise Exception("pre-computed bispectrum table is not a numpy rec array!") + biarr = Obsdata.bispec + # reduce to a minimal set + if count != 'max': + biarr = obsh.reduce_tri_minimal(Obsdata, biarr) + + uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1))) + uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1))) + uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + # add systematic noise + # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrices + A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask) + ) + + return (bi, sigma, A3) + + +def chisqdata_cphase(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for closure phases + """ + + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.) + weighting = kwargs.get('weighting', 'natural') + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I': + clphasearr = Obsdata.c_phases(mode="all", vtype=vtype, + count=count, uv_min=uv_min, snrcut=snrcut) + else: # TODO precomputed with not Stokes I + print("Using pre-computed cphase table in cphase chi^2!") + if not type(Obsdata.cphase) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure phase table is not a numpy rec array!") + clphasearr = Obsdata.cphase + # reduce to a minimal set + if count != 'max': + clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr) + + uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1))) + clphase = clphasearr['cphase'] + sigma = clphasearr['sigmacp'] + + # add systematic cphase noise (in DEGREES) + sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrices + A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask) + ) + return (clphase, sigma, A3) + + +def chisqdata_cphase_diag(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for diagonalized closure phases + """ + + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + + # unpack data + vtype = ehc.vis_poldict[pol] + clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min) + + # loop over timestamps + clphase_diag = [] + sigma_diag = [] + A3_diag = [] + tform_mats = [] + for ic, cl in enumerate(clphasearr): + + # get diagonalized closure phases and errors + clphase_diag.append(cl[0]['cphase']) + sigma_diag.append(cl[0]['sigmacp']) + + # get uv arrays + u1 = cl[2][:, 0].astype('float') + v1 = cl[3][:, 0].astype('float') + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + + u2 = cl[2][:, 1].astype('float') + v2 = cl[3][:, 1].astype('float') + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + + u3 = cl[2][:, 2].astype('float') + v3 = cl[3][:, 2].astype('float') + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + + # compute Fourier matrices + A3 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask) + ) + A3_diag.append(A3) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (np.array(A3_diag), np.array(tform_mats)) + + return (np.array(clphase_diag), np.array(sigma_diag), Amatrices) + + +def chisqdata_camp(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + + # unpack data & mask low snr points + vtype = ehc.vis_poldict[pol] + if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed closure amplitude table in closure amplitude chi^2!") + if not type(Obsdata.camp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.camp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrices + A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask) + ) + + return (clamp, sigma, A4) + + +def chisqdata_logcamp(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + + # unpack data & mask low snr points + vtype = ehc.vis_poldict[pol] + if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!") + if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed log closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.logcamp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # make fourier matrices + A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask) + ) + + return (clamp, sigma, A4) + + +def chisqdata_logcamp_diag(Obsdata, Prior, mask, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for diagonalized log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + + # unpack data & mask low snr points + vtype = ehc.vis_poldict[pol] + clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count, debias=debias, snrcut=snrcut) + + # loop over timestamps + clamp_diag = [] + sigma_diag = [] + A4_diag = [] + tform_mats = [] + for ic, cl in enumerate(clamparr): + + # get diagonalized log closure amplitudes and errors + clamp_diag.append(cl[0]['camp']) + sigma_diag.append(cl[0]['sigmaca']) + + # get uv arrays + u1 = cl[2][:, 0].astype('float') + v1 = cl[3][:, 0].astype('float') + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + + u2 = cl[2][:, 1].astype('float') + v2 = cl[3][:, 1].astype('float') + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + + u3 = cl[2][:, 2].astype('float') + v3 = cl[3][:, 2].astype('float') + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + + u4 = cl[2][:, 3].astype('float') + v4 = cl[3][:, 3].astype('float') + uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1))) + + # compute Fourier matrices + A4 = (obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask), + obsh.ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv4, pulse=Prior.pulse, mask=mask) + ) + A4_diag.append(A4) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (np.array(A4_diag), np.array(tform_mats)) + + return (np.array(clamp_diag), np.array(sigma_diag), Amatrices) + +################################################################################################## +# FFT Chi^2 Data functions +################################################################################################## + + +def chisqdata_vis_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for visibilities + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info = obsh.make_gridder_and_sampler_info( + im_info, uv, conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info[0]] + gridder_info_list = [gs_info[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (vis, sigma, A) + + +def chisqdata_amp_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for visibility amplitudes + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I': + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed amplitude table in amplitude chi^2!") + if not type(Obsdata.amp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed amplitude table is not a numpy rec array!") + data_arr = Obsdata.amp + + # apply systematic noise + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info = obsh.make_gridder_and_sampler_info(im_info, uv, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info[0]] + gridder_info_list = [gs_info[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (amp, sigma, A) + + +def chisqdata_bs_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for bispectra + """ + + # unpack keyword args + # systematic_noise = kwargs.get('systematic_noise',0.) + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I': + biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed bispectrum table in cphase chi^2!") + if not type(Obsdata.bispec) in [np.ndarray, np.recarray]: + raise Exception("pre-computed bispectrum table is not a numpy rec array!") + biarr = Obsdata.bispec + # reduce to a minimal set + if count != 'max': + biarr = obsh.reduce_tri_minimal(Obsdata, biarr) + + uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1))) + uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1))) + uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + # add systematic noise + # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (bi, sigma, A) + + +def chisqdata_cphase_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for closure phases + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + weighting = kwargs.get('weighting', 'natural') + systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I': + clphasearr = Obsdata.c_phases(mode="all", vtype=vtype, + count=count, uv_min=uv_min, snrcut=snrcut) + else: # TODO precomputed with not Stokes I + print("Using pre-computed cphase table in cphase chi^2!") + if not type(Obsdata.cphase) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure phase table is not a numpy rec array!") + clphasearr = Obsdata.cphase + # reduce to a minimal set + if count != 'max': + clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr) + + uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1))) + clphase = clphasearr['cphase'] + sigma = clphasearr['sigmacp'] + + # add systematic cphase noise (in DEGREES) + sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (clphase, sigma, A) + + +def chisqdata_cphase_diag_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for diagonalized closure phases + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min) + + # loop over timestamps + clphase_diag = [] + sigma_diag = [] + tform_mats = [] + u1 = [] + v1 = [] + u2 = [] + v2 = [] + u3 = [] + v3 = [] + for ic, cl in enumerate(clphasearr): + + # get diagonalized closure phases and errors + clphase_diag.append(cl[0]['cphase']) + sigma_diag.append(cl[0]['sigmacp']) + + # get u and v values + u1.append(cl[2][:, 0].astype('float')) + v1.append(cl[3][:, 0].astype('float')) + u2.append(cl[2][:, 1].astype('float')) + v2.append(cl[3][:, 1].astype('float')) + u3.append(cl[2][:, 2].astype('float')) + v3.append(cl[3][:, 2].astype('float')) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # fix formatting of arrays + u1 = np.concatenate(u1) + v1 = np.concatenate(v1) + u2 = np.concatenate(u2) + v2 = np.concatenate(v2) + u3 = np.concatenate(u3) + v3 = np.concatenate(v3) + + # get uv arrays + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1]] + A3 = (im_info, sampler_info_list, gridder_info_list) + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (A3, np.array(tform_mats)) + + return (np.array(clphase_diag), np.array(sigma_diag), Amatrices) + + +def chisqdata_camp_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for closure amplitudes + """ + + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed closure amplitude table in closure amplitude chi^2!") + if not type(Obsdata.camp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.camp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (clamp, sigma, A) + + +def chisqdata_logcamp_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!") + if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed log closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.logcamp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + return (clamp, sigma, A) + + +def chisqdata_logcamp_diag_fft(Obsdata, Prior, pol='I', **kwargs): + """Return the data, sigmas, uv points, and FFT info for diagonalized log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + conv_func = kwargs.get('conv_func', ehc.GRIDDER_CONV_FUNC_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + order = kwargs.get('order', ehc.FFT_INTERP_DEFAULT) + + # unpack data & mask low snr points + vtype = ehc.vis_poldict[pol] + clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count, + debias=debias, snrcut=snrcut) + + # loop over timestamps + clamp_diag = [] + sigma_diag = [] + tform_mats = [] + u1 = [] + v1 = [] + u2 = [] + v2 = [] + u3 = [] + v3 = [] + u4 = [] + v4 = [] + for ic, cl in enumerate(clamparr): + + # get diagonalized log closure amplitudes and errors + clamp_diag.append(cl[0]['camp']) + sigma_diag.append(cl[0]['sigmaca']) + + # get u and v values + u1.append(cl[2][:, 0].astype('float')) + v1.append(cl[3][:, 0].astype('float')) + u2.append(cl[2][:, 1].astype('float')) + v2.append(cl[3][:, 1].astype('float')) + u3.append(cl[2][:, 2].astype('float')) + v3.append(cl[3][:, 2].astype('float')) + u4.append(cl[2][:, 3].astype('float')) + v4.append(cl[3][:, 3].astype('float')) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # fix formatting of arrays + u1 = np.concatenate(u1) + v1 = np.concatenate(v1) + u2 = np.concatenate(u2) + v2 = np.concatenate(v2) + u3 = np.concatenate(u3) + v3 = np.concatenate(v3) + u4 = np.concatenate(u4) + v4 = np.concatenate(v4) + + # get uv arrays + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1))) + + # prepare image and fft info objects + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + im_info = obsh.ImInfo(Prior.xdim, Prior.ydim, npad, Prior.psize, Prior.pulse) + gs_info1 = obsh.make_gridder_and_sampler_info(im_info, uv1, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info2 = obsh.make_gridder_and_sampler_info(im_info, uv2, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info3 = obsh.make_gridder_and_sampler_info(im_info, uv3, + conv_func=conv_func, p_rad=p_rad, order=order) + gs_info4 = obsh.make_gridder_and_sampler_info(im_info, uv4, + conv_func=conv_func, p_rad=p_rad, order=order) + + sampler_info_list = [gs_info1[0], gs_info2[0], gs_info3[0], gs_info4[0]] + gridder_info_list = [gs_info1[1], gs_info2[1], gs_info3[1], gs_info4[1]] + A = (im_info, sampler_info_list, gridder_info_list) + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (A, np.array(tform_mats)) + + return (np.array(clamp_diag), np.array(sigma_diag), Amatrices) + +################################################################################################## +# NFFT Chi^2 Data functions +################################################################################################## + + +def chisqdata_vis_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the visibilities, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv) + A = [A1] + + return (vis, sigma, A) + + +def chisqdata_amp_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the amplitudes, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise', 0.) + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + atype = ehc.amp_poldict[pol] + etype = ehc.sig_poldict[pol] + if (Obsdata.amp is None) or (len(Obsdata.amp) == 0) or pol != 'I': + data_arr = Obsdata.unpack(['t1', 't2', 'u', 'v', vtype, atype, etype], debias=debias) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed amplitude table in amplitude chi^2!") + if not type(Obsdata.amp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed amplitude table is not a numpy rec array!") + data_arr = Obsdata.amp + + # apply systematic noise + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv) + A = [A1] + + return (amp, sigma, A) + + +def chisqdata_bs_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the bispectra, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + # systematic_noise = kwargs.get('systematic_noise',0.) + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.bispec is None) or (len(Obsdata.bispec) == 0) or pol != 'I': + biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed bispectrum table in cphase chi^2!") + if not type(Obsdata.bispec) in [np.ndarray, np.recarray]: + raise Exception("pre-computed bispectrum table is not a numpy rec array!") + biarr = Obsdata.bispec + # reduce to a minimal set + if count != 'max': + biarr = obsh.reduce_tri_minimal(Obsdata, biarr) + + uv1 = np.hstack((biarr['u1'].reshape(-1, 1), biarr['v1'].reshape(-1, 1))) + uv2 = np.hstack((biarr['u2'].reshape(-1, 1), biarr['v2'].reshape(-1, 1))) + uv3 = np.hstack((biarr['u3'].reshape(-1, 1), biarr['v3'].reshape(-1, 1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + # add systematic noise + # sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A = [A1, A2, A3] + + return (bi, sigma, A) + + +def chisqdata_cphase_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the closure phases, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + weighting = kwargs.get('weighting', 'natural') + systematic_cphase_noise = kwargs.get('systematic_cphase_noise', 0.) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.cphase is None) or (len(Obsdata.cphase) == 0) or pol != 'I': + clphasearr = Obsdata.c_phases(mode="all", vtype=vtype, + count=count, uv_min=uv_min, snrcut=snrcut) + else: # TODO precomputed with not Stokes I + print("Using pre-computed cphase table in cphase chi^2!") + if not type(Obsdata.cphase) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure phase table is not a numpy rec array!") + clphasearr = Obsdata.cphase + # reduce to a minimal set + if count != 'max': + clphasearr = obsh.reduce_tri_minimal(Obsdata, clphasearr) + + uv1 = np.hstack((clphasearr['u1'].reshape(-1, 1), clphasearr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clphasearr['u2'].reshape(-1, 1), clphasearr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clphasearr['u3'].reshape(-1, 1), clphasearr['v3'].reshape(-1, 1))) + clphase = clphasearr['cphase'] + sigma = clphasearr['sigmacp'] + + # add systematic cphase noise (in DEGREES) + sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0) + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A = [A1, A2, A3] + + return (clphase, sigma, A) + + +def chisqdata_cphase_diag_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the diagonalized closure phases, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + maxset = kwargs.get('maxset', False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + clphasearr = Obsdata.c_phases_diag(vtype=vtype, count=count, snrcut=snrcut, uv_min=uv_min) + + # loop over timestamps + clphase_diag = [] + sigma_diag = [] + tform_mats = [] + u1 = [] + v1 = [] + u2 = [] + v2 = [] + u3 = [] + v3 = [] + for ic, cl in enumerate(clphasearr): + + # get diagonalized closure phases and errors + clphase_diag.append(cl[0]['cphase']) + sigma_diag.append(cl[0]['sigmacp']) + + # get u and v values + u1.append(cl[2][:, 0].astype('float')) + v1.append(cl[3][:, 0].astype('float')) + u2.append(cl[2][:, 1].astype('float')) + v2.append(cl[3][:, 1].astype('float')) + u3.append(cl[2][:, 2].astype('float')) + v3.append(cl[3][:, 2].astype('float')) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # fix formatting of arrays + u1 = np.concatenate(u1) + v1 = np.concatenate(v1) + u2 = np.concatenate(u2) + v2 = np.concatenate(v2) + u3 = np.concatenate(u3) + v3 = np.concatenate(v3) + + # get uv arrays + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A = [A1, A2, A3] + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (A, np.array(tform_mats)) + + return (np.array(clphase_diag), np.array(sigma_diag), Amatrices) + + +def chisqdata_camp_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the closure amplitudes, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.camp is None) or (len(Obsdata.camp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='camp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed closure amplitude table in closure amplitude chi^2!") + if not type(Obsdata.camp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.camp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='camp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4) + A = [A1, A2, A3, A4] + + return (clamp, sigma, A) + + +def chisqdata_logcamp_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the log closure amplitudes, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + weighting = kwargs.get('weighting', 'natural') + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data + vtype = ehc.vis_poldict[pol] + if (Obsdata.logcamp is None) or (len(Obsdata.logcamp) == 0) or pol != 'I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, + vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!") + if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed log closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.logcamp + # reduce to a minimal set + if count != 'max': + clamparr = obsh.reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1, 1), clamparr['v1'].reshape(-1, 1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1, 1), clamparr['v2'].reshape(-1, 1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1, 1), clamparr['v3'].reshape(-1, 1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1, 1), clamparr['v4'].reshape(-1, 1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting == 'uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4) + A = [A1, A2, A3, A4] + + return (clamp, sigma, A) + + +def chisqdata_logcamp_diag_nfft(Obsdata, Prior, pol='I', **kwargs): + """Return the diagonalized log closure amplitudes, sigmas, uv points, and nfft info + """ + if (Prior.xdim % 2 or Prior.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + # unpack keyword args + maxset = kwargs.get('maxset', False) + if maxset: + count = 'max' + else: + count = 'min' + + snrcut = kwargs.get('snrcut', 0.) + debias = kwargs.get('debias', True) + fft_pad_factor = kwargs.get('fft_pad_factor', ehc.FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', ehc.GRIDDER_P_RAD_DEFAULT) + + # unpack data & mask low snr points + vtype = ehc.vis_poldict[pol] + clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype, count=count, debias=debias, snrcut=snrcut) + + # loop over timestamps + clamp_diag = [] + sigma_diag = [] + tform_mats = [] + u1 = [] + v1 = [] + u2 = [] + v2 = [] + u3 = [] + v3 = [] + u4 = [] + v4 = [] + for ic, cl in enumerate(clamparr): + + # get diagonalized log closure amplitudes and errors + clamp_diag.append(cl[0]['camp']) + sigma_diag.append(cl[0]['sigmaca']) + + # get u and v values + u1.append(cl[2][:, 0].astype('float')) + v1.append(cl[3][:, 0].astype('float')) + u2.append(cl[2][:, 1].astype('float')) + v2.append(cl[3][:, 1].astype('float')) + u3.append(cl[2][:, 2].astype('float')) + v3.append(cl[3][:, 2].astype('float')) + u4.append(cl[2][:, 3].astype('float')) + v4.append(cl[3][:, 3].astype('float')) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # fix formatting of arrays + u1 = np.concatenate(u1) + v1 = np.concatenate(v1) + u2 = np.concatenate(u2) + v2 = np.concatenate(v2) + u3 = np.concatenate(u3) + v3 = np.concatenate(v3) + u4 = np.concatenate(u4) + v4 = np.concatenate(v4) + + # get uv arrays + uv1 = np.hstack((u1.reshape(-1, 1), v1.reshape(-1, 1))) + uv2 = np.hstack((u2.reshape(-1, 1), v2.reshape(-1, 1))) + uv3 = np.hstack((u3.reshape(-1, 1), v3.reshape(-1, 1))) + uv4 = np.hstack((u4.reshape(-1, 1), v4.reshape(-1, 1))) + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A4 = obsh.NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv4) + A = [A1, A2, A3, A4] + + # combine Fourier and transformation matrices into tuple for outputting + Amatrices = (A, np.array(tform_mats)) + + return (np.array(clamp_diag), np.array(sigma_diag), Amatrices) + +################################################################################################## +# Restoring ,Embedding, and Plotting Functions +################################################################################################## + + +def plot_i(im, Prior, nit, chi2_dict, **kwargs): + """Plot the total intensity image at each iteration + """ + cmap = kwargs.get('cmap', 'afmhot') + interpolation = kwargs.get('interpolation', 'gaussian') + pol = kwargs.get('pol', '') + scale = kwargs.get('scale', None) + dynamic_range = kwargs.get('dynamic_range', 1.e5) + gamma = kwargs.get('dynamic_range', .5) + + plt.ion() + plt.pause(1.e-6) + plt.clf() + + imarr = im.reshape(Prior.ydim, Prior.xdim) + + if scale == 'log': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = np.log(imarr + np.max(imarr)/dynamic_range) + + if scale == 'gamma': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma) + + plt.imshow(imarr, cmap=plt.get_cmap(cmap), interpolation=interpolation) + xticks = obsh.ticks(Prior.xdim, Prior.psize/ehc.RADPERAS/1e-6) + yticks = obsh.ticks(Prior.ydim, Prior.psize/ehc.RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel(r'Relative RA ($\mu$as)') + plt.ylabel(r'Relative Dec ($\mu$as)') + plotstr = str(pol) + " : step: %i " % nit + for key in chi2_dict.keys(): + plotstr += r"$\chi^2_{%s}$: %0.2f " % (key, chi2_dict[key]) + plt.title(plotstr, fontsize=18) + + +def embed(im, mask, clipfloor=0., randomfloor=False): + """Embeds a 1d image array into the size of boolean embed mask + """ + + out = np.zeros(len(mask)) + + # Here's a much faster version than before + out[mask.nonzero()] = im + + #if clipfloor != 0.0: + if randomfloor: # prevent total variation gradient singularities + out[(mask-1).nonzero()] = clipfloor * \ + np.abs(np.random.normal(size=len((mask-1).nonzero()[0]))) + else: + out[(mask-1).nonzero()] = clipfloor + + return out diff --git a/imaging/linearize_energy.py b/imaging/linearize_energy.py new file mode 100644 index 00000000..1ebb1ef6 --- /dev/null +++ b/imaging/linearize_energy.py @@ -0,0 +1,111 @@ +from __future__ import division + +import numpy as np +import ehtim.image as image + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * + + +def linearized_bi(x0, A3, bispec, sigs, nPixels, alpha=100, reg="patch"): + + Alin, blin = computeLinTerms_bi(x0, A3, bispec, sigs, nPixels, alpha=alpha, reg=reg) + gradient = 2*np.dot(Alin.T, np.dot(Alin, x0) - blin) + + return -gradient + +def linearizedSol_bs(Obsdata, currImage, Prior, alpha=100, beta=100, reg="patch"): +# note what beta is + + # normalize the prior + # TODO: SHOULD THIS BE DONE?? + zbl = np.nanmax(np.abs(Obsdata.unpack(['vis'])['vis'])) + nprior = zbl * Prior.imvec / np.sum(Prior.imvec) + + if reg == "patch": + linRegTerm, constRegTerm = spatchlingrad(currImage.imvec, nprior) + + # Get bispectra data + biarr = Obsdata.bispectra(mode="all", count="max") + + bispec = biarr['bispec'] + sigs = biarr['sigmab'] + + nans = np.isnan(sigs) + bispec = bispec[nans==False] + sigs = sigs[nans==False] + biarr = biarr[nans==False] + + + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + + + # Compute the fourier matrices + A3 = (ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv1, pulse=currImage.pulse), + ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv2, pulse=currImage.pulse), + ftmatrix(currImage.psize, currImage.xdim, currImage.ydim, uv3, pulse=currImage.pulse) + ) + + Alin, blin = computeLinTerms_bi(currImage.imvec, A3, bispec, sigs, currImage.xdim*currImage.ydim, alpha=alpha, reg=reg) + + out = np.linalg.solve(Alin + beta*linRegTerm, blin + beta*constRegTerm); + + return image.Image(out.reshape((currImage.ydim, currImage.xdim)), currImage.psize, currImage.ra, currImage.dec, rf=currImage.rf, source=currImage.source, mjd=currImage.mjd, pulse=currImage.pulse) + + +def computeLinTerms_bi(x0, A3, bispec, sigs, nPixels, alpha=100, reg="patch"): + + + sigmaR = sigmaI = 1.0/(sigs**2) + + rA = np.real(A3); + iA = np.imag(A3); + + rA1 = rA[0]; + rA2 = rA[1]; + rA3 = rA[2]; + iA1 = iA[0]; + iA2 = iA[1]; + iA3 = iA[2]; + + rA1x0 = np.dot(rA1, x0); + rA2x0 = np.dot(rA2, x0); + rA3x0 = np.dot(rA3, x0); + iA1x0 = np.dot(iA1, x0); + iA2x0 = np.dot(iA2, x0); + iA3x0 = np.dot(iA3, x0); + + fR = (rA1x0*rA2x0*rA3x0 - rA1x0*iA2x0*iA3x0 - iA1x0*rA2x0*iA3x0 - iA1x0*iA2x0*rA3x0); + fI = (rA1x0*rA2x0*iA3x0 + rA1x0*iA2x0*rA3x0 + iA1x0*rA2x0*rA3x0 - iA1x0*iA2x0*iA3x0); + + yR = np.real(bispec); + yI = np.imag(bispec); + + #f = np.sum( sigmaR*(fR - yR)**2 + sigmaI*(fI - yI)**2 + 2.0*sigmaRI*(fR - yR)*(fI - yI)) / 2.0; + + # size number of bispectrum x nPixels + dxReal = (rA1 *np.tile(rA2x0*rA3x0, [nPixels, 1]).T + rA2*np.tile(rA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(rA2x0*rA1x0, [nPixels, 1]).T \ + -( rA1*np.tile(iA2x0*iA3x0, [nPixels, 1]).T + iA2*np.tile(rA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(iA2x0*rA1x0, [nPixels, 1]).T) \ + -( iA1*np.tile(rA2x0*iA3x0, [nPixels, 1]).T + rA2*np.tile(iA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(rA2x0*iA1x0, [nPixels, 1]).T ) \ + -( iA1*np.tile(iA2x0*rA3x0, [nPixels, 1]).T + iA2*np.tile(iA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(iA2x0*iA1x0, [nPixels, 1]).T ) ); + + # size number of bispectrum x nPixels + dxImag = (rA1*np.tile(rA2x0*iA3x0, [nPixels, 1]).T + rA2*np.tile(rA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(rA2x0*rA1x0, [nPixels, 1]).T \ + +( rA1*np.tile(iA2x0*rA3x0, [nPixels, 1]).T + iA2*np.tile(rA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(iA2x0*rA1x0, [nPixels, 1]).T) \ + +( iA1*np.tile(rA2x0*rA3x0, [nPixels, 1]).T + rA2*np.tile(iA1x0*rA3x0, [nPixels, 1]).T + rA3*np.tile(rA2x0*iA1x0, [nPixels, 1]).T ) \ + -( iA1*np.tile(iA2x0*iA3x0, [nPixels, 1]).T + iA2*np.tile(iA1x0*iA3x0, [nPixels, 1]).T + iA3*np.tile(iA2x0*iA1x0, [nPixels, 1]).T ) ); + + #size number of bixpectrum x 1 + betaR = fR - np.dot(dxReal,x0); + betaI = fI - np.dot(dxImag,x0); + + blin = np.dot( np.transpose( np.dot( np.diag( alpha*sigmaR ) , dxReal)) , (yR - betaR) ) + np.dot( np.transpose( np.dot( np.diag( alpha*sigmaI ), dxImag )), (yI - betaI) ) ; + Alin = np.dot( np.transpose( np.dot( np.diag( alpha*sigmaR ) , dxReal)) , dxReal ) + np.dot( np.transpose( np.dot( np.diag( alpha*sigmaI ), dxImag )) , dxImag ); + + return (Alin, blin) + +def spatchlingrad(imvec, priorvec): + return (np.diag(np.ones(len(priorvec))), priorvec) + diff --git a/imaging/multifreq_imager_utils.py b/imaging/multifreq_imager_utils.py new file mode 100644 index 00000000..f2ce28b7 --- /dev/null +++ b/imaging/multifreq_imager_utils.py @@ -0,0 +1,278 @@ +# multifreq_imager_utils.py +# imager functions for multifrequency VLBI data +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from builtins import range + +import string +import time +import numpy as np +import scipy.optimize as opt +import scipy.ndimage as nd +import scipy.ndimage.filters as filt +import matplotlib.pyplot as plt +try: + from pynfft.nfft import NFFT +except ImportError: + pass + #print("Warning: No NFFT installed! Cannot use nfft functions") + +import ehtim.image as image +from . import linearize_energy as le + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * +from ehtim.statistics.dataframes import * + +NORM_REGULARIZER = True +EPSILON = 1.e-12 + +################################################################################################## +# Mulitfrequency regularizers +################################################################################################## + +def regularizer_mf(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs): + """return the regularizer value on spectral index or curvature + """ + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + beam_size = kwargs.get('beam_size', psize) + + if stype == "l2_alpha" or stype=="l2_beta": + s = -l2_alpha(imvec, nprior, norm_reg=norm_reg) + elif stype == "tv_alpha" or stype=="tv_beta": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=False) + s = -tv_alpha(imvec, xdim, ydim, psize, norm_reg=norm_reg, beam_size=beam_size) + else: + s = 0 + + return s + +def regularizergrad_mf(imvec, nprior, mask, flux, xdim, ydim, psize, stype, **kwargs): + """return the regularizer gradient + """ + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + beam_size = kwargs.get('beam_size', psize) + + if stype == "l2_alpha" or stype=="l2_beta": + s = -l2_alpha_grad(imvec, nprior) + elif stype == "tv_alpha" or stype=="tv_beta": + if np.any(np.invert(mask)): + imvec = embed(imvec, mask, randomfloor=False) + s = -tv_alpha_grad(imvec, xdim, ydim, psize, norm_reg=norm_reg, beam_size=beam_size) + s = s[mask] + else: + s = 0 + + return s + + +def l2_alpha(imvec, priorvec, norm_reg=NORM_REGULARIZER): + """L2 norm on spectral index w/r/t prior + """ + + if norm_reg: + norm = float(len(imvec)) + else: + norm = 1 + + out = -(np.sum((imvec - priorvec)**2)) + return out/norm + +def l2_alpha_grad(imvec, priorvec, norm_reg=NORM_REGULARIZER): + """L2 norm on spectral index w/r/t prior + """ + + if norm_reg: + norm = float(len(imvec)) + else: + norm = 1 + + out = -2*(np.sum(imvec - priorvec))*np.ones(len(imvec)) + return out/norm + + +def tv_alpha(imvec, nx, ny, psize, norm_reg=NORM_REGULARIZER, beam_size=None): + """Total variation regularizer + """ + if beam_size is None: beam_size = psize + if norm_reg: + norm = len(imvec)*psize / beam_size + else: + norm = 1 + + im = imvec.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + out = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2 + EPSILON)) + + return out/norm + +def tv_alpha_grad(imvec, nx, ny, psize, norm_reg=NORM_REGULARIZER, beam_size=None): + """Total variation gradient + """ + if beam_size is None: beam_size = psize + if norm_reg: + norm = len(imvec)*psize / beam_size + else: + norm = 1 + + im = imvec.reshape(ny,nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + + #rotate images + im_r1l2 = np.roll(np.roll(impad, 1, axis=0),-1, axis=1)[1:ny+1, 1:nx+1] + im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1] + + #add together terms and return + g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + EPSILON) + g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + EPSILON) + g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + EPSILON) + + #mask the first row column gradient terms that don't exist + mask1 = np.zeros(im.shape) + mask2 = np.zeros(im.shape) + mask1[0,:] = 1 + mask2[:,0] = 1 + g2[mask1.astype(bool)] = 0 + g3[mask2.astype(bool)] = 0 + + # add terms together and return + out= -(g1 + g2 + g3).flatten() + return out/norm + +################################################################################################## + +################################################################################################## +def unpack_mftuple(imvec, inittuple, nimage, mf_solve = (1,1,0)): + """unpack imvec into tuple, + replaces quantities not iterated with their initial values + """ + init0 = inittuple[0] + init1 = inittuple[1] + init2 = inittuple[2] + + imct = 0 + if mf_solve[0] == 0: + im0 = init0 + else: + im0 = imvec[imct*nimage:(imct+1)*nimage] + imct += 1 + + if mf_solve[1] == 0: + im1 = init1 + else: + im1 = imvec[imct*nimage:(imct+1)*nimage] + imct += 1 + + if mf_solve[2] == 0: + im2 = init2 + else: + im2 = imvec[imct*nimage:(imct+1)*nimage] + imct += 1 + return np.array((im0, im1, im2)) + +def pack_mftuple(mftuple, mf_solve = (1,1,0)): + """pack multifreq data into image vector, + ignore quantities not iterated + """ + + vec = np.array([]) + if mf_solve[0] != 0: + vec = np.hstack((vec,mftuple[0])) + if mf_solve[1] != 0: + vec = np.hstack((vec,mftuple[1])) + if mf_solve[2] != 0: + vec = np.hstack((vec,mftuple[2])) + + return vec + +def embed(im, mask, clipfloor=0., randomfloor=False): + """Embeds a 1d image array into the size of boolean embed mask + """ + + out = np.zeros(len(mask)) + + # Here's a much faster version than before + out[mask.nonzero()] = im + + if clipfloor != 0.0: + if randomfloor: # prevent total variation gradient singularities + out[(mask-1).nonzero()] = clipfloor * \ + np.abs(np.random.normal(size=len((mask-1).nonzero()))) + else: + out[(mask-1).nonzero()] = clipfloor + + return out + +def embed_mf(imtuple, mask, clipfloor=0., randomfloor=False): + """Embeds a multifrequency image tuple into the size of boolean embed mask + """ + out0=np.zeros(len(mask)) + out1=np.zeros(len(mask)) + out2=np.zeros(len(mask)) + + # Here's a much faster version than before + out0[mask.nonzero()] = imtuple[0] + out1[mask.nonzero()] = imtuple[1] + out2[mask.nonzero()] = imtuple[2] + + if clipfloor != 0.0: + if randomfloor: # prevent total variation gradient singularities + out0[(mask-1).nonzero()] = clipfloor * np.abs(np.random.normal(size=len((mask-1).nonzero()))) + out1[(mask-1).nonzero()] = clipfloor * np.abs(np.random.normal(size=len((mask-1).nonzero()))) + out2[(mask-1).nonzero()] = clipfloor * np.abs(np.random.normal(size=len((mask-1).nonzero()))) + else: + out0[(mask-1).nonzero()] = clipfloor + out1[(mask-1).nonzero()] = clipfloor + out2[(mask-1).nonzero()] = clipfloor + + return (out0, out1, out2) + +def imvec_at_freq(mftuple, log_freqratio): + """Get the image at a frequency given by ref_freq*e(log_freqratio) + Remember spectral index is defined with a + sign! + """ + imvec_ref_log = np.log(mftuple[0]) + spectral_index = mftuple[1] + curvature = mftuple[2] + + logimvec = imvec_ref_log + spectral_index*log_freqratio + curvature*log_freqratio*log_freqratio + imvec = np.exp(logimvec) + return imvec + +def mf_all_grads_chain(funcgrad, imvec_cur, imvec_ref, log_freqratio): + """Get the gradients of the reference image, spectral index, and curvature + w/r/t the gradient of a function funcgrad to the image given frequency ref_freq*e(log_freqratio) + """ + + dfunc_dI0 = funcgrad * imvec_cur / imvec_ref + dfunc_dalpha = funcgrad * imvec_cur * log_freqratio + dfunc_dbeta = funcgrad * imvec_cur * log_freqratio * log_freqratio + + return np.array((dfunc_dI0, dfunc_dalpha, dfunc_dbeta)) diff --git a/imaging/patch_prior.py b/imaging/patch_prior.py new file mode 100644 index 00000000..d2dd6b1d --- /dev/null +++ b/imaging/patch_prior.py @@ -0,0 +1,150 @@ +# patch_prior.py +# +# Create a "prior" image for vlbi imaging by cleaning the input image. +# Image cleaning is done by breaking the image into patches and cleaning each one +# individually by assigning it a cluster in the input Gaussian mixture model and +# using a weiner filter to denoise each patch. +# These ideas are based on Expected Patch Log Likelihood patch prior work +# +# Code Author: Katie Bouman +# Date: June 1, 2016 + +from __future__ import division +from builtins import map +from builtins import range + +from matplotlib import pyplot as plt +import ehtim.image as image +import scipy.io +import numpy as np + +def patchPrior(im, beta, patchPriorFile='naturalPrior.mat', patchSize=8 ): + + # load data + ldata = scipy.io.loadmat(patchPriorFile) + + # reassign and reshape data + nmodels = ldata['nmodels'].ravel() + nmodels = nmodels[0] + mixweights = ldata['mixweights'].ravel() + covs = np.array(ldata['covs']) + means = np.array(ldata['means']) + + # reshape image + img = np.reshape(im.imvec, (im.ydim, im.xdim) ) + + I1, counts = cleanImage(img, beta, nmodels, covs, mixweights, means, patchSize) + + if not all(counts[0][0] == item for item in np.reshape(counts, (-1)) ): + raise TypeError("The counts are not the same for every pixel in the image") + + I1 = I1/counts[0][0] + out = image.Image(I1, im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + + return (out, counts[0][0]) + +def cleanImage(img, beta, nmodels, covs, mixweights, means, patchSize=8): + + # pad images with 0's + validRegion = np.lib.pad( np.ones(img.shape), (patchSize-1, patchSize-1), 'constant', constant_values=(0, 0) ) + cleanIPad = np.lib.pad( img , (patchSize-1, patchSize-1), 'constant', constant_values=(0, 0) ) + + # adjust the dynamic range of image to be in the range 0 to 1 + minCleanI = min( np.reshape(cleanIPad, (-1)) ) + cleanIPad = cleanIPad - minCleanI + maxCleanI = max( np.reshape(cleanIPad, (-1)) ) + cleanIPad = cleanIPad / maxCleanI; + + # extract all overlapping patches from the image + Z = im2col(np.transpose(cleanIPad), patchSize) + + # clean each patch by weiner filtering + meanZ = np.mean(Z,0) + Z = Z - np.tile( meanZ, [patchSize**2, 1] ); + cleanZ = cleanPatches( Z,patchSize,(beta)**(-0.5), nmodels, covs, mixweights, means); + cleanZ = cleanZ + np.tile( meanZ, [patchSize**2, 1] ); + + # join all patches together + mm = validRegion.shape[0] + nn = validRegion.shape[1] + t = np.reshape(list(range(0,mm*nn,1)), (mm, nn) ) + temp = im2col(t, patchSize) + I1 = np.transpose( np.bincount( np.array(list(map(int, np.reshape(temp, (-1)) ))), weights=np.reshape(cleanZ, (-1))) ) + counts = np.transpose( np.bincount( np.array(list(map(int, np.reshape(temp, (-1)) ))), weights=np.reshape(np.ones(cleanZ.shape), (-1))) ) + + # normalize and put back in the original scale + I1 = I1/counts; + I1 = (I1*maxCleanI) + minCleanI; + I1 = I1*counts; + + # set all negative entries to 0 (hacky) + I1[I1<0] = 0; + + # crop out the center valid region + I1 = np.extract(np.reshape(validRegion, (-1)), I1) + counts = np.extract(np.reshape(validRegion, (-1)), counts) + + # reshape + I1 = np.transpose(np.reshape( I1, (img.shape[1], img.shape[0]))); + counts = np.transpose(np.reshape( counts, (img.shape[1], img.shape[0]))); + + return I1, counts + +def im2col(im, patchSize): + + # extract all overlapping patches from the image + M,N = im.shape + col_extent = N - patchSize + 1 + row_extent = M - patchSize + 1 + # Get Starting block indices + start_idx = np.arange(patchSize)[:,None]*N + np.arange(patchSize) + # Get offsetted indices across the height and width of input array + offset_idx = np.arange(row_extent)[:,None]*N + np.arange(col_extent) + # Get all actual indices & index into input array for final output + Z = np.take (im,start_idx.ravel()[:,None] + offset_idx.ravel()) + return Z + + +def cleanPatches(Y, patchSize, noiseSD, nmodels, covs, mixweights, means): + + SigmaNoise = noiseSD**2 * np.eye(patchSize**2); + + #remove DC component + meanY = np.mean(Y,0); + Y = Y - np.tile(meanY, [Y.shape[0], 1] ); + + #calculate assignment probabilities for each mixture component for all patches + PYZ = np.zeros((nmodels,Y.shape[1])); + for i in range (0,nmodels): + PYZ[i,:] = np.log(mixweights[i]) + loggausspdf2(Y, covs[:,:,i] + SigmaNoise); + + + #find the most likely component for each patch + ks = PYZ.argmax(axis = 0) + + # and now perform weiner filtering + Xhat = np.zeros(Y.shape); + for i in range (0,nmodels): + inds = np.array(np.where(ks==i)).ravel() + Xhat[:,inds] = np.dot( covs[:,:,i], np.dot( np.linalg.inv( covs[:,:,i]+SigmaNoise ), Y[:,inds] ) ) + np.dot( SigmaNoise, np.dot( np.linalg.inv(covs[:,:,i]+SigmaNoise), np.transpose(np.tile( np.transpose(means[:,i]), [inds.shape[0], 1] )) )); + + + Xhat = Xhat + np.tile(meanY, [Xhat.shape[0], 1] ) + return Xhat + + +def loggausspdf2(X, sigma): +#log pdf of Gaussian with zero mena +#Based on code written by Mo Chen (mochen@ie.cuhk.edu.hk). March 2009. + d = X.shape[0] + + R = np.linalg.cholesky(sigma).T; + # todo check that sigma is psd + + q = np.sum( ( np.dot( np.linalg.inv(np.transpose(R)) , X ) )**2 , 0); # quadratic term (M distance) + c = d*np.log(2*np.pi)+2*np.sum(np.log( np.diagonal(R) ), 0); # normalization constant + y = -(c+q)/2.0; + + return y + + diff --git a/imaging/pol_imager_utils.py b/imaging/pol_imager_utils.py new file mode 100644 index 00000000..e34f4d6e --- /dev/null +++ b/imaging/pol_imager_utils.py @@ -0,0 +1,2098 @@ +# pol_imager_utils.py +# General imager functions for polarimetric VLBI data +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +#TODO +# FIX NFFTS for m and for p -- offsets? imag? + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from builtins import range + +import string +import time +import copy +import numpy as np +import scipy.optimize as opt +import scipy.ndimage as ndF +import scipy.ndimage.filters as filt +import matplotlib.pyplot as plt +try: + from pynfft.nfft import NFFT +except ImportError: + pass + #print("Warning: No NFFT installed! Cannot use nfft functions") +from scipy.special import jv + +import ehtim.image as image +from . import linearize_energy as le + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * + +TANWIDTH_M = 0.5 +TANWIDTH_V = 1 + +################################################################################################## +# Constants & Definitions +################################################################################################## + +NORM_REGULARIZER = False #ANDREW TODO change this default in the future + +MAXLS = 100 # maximum number of line searches in L-BFGS-B +NHIST = 100 # number of steps to store for hessian approx +MAXIT = 100 # maximum number of iterations +STOP = 1.e-100 # convergence criterion + +DATATERMS_POL = ['pvis', 'm', 'pbs','vvis'] +REGULARIZERS_POL = ['msimple', 'hw', 'ptv','l1v','l2v','vtv','v2tv2','vflux'] + +nit = 0 # global variable to track the iteration number in the plotting callback + +def qimage(iimage, mimage, chiimage): + """Return the Q image from m and chi""" + return iimage * mimage * np.cos(2*chiimage) + +def uimage(iimage, mimage, chiimage): + """Return the U image from m and chi""" + return iimage * mimage * np.sin(2*chiimage) + + +################################################################################################## +# Polarimetric Imager +################################################################################################## +def pol_imager_func(Obsdata, InitIm, Prior, + pol_trans=True, pol_solve = (0,1,1), + d1='pvis', d2=False, + s1='msimple', s2=False, + alpha_d1=100, alpha_d2=100, + alpha_s1=1, alpha_s2=1, + **kwargs): + + """Run a polarimetric imager. + + Args: + Obsdata (Obsdata): The Obsdata object with polarimetric VLBI data + Prior (Image): The Image object with the prior image + InitIm (Image): The Image object with the initial image for the minimization + + pol_trans (bool): polarization scaled to physical values or not + pol_solve (tuple): len 3 tuple, solve for the corresponding pol image when not zero + + d1 (str): The first data term; options are 'pvis', 'm', 'pbs' + d2 (str): The second data term; options are 'pvis', 'm', 'pbs' + s1 (str): The first regularizer; options are 'msimple', 'hw', 'ptv' + s2 (str): The second regularizer; options are 'msimple', 'hw', 'ptv' + alpha_d1 (float): The first data term weighting + alpha_d2 (float): The second data term weighting + alpha_s1 (float): The first regularizer term weighting + alpha_s2 (float): The second regularizer term weighting + + ttype (str): The Fourier transform type; options are 'fast' and 'direct' + fft_pad_factor (float): The FFT will pre-pad the image by this factor x the original size + p_rad (float): The pixel radius for the convolving function in gridding for FFTs + + clipfloor (float): The Stokes I Jy/pixel level above which prior image pixels are varied + grads (bool): If True, analytic gradients are used + norm_reg (bool): If True, normalizes regularizer terms + beam_size (float): beam size in radians for normalizing the regularizers + flux (float): The total flux of the output image in Jy + + maxit (int): Maximum number of minimizer iterations + stop (float): The convergence criterion + show_updates (bool): If True, displays the progress of the minimizer + + Returns: + Image: Image object with result + """ + + # some kwarg default values + maxit = kwargs.get('maxit', MAXIT) + stop = kwargs.get('stop', STOP) + clipfloor = kwargs.get('clipfloor', -1) + ttype = kwargs.get('ttype','direct') + grads = kwargs.get('grads',True) + norm_init = kwargs.get('norm_init',False) + show_updates = kwargs.get('show_updates',True) + + beam_size = kwargs.get('beam_size',Obsdata.res()) + flux = kwargs.get('flux',InitIm.total_flux()) + kwargs['beam_size'] = beam_size + + # Make sure data and regularizer options are ok + if not d1 and not d2: + raise Exception("Must have at least one data term!") + if not s1 and not s2: + raise Exception("Must have at least one regularizer term!") + if (not ((d1 in DATATERMS_POL) or d1==False)) or (not ((d2 in DATATERMS_POL) or d2==False)): + raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS_POL)) + if (not ((s1 in REGULARIZERS_POL) or s1==False)) or (not ((s2 in REGULARIZERS_POL) or s2==False)): + raise Exception("Invalid regularizer: valid regularizers are: " + ' '.join(REGULARIZERS_POL)) + if (Prior.psize != InitIm.psize) or (Prior.xdim != InitIm.xdim) or (Prior.ydim != InitIm.ydim): + raise Exception("Initial image does not match dimensions of the prior image!") + if (ttype not in ["direct","nfft"]): + raise Exception("FFT no yet implemented in polarimetric imaging -- use NFFT!") + if (not pol_trans): + raise Exception("Only pol_trans==True supported!") + if (len(pol_solve)!=3): + raise Exception("pol_solve tuple must have 3 entries!") + + # Catch scale and dimension problems + imsize = np.max([Prior.xdim, Prior.ydim]) * Prior.psize + uvmax = 1.0/Prior.psize + uvmin = 1.0/imsize + uvdists = Obsdata.unpack('uvdist')['uvdist'] + maxbl = np.max(uvdists) + minbl = np.max(uvdists[uvdists > 0]) + maxamp = np.max(np.abs(Obsdata.unpack('amp')['amp'])) + + if uvmax < maxbl: + print("Warning! Pixel Spacing is larger than smallest spatial wavelength!") + if uvmin > minbl: + print("Warning! Field of View is smaller than largest nonzero spatial wavelength!") + + # convert polrep to stokes + # ANDREW todo -- make more general?? + Prior = Prior.switch_polrep(polrep_out='stokes', pol_prim_out='I') + InitIm = InitIm.switch_polrep(polrep_out='stokes', pol_prim_out='I') + + # embedding mask + embed_mask = Prior.imvec > clipfloor + + # initial Stokes I image + iimage = InitIm.imvec[embed_mask] + nimage = len(iimage) + + # initial pol image + if pol_trans: + if len(InitIm.qvec) and (np.any(InitIm.qvec!=0) or np.any(InitIm.uvec!=0)): #TODO right? or should it be if any=0 + init1 = (np.abs(InitIm.qvec + 1j*InitIm.uvec) / InitIm.imvec)[embed_mask] + init2 = (np.arctan2(InitIm.uvec, InitIm.qvec) / 2.0)[embed_mask] + else: + # !AC TODO get the actual zero baseline pol. frac from the data!?? + print("No polarimetric image in the initial image!") + init1 = 0.2 * (np.ones(len(iimage)) + 1e-2 * np.random.rand(len(iimage))) + init2 = np.zeros(len(iimage)) + 1e-2 * np.random.rand(len(iimage)) + + # Change of variables + inittuple = np.array((iimage, init1, init2)) + xtuple = mcv_r(inittuple) + + # Get data and fourier matrices for the data terms + (data1, sigma1, A1) = polchisqdata(Obsdata, Prior, embed_mask, d1, **kwargs) + (data2, sigma2, A2) = polchisqdata(Obsdata, Prior, embed_mask, d2, **kwargs) + + # Define the chi^2 and chi^2 gradient + def chisq1(imtuple): + return polchisq(imtuple, A1, data1, sigma1, d1, ttype=ttype, + mask=embed_mask, pol_trans=pol_trans) + + def chisq1grad(imtuple): + return polchisqgrad(imtuple, A1, data1, sigma1, d1, ttype=ttype, mask=embed_mask, + pol_trans=pol_trans, pol_solve=pol_solve) + + def chisq2(imtuple): + return polchisq(imtuple, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask, + pol_trans=pol_trans) + + def chisq2grad(imtuple): + return polchisqgrad(imtuple, A2, data2, sigma2, d2, ttype=ttype, mask=embed_mask, + pol_trans=pol_trans,pol_solve=pol_solve) + + # Define the regularizer and regularizer gradient + def reg1(imtuple): + return polregularizer(imtuple, embed_mask, flux, flux, flux, + Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs) + + def reg1grad(imtuple): + return polregularizergrad(imtuple, embed_mask, flux, flux, flux, + Prior.xdim, Prior.ydim, Prior.psize, s1, **kwargs) + + def reg2(imtuple): + return polregularizer(imtuple, embed_mask, flux, flux, flux, + Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs) + + def reg2grad(imtuple): + return polregularizergrad(imtuple, embed_mask, flux, flux, flux, + Prior.xdim, Prior.ydim, Prior.psize, s2, **kwargs) + + + # Define the objective function and gradient + def objfunc(allvec): + # unpack allvec into image tuple + cvtuple = unpack_poltuple(allvec, xtuple, nimage, pol_solve) + + # change of variables + if pol_trans: + imtuple = mcv(cvtuple) + else: + raise Exception() + + datterm = alpha_d1 * (chisq1(imtuple) - 1) + alpha_d2 * (chisq2(imtuple) - 1) + regterm = alpha_s1 * reg1(imtuple) + alpha_s2 * reg2(imtuple) + + return datterm + regterm + + def objgrad(allvec): + # unpack allvec into image tuple + cvtuple = unpack_poltuple(allvec, xtuple, nimage, pol_solve) + + # change of variables + if pol_trans: + imtuple = mcv(cvtuple) + else: + raise Exception() + + datterm = alpha_d1 * chisq1grad(imtuple) + alpha_d2 * chisq2grad(imtuple) + regterm = alpha_s1 * reg1grad(imtuple) + alpha_s2 * reg2grad(imtuple) + gradarr = datterm + regterm + + # chain rule + if pol_trans: + chainarr = mchain(cvtuple) + gradarr = gradarr*chainarr + + # repack grad into single vector + grad = pack_poltuple(gradarr, pol_solve) + + return grad + + # Define plotting function for each iteration + global nit + nit = 0 + def plotcur(im_step): + global nit + cvtuple = unpack_poltuple(im_step, xtuple, nimage, pol_solve) + if pol_trans: + imtuple = mcv(cvtuple) #change of variables + else: + raise Exception() + + if show_updates: + #print( np.max(np.abs((imtuple[2]-xx_static)/xx_static))) + chi2_1 = chisq1(imtuple) + chi2_2 = chisq2(imtuple) + s_1 = reg1(imtuple) + s_2 = reg2(imtuple) + if np.any(np.invert(embed_mask)): + imtuple = embed_pol(imtuple, embed_mask) + plot_m(imtuple, Prior, nit, {d1:chi2_1, d2:chi2_2}) + print("i: %d chi2_1: %0.2f chi2_2: %0.2f s_1: %0.2f s_2: %0.2f" % (nit, chi2_1, chi2_2,s_1,s_2)) + nit += 1 + + # Print stats + print("Initial S_1: %f S_2: %f" % (reg1(inittuple), reg2(inittuple))) + print("Initial Chi^2_1: %f Chi^2_2: %f" % (chisq1(inittuple), chisq2(inittuple))) + if d1 in DATATERMS_POL: + print("Total Data 1: ", (len(data1))) + if d2 in DATATERMS_POL: + print("Total Data 2: ", (len(data2))) + print("Total Pixel #: ", (len(Prior.imvec))) + print("Clipped Pixel #: ", nimage) + print() + + # Plot Initial + xinit = pack_poltuple(xtuple.copy(), pol_solve) + plotcur(xinit) + + # Minimize + optdict = {'maxiter':maxit, 'ftol':stop, 'maxcor':NHIST,'maxls':MAXLS,'gtol':stop,'maxfun':1.e100} # minimizer dict params + tstart = time.time() + if grads: + res = opt.minimize(objfunc, xinit, method='L-BFGS-B', jac=objgrad,callback=plotcur, + options=optdict) + else: + res = opt.minimize(objfunc, xinit, method='L-BFGS-B', + options=optdict, callback=plotcur) + + + + tstop = time.time() + + # Format output + outcv = unpack_poltuple(res.x, xtuple, nimage, pol_solve) + if pol_trans: + outcut = mcv(outcv) #change of variables + else: + outcut = outcv + + if np.any(np.invert(embed_mask)): + out = embed_pol(out, embed_mask) #embed + else: + out = outcut + + iimage = out[0] + qimage = make_q_image(out, pol_trans) + uimage = make_u_image(out, pol_trans) + + outim = image.Image(iimage.reshape(Prior.ydim, Prior.xdim), Prior.psize, + Prior.ra, Prior.dec, rf=Prior.rf, source=Prior.source, + mjd=Prior.mjd, pulse=Prior.pulse) + outim.add_qu(qimage.reshape(Prior.ydim, Prior.xdim), uimage.reshape(Prior.ydim, Prior.xdim)) + + # Print stats + print("time: %f s" % (tstop - tstart)) + print("J: %f" % res.fun) + print("Final Chi^2_1: %f Chi^2_2: %f" % (chisq1(outcut), chisq2(outcut))) + print(res.message) + + # Return Image object + return outim + +################################################################################################## +# Linear Polarimetric image representations and Change of Variables +################################################################################################## +def pack_poltuple(poltuple, pol_solve = (0,1,1)): + """pack polvec into image vector, + ignore quantities not iterated + """ + + vec = np.array([]) + if pol_solve[0] != 0: + vec = np.hstack((vec,poltuple[0])) + if pol_solve[1] != 0: + vec = np.hstack((vec,poltuple[1])) + if pol_solve[2] != 0: + vec = np.hstack((vec,poltuple[2])) + if len(pol_solve)==4 and pol_solve[3] != 0: + vec = np.hstack((vec,poltuple[3])) + + return vec + + +def unpack_poltuple(polvec, inittuple, nimage, pol_solve = (0,1,1)): + """unpack polvec into image tuple, + replaces quantities not iterated with initial values + """ + init0 = inittuple[0] + init1 = inittuple[1] + init2 = inittuple[2] + if len(pol_solve)==4: + init3 = inittuple[3] + + imct = 0 + if pol_solve[0] == 0: + im0 = init0 + else: + im0 = polvec[imct*nimage:(imct+1)*nimage] + imct += 1 + + if pol_solve[1] == 0: + im1 = init1 + else: + im1 = polvec[imct*nimage:(imct+1)*nimage] + imct += 1 + + if pol_solve[2] == 0: + im2 = init2 + else: + im2 = polvec[imct*nimage:(imct+1)*nimage] + imct += 1 + + if len(pol_solve)==4: + if pol_solve[3] == 0: + im3 = init3 + else: + im3 = polvec[imct*nimage:(imct+1)*nimage] + imct += 1 + out = np.array((im0, im1, im2, im3)) + else: + out = np.array((im0, im1, im2)) + + return out + +def make_p_image(imtuple, pol_trans=True): + """construct a polarimetric image P = Q + iU + """ + + if pol_trans: + pimage = imtuple[0] * imtuple[1] * np.exp(2j*imtuple[2]) + else: + pimage = imtuple[1] + 1j*imtuple[2] + + return pimage + +def make_m_image(imtuple, pol_trans=True): + """construct a polarimetric ratrio image abs(P/I) = abs(Q + iU)/I + """ + + if pol_trans: + mimage = imtuple[1] + else: + mimage = np.abs((imtuple[1] + 1j*imtuple[2])/imtuple[0]) + return mimage + +def make_chi_image(imtuple, pol_trans=True): + """construct a polarimetric angle image + """ + + if pol_trans: + chiimage = imtuple[2] + else: + chiimage = 0.5*np.angle((imtuple[1] + 1j*imtuple[2])/imtuple[0]) + + return chiimage + +def make_q_image(imtuple, pol_trans=True): + """construct an image of stokes Q + """ + + if pol_trans: + qimage = imtuple[0] * imtuple[1] * np.cos(2*imtuple[2]) + else: + qimage = imtuple[1] + + return qimage + +def make_u_image(imtuple, pol_trans=True): + """construct an image of stokes U + """ + + if pol_trans: + uimage = imtuple[0] * imtuple[1] * np.sin(2*imtuple[2]) + else: + uimage = imtuple[2] + + return uimage + +def make_v_image(imtuple, pol_trans=True): + """construct an image of stokes V + """ + + if len(imtuple)==4: + if pol_trans: + vimage = imtuple[0] * imtuple[3] + else: + vimage = imtuple[3] + else: + vimage = np.zeros(imtuple[0].shape) + + return vimage + +def make_vfrac_image(imtuple, pol_trans=True): + """construct an image of stokes V + """ + if len(imtuple)==4: + if pol_trans: + vfimage = imtuple[3] + else: + vfimage = imtuple[3]/imtuple[0] + else: + vfimage = np.zeros(imtuple[0].shape) + + return vfimage + + +# these change of variables only apply to polarimetric ratios +# !AC In these pol. changes of variables, might be useful to +# take m -> m/100 by adjusting B (function becomes less steep around m' = 0) + +def mcv(imtuple): + """Change of pol. ratio from range (-inf, inf) to (0,1) + """ + + iimage = imtuple[0] + mimage = imtuple[1] + chiimage = imtuple[2] + + mtrans = 0.5 + np.arctan(mimage/TANWIDTH_M)/np.pi + if len(imtuple)==4: + vfimage = imtuple[3] + vtrans = 2*np.arctan(vfimage/TANWIDTH_V)/np.pi + out = np.array((iimage, mtrans, chiimage, vtrans)) + else: + out = np.array((iimage, mtrans, chiimage)) + + + return out + +def mcv_r(imtuple): + """Change of pol. ratio from range (0,1) to (-inf,inf) + """ + iimage = imtuple[0] + mimage = imtuple[1] + chiimage = imtuple[2] + + mtrans = TANWIDTH_M*np.tan(np.pi*(mimage - 0.5)) + if len(imtuple)==4: + vfimage = imtuple[3] + vtrans = TANWIDTH_V*np.tan(0.5*np.pi*(vfimage)) + out = np.array((iimage, mtrans, chiimage, vtrans)) + else: + out = np.array((iimage, mtrans, chiimage)) + + return out + +def mchain(imtuple): + """The gradient change of variables, dm/dm' + """ + iimage = imtuple[0] + mimage = imtuple[1] + chiimage = imtuple[2] + + ichain = np.ones(len(iimage)) + mmchain = 1 / (TANWIDTH_M*np.pi*(1 + (mimage/TANWIDTH_M)**2)) + chichain = np.ones(len(chiimage)) + + if len(imtuple)==4: + vfimage = imtuple[3] + vchain = 2. / (TANWIDTH_V*np.pi*(1 + (vfimage/TANWIDTH_V)**2)) + out = np.array((ichain, mmchain, chichain, vchain)) + else: + out = np.array((ichain, mmchain, chichain)) + + return np.array(out) + + +################################################################################################## +# Wrapper Functions +################################################################################################## + +def polchisq(imtuple, A, data, sigma, dtype, ttype='direct', mask=[], pol_trans=True): + """return the chi^2 for the appropriate dtype + """ + + chisq = 1 + if not dtype in DATATERMS_POL: + return chisq + if ttype not in ['fast','direct','nfft']: + raise Exception("Possible ttype values are 'fast' and 'direct'!") + + if ttype == 'direct': + # linear + if dtype == 'pvis': + chisq = chisq_p(imtuple, A, data, sigma, pol_trans) + elif dtype == 'm': + chisq = chisq_m(imtuple, A, data, sigma, pol_trans) + elif dtype == 'pbs': + chisq = chisq_pbs(imtuple, A, data, sigma, pol_trans) + + # circular + elif dtype == 'vvis': + chisq = chisq_vvis(imtuple, A, data, sigma, pol_trans) + + elif ttype== 'fast': + raise Exception("FFT not yet implemented in polchisq!") + + elif ttype== 'nfft': + if len(mask)>0 and np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + + # linear + if dtype == 'pvis': + chisq = chisq_p_nfft(imtuple, A, data, sigma, pol_trans) + elif dtype == 'm': + chisq = chisq_m_nfft(imtuple, A, data, sigma, pol_trans) + elif dtype == 'pbs': + chisq = chisq_pbs_nfft(imtuple, A, data, sigma, pol_trans) + + # circular + elif dtype == 'vvis': + chisq = chisq_vvis_nfft(imtuple, A, data, sigma, pol_trans) + + return chisq + +def polchisqgrad(imtuple, A, data, sigma, dtype, ttype='direct', + mask=[], pol_trans=True,pol_solve=(0,1,1)): + + """return the chi^2 gradient for the appropriate dtype + """ + + chisqgrad = np.zeros((3,len(imtuple[0]))) + if not dtype in DATATERMS_POL: + return chisqgrad + if ttype not in ['fast','direct','nfft']: + raise Exception("Possible ttype values are 'fast' and 'direct'!") + + if ttype == 'direct': + # linear + if dtype == 'pvis': + chisqgrad = chisqgrad_p(imtuple, A, data, sigma, pol_trans,pol_solve) + elif dtype == 'm': + chisqgrad = chisqgrad_m(imtuple, A, data, sigma, pol_trans,pol_solve) + elif dtype == 'pbs': + chisqgrad = chisqgrad_pbs(imtuple, A, data, sigma, pol_trans,pol_solve) + + # circular + elif dtype == 'vvis': + chisqgrad = chisqgrad_vvis(imtuple, A, data, sigma, pol_trans,pol_solve) + + elif ttype== 'fast': + raise Exception("FFT not yet implemented in polchisqgrad!") + + elif ttype== 'nfft': + if len(mask)>0 and np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + + # linear + if dtype == 'pvis': + chisqgrad = chisqgrad_p_nfft(imtuple, A, data, sigma, pol_trans,pol_solve) + elif dtype == 'm': + chisqgrad = chisqgrad_m_nfft(imtuple, A, data, sigma, pol_trans,pol_solve) + elif dtype == 'pbs': + chisqgrad = chisqgrad_pbs_nfft(imtuple, A, data, sigma, pol_trans,pol_solve) + + # circular + elif dtype == 'vvis': + chisqgrad = chisqgrad_vvis_nfft(imtuple, A, data, sigma, pol_trans,pol_solve) + + if len(mask)>0 and np.any(np.invert(mask)): + if len(chisqgrad)==4: + chisqgrad = np.array((chisqgrad[0][mask],chisqgrad[1][mask],chisqgrad[2][mask],chisqgrad[3][mask])) + else: + chisqgrad = np.array((chisqgrad[0][mask],chisqgrad[1][mask],chisqgrad[2][mask])) + + return chisqgrad + + +def polregularizer(imtuple, mask, flux, pflux, vflux, xdim, ydim, psize, stype, **kwargs): + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + pol_trans = kwargs.get('pol_trans', True) + beam_size = kwargs.get('beam_size',1) + + # linear + if stype == "msimple": + reg = -sm(imtuple, flux, pol_trans, norm_reg=norm_reg) + elif stype == "hw": + reg = -shw(imtuple, flux, pol_trans, norm_reg=norm_reg) + elif stype == "ptv": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reg = -stv_pol(imtuple, flux, xdim, ydim, psize, pol_trans, + norm_reg=norm_reg, beam_size=beam_size) + # circular + elif stype == 'vflux': + reg = -svflux(imtuple, vflux, norm_reg=norm_reg) + elif stype == "l1v": + reg = -sl1v(imtuple, vflux, pol_trans, norm_reg=norm_reg) + elif stype == "l2v": + reg = -sl2v(imtuple, vflux, pol_trans, norm_reg=norm_reg) + elif stype == "vtv": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reg = -stv_v(imtuple, vflux, xdim, ydim, psize, pol_trans, + norm_reg=norm_reg, beam_size=beam_size) + elif stype == "vtv2": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reg = -stv2_v(imtuple, vflux, xdim, ydim, psize, pol_trans, + norm_reg=norm_reg, beam_size=beam_size) + else: + reg = 0 + + return reg + +def polregularizergrad(imtuple, mask, flux, pflux, vflux, xdim, ydim, psize, stype, **kwargs): + + norm_reg = kwargs.get('norm_reg', NORM_REGULARIZER) + pol_trans = kwargs.get('pol_trans', True) + pol_solve = kwargs.get('pol_solve', (0,1,1)) + beam_size = kwargs.get('beam_size',1) + + # linear + if stype == "msimple": + reggrad = -smgrad(imtuple, flux, pol_trans, pol_solve, norm_reg=norm_reg) + elif stype == "hw": + reggrad = -shwgrad(imtuple, flux, pol_trans, pol_solve, norm_reg=norm_reg) + elif stype == "ptv": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reggrad = -stv_pol_grad(imtuple, flux, xdim, ydim, psize, pol_trans, pol_solve, + norm_reg=norm_reg, beam_size=beam_size) + if np.any(np.invert(mask)): + if len(reggrad)==4: + reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask])) + else: + reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask])) + + # circular + elif stype == 'vflux': + reggrad = -svfluxgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg) + elif stype == "l1v": + reggrad = -sl1vgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg) + elif stype == "l2v": + reggrad = -sl2vgrad(imtuple, vflux, pol_trans, pol_solve=pol_solve, norm_reg=norm_reg) + elif stype == "vtv": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reggrad = -stv_v_grad(imtuple, vflux, xdim, ydim, psize, pol_trans, + pol_solve=pol_solve, norm_reg=norm_reg, beam_size=beam_size) + if np.any(np.invert(mask)): + reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask])) + elif stype == "vtv2": + if np.any(np.invert(mask)): + imtuple = embed_pol(imtuple, mask, randomfloor=True) + reggrad = -stv2_v_grad(imtuple, vflux, xdim, ydim, psize, pol_trans, + pol_solve=pol_solve, norm_reg=norm_reg, beam_size=beam_size) + if np.any(np.invert(mask)): + reggrad = np.array((reggrad[0][mask],reggrad[1][mask],reggrad[2][mask],reggrad[3][mask])) + + else: + reggrad = np.zeros((len(imtuple),len(imtuple[0]))) + + return reggrad + + +def polchisqdata(Obsdata, Prior, mask, dtype, **kwargs): + + """Return the data, sigma, and matrices for the appropriate dtype + """ + + ttype=kwargs.get('ttype','direct') + + (data, sigma, A) = (False, False, False) + if ttype not in ['fast','direct','nfft']: + raise Exception("Possible ttype values are 'fast' and 'direct' and 'nfft'!") + if ttype=='direct': + if dtype == 'pvis': + (data, sigma, A) = chisqdata_pvis(Obsdata, Prior, mask) + elif dtype == 'm': + (data, sigma, A) = chisqdata_m(Obsdata, Prior, mask) + elif dtype == 'pbs': + (data, sigma, A) = chisqdata_pbs(Obsdata, Prior, mask) + elif dtype == 'vvis': + (data, sigma, A) = chisqdata_vvis(Obsdata, Prior, mask) + elif ttype=='fast': + raise Exception("FFT not yet implemented in polchisqdata!") + + elif ttype=='nfft': + if dtype == 'pvis': + (data, sigma, A) = chisqdata_pvis_nfft(Obsdata, Prior, mask, **kwargs) + elif dtype == 'm': + (data, sigma, A) = chisqdata_m_nfft(Obsdata, Prior, mask, **kwargs) + elif dtype == 'pbs': + (data, sigma, A) = chisqdata_pbs_nfft(Obsdata, Prior, mask, **kwargs) + elif dtype == 'vvis': + (data, sigma, A) = chisqdata_vvis_nfft(Obsdata, Prior, mask, **kwargs) + + return (data, sigma, A) + + +################################################################################################## +# DFT Chi-squared and Gradient Functions +################################################################################################## + +def chisq_p(imtuple, Amatrix, p, sigmap, pol_trans=True): + """Polarimetric visibility chi-squared + """ + + pimage = make_p_image(imtuple, pol_trans) + psamples = np.dot(Amatrix, pimage) + chisq = np.sum(np.abs((p - psamples))**2/(sigmap**2)) / (2*len(p)) + return chisq + +def chisqgrad_p(imtuple, Amatrix, p, sigmap, pol_trans=True,pol_solve=(0,1,1)): + """Polarimetric visibility chi-squared gradient + """ + + + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + psamples = np.dot(Amatrix, pimage) + pdiff = (p - psamples) / (sigmap**2) + zeros = np.zeros(len(iimage)) + + if pol_trans: + + mimage = imtuple[1] + chiimage = imtuple[2] + + if pol_solve[0]!=0: + gradi = -np.real(mimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, pdiff)) / len(p) + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = -np.real(iimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, pdiff)) / len(p) + else: + gradm = zeros + + if pol_solve[2]!=0: + gradchi = -2 * np.imag(pimage.conj() * np.dot(Amatrix.conj().T, pdiff)) / len(p) + else: + gradchi = zeros + + # output tuple can be length 3 or 4 depending on V + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +def chisq_m(imtuple, Amatrix, m, sigmam, pol_trans=True): + """Polarimetric ratio chi-squared + """ + + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + msamples = np.dot(Amatrix, pimage) / np.dot(Amatrix, iimage) + return np.sum(np.abs((m - msamples))**2/(sigmam**2)) / (2*len(m)) + +def chisqgrad_m(imtuple, Amatrix, m, sigmam, pol_trans=True,pol_solve=(0,1,1)): + """The gradient of the polarimetric ratio chisq + """ + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + + isamples = np.dot(Amatrix, iimage) + psamples = np.dot(Amatrix, pimage) + zeros = np.zeros(len(iimage)) + + if pol_trans: + + mimage = imtuple[1] + chiimage = imtuple[2] + + msamples = psamples/isamples + mdiff = (m - msamples) / (isamples.conj() * sigmam**2) + + if pol_solve[0]!=0: + gradi = (-np.real(mimage * np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, mdiff)) / len(m) + + np.real(np.dot(Amatrix.conj().T, msamples.conj() * mdiff)) / len(m)) + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = -np.real(iimage*np.exp(-2j*chiimage) * np.dot(Amatrix.conj().T, mdiff)) / len(m) + else: + gradm = zeros + + if pol_solve[2]!=0: + gradchi = -2 * np.imag(pimage.conj() * np.dot(Amatrix.conj().T, mdiff)) / len(m) + else: + gradchi = zeros + + # output tuple can be length 3 or 4 depending on V + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +def chisq_pbs(imtuple, Amatrices, bis_p, sigma, pol_trans=True): + """Polarimetric bispectrum chi-squared + """ + + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + bisamples_p = np.dot(Amatrices[0], pimage) * np.dot(Amatrices[1], pimage) * np.dot(Amatrices[2], pimage) + chisq = np.sum(np.abs((bis_p - bisamples_p)/sigma)**2) / (2.*len(bis_p)) + return chisq + +def chisqgrad_pbs(imtuple, Amatrices, bis_p, sigma, pol_trans=True,pol_solve=(0,1,1)): + """Polarimetric bispectrum chi-squared gradient + """ + pimage = make_p_image(imtuple, pol_trans) + bisamples_p = np.dot(Amatrices[0], pimage) * np.dot(Amatrices[1], pimage) * np.dot(Amatrices[2], pimage) + + wdiff = ((bis_p - bisamples_p).conj()) / (sigma**2) + pt1 = wdiff * np.dot(Amatrices[1],pimage) * np.dot(Amatrices[2],pimage) + pt2 = wdiff * np.dot(Amatrices[0],pimage) * np.dot(Amatrices[2],pimage) + pt3 = wdiff * np.dot(Amatrices[0],pimage) * np.dot(Amatrices[1],pimage) + ptsum = np.dot(pt1, Amatrices[0]) + np.dot(pt2, Amatrices[1]) + np.dot(pt3, Amatrices[2]) + + if pol_trans: + iimage = imtuple[0] + mimage = imtuple[1] + chiimage = imtuple[2] + + if pol_solve[0]!=0: + gradi = -np.real(ptsum * mimage * np.exp(2j*chiimage)) / len(bis_p) + else: + gradi = zeros + if pol_solve[1]!=0: + gradm = -np.real(ptsum * iimage * np.exp(2j*chiimage)) / len(bis_p) + else: + gradm = zeros + if pol_solve[2]!=0: + gradchi = 2 * np.imag(ptsum * pimage) / len(bis_p) + else: + gradchi = zeros + + # output tuple can be length 3 or 4 depending on V + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +# stokes v +def chisq_vvis(imtuple, Amatrix, v, sigmav, pol_trans=True): + """V visibility chi-squared + """ + + vimage = make_v_image(imtuple, pol_trans) + vsamples = np.dot(Amatrix, vimage) + chisq = np.sum(np.abs((v - vsamples))**2/(sigmav**2)) / (2*len(v)) + return chisq + +def chisqgrad_vvis(imtuple, Amatrix, v, sigmap, pol_trans=True,pol_solve=(0,1,1)): + """V visibility chi-squared gradient + """ + + + iimage = imtuple[0] + vimage = make_v_image(imtuple, pol_trans) + vsamples = np.dot(Amatrix, vimage) + vdiff = (v - vsamples) / (sigmav**2) + zeros = np.zeros(len(iimage)) + + if pol_trans: + vfimage = imtuple[3] # fractional + if pol_solve[0]!=0: + gradi = -np.real(vfimage * np.dot(Amatrix.conj().T, vdiff)) / len(v) + else: + gradi = zeros + + if pol_solve[3]!=0: + gradv = -np.real(iimage * np.dot(Amatrix.conj().T, vdiff)) / len(v) + else: + gradm = zeros + + + gradout = np.array((gradi, zeros, zeros, gradv)) + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +################################################################################################## +# NFFT Chi-squared and Gradient Functions +################################################################################################## +def chisq_p_nfft(imtuple, A, p, sigmap, pol_trans=True): + """P visibility chi-squared + """ + + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + pimage = make_p_image(imtuple, pol_trans) + plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + psamples = plan.f.copy()*pulsefac + + #compute chi^2 + chisq = np.sum(np.abs((p - psamples))**2/(sigmap**2)) / (2*len(p)) + + return chisq + +def chisqgrad_p_nfft(imtuple, A, p, sigmap, pol_trans=True,pol_solve=(0,1,1)): + """P visibility chi-squared gradient + """ + + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + psamples = plan.f.copy()*pulsefac + + + pdiff_vec = (-1.0/len(p)) * (p - psamples) / (sigmap**2) * pulsefac.conj() + plan.f = pdiff_vec + plan.adjoint() + ppart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim) + + zeros = np.zeros(len(iimage)) + + if pol_trans: + + mimage = imtuple[1] + chiimage = imtuple[2] + + if pol_solve[0]!=0: + + gradi = np.real(mimage * np.exp(-2j*chiimage) * ppart) + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = np.real(iimage*np.exp(-2j*chiimage) *ppart) + else: + gradm = zeros + + if pol_solve[2]!=0: + gradchi = 2 * np.imag(pimage.conj() * ppart) + else: + gradchi = zeros + + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + + +def chisq_m_nfft(imtuple, A, m, sigmam, pol_trans=True): + """Polarimetric ratio chi-squared + """ + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + iimage = imtuple[0] + plan.f_hat = iimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + isamples = plan.f.copy()*pulsefac + + pimage = make_p_image(imtuple, pol_trans) + plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + psamples = plan.f.copy()*pulsefac + + #compute chi^2 + msamples = psamples/isamples + chisq = np.sum(np.abs((m - msamples))**2/(sigmam**2)) / (2*len(m)) + return chisq + +def chisqgrad_m_nfft(imtuple, A, m, sigmam, pol_trans=True,pol_solve=(0,1,1)): + """Polarimetric ratio chi-squared gradient + """ + iimage = imtuple[0] + pimage = make_p_image(imtuple, pol_trans) + + + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + iimage = imtuple[0] + plan.f_hat = iimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + isamples = plan.f.copy()*pulsefac + + pimage = make_p_image(imtuple, pol_trans) + plan.f_hat = pimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + psamples = plan.f.copy()*pulsefac + + zeros = np.zeros(len(iimage)) + + if pol_trans: + + mimage = imtuple[1] + chiimage = imtuple[2] + + msamples = psamples/isamples + mdiff_vec = (-1./len(m))*(m - msamples) / (isamples.conj() * sigmam**2) * pulsefac.conj() + plan.f = mdiff_vec + plan.adjoint() + mpart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim) + + if pol_solve[0]!=0: #TODO -- not right?? + plan.f = mdiff_vec * msamples.conj() + plan.adjoint() + mpart2 = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim) + + gradi = (np.real(mimage * np.exp(-2j*chiimage) * mpart) - np.real(mpart2)) + + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = np.real(iimage*np.exp(-2j*chiimage) * mpart) + else: + gradm = zeros + + if pol_solve[2]!=0: + gradchi = 2 * np.imag(pimage.conj() * mpart) + else: + gradchi = zeros + + # output tuple can be length 3 or 4 depending on V + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + + +def chisq_pbs_nfft(imtuple, A, bis_p, sigma, pol_trans=True): + """Polarimetric bispectrum chi-squared + """ + + #get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + #compute uniform --> nonuniform transforms + pimage = make_p_image(imtuple, pol_trans) + + plan1.f_hat = pimage.copy().reshape((nfft_info1.ydim,nfft_info1.xdim)).T + plan1.trafo() + samples1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = pimage.copy().reshape((nfft_info2.ydim,nfft_info2.xdim)).T + plan2.trafo() + samples2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = pimage.copy().reshape((nfft_info3.ydim,nfft_info3.xdim)).T + plan3.trafo() + samples3 = plan3.f.copy()*pulsefac3 + + #compute chi^2 + bisamples_p = samples1*samples2*samples3 + chisq = np.sum(np.abs((bis_p - bisamples_p)/sigma)**2) / (2.*len(bis_p)) + + return chisq + +def chisqgrad_pbs(imtuple, Amatrices, bis_p, sigma, pol_trans=True,pol_solve=(0,1,1)): + """Polarimetric bispectrum chi-squared gradient + """ + + #get nfft objects + nfft_info1 = A[0] + plan1 = nfft_info1.plan + pulsefac1 = nfft_info1.pulsefac + + nfft_info2 = A[1] + plan2 = nfft_info2.plan + pulsefac2 = nfft_info2.pulsefac + + nfft_info3 = A[2] + plan3 = nfft_info3.plan + pulsefac3 = nfft_info3.pulsefac + + #compute uniform --> nonuniform transforms + pimage = make_p_image(imtuple, pol_trans) + + plan1.f_hat = pimage.copy().reshape((nfft_info1.ydim,nfft_info1.xdim)).T + plan1.trafo() + v1 = plan1.f.copy()*pulsefac1 + + plan2.f_hat = pimage.copy().reshape((nfft_info2.ydim,nfft_info2.xdim)).T + plan2.trafo() + v2 = plan2.f.copy()*pulsefac2 + + plan3.f_hat = pimage.copy().reshape((nfft_info3.ydim,nfft_info3.xdim)).T + plan3.trafo() + v3 = plan3.f.copy()*pulsefac3 + + # gradient vec for adjoint fft + bisamples_p = v1*v2*v3 + wdiff = (-1./len(bis_p)) * ((bis_p - bisamples_p).conj()) / (sigma**2) + pt1 = wdiff * (v2 * v3).conj() * pulsefac1.conj() + pt2 = wdiff * (v1 * v3).conj() * pulsefac2.conj() + pt3 = wdiff * (v1 * v2).conj() * pulsefac3.conj() + + plan1.f = pt1 + plan1.adjoint() + out1 = (plan1.f_hat.copy().T).reshape(nfft_info1.xdim*nfft_info1.ydim) + + plan2.f = pt2 + plan2.adjoint() + out2 = (plan2.f_hat.copy().T).reshape(nfft_info2.xdim*nfft_info2.ydim) + + plan3.f = pt3 + plan3.adjoint() + out3 = (plan3.f_hat.copy().T).reshape(nfft_info3.xdim*nfft_info3.ydim) + + ptsum = out1 + out2 + out3 + + if pol_trans: + iimage = imtuple[0] + mimage = imtuple[1] + chiimage = imtuple[2] + + if pol_solve[0]!=0: + gradi = np.real(ptsum * mimage * np.exp(2j*chiimage)) + else: + gradi = zeros + if pol_solve[1]!=0: + gradm = np.real(ptsum * iimage * np.exp(2j*chiimage)) + else: + gradm = zeros + if pol_solve[2]!=0: + gradchi = -2 * np.imag(ptsum * pimage) + else: + gradchi = zeros + + # output tuple can be length 3 or 4 depending on V + if len(imtuple)==4: + gradout = np.array((gradi, gradm, gradchi, zeros)) + else: + gradout = np.array((gradi, gradm, gradchi)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +# stokes v +def chisq_vvis_nfft(imtuple, A, v, sigmav, pol_trans=True): + """V visibility chi-squared + """ + + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + vimage = make_v_image(imtuple, pol_trans) + plan.f_hat = vimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + vsamples = plan.f.copy()*pulsefac + + #compute chi^2 + chisq = np.sum(np.abs((v - vsamples))**2/(sigmav**2)) / (2*len(v)) + + return chisq + +def chisqgrad_vvis_nfft(imtuple, A, v, sigmav, pol_trans=True,pol_solve=(0,1,1)): + """V visibility chi-squared gradient + """ + + #get nfft object + nfft_info = A[0] + plan = nfft_info.plan + pulsefac = nfft_info.pulsefac + + #compute uniform --> nonuniform transform + iimage = imtuple[0] + vimage = make_v_image(imtuple, pol_trans) + plan.f_hat = vimage.copy().reshape((nfft_info.ydim,nfft_info.xdim)).T + plan.trafo() + vsamples = plan.f.copy()*pulsefac + + + vdiff_vec = (-1.0/len(v)) * (v - vsamples) / (sigmav**2) * pulsefac.conj() + plan.f = vdiff_vec + plan.adjoint() + vpart = (plan.f_hat.copy().T).reshape(nfft_info.xdim*nfft_info.ydim) + + zeros = np.zeros(len(iimage)) + + if pol_trans: + + vfimage = imtuple[3] #fractional + + if pol_solve[0]!=0: + gradi = np.real(vfimage*vpart) + else: + gradi = zeros + + if pol_solve[3]!=0: + gradv = np.real(iimage*vpart) + else: + gradv = zeros + + + gradout = np.array((gradi, zeros, zeros, gradv)) + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return gradout + +################################################################################################## +# Polarimetric Entropy and Gradient Functions +################################################################################################## + +def sm(imtuple, flux, pol_trans=True, + norm_reg=NORM_REGULARIZER): + """I log m entropy + """ + if norm_reg: norm = flux + else: norm = 1 + + iimage = imtuple[0] + mimage = make_m_image(imtuple, pol_trans) + S = -np.sum(iimage * np.log(mimage)) + return S/norm + +def smgrad(imtuple, flux, pol_trans=True,pol_solve=(0,1,1), + norm_reg=NORM_REGULARIZER): + """I log m entropy gradient + """ + + if norm_reg: norm = flux + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + mimage = make_m_image(imtuple, pol_trans) + + if pol_trans: + + + if pol_solve[0]!=0: + gradi = -np.log(mimage) + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = -iimage / mimage + else: + gradm = zeros + + gradchi = zeros + + if len(imtuple)==4: + out = np.array((gradi, gradm, gradchi,zeros)) + else: + out = np.array((gradi, gradm, gradchi)) + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return np.array(out)/norm + +def shw(imtuple, flux, pol_trans=True, norm_reg=NORM_REGULARIZER): + """Holdaway-Wardle polarimetric entropy + """ + + if norm_reg: norm = flux + else: norm = 1 + + iimage = imtuple[0] + mimage = make_m_image(imtuple, pol_trans) + S = -np.sum(iimage * (((1+mimage)/2) * np.log((1+mimage)/2) + ((1-mimage)/2) * np.log((1-mimage)/2))) + return S/norm + +def shwgrad(imtuple, flux, pol_trans=True,pol_solve=(0,1,1), + norm_reg=NORM_REGULARIZER): + """Gradient of the Holdaway-Wardle polarimetric entropy + """ + if norm_reg: norm = flux + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + mimage = make_m_image(imtuple, pol_trans) + if pol_trans: + + if pol_solve[0]!=0: + gradi = -(((1+mimage)/2) * np.log((1+mimage)/2) + ((1-mimage)/2) * np.log((1-mimage)/2)) + else: + gradi = zeros + + if pol_solve[1]!=0: + gradm = -iimage * np.arctanh(mimage) + else: + gradm = zeros + + gradchi = zeros + + if len(imtuple)==4: + out = np.array((gradi, gradm, gradchi,zeros)) + else: + out = np.array((gradi, gradm, gradchi)) + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return np.array(out)/norm + +def stv_pol(imtuple, flux, nx, ny, psize, pol_trans=True, + norm_reg=NORM_REGULARIZER, beam_size=None): + """Total variation of I*m*exp(2Ichi)""" + + if beam_size is None: beam_size = psize + if norm_reg: norm = flux*psize / beam_size + else: norm = 1 + + pimage = make_p_image(imtuple, pol_trans) + im = pimage.reshape(ny, nx) + + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + S = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2)) + return S/norm + +def stv_pol_grad(imtuple, flux, nx, ny, psize, pol_trans=True, pol_solve=(0,1,1), + norm_reg=NORM_REGULARIZER, beam_size=None): + """Total variation entropy gradient""" + + if beam_size is None: beam_size = psize + if norm_reg: norm = flux*psize / beam_size + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + pimage = make_p_image(imtuple, pol_trans) + + im = pimage.reshape(ny, nx) + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1] + im_l1r2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1] + + # Denominators + d1 = np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2) + d2 = np.sqrt(np.abs(im_r1 - im)**2 + np.abs(im_r1l2 - im_r1)**2) + d3 = np.sqrt(np.abs(im_r2 - im)**2 + np.abs(im_l1r2 - im_r2)**2) + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + m1 = 2*np.abs(im*im) - np.abs(im*im_l1)*np.cos(2*(np.angle(im_l1) - np.angle(im))) - np.abs(im*im_l2)*np.cos(2*(np.angle(im_l2) - np.angle(im))) + m2 = np.abs(im*im) - np.abs(im*im_r1)*np.cos(2*(np.angle(im) - np.angle(im_r1))) + m3 = np.abs(im*im) - np.abs(im*im_r2)*np.cos(2*(np.angle(im) - np.angle(im_r2))) + igrad = -(1./iimage)*(m1/d1 + m2/d2 + m3/d3).flatten() + else: + igrad = zeros + + # dS/dm numerators + if pol_solve[1]!=0: + m1 = 2*np.abs(im) - np.abs(im_l1)*np.cos(2*(np.angle(im_l1) - np.angle(im))) - np.abs(im_l2)*np.cos(2*(np.angle(im_l2) - np.angle(im))) + m2 = np.abs(im) - np.abs(im_r1)*np.cos(2*(np.angle(im) - np.angle(im_r1))) + m3 = np.abs(im) - np.abs(im_r2)*np.cos(2*(np.angle(im) - np.angle(im_r2))) + mgrad = -iimage*(m1/d1 + m2/d2 + m3/d3).flatten() + else: + mgrad=zeros + + # dS/dchi numerators + if pol_solve[2]!=0: + c1 = -2*np.abs(im*im_l1)*np.sin(2*(np.angle(im_l1) - np.angle(im))) - 2*np.abs(im*im_l2)*np.sin(2*(np.angle(im_l2) - np.angle(im))) + c2 = 2*np.abs(im*im_r1)*np.sin(2*(np.angle(im) - np.angle(im_r1))) + c3 = 2*np.abs(im*im_r2)*np.sin(2*(np.angle(im) - np.angle(im_r2))) + chigrad = -(c1/d1 + c2/d2 + c3/d3).flatten() + else: + chigrad = zeros + + if len(imtuple)==4: + out = np.array((igrad, mgrad, chigrad,zeros)) + else: + out = np.array((igrad, mgrad, chigrad)) + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +# circular polarization +# TODO check!! +def svflux(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER): + """Total flux constraint + """ + if norm_reg: norm = np.abs(vflux)**2 + else: norm = 1 + + vimage = make_v_image(imtuple, pol_trans) + + out = -(np.sum(vimage) - vflux)**2 + return out/norm + + +def svfluxgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER): + """Total flux constraint gradient + """ + if norm_reg: norm = np.abs(vflux)**2 + else: norm = 1 + + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + vimage = make_v_image(imtuple, pol_trans) + grad = -2*(np.sum(vimage) - vflux)*np.ones(len(vimage)) + + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + igrad = (vimage/iimage)*grad + else: + igrad = zeros + + # dS/dv numerators + if pol_solve[3]!=0: + vgrad = iimage*grad + else: + vgrad=zeros + + + out = np.array((igrad, zeros, zeros, vgrad)) + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +def sl1v(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER): + """L1 norm regularizer on V + """ + if norm_reg: norm = np.abs(vflux) + else: norm = 1 + + vimage = make_v_image(imtuple, pol_trans) + l1 = -np.sum(np.abs(vimage)) + return l1/norm + + +def sl1vgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER): + """L1 norm gradient + """ + if norm_reg: norm = np.abs(vflux) + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + vimage = make_v_image(imtuple, pol_trans) + grad = -np.sign(vimage) + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + igrad = (vimage/iimage)*grad + else: + igrad = zeros + + # dS/dv numerators + if pol_solve[3]!=0: + vgrad = iimage*grad + else: + vgrad=zeros + + + out = np.array((igrad, zeros, zeros, vgrad)) + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +def sl2v(imtuple, vflux, pol_trans=True, norm_reg=NORM_REGULARIZER): + """L1 norm regularizer on V + """ + if norm_reg: norm = np.abs(vflux**2) + else: norm = 1 + + iimage = imtuple[0] + vimage = make_v_image(imtuple, pol_trans) + l2 = -np.sum((vimage)**2) + return l2/norm + + +def sl2vgrad(imtuple, vflux, pol_trans=True, pol_solve=(0,0,0,1), norm_reg=NORM_REGULARIZER): + """L2 norm gradient + """ + if norm_reg: norm = np.abs(vflux**2) + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + vimage = make_v_image(imtuple, pol_trans) + grad = -2*vimage + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + igrad = (vimage/iimage)*grad + else: + igrad = zeros + + # dS/dv numerators + if pol_solve[3]!=0: + vgrad = iimage*grad + else: + vgrad=zeros + + + out = np.array((igrad, zeros, zeros, vgrad)) + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +def stv_v(imtuple, vflux, nx, ny, psize, pol_trans=True, + norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.): + """Total variation of I*vfrac""" + + if beam_size is None: beam_size = psize + if norm_reg: norm = np.abs(vflux)*psize / beam_size + else: norm = 1 + + vimage = make_v_image(imtuple, pol_trans) + im = vimage.reshape(ny, nx) + + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + S = -np.sum(np.sqrt(np.abs(im_l1 - im)**2 + np.abs(im_l2 - im)**2+epsilon)) + return S/norm + +def stv_v_grad(imtuple, vflux, nx, ny, psize, pol_trans=True, pol_solve=(0,0,0,1), + norm_reg=NORM_REGULARIZER, beam_size=None, epsilon=0.): + """Total variation gradient""" + + if beam_size is None: beam_size = psize + if norm_reg: norm = np.abs(vflux)*psize / beam_size + else: norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + vimage = make_v_image(imtuple, pol_trans) + + im = vimage.reshape(ny, nx) + + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + + # rotate images + im_r1l2 = np.roll(np.roll(impad, 1, axis=0), -1, axis=1)[1:ny+1, 1:nx+1] + im_l1r2 = np.roll(np.roll(impad, -1, axis=0), 1, axis=1)[1:ny+1, 1:nx+1] + + # add together terms and return + g1 = (2*im - im_l1 - im_l2) / np.sqrt((im - im_l1)**2 + (im - im_l2)**2 + epsilon) + g2 = (im - im_r1) / np.sqrt((im - im_r1)**2 + (im_r1l2 - im_r1)**2 + epsilon) + g3 = (im - im_r2) / np.sqrt((im - im_r2)**2 + (im_l1r2 - im_r2)**2 + epsilon) + + # mask the first row column gradient terms that don't exist + mask1 = np.zeros(im.shape) + mask2 = np.zeros(im.shape) + mask1[0, :] = 1 + mask2[:, 0] = 1 + g2[mask1.astype(bool)] = 0 + g3[mask2.astype(bool)] = 0 + + # add terms together and return + grad = -(g1 + g2 + g3).flatten() + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + igrad = (vimage/iimage)*grad + else: + igrad = zeros + + # dS/dv numerators + if pol_solve[3]!=0: + vgrad = iimage*grad + else: + vgrad=zeros + + + out = np.array((igrad, zeros, zeros, vgrad)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +def stv2_v(imtuple, vflux, nx, ny, psize, pol_trans=True, + norm_reg=NORM_REGULARIZER, beam_size=None): + """Squared Total variation of I*vfrac + """ + + if beam_size is None: + beam_size = psize + if norm_reg: + norm = psize**4 * np.abs(vflux**2) / beam_size**4 + else: + norm = 1 + + vimage = make_v_image(imtuple, pol_trans) + im = vimage.reshape(ny, nx) + + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + out = -np.sum((im_l1 - im)**2 + (im_l2 - im)**2) + return out/norm + +def stv2_v_grad(imtuple, vflux, nx, ny, psize, pol_trans=True, pol_solve=(0,0,0,1), + norm_reg=NORM_REGULARIZER, beam_size=None): + """Squared Total variation gradient + """ + if beam_size is None: + beam_size = psize + if norm_reg: + norm = psize**4 * np.abs(vflux**2) / beam_size**4 + else: + norm = 1 + + iimage = imtuple[0] + zeros = np.zeros(len(iimage)) + vimage = make_v_image(imtuple, pol_trans) + im = vimage.reshape(ny, nx) + + impad = np.pad(im, 1, mode='constant', constant_values=0) + im_l1 = np.roll(impad, -1, axis=0)[1:ny+1, 1:nx+1] + im_l2 = np.roll(impad, -1, axis=1)[1:ny+1, 1:nx+1] + im_r1 = np.roll(impad, 1, axis=0)[1:ny+1, 1:nx+1] + im_r2 = np.roll(impad, 1, axis=1)[1:ny+1, 1:nx+1] + + g1 = (2*im - im_l1 - im_l2) + g2 = (im - im_r1) + g3 = (im - im_r2) + + # mask the first row column gradient terms that don't exist + mask1 = np.zeros(im.shape) + mask2 = np.zeros(im.shape) + mask1[0, :] = 1 + mask2[:, 0] = 1 + g2[mask1.astype(bool)] = 0 + g3[mask2.astype(bool)] = 0 + + # add together terms and return + grad = -2*(g1 + g2 + g3).flatten() + + if pol_trans: + + # dS/dI Numerators + if pol_solve[0]!=0: + igrad = (vimage/iimage)*grad + else: + igrad = zeros + + # dS/dv numerators + if pol_solve[3]!=0: + vgrad = iimage*grad + else: + vgrad=zeros + + + out = np.array((igrad, zeros, zeros, vgrad)) + + + else: + raise Exception("polarimetric representation %s not added to pol gradient yet!" % pol_trans) + + return out/norm + +################################################################################################## +# Embedding and Chi^2 Data functions +################################################################################################## +def embed_pol(imtuple, mask, clipfloor=0., randomfloor=False): + """Embeds a polarimetric image tuple into the size of boolean embed mask + """ + + out0=np.zeros(len(mask)) + out1=np.zeros(len(mask)) + out2=np.zeros(len(mask)) + if len(imtuple)==4: out3=np.zeros(len(mask)) + + # Here's a much faster version than before + out0[mask.nonzero()] = imtuple[0] + out1[mask.nonzero()] = imtuple[1] + out2[mask.nonzero()] = imtuple[2] + if len(imtuple)==4: + out3[mask.nonzero()] = imtuple[3] + + if clipfloor != 0.0: + out0[(mask-1).nonzero()] = clipfloor + out1[(mask-1).nonzero()] = 0 + out2[(mask-1).nonzero()] = 0 + if len(imtuple)==4: out3[(mask-1).nonzero()] = 0 + if randomfloor: # prevent total variation gradient singularities + out0[(mask-1).nonzero()] *= np.abs(np.random.normal(size=len((mask-1).nonzero()))) + + if len(imtuple)==4: + out = (out0, out1, out2, out3) + else: + out = (out0, out1, out2) + + return out + +def chisqdata_pvis(Obsdata, Prior, mask): + """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask + """ + + data_arr = Obsdata.unpack(['u','v','pvis','psigma'], conj=True) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + vis = data_arr['pvis'] + sigma = data_arr['psigma'] + A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask) + + return (vis, sigma, A) + +def chisqdata_pvis_nfft(Obsdata, Prior, mask, **kwargs): + """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask + """ + + # unpack keyword args + fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT) + + # unpack data + data_arr = Obsdata.unpack(['u','v','pvis','psigma'], conj=True) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + vis = data_arr['pvis'] + sigma = data_arr['psigma'] + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv) + A = [A1] + + return (vis, sigma, A) + + +def chisqdata_m(Obsdata, Prior, mask): + """Return the pol ratios, sigmas, and fourier matrix for and observation, prior, mask + """ + + mdata = Obsdata.unpack(['u','v','m','msigma'], conj=True) + uv = np.hstack((mdata['u'].reshape(-1,1), mdata['v'].reshape(-1,1))) + m = mdata['m'] + sigmam = mdata['msigma'] + A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask) + + return (m, sigmam, A) + +def chisqdata_m_nfft(Obsdata, Prior, mask, **kwargs): + """Return the pol ratios, sigmas, and fourier matrix for an observation, prior, mask + """ + + # unpack keyword args + fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT) + + # unpack data + mdata = Obsdata.unpack(['u','v','m','msigma'], conj=True) + uv = np.hstack((mdata['u'].reshape(-1,1), mdata['v'].reshape(-1,1))) + m = mdata['m'] + sigmam = mdata['msigma'] + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv) + A = [A1] + + return (m, sigmam, A) + + +def chisqdata_pbs(Obsdata, Prior, mask): + """return the bispectra, sigmas, and fourier matrices for and observation, prior, mask + """ + + biarr = Obsdata.bispectra(mode="all", vtype='rlvis', count="min") #TODO CONJ?? + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + A3 = (ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv1, pulse=Prior.pulse, mask=mask), + ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv2, pulse=Prior.pulse, mask=mask), + ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv3, pulse=Prior.pulse, mask=mask) + ) + + return (bi, sigma, A3) + +def chisqdata_pbs_nfft(Obsdata, Prior, mask): + """return the bispectra, sigmas, and fourier matrices for and observation, prior, mask + """ + + # unpack keyword args + fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT) + + # unpack data + biarr = Obsdata.bispectra(mode="all", vtype='rlvis', count="min") + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv1) + A2 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv2) + A3 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv3) + A = [A1,A2,A3] + + return (bi, sigma, A) + +def chisqdata_vvis(Obsdata, Prior, mask): + """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask + """ + + data_arr = Obsdata.unpack(['u','v','vvis','vsigma'], conj=False) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + vis = data_arr['vvis'] + sigma = data_arr['vsigma'] + A = ftmatrix(Prior.psize, Prior.xdim, Prior.ydim, uv, pulse=Prior.pulse, mask=mask) + + return (vis, sigma, A) + +def chisqdata_vvis_nfft(Obsdata, Prior, mask, **kwargs): + """Return the visibilities, sigmas, and fourier matrix for an observation, prior, mask + """ + + # unpack keyword args + fft_pad_factor = kwargs.get('fft_pad_factor',FFT_PAD_DEFAULT) + p_rad = kwargs.get('p_rad', GRIDDER_P_RAD_DEFAULT) + + # unpack data + data_arr = Obsdata.unpack(['u','v','vvis','vsigma'], conj=False) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + vis = data_arr['vvis'] + sigma = data_arr['vsigma'] + + # get NFFT info + npad = int(fft_pad_factor * np.max((Prior.xdim, Prior.ydim))) + A1 = NFFTInfo(Prior.xdim, Prior.ydim, Prior.psize, Prior.pulse, npad, p_rad, uv) + A = [A1] + + return (vis, sigma, A) + +################################################################################################## +# Plotting +################################################################################################## + +#TODO this only works for pol_trans == "amp_cphase" +def plot_m(imtuple, Prior, nit, chi2_dict, **kwargs): + + cmap = kwargs.get('cmap','afmhot') + interpolation = kwargs.get('interpolation', 'gaussian') + pcut = kwargs.get('pcut', 0.05) + nvec = kwargs.get('nvec', 15) + scale = kwargs.get('scale',None) + dynamic_range = kwargs.get('dynamic_range',1.e5) + gamma = kwargs.get('dynamic_range',.5) + + plt.ion() + plt.pause(1.e-6) + plt.clf() + + # unpack + im = imtuple[0] + mim = imtuple[1] + chiim = imtuple[2] + imarr = im.reshape(Prior.ydim,Prior.xdim) + + if scale=='log': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr<0.0] = 0.0 + imarr = np.log(imarr + np.max(imarr)/dynamic_range) + #unit = 'log(' + cbar_unit[0] + ' per ' + cbar_unit[1] + ')' + + if scale=='gamma': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr<0.0] = 0.0 + imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma) + #unit = '(' + cbar_unit[0] + ' per ' + cbar_unit[1] + ')^gamma' + + # Mask for low flux points + thin = int(round(Prior.xdim/nvec)) + mask = imarr > pcut * np.max(im) + mask2 = mask[::thin, ::thin] + + # Get vectors and ratio from current image + x = np.array([[i for i in range(Prior.xdim)] for j in range(Prior.ydim)])[::thin, ::thin][mask2] + y = np.array([[j for i in range(Prior.xdim)] for j in range(Prior.ydim)])[::thin, ::thin][mask2] + q = qimage(im, mim, chiim) + u = uimage(im, mim, chiim) + a = -np.sin(np.angle(q+1j*u)/2).reshape(Prior.ydim, Prior.xdim)[::thin, ::thin][mask2] + b = np.cos(np.angle(q+1j*u)/2).reshape(Prior.ydim, Prior.xdim)[::thin, ::thin][mask2] + m = (np.abs(q + 1j*u)/im).reshape(Prior.ydim, Prior.xdim) + m[~mask] = 0 + + # Stokes I plot + plt.subplot(121) + plt.imshow(imarr, cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01*Prior.xdim, units='x', pivot='mid', color='k', angles='uv', scale=1.0/thin) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005*Prior.xdim, units='x', pivot='mid', color='w', angles='uv', scale=1.1/thin) + + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Stokes I') + + # Ratio plot + plt.subplot(122) + plt.imshow(m, cmap=plt.get_cmap('winter'), interpolation='gaussian', vmin=0, vmax=1) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01*Prior.xdim, units='x', pivot='mid', color='k', angles='uv', scale=1.0/thin) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005*Prior.xdim, units='x', pivot='mid', color='w', angles='uv', scale=1.1/thin) + + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('m (above %i %% max flux)' % int(pcut*100)) + + # Create title + plotstr = "step: %i " % nit + for key in chi2_dict.keys(): + plotstr += "$\chi^2_{%s}$: %0.2f " % (key, chi2_dict[key]) + plt.suptitle(plotstr, fontsize=18) + + diff --git a/imaging/starwarps.py b/imaging/starwarps.py new file mode 100644 index 00000000..16eb3d45 --- /dev/null +++ b/imaging/starwarps.py @@ -0,0 +1,1713 @@ +# See example_starwarps.py for an example of how to use these methods +# Contact Katie Bouman (klbouman@caltech.edu) for any questions +# +# The methods/techniques used in this, referred to as StarWars, are described in +# "Reconstructing Video from Interferometric Measurements of Time-Varying Sources" +# by Katherine L. Bouman, Michael D. Johnson, Adrian V. Dalca, +# Andrew Chael, Freek Roelofs, Sheperd S. Doeleman, and William T. Freeman + +from __future__ import division +from __future__ import print_function + +import numpy as np +#import ehtim as eh +import ehtim.image as image +import ehtim.observing.pulses +from ehtim.observing.obs_helpers import * +from ehtim.imaging.imager_utils import chisqdata + +import scipy.stats as st +import scipy +import copy +import sys + +import matplotlib.pyplot as plt + +PROPERROR = True + +################################################################################################## + + +def solve_singleImage(mu, Lambda_orig, obs, measurement={'vis':1}, numLinIters=5, mask=[], normalize=False): + + if len(mask): + Lambda = Lambda_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + else: + Lambda = Lambda_orig + mask = np.ones(mu.imvec.shape)>0 + + if list(measurement.keys())==1 and measurement.keys()[0]=='vis': + numLinIters = 1 + + z_List_t_t = mu.copy() + z_List_lin = mu.copy() + + for k in range(0,numLinIters): + meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs, z_List_lin, measurement=measurement, mask=mask, normalize=normalize) + if valid: + z_List_t_t.imvec[mask], P_List_t_t = prodGaussiansLem2(F, measCov, meas, mu.imvec[mask], Lambda) + + if k < numLinIters-1: + z_List_lin = z_List_t_t.copy() + else: + z_List_t_t = mu.copy() + P_List_t_t = copy.deepcopy(Lambda) + + return (z_List_t_t, P_List_t_t, z_List_lin) + + +################################################################################################## + +def forwardUpdates_apxImgs(mu, Lambda_orig, obs_List, A_orig, Q_orig, init_images, measurement={'vis':1}, lightcurve=None, numLinIters=5, interiorPriors=False, mask=[], normalize=False): + ''' + Gaussian image prior: + :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image + at given timestep. If list length is one mean image is duplicated for all time steps + :param Lambda_orig: (list - len(Lambda_orig)=num_time_steps or 1): original unmasked covariance matrix. + Every element is a 2D numpy array which contains the covariance at a given timestep. + If list length is one, the cov image is duplicated for all time steps. + + Observations: + :param obs_List: list of observations, for each time step + + Dynamical Evolution Model + :param A_orig: original unmasked A matrix - time-invariant mean of warp field for dynamical evolution + :param Q_orig: original unmasked Q matrix - time-invariant covariance matrix of dynamical evolution model, + describing the amount of allowed intensity deviation + + Other Parameters: + :param init_images: option to provide initialization for the forward updates. + If none provided, then use the initialization from StarWarps paper + :param measurement: data products used + :param lightcurve: light curve time seres, needed if imposing a flux constraint + :param numLinIters: number of linearized iterations. We have non-linear measurement function f_{t}(x_{t}), + and we linearize the solution around \tilde{x_{t}} by taking the first order Taylor series expansion of f. + To improve the solution of the forward and backward terms, each step in the forward pass can be + iteratively re-solved and \tilde{x}_{t} can be updated at each iteration. + The values of \tilde{x}_{t} are fixed for the backward pass. + Note that if f is linear, only a single iteration will be enough to converge to the optimal solution. + Thus, if the only measurement is visibility, numLinIters = 1 should be set. + :param interiorPriors: flag for whether to use interior priors + :param mask: to select parts of the image to utilize. default is True for all pixels. + This is because getMeasurementTerms doesn't work with a mask yet. + :param normalize: flag for whether to normalize sigma in getMeasurementTerms + ''' + + # linear case: measurement function is linear and problem is convex + if list(measurement.keys())==1 and measurement.keys()[0]=='vis': + numLinIters = 1 + + # apply mask + if len(mask): + A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + Q = Q_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + Lambda = [] + for t in range(0, len(Lambda_orig)): + Lambda.append( Lambda_orig[t][mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) ) + else: + Lambda = Lambda_orig + A = A_orig + Q = Q_orig + + #if measurement=='bispectrum': + # print 'WARNING: check the loglikelihood for non-linear functions' + + # create an image of 0's + zero_im = mu[0].copy() + zero_im.imvec = 0.0*zero_im.imvec + + # intitilize the prediction and update mean and covariances + z_List_t_t = [] # Mean of the hidden (state) image at t given data up to time t + P_List_t_t = [] # Covariance of the hidden (state) image at t given data up to time t + + # initialize z and P lists with to be all zeros + for t in range(0,len(obs_List)): + z_List_t_t.append(zero_im.copy()) + P_List_t_t.append(np.zeros(Lambda[0].shape)) + + # prediction 1 list: prediction at time t given all information up to t-1 + z_List_t_tm1 = copy.deepcopy(z_List_t_t) + P_List_t_tm1 = copy.deepcopy(P_List_t_t) + + # prediction 2 list: deep copy of prediction 1 list (possibly intermediate state variable) + z_star_List_t_tm1 = copy.deepcopy(z_List_t_t) + P_star_List_t_tm1 = copy.deepcopy(P_List_t_t) + # initialize linearization z's + z_List_lin = copy.deepcopy(z_List_t_t) + + loglikelihood_prior = 0.0 + loglikelihood_data = 0.0 + + # for each forward timestep... + for t in range(0,len(obs_List)): + sys.stdout.write('\rForward timestep %i of %i total timesteps...' % (t,len(obs_List))) + sys.stdout.flush() + + print('forward timestep: ' + str(t)) + + # Duplicate mean and covariance if needed + if len(mu) == 1: + mu_t = mu[0] + Lambda_t = Lambda[0] + # get current image's mean and covariance + else: + mu_t = mu[t] + Lambda_t = Lambda[t] + + # use lightcurve data if provided + if lightcurve: + tot = np.sum(mu_t.imvec) + mu_t.imvec = lightcurve[t]/tot * mu_t.imvec + Lambda_t = (lightcurve[t]/tot)**2 * Lambda_t + + + # predict + # Initialization of hidden state mean and covariance for t = 0 + if t==0: + z_star_List_t_tm1[t].imvec = copy.deepcopy(mu_t.imvec) + P_star_List_t_tm1[t] = copy.deepcopy(Lambda_t) + # StarWarps initialization of hidden state mean and covariance + else: + z_List_t_tm1[t].imvec[mask] = np.dot( A, z_List_t_t[t-1].imvec[mask] ) + if PROPERROR: + P_List_t_tm1[t] = Q + np.dot( np.dot( A, P_List_t_t[t-1] ), np.transpose(A) ) + else: + print('no prop error') + P_List_t_tm1[t] = Q + + # main predict step, using Lemma 1 from StarWarps supplementary doc (also see eq 29-30), using interior priors + if interiorPriors: + z_star_List_t_tm1[t].imvec[mask], P_star_List_t_tm1[t] = prodGaussiansLem1( mu_t.imvec[mask], Lambda_t, z_List_t_tm1[t].imvec[mask], P_List_t_tm1[t] ) + else: + z_star_List_t_tm1[t] = z_List_t_tm1[t].copy() + P_star_List_t_tm1[t] = copy.deepcopy(P_List_t_tm1[t]) + + + # update + # either go with user-provided initialization (if given) or take z_star_List_t_tm1 as an initialization + if init_images is None: + init_images_t = z_star_List_t_tm1[t].copy() + elif len(init_images) == 1: + init_images_t = init_images[0] + else: + init_images_t = init_images[t] + z_List_lin[t] = init_images_t.copy() + + # Do the linearized iterations + for k in range(0,numLinIters): + # F is the derivative of the Forward model with respect to the unknown parameters + if lightcurve: + meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], z_List_lin[t], measurement=measurement, tot_flux=lightcurve[t], mask=mask, normalize=normalize) + else: + meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], z_List_lin[t], measurement=measurement, tot_flux=None, mask=mask, normalize=normalize) + + # main update step, using Lemma 2 from StarWarps supplementary doc (also see eq 30-31) + if valid: + z_List_t_t[t].imvec[mask], P_List_t_t[t] = prodGaussiansLem2(F, measCov, meas, z_star_List_t_tm1[t].imvec[mask], P_star_List_t_tm1[t]) + + if k < numLinIters-1: + z_List_lin[t] = z_List_t_t[t].copy() + else: + z_List_t_t[t] = z_star_List_t_tm1[t].copy() + P_List_t_t[t] = copy.deepcopy(P_star_List_t_tm1[t]) + + # update the prior log likelihood, using interior priors + if t>0 and interiorPriors: + loglikelihood_prior = loglikelihood_prior + evaluateGaussianDist_log( z_List_t_tm1[t].imvec[mask], mu_t.imvec[mask], Lambda_t + P_List_t_tm1[t] ) + + # update the data log likelihood + if valid: + loglikelihood_data = loglikelihood_data + evaluateGaussianDist_log( np.dot(F , z_star_List_t_tm1[t].imvec[mask]), meas, measCov + np.dot( F, np.dot(P_star_List_t_tm1[t], F.T)) ) + + + # compute the log likelihood (equation 27 in StarWarps paper) + loglikelihood = loglikelihood_prior + loglikelihood_data + return ((loglikelihood_data, loglikelihood_prior, loglikelihood), z_List_t_tm1, P_List_t_tm1, z_List_t_t, P_List_t_t, z_List_lin) + + +###################################### EXTENDED MESSAGE PASSING ######################################## + +def backwardUpdates(mu, Lambda_orig, obs_List, A_orig, Q_orig, measurement={'vis':1}, lightcurve=None, apxImgs=False, mask=[], normalize=False): + ''' + Gaussian image prior: + :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image + at given timestep. If list length is one mean image is duplicated for all time steps + :param Lambda_orig: (list - len(Lambda_orig)=num_time_steps or 1): original unmasked covariance matrix. + Every element is a 2D numpy array which contains the covariance at a given timestep. + If list length is one, the cov image is duplicated for all time steps. + + Observations: + :param obs_List: list of observations, for each time step + + Dynamical Evolution Model + :param A_orig: original unmasked A matrix - time-invariant mean of warp field for dynamical evolution + :param Q_orig: original unmasked Q matrix - time-invariant covariance matrix of dynamical evolution model, + describing the amount of allowed intensity deviation + + Other Parameters + :param measurement: data products used + :param lightcurve: light curve time seres, needed if imposing a flux constraint + :param apxImgs: ??? + :param mask: to select parts of the image to utilize. default is True for all pixels. + This is because getMeasurementTerms doesn't work with a mask yet. + :param normalize: flag for whether to normalize sigma in getMeasurementTerms + ''' + + if len(mask): + A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + Q = Q_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + Lambda = [] + for t in range(0, len(Lambda_orig)): + Lambda.append( Lambda_orig[t][mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) ) + else: + Lambda = Lambda_orig + A = A_orig + Q = Q_orig + + # create an image of 0's + zero_im = mu[0].copy() + zero_im.imvec = 0.0*zero_im.imvec + + # intitilize the prediction and update mean and covariances + z_t_t = [] + P_t_t = [] + # update list + for t in range(0,len(obs_List)): + z_t_t.append(zero_im.copy()) + P_t_t.append(np.zeros(Lambda[0].shape)) + # prediction 1 list + z_star_t_tp1 = copy.deepcopy(z_t_t) + P_star_t_tp1 = copy.deepcopy(P_t_t) + + + lastidx = len(obs_List)-1 + for t in range(lastidx,-1,-1): + sys.stdout.write('\rBackward timestep %i of %i total timesteps...' % (t,len(obs_List))) + sys.stdout.flush() + + if len(mu) == 1: + mu_t = mu[0] + Lambda_t = Lambda[0] + else: + mu_t = mu[t] + Lambda_t = Lambda[t] + if lightcurve is not None: + tot = np.sum(mu_t.imvec) + mu_t.imvec = lightcurve[t]/tot * mu_t.imvec + Lambda_t = (lightcurve[t]/tot)**2 * Lambda_t + + # predict + if t==lastidx: + z_star_t_tp1[t].imvec = copy.deepcopy(mu_t.imvec) + P_star_t_tp1[t] = copy.deepcopy(Lambda_t) + else: + if PROPERROR: + z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t] = prodGaussiansLem2( A, Q + P_t_t[t+1], z_t_t[t+1].imvec[mask], mu_t.imvec[mask], Lambda_t) + else: + print('no prop error') + z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t] = prodGaussiansLem2( A, Q, z_t_t[t+1].imvec[mask], mu_t.imvec[mask], Lambda_t) + + # update + if lightcurve: + meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], apxImgs[t], measurement=measurement, tot_flux=lightcurve[t], mask=mask, normalize=normalize) + else: + meas, idealmeas, F, measCov, valid = getMeasurementTerms(obs_List[t], apxImgs[t], measurement=measurement, tot_flux=None, mask=mask, normalize=normalize) + + if valid: + z_t_t[t].imvec[mask], P_t_t[t] = prodGaussiansLem2(F, measCov, meas, z_star_t_tp1[t].imvec[mask], P_star_t_tp1[t]) + + else: + z_t_t[t] = z_star_t_tp1[t].copy() + P_t_t[t] = copy.deepcopy(P_star_t_tp1[t]) + + return (z_t_t, P_t_t) + + +def smoothingUpdates(z_t_t, P_t_t, z_t_tm1, P_t_tm1, A_orig, mask=[]): + ''' + Smoothing + ''' + z = copy.deepcopy(z_t_t) + P = copy.deepcopy(P_t_t) + backwardsA = copy.deepcopy(P_t_t) + + if len(mask): + A = A_orig[mask[:,None] & mask[None,:]].reshape([np.sum(mask), -1]) + + lastidx = len(z)-1 + for t in range(lastidx,-1,-1): + + if t < lastidx: + backwardsA[t] = np.dot( np.dot(P_t_t[t], A.T ), np.linalg.inv(P_t_tm1[t+1]) ) + z[t].imvec[mask] = z_t_t[t].imvec[mask] + np.dot( backwardsA[t], z[t+1].imvec[mask] - z_t_tm1[t+1].imvec[mask] ) + P[t] = np.dot( np.dot( backwardsA[t] , P[t+1] - P_t_tm1[t+1]), backwardsA[t].T ) + P_t_t[t] + + return (z, P, backwardsA) + + + +def computeSuffStatistics(mu, Lambda, obs_List, Upsilon, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, init_images=None, method='phase', measurement={'vis':1}, lightcurve=None, interiorPriors=False, numLinIters=1, compute_expVal_tm1_t=True, mask=[], normalize=False): + ''' + Gaussian image prior: + :param mu: (list - len(mu)=num_time_steps or 1): every element is an image object which contains the mean image + at given timestep. If list length is one mean image is duplicated for all time steps + :param Lambda_orig: (list - len(Lambda_orig)=num_time_steps or 1): original unmasked covariance matrix. + Every element is a 2D numpy array which contains the covariance at a given timestep. + If list length is one, the cov image is duplicated for all time steps. + + Observations: + :param obs_List: list of observations, for each time step + + Dynamical Evolution Model + :param A_orig: original unmasked A matrix - time-invariant mean of warp field for dynamical evolution + :param Q_orig: original unmasked Q matrix - time-invariant covariance matrix of dynamical evolution model, + describing the amount of allowed intensity deviation + + Other Parameters: + :param init_images: option to provide initialization for the forward updates. + If none provided, then use the initialization from StarWarps paper + :param measurement: data products used + :param lightcurve: light curve time seres, needed if imposing a flux constraint + :param numLinIters: number of linearized iterations. We have non-linear measurement function f_{t}(x_{t}), + and we linearize the solution around \tilde{x_{t}} by taking the first order Taylor series expansion of f. + To improve the solution of the forward and backward terms, each step in the forward pass can be + iteratively re-solved and \tilde{x}_{t} can be updated at each iteration. + The values of \tilde{x}_{t} are fixed for the backward pass. + Note that if f is linear, only a single iteration will be enough to converge to the optimal solution. + Thus, if the only measurement is visibility, numLinIters = 1 should be set. + :param interiorPriors: Flag for whether to use interior priors. + :param compute_expVal_tm1_t: flag for whether to compute the second sufficient statistic, E[x_{t-1}x_{t}^{T}]. + :param mask: to select parts of the image to utilize. default is True for all pixels. + This is because getMeasurementTerms doesn't work with a mask yet. + :param normalize: flag for whether to normalize sigma in getMeasurementTerms + ''' + + # if mask not provided, create default mask + if not len(mask): + mask = np.ones(mu[0].imvec.shape)>0 + + # check if first mean image is square + if mu[0].xdim != mu[0].ydim: + error('Error: This has only been checked thus far on square images!') + + # lightcurve and flux constraint go together + if lightcurve == None and 'flux' in measurement.keys(): #KATIE ADDED FEB 1 2021 + error('Error: if you are using a flux constraint you must specify a lightcurve') + + # if the visibility is the only measurement, + if list(measurement.keys())==1 and measurement.keys()[0]=='vis': + numLinIters = 1 + + # calculate matrix to represent warp field, from first mean image + warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + + # Parameterize dynamical evolution model with Guassian + # the time-invariant mean warp field + A = warpMtx + # time-invariant covariance matrix describing amount of allowed intensity variation in the warp field + Q = Upsilon + + # Do forward passes + loglikelihood, z_t_tm1, P_t_tm1, z_t_t, P_t_t, apxImgs = forwardUpdates_apxImgs(mu, Lambda, obs_List, A, Q, init_images=init_images, measurement=measurement, lightcurve=lightcurve, interiorPriors=interiorPriors, numLinIters=numLinIters, mask=mask, normalize=normalize) + + # Extended message passing with backward passes, using interior priors + if interiorPriors: + # Do backward passes + z_backward_t_t, P_backward_t_t = backwardUpdates(mu, Lambda, obs_List, A, Q, measurement=measurement, lightcurve=lightcurve, apxImgs=apxImgs, mask=mask, normalize=normalize) + + z = copy.deepcopy(z_backward_t_t) + P = copy.deepcopy(P_backward_t_t) + for t in range(0,len(obs_List)): + if t==0: + z[t] = z_backward_t_t[t].copy() + P[t] = copy.deepcopy(P_backward_t_t[t]) + else: + z[t].imvec[mask], P[t] = prodGaussiansLem1(z_t_tm1[t].imvec[mask], P_t_tm1[t], z_backward_t_t[t].imvec[mask], P_backward_t_t[t]) + + # Use smoothing updates instead + else: + z, P, backwardsA = smoothingUpdates(z_t_t, P_t_t, z_t_tm1, P_t_tm1, A, mask=mask) + + + + expVal_t = copy.deepcopy(z) + #initilize the lists + expVal_t_t = copy.deepcopy(P) + expVal_tm1_t = copy.deepcopy(P) + for t in range(0,len(obs_List)): + # expected value of xx^T for each x + z_t_hvec = np.array([z[t].imvec[mask]]) + expVal_t_t[t] = np.dot(z_t_hvec.T, z_t_hvec) + P[t] + + # expected value of x_t x_t-1^T for each x except for the first one + if t>0 and interiorPriors==False and compute_expVal_tm1_t: + z_tm1_hvec = np.array([z[t-1].imvec[mask]]) + expVal_tm1_t[t] = np.dot(z_tm1_hvec.T, z_t_hvec) + np.dot(backwardsA[t-1], P[t]) + + # expected value of x_t x_t-1^T, using interior priors + if interiorPriors and compute_expVal_tm1_t: + expVal_tm1_t = JointDist(z, z_t_t, P_t_t, z_backward_t_t, P_backward_t_t, A, Q) + + return (expVal_t, expVal_t_t, expVal_tm1_t, loglikelihood, apxImgs) + + + + +def JointDist(z, z_List_t_t_forward, P_List_t_t_forward, z_List_t_t_backward, P_List_t_t_backward, A, Q): + ''' + Calculate the joint distribution p(x_{t},x_{t-1} | y_{1:N}) + See section 2.2 in StarWarps supplementary doc, starting from eq 60 + ''' + + expVal_tm1_t = [] + expVal_tm1_t.append(0.0) + + # section 2.2 + for t in range(1, len(z_List_t_t_forward) ): + + Sigma = Q + P_List_t_t_backward[t] + Sigma_inv = np.linalg.inv(Sigma) + + # eq 76 + M = np.dot(P_List_t_t_backward[t], np.dot(Sigma_inv, A) ) + + # eq 77, 78 + (m, C) = prodGaussiansLem2(A, Sigma, z_List_t_t_backward[t].imvec, z_List_t_t_forward[t-1].imvec, P_List_t_t_forward[t-1]) + + # eq 79 + D_tmp1 = np.dot(M, np.dot(C, M.T)) + D_tmp2 = np.dot(Q, np.dot( Sigma_inv, P_List_t_t_backward[t] ) ) + D = np.dot(C, np.dot(M.T, np.linalg.inv(D_tmp1 + D_tmp2) ) ) + + # eq 81 + F = C - np.dot(D, np.dot(M, C)) + + # E[x_{t}] + z_t_hvec = np.array([z[t].imvec]) + # E[x_{t-1}] + z_tm1_hvec = np.array([z[t-1].imvec]) + + # eq 88 + expVal_tm1_t.append( np.dot(F, np.linalg.inv(D.T)) + np.dot(z_tm1_hvec.T, z_t_hvec) ) + + return expVal_tm1_t + + + +################################# + + +def maximizeWarpMtx(expVal_t_t, expVal_tm1_t, expVal_t=0, B=0): + + M1 = np.zeros(expVal_tm1_t[1].shape) + M2 = np.zeros(expVal_t_t[1].shape) + + for t in range(1,len(expVal_t_t)): + #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T + M1 = M1 + expVal_tm1_t[t].T + if B !=0: + M1 = M1 + np.dot(B, expVal_t[t]) + M2 = M2 + expVal_t_t[t-1] + + warpMtx = np.dot( M1, np.linalg.inv(M2) ) + return warpMtx + +def maximizeTheta_multiIter(expVal_t_t, expVal_tm1_t, dummy_im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase', nIter=10): + + newTheta = centerTheta + for i in range(0, nIter): + newTheta = maximizeTheta(expVal_t_t, expVal_tm1_t, dummy_im, newTheta, newTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + + return newTheta + + +def maximizeTheta(expVal_t_t, expVal_tm1_t, dummy_im, Q, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'): + + if method == 'phase' or method == 'approx_phase': + dWarp_dTheta = calc_dWarp_dTheta(dummy_im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase') + else: + error('ERROR: WE ONLY HANDLE PHASE WARP MINIMIZATION RIGHT NOW') + + warpMtx = calcWarpMtx(dummy_im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + invQ = np.linalg.inv(Q) + + nbasis = len(initTheta) + thetaNew = np.zeros( initTheta.shape ) + + G1 = np.zeros(expVal_tm1_t[1].shape) + for t in range(1,len(expVal_t_t)): + #G1 = G1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] ) + G1 = G1 + expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] ) + #G1 = G1 + expVal_tm1_t[t] - np.dot( warpMtx, expVal_t_t[t-1] ) + for b in range(0, nbasis): + G1 = G1 + np.dot( dWarp_dTheta[b], expVal_t_t[t-1] )*centerTheta[b] + G1 = np.dot(invQ, G1) + + G2 = [] + for b in range (0,nbasis): + G2.append(np.zeros(expVal_t_t[1].shape)) + for t in range(1,len(expVal_t_t)): + G2[b] = G2[b] + np.dot( dWarp_dTheta[b], expVal_t_t[t-1] ) + G2[b] = np.dot(invQ, G2[b]) + + D1 = np.zeros(initTheta.shape) + for b1 in range(0, nbasis): + for p in range(0,dWarp_dTheta[b1].shape[0]): + for q in range(0,dWarp_dTheta[b1].shape[1]): + D1[b1] = D1[b1] + G1[p,q]*dWarp_dTheta[b1][p,q] + + D2 = np.zeros([nbasis, nbasis]) + for b1 in range(0, nbasis): + for b2 in range(0, nbasis): + for p in range(0,dWarp_dTheta[b1].shape[0]): + for q in range(0,dWarp_dTheta[b1].shape[1]): + D2[b1,b2] = D2[b1,b2] + G2[b2][p,q]*dWarp_dTheta[b1][p,q] + + + thetaNew = np.dot(np.linalg.inv(D2), D1) + + + + secondDeriv = np.zeros((nbasis, nbasis)) + for b in range(0, nbasis): + thetaNew_tmp = copy.deepcopy(thetaNew) + thetaNew_tmp[b] = 1.0 + secondDeriv[:,b] = D1 - np.dot(D2, thetaNew_tmp) + eigvals,_ = np.linalg.eig(secondDeriv) + if all(eigvals>0): + print('local min') + elif all(eigvals<0): + print('local max') + elif any(eigvals==0.0): + print('inconclusive') + else: + print('saddle point: ' + str(np.sum(eigvals<0)) + ' negative eigs of ' + str(len(eigvals))) + + + return (thetaNew, secondDeriv, D1, D2) + + +def negloglikelihood(theta, mu, Lambda, obs_List, Upsilon, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, mask=[]): + + warpMtx = calcWarpMtx(mu, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + + A = warpMtx + B = np.zeros(mu.imvec.shape) + Q = Upsilon + + loglike, z_t_tm1, P_t_tm1, z_t_t, P_t_t = forwardUpdates(mu, Lambda, obs_List, A, B, Q, measurement=measurement, interiorPriors=interiorPriors, mask=mask) + + return -loglike[2] + +def expnegloglikelihood_full(theta, expectation_theta, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, numLinIters, apxImgs): + + expVal_t, expVal_t_t, expVal_tm1_t, _ = computeSuffStatistics(mu, Lambda, obs_List, Q, expectation_theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method, measurement=measurement, interiorPriors=interiorPriors, numLinIters=numLinIters, apxImgs=apxImgs) + neg_expll = expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method) + print(neg_expll) + + +def expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Upsilon, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method): + + #if interiorPriors: + #TODO: print 'WARNING: not sure if this works with interior priors because of the derivation of the E[xMx] terms may be different' + + warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + A = warpMtx + B = np.zeros(mu[0].imvec.shape) + Q = Upsilon + invQ = np.linalg.inv(Q) + + value = 0.0 + for t in range(1, len(expVal_t)): + x_t = np.array([expVal_t[t].imvec]).T + x_tm1 = np.array([expVal_t[t-1].imvec]).T + + P_tm1_t = expVal_tm1_t[t] - np.dot(x_tm1, x_t.T) + P_tm1_tm1 = expVal_t_t[t-1] - np.dot(x_tm1, x_tm1.T) + + term1 = exp_xtm1_M_xt(P_tm1_t.T, x_t, x_tm1, np.dot(invQ, A) ) + term2 = exp_xtm1_M_xt(P_tm1_t, x_tm1, x_t, np.dot(A.T, invQ) ) + term3 = exp_xtm1_M_xt(P_tm1_tm1, x_tm1, x_tm1, np.dot(A.T, np.dot(invQ, A) ) ) + term4 = np.dot(B.T, np.dot(invQ, np.dot(A, x_tm1))) + + value = value - 0.5*( -term1 - term2 + term3 + term4 + term4.T ) + + value = -value + return value + +def exp_xtm1_M_xt(P, z1, z2, M): + value = np.trace( np.dot(P, M.T) ) + np.dot(z1.T, np.dot(M, z2 ) ) + return value + + + +def deriv_expnegloglikelihood_full(theta, expectation_theta, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method, measurement, interiorPriors, numLinIters, apxImgs): + + expVal_t, expVal_t_t, expVal_tm1_t, _ = computeSuffStatistics(mu, Lambda, obs_List, Q, expectation_theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method, measurement=measurement, interiorPriors=interiorPriors, numLinIters=numLinIters, apxImgs=apxImgs) + return deriv_expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method) + + +def deriv_expnegloglikelihood(theta, expVal_t, expVal_t_t, expVal_tm1_t, mu, Lambda, obs_List, Q, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method): + + if method == 'phase': + dWarp_dTheta = calc_dWarp_dTheta(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + else: + print('WARNING: WE ONLY HANDLE PHASE WARP MINIMIZATION RIGHT NOW') + + warpMtx = calcWarpMtx(mu[0], theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method) + + invQ = np.linalg.inv(Q) + M1 = np.zeros(expVal_tm1_t[1].shape) + for t in range(1,len(expVal_t_t)): + #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] ) + M1 = M1 + expVal_tm1_t[t].T - np.dot( warpMtx, expVal_t_t[t-1] ) + M1 = np.dot( invQ , M1) + + deriv = np.zeros( initTheta.shape ) + for b in range(0,len(initTheta)): + for p in range(0,dWarp_dTheta[b].shape[0]): + for q in range(0,dWarp_dTheta[b].shape[1]): + deriv[b] = deriv[b] + M1[p,q]*dWarp_dTheta[b][p,q] + + # the derivative computed is for the ll but we want the derivative of the neg ll + deriv = -deriv + return deriv + + +def maximizeBrightness(expVal_t_t, expVal_tm1_t, dummy_im, Q): + + dWarp_dTheta = np.eye(dummy_im.xdim*dummy_im.ydim) + invQ = np.linalg.inv(Q) + + M1 = np.zeros(expVal_tm1_t[1].shape) + M2 = np.zeros(expVal_t_t[1].shape) + + top = 0.0 + bottom = 0.0 + + for t in range(1,len(expVal_t_t)): + #M1 = M1 + 0.5*expVal_tm1_t[t] + 0.5*expVal_tm1_t[t].T + M1 = M1 + expVal_tm1_t[t].T + M2 = M2 + expVal_t_t[t-1] + M1 = np.dot(invQ, M1) + M2 = np.dot(invQ, M2) + + for p in range(0,dWarp_dTheta.shape[0]): + for q in range(0,dWarp_dTheta.shape[1]): + top = top + M1[p,q]*dWarp_dTheta[p,q] + bottom = bottom + M2[p,q]*dWarp_dTheta[p,q] + + thetaNew = top/bottom + + return thetaNew + + + +def evaluateGaussianDist_log(y, x, Sigma): + + n = len(x) + if len(x) != np.prod(x.shape): + raise AssertionError() + + diff = x - y + (sign, logdet) = np.linalg.slogdet(Sigma) + expval_log = - (n/2.0)*np.log( 2.0*np.pi ) - 0.5*(sign*logdet) - 0.5*np.dot( diff.T, np.dot( np.linalg.inv(Sigma), diff ) ) + + return expval_log + +def evaluateGaussianDist(y, x, Sigma): + + expval_log = evaluateGaussianDist_log(y, x, Sigma) + expval = np.exp( expval_log ) + return expval + +def prodGaussiansLem1(m1, S1, m2, S2): + + K = np.linalg.inv(S1 + S2) + + covariance = np.dot( S1, np.dot( K, S2 ) ) + mean = np.dot(S1, np.dot(K, m2) ) + np.dot(S2, np.dot(K, m1) ) + + return (mean, covariance) + +def prodGaussiansLem2(A, Sigma, y, mu, Q): + + K1 = np.linalg.inv( Sigma + np.dot(A, np.dot(Q, np.transpose(A))) ) + K2 = np.dot( Q, np.dot( A.T, K1 ) ) + + covariance = Q - np.dot( K2, np.dot( A, Q ) ) + mean = mu + np.dot( K2, y - np.dot(A, mu) ) + + return (mean, covariance) + + +def getMeasurementTerms(obs, im, measurement={'vis': 1}, tot_flux=None, mask=[], normalize=False): + if not np.sum(mask)==len(mask): + raise ValueError('The code doenst currently work with a mask!') + + #initilize the concatenated data terms + measdiff_all = [] + ideal_all = [] + F_all = [] + Cov_all = [] + data_all = [] + + count = 0 + # loop through data products we want to constrain + for dname in list(measurement.keys()): + + # ignore data terms that 0 weight + if np.allclose(measurement[dname],0.0): + continue + + # check to see if you have data in the current obs + try: + if dname=='flux': + if tot_flux == None: + error('Error: if you are using a flux constraint you must specify a total flux (via the lightcurve)') + data = np.array([tot_flux]) + sigma = np.array([1]) + else: + data, sigma, A = chisqdata(obs, im, mask, dtype=dname, ttype='direct') + count = count + 1 + except: + continue + + #compute the derivative matrix and the ideal measurements if im was the true image + if dname == 'vis': + F = A + ideal = np.dot(A,im.imvec) + elif dname == 'bs': + F = grad_bs(im.imvec, A) + ideal = bs(im.imvec,A) + elif dname == 'cphase': + F = grad_cphase(im.imvec, A) + ideal = cphase(im.imvec,A) + elif dname == 'amp': + F = grad_amp(im.imvec, A) + ideal = amp(im.imvec,A) + elif dname == 'logcamp': + F = grad_logcamp(im.imvec, A) + ideal = logcamp(im.imvec,A) + elif dname == 'flux': + F = grad_flux(im.imvec) + ideal = flux(im.imvec) + + #turn complex matrices to real + if not np.allclose(data.imag,0): + F = realimagStack(F) + data = realimagStack(data) + ideal = realimagStack(ideal) + sigma = np.concatenate((sigma,sigma), axis=0) + + # change the error bars based upon which elements we want to constrain more + weight = measurement[dname] + if normalize: + sigma = sigma / np.sqrt(np.sum(sigma ** 2)) + sigma = sigma / np.sqrt(weight) + Cov = np.diag(sigma ** 2) + + data_all = np.concatenate((data_all, data.reshape(-1)), axis=0).reshape(-1) + ideal_all = np.concatenate((ideal_all, ideal.reshape(-1)), axis=0).reshape(-1) + measdiff_all = np.concatenate( + (measdiff_all, data.reshape(-1) + np.dot(F, im.imvec[mask]) - ideal.reshape(-1)), axis=0) + F_all = np.concatenate((F_all, F), axis=0) if len(F_all) else F + Cov_all = scipy.linalg.block_diag(Cov_all, Cov) if len(Cov_all) else Cov + + if len(data_all): + return (measdiff_all, ideal_all, F_all, Cov_all, True) + else: + return (-1, -1, -1, -1, False) + +def bs(imvec, Amatrices): + """the bispectrum""" + out = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[1],imvec) * np.dot(Amatrices[2],imvec) + return out + +def grad_bs(imvec, Amatrices): + """The gradient of the bispectrum""" + pt1 = np.dot(Amatrices[1],imvec) * np.dot(Amatrices[2],imvec) + pt2 = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[2],imvec) + pt3 = np.dot(Amatrices[0],imvec) * np.dot(Amatrices[1],imvec) + out = pt1[:,None] * Amatrices[0] + pt2[:,None] * Amatrices[1] + pt3[:,None] * Amatrices[2] + return out + +def flux(imvec): + return np.sum(imvec) + +def grad_flux(imvec): + return np.ones((1, len(imvec))) + +def cphase(imvec, Amatrices): + """the closure phase""" + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + clphase_samples = np.angle(i1 * i2 * i3) + out = np.exp(1j * clphase_samples) + return out + +def grad_cphase(imvec, Amatrices): + """The gradient of the closure phase""" + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + clphase_samples = np.angle(i1 * i2 * i3) + pt1 = 1.0/i1 + pt2 = 1.0/i2 + pt3 = 1.0/i3 + dphi = (pt1[:,None]*Amatrices[0]) + (pt2[:,None] * Amatrices[1]) + \ + (pt3[:,None]* Amatrices[2]) + out = 1j * np.imag(dphi) * np.exp(1j * clphase_samples[:,None]) + return out + +def amp(imvec, A): + """the amplitude""" + i1 = np.dot(A, imvec) + out = np.abs(i1) + return out + +def grad_amp(imvec, A): + """The gradient of the amplitude """ + i1 = np.dot(A, imvec) + pp = np.abs(i1) / i1 + out = np.real(pp[:,None] * A) + return out + +def camp(imvec, Amatrices): + """the closure amplitude""" + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + out = np.abs((i1 * i2)/(i3 * i4)) + return out + + +def grad_camp(imvec, Amatrices): + """The gradient of the closure amplitude """ + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + clamp_samples = np.abs((i1 * i2)/(i3 * i4)) + + pt1 = clamp_samples/i1 + pt2 = clamp_samples/i2 + pt3 = -clamp_samples/i3 + pt4 = -clamp_samples/i4 + out = (pt1[:,None]*Amatrices[0]) + (pt2[:,None]*Amatrices[1]) + (pt3[:,None]*Amatrices[2]) + (pt4[:,None]*Amatrices[3]) + return np.real(out) + + +def logcamp(imvec, Amatrices): + """The Log closure amplitude """ + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + out = np.log(np.abs(i1)) + np.log(np.abs(i2)) - np.log(np.abs(i3)) - np.log(np.abs(i4)) + return out + +def grad_logcamp(imvec, Amatrices): + """The gradient of the Log closure amplitude """ + + i1 = np.dot(Amatrices[0], imvec) + i2 = np.dot(Amatrices[1], imvec) + i3 = np.dot(Amatrices[2], imvec) + i4 = np.dot(Amatrices[3], imvec) + + pt1 = 1/i1 + pt2 = 1/i2 + pt3 = -1/i3 + pt4 = -1/i4 + out = np.real(pt1[:,None] * Amatrices[0] + pt2[:,None] * Amatrices[1] + \ + pt3[:,None] * Amatrices[2] + pt4[:,None] * Amatrices[3]) + return out + + +def mergeObs(obs_List): + + obs = obs_List[0].copy() + data = obs.data + for t in range(1,len(obs_List)): + data = np.concatenate((data, obs_List[t].data)) + obs.data = data + return obs + + +def splitObs(obs): + """Split single observation into multiple observation files, one per scan + """ + + print("Splitting Observation File into " + str(len(obs.tlist())) + " scans") + + #Note that the tarr of the output includes all sites, even those that don't participate in the scan + obs_List = [ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, tdata, obs.tarr, source=obs.source, + mjd=obs.mjd, ampcal=obs.ampcal, phasecal=obs.phasecal) + for tdata in obs.tlist() + ] + return obs_List + + +def movie(im_List, out='movie.mp4', fps=10, dpi=120): + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + fig = plt.figure() + frame = im_List[0].imvec #read_auto(filelist[len(filelist)/2]) + fov = im_List[0].psize*im_List[0].xdim + extent = fov * np.array((1,-1,-1,1)) / 2. + maxi = np.max(frame) + im = plt.imshow( np.reshape(frame,[im_List[0].xdim, im_List[0].xdim]) , cmap='hot', extent=extent) #inferno + plt.colorbar() + im.set_clim([0,maxi]) + fig.set_size_inches([5,5]) + plt.tight_layout() + + def update_img(n): + sys.stdout.write('\rprocessing image %i of %i ...' % (n,len(im_List)) ) + sys.stdout.flush() + im.set_data(np.reshape(im_List[n].imvec, [im_List[n].xdim, im_List[n].xdim]) ) + return im + + ani = animation.FuncAnimation(fig,update_img,len(im_List),interval=1e3/fps) + writer = animation.writers['ffmpeg'](fps=max(20, fps), bitrate=1e6) + ani.save(out,writer=writer,dpi=dpi) + + +def dirtyImage(im, obs_List, init_x=[], init_y=[], flowbasis_x=[], flowbasis_y=[], initTheta=[]): + + if len(initTheta)==0: + init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im) + + im_List = []; + for t in range(0,len(obs_List)): + im_List.append(im.copy()) + + for t in range(0,len(obs_List)): + A = genPhaseShiftMtx_obs(obs_List[t],init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, pulse=ehtim.observing.pulses.deltaPulse2D) + #im_List[t].imvec = np.real( np.dot( np.linalg.inv( np.dot(np.transpose(A),A) ), np.dot( np.transpose(A), obs_List[t].data['vis'] ) ) ) + im_List[t].imvec = np.real( np.dot( np.transpose(np.conj(A) ), obs_List[t].data['vis']) ) + + return im_List + +def weinerFiltering(meanImg, covImg, obs_List, mask=[]): + + if type(obs_List) != list: + obs_List = [obs_List] + + im_List = [] + cov_List = [] + exp_tm1_t = [] + + for t in range(0,len(obs_List)): + im_List.append(meanImg.copy()) + cov_List.append(np.zeros(covImg.shape)) + exp_tm1_t.append(np.zeros(covImg.shape)) + + for t in range(0,len(obs_List)): + + meas, idealmeas, A, measCov, valid = getMeasurementTerms(obs_List[t], meanImg, measurement={'vis':1}, mask=mask) + + if valid==False: + im_List[t] = meanImg.copy() + cov_List[t] = copy.deepcopy(covImg) + else: + im_List[t].imvec, cov_List[t] = newDensity(meanImg.imvec, covImg, A, meas, idealmeas, measCov) + + if t>0: + exp_tm1_t[t] = np.dot( np.array([im_List[t-1].imvec]).T , np.array([im_List[t].imvec]) ) + + return (im_List, cov_List, exp_tm1_t) + + + +def newDensity(X, covX, A, Y, idealY, covY): + + measresidual = Y - idealY + residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY + G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) ) + Xnew = X + np.dot( G, measresidual ) + covXnew = covX - np.dot(G, np.dot( A, covX ) ) + + return (Xnew, covXnew) + + +def newDensity_linearize(X, covX, A, Y, idealY, covY, Xlin): + + measresidual = Y - idealY + np.dot(A,Xlin) - np.dot(A,X) + residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY + G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) ) + Xnew = X + np.dot( G, measresidual ) + covXnew = covX - np.dot(G, np.dot( A, covX ) ) + + return (Xnew, covXnew) + + +def newDensity3(X0, covX0, X1, covX1, X2, Y, idealY_X2, A_X2, covY): + + invCovX0 = np.linalg.inv(covX0) + invCovX1 = np.linalg.inv(covX1) + invCovY = np.linalg.inv(covY) + At_CovY = np.dot( np.transpose(A_X2), invCovY ) + At_CovY_A = np.dot( At_CovY , A_X2 ) + + covXnew = np.linalg.inv( At_CovY_A + invCovX0 + invCovX1 ) + Xnew = np.dot( covXnew , ( np.dot(At_CovY, Y - idealY_X2) + np.dot(At_CovY_A, X2) + np.dot(invCovX0,X0) + np.dot(invCovX1, X1) ) ) + + return (Xnew, covXnew) + + +def newDensity2(X, covX, A, Y, covY): + + measresidual = Y - np.dot( A, X ) + residualcov = np.dot(np.dot(A , covX ), np.transpose(A)) + covY + G = np.dot( covX, np.dot( np.transpose(A) , np.linalg.inv ( residualcov ) ) ) + Xnew = X + np.dot( G, measresidual ) + covXnew = covX - np.dot(G, np.dot( A, covX ) ) + + return (Xnew, covXnew) + + + +def gaussImgCovariance_2(im, powerDropoff=1.0, frac=1.): + + eps = 0.001 + + init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im) + ufull, vfull = genFreqComp(im) + shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) ) + eps + uvdist = uvdist / np.min(uvdist) + uvdist[0] = 'Inf' + + #shiftMtx = np.dot(shiftMtx, np.diag(im.imvec) ) + shiftMtx_exp = realimagStack(shiftMtx) + uvdist_exp = np.concatenate( (uvdist, uvdist), axis=0) + + imCov = np.dot( np.transpose(shiftMtx_exp) , np.dot( np.diag( 1/(uvdist_exp**powerDropoff) ), shiftMtx_exp ) ) + imCov = frac**2 * np.dot( np.diag(im.imvec).T, np.dot(imCov/imCov[0,0], np.diag(im.imvec) ) ); + return imCov + +def gaussImgCovariance(im, pixelStdev=1.0, powerDropoff=1.0, filter='none', kernsig=3.0): + + eps = 0.001 + + init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im) + ufull, vfull = genFreqComp(im) + shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) ) + eps + uvdist = uvdist / np.min(uvdist) + uvdist[0] = 'Inf' + + if filter == 'hamming': + hammingwindow = np.dot(np.reshape(np.hamming(im.xdim), (im.xdim,1)), np.reshape(np.hamming(im.ydim) , (1, im.ydim)) ) + shiftMtx = np.dot(shiftMtx, np.diag(np.reshape(hammingwindow, (im.xdim*im.ydim))) ) + if filter == 'gaussian': + gausswindow = gkern(kernlen=im.xdim, nsig=kernsig) + shiftMtx = np.dot(shiftMtx, np.diag(np.reshape(gausswindow, (im.xdim*im.ydim))) ) + + shiftMtx_exp = realimagStack(shiftMtx) + uvdist_exp = np.concatenate( (uvdist, uvdist), axis=0) + + imCov = np.dot( np.transpose(shiftMtx_exp) , np.dot( np.diag( 1/(uvdist_exp**powerDropoff) ), shiftMtx_exp ) ) + imCov = pixelStdev**2 * (imCov/imCov[0,0]); + return imCov + +###################################### BASIS ######################################## + +def affineMotionBasis(im): + + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + init_x = np.array([[ [0] for i in xlist] for j in ylist]) + init_y = np.array([[ [0] for i in xlist] for j in ylist]) + + flowbasis_x = np.array([[ [i, j ,im.psize, 0, 0, 0] for i in xlist] for j in ylist]) + flowbasis_y = np.array([[ [0, 0, 0, i, j, im.psize] for i in xlist] for j in ylist]) + initTheta = np.array([1, 0, 0, 0, 1, 0]) + + return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + +def affineMotionBasis_noTranslation(im): + + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + init_x = np.array([[ [0] for i in xlist] for j in ylist]) + init_y = np.array([[ [0] for i in xlist] for j in ylist]) + + flowbasis_x = np.array([[ [i, j, 0, 0] for i in xlist] for j in ylist]) + flowbasis_y = np.array([[ [0, 0, i, j] for i in xlist] for j in ylist]) + initTheta = np.array([1, 0, 0, 1]) + + return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + +def translationBasis(im): + + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + init_x = np.array([[ [i] for i in xlist] for j in ylist]) + init_y = np.array([[ [j] for i in xlist] for j in ylist]) + + flowbasis_x = np.array([[ [im.psize, 0.0] for i in xlist] for j in ylist]) + flowbasis_y = np.array([[ [0.0, im.psize] for i in xlist] for j in ylist]) + initTheta = np.array([0.0, 0.0]) + + return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + +def xTranslationBasis(im): + + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + init_x = np.array([[ [i] for i in xlist] for j in ylist]) + init_y = np.array([[ [j] for i in xlist] for j in ylist]) + + flowbasis_x = np.array([[ [im.psize] for i in xlist] for j in ylist]) + flowbasis_y = np.array([[ [0.0] for i in xlist] for j in ylist]) + initTheta = np.array([0.0]) + + return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + +def dftMotionBasis(im): + + DFTBASIS_THRESH = 0.1 + print('WARNING SOMETHING ISNT RIGHT WITH THE DFT GEN') + + init_x, init_y, flowbasis_x, flowbasis_y, initTheta = affineMotionBasis(im) + ufull, vfull = genFreqComp(im) + shiftMtx = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + uvdist = np.reshape( np.sqrt(ufull**2 + vfull**2), (ufull.shape[0]) ) + uvdist_norm = uvdist / np.max(uvdist) + + halfbasis = shiftMtx[uvdist_norm < DFTBASIS_THRESH,:] + fullbasis = np.concatenate( ( np.real( halfbasis[1:,:] ) , np.imag( halfbasis[1:,:] ) ), axis = 0) + fullbasis = np.reshape(fullbasis.T, [im.xdim, im.ydim, fullbasis.shape[0]]) * im.psize + + flowbasis_x = np.concatenate( (fullbasis, np.zeros(fullbasis.shape)), axis=2 ) + flowbasis_y = np.concatenate( (np.zeros(fullbasis.shape), fullbasis), axis=2 ) + + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + init_x = np.array([[ [i] for i in xlist] for j in ylist]) + init_y = np.array([[ [j] for i in xlist] for j in ylist]) + + initTheta = np.zeros(flowbasis_x.shape[2]) + + return (init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + + +def applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta): + + imsize = flowbasis_x.shape[0:2] + nbasis = theta.shape[0] + npixels = np.prod(imsize) + + flow_x = init_x[:,:,0] + np.reshape( np.dot( np.reshape(flowbasis_x, (npixels, nbasis), order ='F'), theta ), imsize, order='F') + flow_y = init_y[:,:,0] + np.reshape( np.dot( np.reshape(flowbasis_y, (npixels, nbasis), order ='F'), theta ), imsize, order='F') + + return (flow_x, flow_y) + + +###################################### FULL WARPING ######################################## + +def applyWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'): + + if method=='phase': + outim = applyPhaseWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + else: + outim = applyImageWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + return outim + + +def calcWarpMtx(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase', normalize=False): + + npixels = im.xdim*im.ydim + + if method=='phase': + ufull, vfull = genFreqComp(im) + shiftMtx0 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, theta, im.psize, im.pulse) + shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + + #outMtx = np.real( np.dot(np.transpose(np.conj(shiftMtx1)), shiftMtx0 ) / (npixels) ) + outMtx = np.real( np.dot( np.linalg.inv(shiftMtx1), shiftMtx0 ) ) + + elif method=='img': + probeim = im.copy() + outMtx = np.zeros((npixels, npixels)) + for i in range(0,npixels): + probeim.imvec = np.zeros(probeim.imvec.shape) + probeim.imvec[i] = 1.0 + outprobeim = applyImageWarp(probeim, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, pad=True) + outMtx[:,i] = outprobeim.imvec + + outMtx = np.nan_to_num(outMtx) + + if normalize: + for i in range(0,npixels): + outMtx[i,:] = outMtx[i,:]/np.sum(outMtx[i,:]) + + return outMtx + +def applyImageWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, pad=False): + + flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta) + + from_pts = np.concatenate( ( reshapeFlowbasis(flow_x_new), reshapeFlowbasis(flow_y_new) ), axis=1) + to_pts = np.concatenate( ( reshapeFlowbasis(flow_x_orig), reshapeFlowbasis(flow_y_orig) ), axis=1) + im_pts = im.imvec + + # add padding of 0's around the warped image so that the griddata function works properly + if pad: + # add padding on the x axis + for i in range(0,im.xdim): + + vec_x = flow_x_new[1,i] - flow_x_new[0,i] + vec_y = flow_y_new[1,i] - flow_y_new[0,i] + from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[0,i]-vec_x, flow_y_new[0,i]-vec_y ] ) ) ) + im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 ) + + vec_x = flow_x_new[im.ydim-2,i] - flow_x_new[im.ydim-1,i] + vec_y = flow_y_new[im.ydim-2,i] - flow_y_new[im.ydim-1,i] + from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[im.ydim-1,i]-vec_x, flow_y_new[im.ydim-1,i]-vec_y ] ) ) ) + im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 ) + # add padding on the y axis + for i in range(0,im.ydim): + + vec_x = flow_x_new[i,1] - flow_x_new[i,0] + vec_y = flow_y_new[i,1] - flow_y_new[i,0] + from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[i,0]-vec_x, flow_y_new[i,0]-vec_y ] ) ) ) + im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 ) + + vec_x = flow_x_new[i,im.xdim-2] - flow_x_new[i,im.xdim-1] + vec_y = flow_y_new[i,im.xdim-2] - flow_y_new[i,im.xdim-1] + from_pts = np.row_stack( (from_pts, np.array( [ flow_x_new[i,im.xdim-1]-vec_x, flow_y_new[i,im.xdim-1]-vec_y ] ) ) ) + im_pts = np.concatenate( ( im_pts, np.array([0.0]) ), axis = 0 ) + + #npixels = flowbasis_x.shape[0]*flowbasis_x.shape[1] + out = scipy.interpolate.griddata( from_pts , im_pts, to_pts , method='linear', fill_value=0.0 ) + #out = scipy.interpolate.griddata( np.concatenate( ( np.reshape(flow_x_new, (npixels, -1)), np.reshape(flow_y_new, (npixels, -1)) ), axis=1) , im.imvec, np.concatenate( ( np.reshape(flow_x_orig, (npixels, -1)), np.reshape(flow_y_orig, (npixels, -1)) ), axis=1), method='linear', fill_value=0.0 ) + outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + return outim + + +def applyPhaseWarp(im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta): + + ufull, vfull = genFreqComp(im) + + shiftMtx0 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, theta, im.psize, im.pulse) + shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + + #out = np.real( np.dot(np.transpose(np.conj(shiftMtx1)), np.dot(shiftMtx0, im.imvec) ) ) / (im.xdim * im.ydim) + out = np.real( np.dot( np.linalg.inv(shiftMtx1) , np.dot(shiftMtx0, im.imvec) ) ) + outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + return outim + +def genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse=ehtim.observing.pulses.deltaPulse2D): + + flow_x, flow_y = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta) + + imsize = flow_x.shape + npixels = np.prod(imsize) + + flow_x_vec = np.reshape(flow_x, (npixels)) + flow_y_vec = np.reshape(flow_y, (npixels)) + + shiftMtx_y = np.exp( [-1j * 2.0 * np.pi * flow_y_vec * v for v in vlist] ) + shiftMtx_x = np.exp( [-1j * 2.0 * np.pi * flow_x_vec * u for u in ulist] ) + + uvlist = np.transpose(np.squeeze(np.array([ulist, vlist]))) + uvlist = np.reshape(uvlist, (vlist.shape[0], 2)) + pulseVec = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") for uv in uvlist ] + + shiftMtx = np.dot( np.diag(pulseVec) , np.reshape( np.squeeze(shiftMtx_x * shiftMtx_y), (vlist.shape[0], npixels) ) ) + return shiftMtx + +################################### APPROXIMATE SHIFTING ########################################### + +def calc_dWarp_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'): + + if method == 'phase': + + ufull, vfull = genFreqComp(im) + + derivShiftMtx_x, derivShiftMtx_y = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=False) + + shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + invShiftMtx1 = np.linalg.inv(shiftMtx1) + + flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0) + + reshape_flowbasis_x = reshapeFlowbasis(flowbasis_x) + reshape_flowbasis_y = reshapeFlowbasis(flowbasis_y) + + + dWarp_dTheta = []; + for b in range(0, flowbasis.shape[1]): + K = np.dot( derivShiftMtx_x, np.diag( reshape_flowbasis_x[:,b] ) ) + np.dot( derivShiftMtx_y, np.diag( reshape_flowbasis_y[:,b] ) ) + dWarp_dTheta.append( np.real( np.dot( invShiftMtx1 , K ) ) ) + + + else: + print('WARNING: we do not handle this method yet') + + return dWarp_dTheta + + + +def applyAppxWarp(im, theta, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase'): + + centerIm, dImg_dTheta = calAppxWarpTerms(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1=method1, method2=method2) + out = centerIm.imvec + np.dot(dImg_dTheta, theta - centerTheta) + outim = image.Image(np.reshape(out, (im.ydim, im.xdim)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + return outim + +def calAppxWarpTerms(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase'): + + centerIm = applyWarp(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method1) + dImg_dTheta = calc_dImage_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method=method2) + return (centerIm, dImg_dTheta) + +def calc_dImage_dTheta(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'): + + if method == 'phase': + ufull, vfull = genFreqComp(im) + thetaDerivShiftMtx = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=True) + shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + #dImg_dTheta = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , thetaDerivShiftMtx) / (im.xdim * im.ydim) ) + dImg_dTheta = np.real( np.dot( np.linalg.inv(shiftMtx1) , thetaDerivShiftMtx) ) + else: + dImg_dTheta = calcDerivShiftMtx_image(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + + return dImg_dTheta + +def calcDerivWarpMtx_noimg_noshift(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method='phase'): + + if method == 'phase': + + ufull, vfull = genFreqComp(im) + #npixels = im.xdim*im.ydim + + freqShiftMtx_x, freqShiftMtx_y = calcDerivShiftMtx_freq(ufull, vfull, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=False) + shiftMtx1 = genPhaseShiftMtx(ufull, vfull, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, im.psize, im.pulse) + + #derivShiftMtx_y = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , freqShiftMtx_y) / npixels ) + #derivShiftMtx_x = np.real( np.dot( np.transpose(np.conj(shiftMtx1)) , freqShiftMtx_x) / npixels ) + invShiftMtx1 = np.linalg.inv(shiftMtx1) + derivShiftMtx_y = np.real( np.dot( invShiftMtx1 , freqShiftMtx_y) ) + derivShiftMtx_x = np.real( np.dot( invShiftMtx1 , freqShiftMtx_x) ) + else: + if (centerTheta != initTheta).any(): + raise ValueError('Can only take the optical flow derivative around no shift') + + derivShiftMtx_x = -gradMtx(im.xdim, im.ydim, dir='x')/im.psize + derivShiftMtx_y = -gradMtx(im.xdim, im.ydim, dir='y')/im.psize + + return (derivShiftMtx_x, derivShiftMtx_y) + + +def calcDerivShiftMtx_freq(ulist, vlist, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, includeImgFlow=True): + + shiftMtx = genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, centerTheta, im.psize, im.pulse) + + shiftVec_y = np.array( [-1j * 2.0 * np.pi * v * np.ones(im.xdim*im.ydim) for v in vlist] ) + shiftVec_x = np.array( [-1j * 2.0 * np.pi * u * np.ones(im.xdim*im.ydim) for u in ulist] ) + derivShiftMtx_x = shiftVec_x * shiftMtx + derivShiftMtx_y = shiftVec_y * shiftMtx + + + if includeImgFlow: + # TODO: WARNING DOES THIS HANDLE THE INIT_X INIT_Y + print('WARNING: is this handling the init x and y??') + + flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0) + derivShiftMtx = np.concatenate( ( np.dot(derivShiftMtx_x, np.diag(im.imvec)) , np.dot(derivShiftMtx_y, np.diag(im.imvec)) ) , axis=1) + thetaDerivShiftMtx = np.dot(derivShiftMtx, flowbasis) + return thetaDerivShiftMtx + else: + return (derivShiftMtx_x, derivShiftMtx_y) + +def calcDerivShiftMtx_image(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta): +# out = im.imvec + np.dot(gradIm_x, theta - initTheta ) + np.dot(gradIm_y, theta - initTheta ) + + centerImg = applyImageWarp(im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + + Gx = -gradMtx(im.xdim, im.ydim, dir='x')/im.psize + Gy = -gradMtx(im.xdim, im.ydim, dir='y')/im.psize + + gradIm_x = np.diag( np.dot(Gx, centerImg.imvec) ) + gradIm_y = np.diag( np.dot(Gy, centerImg.imvec) ) + + # TODO: WARNING DOES THIS HANDLE THE INIT_X INIT_Y + print('WARNING: is this handling the init x and y??') + + flowbasis = np.concatenate((reshapeFlowbasis(flowbasis_x), reshapeFlowbasis(flowbasis_y)), axis=0) + derivShiftMtx = np.concatenate( (gradIm_x, gradIm_y), axis=1) + + thetaDerivShiftMtx = np.dot(derivShiftMtx, flowbasis) + return thetaDerivShiftMtx + +def calcAppxWarpMtx_image(im, init_x, init_y, flowbasis_x, flowbasis_y, theta, initTheta): + + flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta) + + flow_x = flow_x_new - flow_x_orig + flow_y = flow_y_new - flow_y_orig + + Gx = -gradMtx(im.xdim, im.ydim, dir='x') + Gy = -gradMtx(im.xdim, im.ydim, dir='y') + + derivx = np.dot( np.diag( vec(flow_x) ), Gx/im.psize) + derivy = np.dot( np.diag( vec(flow_y) ), Gy/im.psize) + + appxMtx = np.eye(im.xdim*im.ydim) + derivx + derivy + return appxMtx + + +###################################### FREQUENCY SPACE ######################################## + +def shiftVisibilities(obs, shiftX, shiftY): + obs.data['vis'] = obs.data['vis']*np.exp(-1j*2.0*np.pi*( obs.data['u']*shiftX + obs.data['v']*shiftY )) + return obs + +def genAppxShiftMtx(ulist, vlist, npixels, shiftMtx): + + derivShiftMtx_y = np.array( [-1j * 2.0 * np.pi * np.ones((npixels)) * v for v in vlist] ) * shiftMtx + derivShiftMtx_x = np.array( [-1j * 2.0 * np.pi * np.ones((npixels)) * u for u in ulist] ) * shiftMtx + + return (derivShiftMtx_x, derivShiftMtx_y) + +def genFreqComp(im): + + fN2 = int(np.floor(im.xdim/2)) #TODO: !!! THIS DOESNT WORK FOR ODD IMAGE SIZES + fM2 = int(np.floor(im.ydim/2)) + + ulist = (np.array([np.concatenate((np.linspace(0, fN2 - 1, fN2), np.linspace(-fN2, -1, fN2)), axis=0)]) / im.xdim ) / im.psize + vlist = (np.array([np.concatenate((np.linspace(0, fM2 - 1, fM2), np.linspace(-fM2, -1, fM2)), axis=0)]) / im.ydim ) / im.psize + + ufull, vfull = np.meshgrid(ulist, vlist) + + ufull = np.reshape(ufull, (im.xdim*im.ydim, -1), order='F') + vfull = np.reshape(vfull, (im.xdim*im.ydim, -1), order='F') + + return (ufull, vfull) + +def genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse=ehtim.observing.pulses.deltaPulse2D): + ulist = obs.unpack('u')['u'] + vlist = obs.unpack('v')['v'] + + shiftMtx = genPhaseShiftMtx(ulist, vlist, init_x, init_y, flowbasis_x, flowbasis_y, theta, pdim, pulse) + return shiftMtx + + +def calcDerivShiftMtx_obs(obs, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y): + + ulist = obs.unpack('u')['u'] + vlist = obs.unpack('v')['v'] + + thetaDerivShiftMtx = calcDerivShiftMtx_freq(ulist, vlist, im, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y) + return thetaDerivShiftMtx + +def cmpFreqExtraction_phaseWarp(obs, im_true, im_canonical, theta, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta): + + data = obs.unpack(['u','v','vis','sigma']) + uv = np.hstack((data['u'].reshape(-1,1), data['v'].reshape(-1,1))) + A = ftmatrix(im_true.psize, im_true.xdim, im_true.ydim, uv, pulse=im_true.pulse) + + shiftMtx_true = genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, theta, im_canonical.psize, im_canonical.pulse) + + shiftMtx_center = genPhaseShiftMtx_obs(obs, init_x, init_y, flowbasis_x, flowbasis_y, centerTheta, im_canonical.psize, im_canonical.pulse) + thetaDerivShiftMtx = calcDerivShiftMtx_obs(obs, im_canonical, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y) + + centerim, dImg_dTheta = calAppxWarpTerms(im_canonical, centerTheta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, method1='phase', method2='phase') + + chiSq_shift = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(shiftMtx_true, im_canonical.imvec ) )/ obs.data['sigma'] )**2 ) + chiSq_appxshift = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(shiftMtx_center, im_canonical.imvec) - np.dot(thetaDerivShiftMtx, theta-centerTheta) ) / obs.data['sigma'] )**2 ) + chiSq_true = 0.5 * np.mean( np.abs( (obs.data['vis'] - np.dot(A, im_true.imvec) ) / obs.data['sigma'] )**2 ) + + return (chiSq_true, chiSq_shift, chiSq_appxshift) + + +################################# HELPER FUNCTIONS ############################################# + + +def gkern(kernlen=21, nsig=3): + """Returns a 2D Gaussian kernel array.""" + + interval = (2*nsig+1.)/(kernlen) + x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) + kern1d = np.diff(st.norm.cdf(x)) + kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) + kernel = kernel_raw/kernel_raw.sum() + return kernel + +def gradMtx(w, h, dir='x'): + + G = np.eye(h*w) + if dir=='x': + G = G - np.diag(np.ones((h*w-1)), k=1) + delrows = (np.linspace(w,h*w,h)-1).astype(int) + G[delrows,:] = 0 + else: + G = G - np.diag(np.ones((h*w-w)), k=h) + delrows = range(h*w-w,h*w) + G[delrows,:] = 0 + return G + + +def realimagStack(mtx): + stack = np.concatenate( ( np.real(mtx), np.imag(mtx) ), axis=0 ) + return stack + +def reshapeFlowbasis(flowbasis): + npixels = flowbasis.shape[0] * flowbasis.shape[1] + #nbasis = flowbasis.shape[2] + flowbasis = np.reshape(flowbasis, (npixels, -1) ) #, order ='F') + return flowbasis + +def vec(x, order='F'): + x = np.reshape(x, (np.prod(x.shape)), order = order) + return x + +def listconcatenate(*lists): + new_list = [] + for i in lists: + new_list.extend(i) + return new_list + +def padNewFOV(im, fov_arcseconds): + + oldfov = im.psize * im.xdim + newfov = fov_arcseconds * ehtim.RADPERUAS + tnpixels = np.ceil(im.xdim * newfov/oldfov).astype('int') + + origimg = np.reshape(im.imvec, [im.xdim, im.xdim]) + padimg = np.pad(origimg, ((0,tnpixels-im.xdim), (0,tnpixels-im.xdim)), 'constant') + padimg = np.roll(padimg, np.floor((tnpixels-im.xdim)/2.).astype('int'), axis=0) + padimg = np.roll(padimg, np.floor((tnpixels-im.xdim)/2.).astype('int'), axis=1) + + return image.Image(padimg.reshape((tnpixels, tnpixels)), im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + + +def flipImg(im, flip_lr, flip_ud): + + img = np.reshape(im.imvec, [im.xdim, im.xdim]) + if flip_lr: + img = np.fliplr(img) + if flip_ud: + img = np.flipud(img) + im.imec = img.reshape((im.xdim*im.xdim)) + return image.Image(img, im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + + +def rotateImg(im, k): + + img = np.reshape(im.imvec, [im.xdim, im.xdim]) + img = np.rot90(img, k=k) + im.imec = img.reshape((im.xdim*im.xdim)) + return image.Image(img, im.psize, im.ra, im.dec, rf=im.rf, source=im.source, mjd=im.mjd, pulse=im.pulse) + + +####################################### MICHAELS STUFF ####################################### + +def plot_im_List(im_List, title_List=[], ipynb=False): + + plt.title("Test", fontsize=20) + plt.ion() + plt.clf() + + Prior = im_List[0] + + for i in range(len(im_List)): + plt.subplot(1, len(im_List), i+1) + plt.imshow(im_List[i].imvec.reshape(Prior.ydim,Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + plt.axis('off') + xticks = ticks(Prior.xdim, Prior.psize/ehtim.RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/ehtim.RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + if i == 0: + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + else: + plt.xlabel('') + plt.ylabel('') + #plt.title('') + if len(title_List)==len(im_List): + plt.title(title_List[i], fontsize = 5) + + + plt.draw() + + +def plot_Flow(Im, theta, init_x, init_y, flowbasis_x, flowbasis_y, initTheta, step=4, ipynb=False): + + # Get vectors and ratio from current image + x = np.array([[i for i in range(Im.xdim)] for j in range(Im.ydim)]) + y = np.array([[j for i in range(Im.xdim)] for j in range(Im.ydim)]) + + flow_x_new, flow_y_new = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, theta) + flow_x_orig, flow_y_orig = applyMotionBasis(init_x, init_y, flowbasis_x, flowbasis_y, initTheta) + + vx = -(flow_x_new - flow_x_orig) + vy = -(flow_y_new - flow_y_orig) + + # Create figure and title + plt.ion() + plt.clf() + + # Stokes I plot + plt.subplot(111) + plt.imshow(Im.imvec.reshape(Im.ydim, Im.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + plt.quiver(x[::step,::step], y[::step,::step], vx[::step,::step], vy[::step,::step], + headaxislength=3, headwidth=7, headlength=5, minlength=0, minshaft=1, + width=.005*Im.xdim/30., pivot='mid', color='w', angles='xy') + + xticks = ticks(Im.xdim, Im.psize/ehtim.RADPERAS/1e-6) + yticks = ticks(Im.ydim, Im.psize/ehtim.RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Flow Map') + #plt.ylim(plt.ylim()[::-1]) + # Display + plt.draw() + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/io/__init__.py b/io/__init__.py new file mode 100644 index 00000000..b7ba4935 --- /dev/null +++ b/io/__init__.py @@ -0,0 +1,11 @@ +""" +.. module:: ehtim.io + :platform: Unix + :synopsis: EHT Imaging Utilities: I/O functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from . import load + +from ..const_def import * diff --git a/io/load.py b/io/load.py new file mode 100644 index 00000000..8e38a4f5 --- /dev/null +++ b/io/load.py @@ -0,0 +1,1837 @@ +# load.py +# functions to load observation & image data from files +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import astropy.io.fits as fits +import datetime +import os +import copy +import sys +import time as ttime +import h5py + +import ehtim.obsdata +import ehtim.image +import ehtim.array +import ehtim.movie +import ehtim.vex +import ehtim.observing + +import ehtim.io.oifits +import ehtim.const_def as ehc + +import warnings +warnings.filterwarnings("ignore", message="Mean of empty slice") +warnings.filterwarnings("ignore", message="invalid value encountered in true_divide") + +################################################################################################## +# Vex IO +################################################################################################## + + +def load_vex(fname): + """Read in .vex files. + Assumes there is only 1 MODE in vex file + Hotaka Shiokawa - 2017 + + Args: + fname (str): path to input .vex file + Returns: + vex (Vex): Vex file object + """ + print("Loading vexfile: ", fname) + return ehtim.vex.Vex(fname) + + +################################################################################################## +# Image IO +################################################################################################## +def load_im_txt(filename, pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim='I', zero_pol=True): + """Read in an image from a text file. + + Args: + filename (str): path to input text file + pulse (function): The function convolved with the pixel values for continuous image. + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + + Returns: + (Image): loaded image object + """ + + print("Loading text image: ", filename) + + # Read the header + file = open(filename) + src = ' '.join(file.readline().split()[2:]) + ra = file.readline().split() + ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0 + dec = file.readline().split() + dec = np.sign(float(dec[2])) * (abs(float(dec[2])) + + float(dec[4]) / 60.0 + float(dec[6]) / 3600.0) + mjd_float = float(file.readline().split()[2]) + mjd = int(mjd_float) + time = (mjd_float - mjd) * 24 + rf = float(file.readline().split()[2]) * 1e9 + xdim = file.readline().split() + xdim_p = int(xdim[2]) + psize_x = float(xdim[4]) * ehc.RADPERAS / xdim_p + ydim = file.readline().split() + ydim_p = int(ydim[2]) + psize_y = float(ydim[4]) * ehc.RADPERAS / ydim_p + file.close() + + if psize_x != psize_y: + raise Exception("Pixel dimensions in x and y are inconsistent!") + + # Load the data, convert to list format, make object + datatable = np.loadtxt(filename, dtype=float) + image = datatable[:, 2].reshape(ydim_p, xdim_p) + outim = ehtim.image.Image(image, psize_x, ra, dec, + rf=rf, source=src, mjd=mjd, time=time, pulse=pulse, + polrep='stokes', pol_prim='I') + + # Look for Stokes Q and U + qimage = uimage = vimage = np.zeros(image.shape) + if datatable.shape[1] == 6: + qimage = datatable[:, 3].reshape(ydim_p, xdim_p) + uimage = datatable[:, 4].reshape(ydim_p, xdim_p) + vimage = datatable[:, 5].reshape(ydim_p, xdim_p) + elif datatable.shape[1] == 5: + qimage = datatable[:, 3].reshape(ydim_p, xdim_p) + uimage = datatable[:, 4].reshape(ydim_p, xdim_p) + + if np.any((qimage != 0) + (uimage != 0)) and np.any((vimage != 0)): + # print('Loaded Stokes I, Q, U, and V Images') + outim.add_qu(qimage, uimage) + outim.add_v(vimage) + elif np.any((vimage != 0)): + # print('Loaded Stokes I and V Images') + outim.add_v(vimage) + if zero_pol: + outim.add_qu(0 * vimage, 0 * vimage) + elif np.any((qimage != 0) + (uimage != 0)): + # print('Loaded Stokes I, Q, and U Images') + outim.add_qu(qimage, uimage) + if zero_pol: + outim.add_v(0 * qimage) + else: + if zero_pol: + outim.add_qu(0 * image, 0 * image) + outim.add_v(0 * image) + # print('Loaded Stokes I Image Only') + + # Transform to desired pol rep + if not (polrep == 'stokes' and pol_prim == 'I'): + outim = outim.switch_polrep(polrep_out=polrep, pol_prim_out=pol_prim) + + return outim + + +def load_im_hdf5(filename): + """Read in an image from an hdf5 file. + Args: + filename (str): path to input hdf5 file + Returns: + (Image): loaded image object + """ + print("Loading hdf5 image: ", filename) + + # Load information from hdf5 file + + hfp = h5py.File(filename,'r') + dsource = hfp['header']['dsource'][()] # distance to source in cm + jyscale = hfp['header']['scale'][()] # convert cgs intensity -> Jy flux density + rf = hfp['header']['freqcgs'][()] # in cgs + tunit = hfp['header']['units']['T_unit'][()] # in seconds + lunit = hfp['header']['units']['L_unit'][()] # in cm + DX = hfp['header']['camera']['dx'][()] # in GM/c^2 + nx = hfp['header']['camera']['nx'][()] # width in pixels + time = hfp['header']['t'][()] * tunit / 3600. # time in hours + if 'pol' in hfp: + poldat = np.copy(hfp['pol'])[:, :, :4] # NX,NY,{I,Q,U,V} + else: # unpolarized data only + unpoldat = np.copy(hfp['unpol']) # NX,NY + poldat = np.zeros(list(unpoldat.shape)+[4]) + poldat[:,:,0] = unpoldat + hfp.close() + + # Correct image orientation + # unpoldat = np.flip(unpoldat.transpose((1, 0)), axis=0) + poldat = np.flip(poldat.transpose((1, 0, 2)), axis=0) + + # Make a guess at the source based on distance and optionally fall back on mass + src = ehc.SOURCE_DEFAULT + if dsource > 4.e25 and dsource < 6.2e25: + src = "M87" + elif dsource > 2.45e22 and dsource < 2.6e22: + src = "SgrA" + + # Fill in information according to the source + ra = ehc.RA_DEFAULT + dec = ehc.DEC_DEFAULT + if src == "SgrA": + ra = 17.76112247 + dec = -28.992189444 + elif src == "M87": + ra = 187.70593075 + dec = 12.391123306 + + # Process image to set proper dimensions + fovmuas = DX / dsource * lunit * 2.06265e11 + psize_x = ehc.RADPERUAS * fovmuas / nx + + Iim = poldat[:, :, 0] * jyscale + Qim = poldat[:, :, 1] * jyscale + Uim = poldat[:, :, 2] * jyscale + Vim = poldat[:, :, 3] * jyscale + + outim = ehtim.image.Image(Iim, psize_x, ra, dec, rf=rf, source=src, + polrep='stokes', pol_prim='I', time=time) + outim.add_qu(Qim, Uim) + outim.add_v(Vim) + + return outim + + +def load_im_fits(filename, aipscc=False, pulse=ehc.PULSE_DEFAULT, + punit="deg", polrep='stokes', pol_prim=None, zero_pol=True): + """Read in an image from a FITS file. + + Args: + fname (str): path to input fits file + aipscc (bool): if True, then AIPS CC table will be loaded + pulse (function): The function convolved with the pixel values for continuous image. + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U,V for Stokes, RR,LL,LR,RL for circ + zero_pol (bool): If True, loads any missing polarizations as zeros + + Returns: + (Image): loaded image object + """ + + print("Loading fits image: ", filename) + + # Radian or Degree? + if punit == "deg": + pscl = ehc.DEGREE + elif punit == "rad": + pscl = 1.0 + elif punit == "uas": + pscl = ehc.RADPERUAS + elif punit == "mas": + pscl = ehc.RADPERUAS * 1000.0 + + # Open the FITS file + hdulist = fits.open(filename) + + # Assume stokes I is the primary hdu + header = hdulist[0].header + + # Read some header values + xdim_p = header['NAXIS1'] + psize_x = np.abs(header['CDELT1']) * pscl + ydim_p = header['NAXIS2'] + psize_y = np.abs(header['CDELT2']) * pscl + + if 'OBSRA' in list(header.keys()): + ra = header['OBSRA'] * 12 / 180. + elif 'CRVAL1' in list(header.keys()): + ra = header['CRVAL1'] * 12 / 180. + else: + ra = 0. + + if 'OBSDEC' in list(header.keys()): + dec = header['OBSDEC'] + elif 'CRVAL2' in list(header.keys()): + dec = header['CRVAL2'] + else: + dec = 0. + + if 'FREQ' in list(header.keys()): + rf = header['FREQ'] + elif 'CRVAL3' in list(header.keys()): + rf = header['CRVAL3'] + else: + rf = 0. + + if 'MJD' in list(header.keys()): + mjd_float = header['MJD'] + else: + mjd_float = 0. + mjd = int(mjd_float) + time = (mjd_float - mjd) * 24 + + if 'OBJECT' in list(header.keys()): + src = header['OBJECT'] + else: + src = '' + + # Get the image and create the object + data = hdulist[0].data + + # Check for multiple stokes in top hdu + stokes_in_hdu0 = False + if len(data.shape) == 4: + print("reading all stokes images from top HDU -- assuming IQUV order") + stokesdata = data[:4,:,:,:] # ignore fields after the first 4 + data = stokesdata[0, 0] + stokes_in_hdu0 = True + + elif len(data.shape) == 3: + # ANDREW added this for BHAC models 3/22/23 + print("reading all stokes images from top HDU -- assuming IQUV order") + stokesdata = data[:4,:,:] # ignore fields after the first 4 + stokesdata = stokesdata.reshape(4,-1,stokesdata.shape[-2],stokesdata.shape[-1]) + data = stokesdata[0, 0] + stokes_in_hdu0 = True + + #data = data.reshape((data.shape[-2], data.shape[-1])) + data = data.reshape((data.shape[-2], data.shape[-1])) + + # Update the image using the AIPS CC table + if aipscc: + try: + aipscctab = hdulist["AIPS CC"] + except BaseException: + print("Input FITS file does not have an AIPS CC table. Loading image instead.") + aipscc = False + + if aipscc: + + print("loading the AIPS CC table.") + print("force the pulse function to be the delta function.") + pulse = ehtim.observing.pulses.deltaPulse2D + + # get the source brightness brightness ifromation + flux = aipscctab.data["FLUX"] + deltax = aipscctab.data["DELTAX"] + deltay = aipscctab.data["DELTAY"] + + # check to make sure all the source types are point sources and gaussian components + try: + checkmtype = np.abs(np.unique(aipscctab.data["TYPE OBJ"])) < 2.0 + if False in checkmtype.tolist(): + errmsg = "The primary AIPS CC table in the input FITS file has non point-source" + errmsg += " or Gaussian Source CC components, which are not currently supported." + raise ValueError(errmsg) + point_src = aipscctab.data["TYPE OBJ"] == 0 + gaussian_src = aipscctab.data["TYPE OBJ"] == 1 + except(KeyError): + print("Cannot load AIPS CC Table OBJ data -- assuming all CC components are point sources!") + point_src = np.ones(aipscctab.data.shape).astype(bool) + gaussian_src = np.zeros(aipscctab.data.shape).astype(bool) + print("%d CC components are loaded." % (len(flux))) + + # compile the point source aipscc info + flux_ps = flux[point_src] + deltax_ps = deltax[point_src] + deltay_ps = deltay[point_src] + + # compile the gaussian aipscc info, if any + if np.any(gaussian_src): + flux_gs = flux[gaussian_src] + deltax_gs = deltax[gaussian_src] + deltay_gs = deltay[gaussian_src] + maj_gs = aipscctab.data["MAJOR AX"][gaussian_src] + min_gs = aipscctab.data["MINOR AX"][gaussian_src] + pa_gs = aipscctab.data["POSANGLE"][gaussian_src] + else: + flux_gs = [] + + # the map_coordinates delta x / delta y of each delta CC component are + # relative to the reference pixel which is defined by CRPIX1 and CRPIX2. + try: + Nxref = header.get("CRPIX1") + except BaseException: + Nxref = header.get("NAXIS1") // 2 + 1 + try: + Nyref = header.get("CRPIX2") + except BaseException: + Nyref = header.get("NAXIS2") // 2 + 1 + + # compute the corresponding index of pixel for each deltax_ps / deltay_ps + ix = np.array(np.int64(np.round(deltax_ps / header.get("CDELT1") + Nxref - 1))) + iy = np.array(np.int64(np.round(deltay_ps / header.get("CDELT2") + Nyref - 1))) + + # reset the image and input flux information + data[:, :] = 0. + Noutcomp = 0 + for i in range(len(flux_ps)): + try: + data[iy[i], ix[i]] += flux_ps[i] + except BaseException: + Noutcomp += 1 + print("added %d CC delta components." % (len(flux_ps))) + if Noutcomp > 0: + print("%d CC delta components are outside of the FoV and ignored." % (Noutcomp)) + + # flip y-axis! + image = data[::-1, :] + + # normalize the flux + normalizer = 1.0 + if 'BUNIT' in list(header.keys()): + if header['BUNIT'].lower() == 'JY/BEAM'.lower(): + + print("converting Jy/Beam --> Jy/pixel") + bmaj = bmin = 1.0 # default values + + if 'BMAJ' in list(header.keys()): + bmaj = header['BMAJ'] + bmin = header['BMIN'] + + elif 'HISTORY' in list(header.keys()): # Alternate option, to read AIPS fits images + print("No beam info in header; reading from AIPS HISTORY instead...") + for line in header['HISTORY']: + if 'BMAJ' in line and len(line.split())>6: + bmaj = float(line.split()[3]) + bmin = float(line.split()[5]) + + if bmaj==1.0 and bmin==1.0: + print("No beam info found! Assuming nominal values for conversion.") + bmaj = bmin = 1.0 + + beamarea = (2.0 * np.pi * bmaj * bmin / (8.0 * np.log(2))) + normalizer = (header['CDELT2'])**2 / beamarea + + if aipscc: + print("the computed normalizer will not be applied since we are loading the AIPS CC table") + else: + image *= normalizer + + # make image object in Stokes I + outim = ehtim.image.Image(image, psize_x, ra, dec, + rf=rf, source=src, mjd=mjd, time=time, pulse=pulse, + polrep='stokes', pol_prim='I') + + # add gaussian components to the image from the aipscc table + if aipscc and len(flux_gs): + Noutcomp = 0 + for i in range(len(flux_gs)): + # make sure the aipscc table gaussian is within the FOV + if ((deltax_gs[i] * ehc.DEGREE < outim.fovx() / 2.0) and + (deltay_gs[i] * ehc.DEGREE < outim.fovy() / 2.0) and + (maj_gs[i] * ehc.DEGREE * 3 < (np.min([outim.fovx(), outim.fovy()]) / 2.0))): + # add a gaussian component with the specified flux and location + outim = outim.add_gauss(flux_gs[i], (maj_gs[i] * ehc.DEGREE, min_gs[i] * ehc.DEGREE, + pa_gs[i] * ehc.DEGREE, + deltax_gs[i] * ehc.DEGREE, + deltay_gs[i] * ehc.DEGREE)) + else: + Noutcomp += 1 + print("added %d CC gaussian components." % (len(flux_gs))) + if Noutcomp > 0: + print("%d CC gaussian components are outside of the FoV and ignored." % (Noutcomp)) + + # Look for Stokes Q and U and V + qimage = uimage = vimage = np.array([]) + + if stokes_in_hdu0: # stokes in top HDU + try: + qdata = stokesdata[1, 0].reshape((data.shape[-2], data.shape[-1])) + qimage = normalizer * qdata[::-1, :] # flip y-axis! + except IndexError: + pass + try: + udata = stokesdata[2, 0].reshape((data.shape[-2], data.shape[-1])) + uimage = normalizer * udata[::-1, :] # flip y-axis! + except IndexError: + pass + try: + vdata = stokesdata[3, 0].reshape((data.shape[-2], data.shape[-1])) + vimage = normalizer * vdata[::-1, :] # flip y-axis! + except IndexError: + pass + + else: # stokes in different HDUS + for hdu in hdulist[1:]: + header = hdu.header + data = hdu.data + try: + data = data.reshape((data.shape[-2], data.shape[-1])) + except IndexError: + continue + + if 'STOKES' in list(header.keys()) and header['STOKES'] == 'Q': + qimage = normalizer * data[::-1, :] # flip y-axis! + if 'STOKES' in list(header.keys()) and header['STOKES'] == 'U': + uimage = normalizer * data[::-1, :] # flip y-axis! + if 'STOKES' in list(header.keys()) and header['STOKES'] == 'V': + vimage = normalizer * data[::-1, :] # flip y-axis! + + if qimage.shape == uimage.shape == vimage.shape == image.shape: + # print('Loaded Stokes I, Q, U, and V Images') + outim.add_qu(qimage, uimage) + outim.add_v(vimage) + elif vimage.shape == image.shape: + # print('Loaded Stokes I and V Images') + outim.add_v(vimage) + if zero_pol: + outim.add_qu(0 * vimage, 0 * vimage) + elif qimage.shape == uimage.shape == image.shape: + # print('Loaded Stokes I, Q, and U Images') + outim.add_qu(qimage, uimage) + if zero_pol: + outim.add_v(0 * qimage) + else: + if zero_pol: + outim.add_qu(0 * image, 0 * image) + outim.add_v(0 * image) + # print('Loaded Stokes I Image Only') + + # Transform to desired pol rep + if not (polrep == 'stokes' and pol_prim == 'I'): + outim = outim.switch_polrep(polrep_out=polrep, pol_prim_out=pol_prim) + + return outim + +################################################################################################## +# Movie IO +################################################################################################## + +# Old version for arizona datasets +# def load_movie_hdf5(file_name, framedur_sec=1, psize=-1, +# ra=17.761122472222223, dec=-28.992189444444445, rf=230e9, source='SgrA', +# pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None, zero_pol=True): + + +# """Read in a movie from an hdf5 file and create a Movie object. + +# Args: +# file_name (str): The name of the hdf5 file. +# framedur_sec (float): The frame duration in seconds +# psize (float): Pixel size in radian +# ra (float): The movie right ascension +# dec (float): The movie declination +# rf (float): The movie frequency +# pulse (function): The function convolved with the pixel values for continuous image +# polrep (str): polarization representation, either 'stokes' or 'circ' +# pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular +# zero_pol (bool): If True, loads any missing polarizations as zeros + +# Returns: +# Movie: a Movie object +# """ + +# # TODO: Currently only supports one polarization! +# file = h5py.File(file_name, 'r') +# name = list(file.keys())[0] +# d = file[str(name)] +# frames = d[:] +# file.close() + +# # TODO: currently no frame times stored in hdf5! +# framedur_hr = framedur/3600. +# mjd0 = ehc.MJD_DEFAULT +# hour0 = 0 +# nframes = len(frames) +# tstart = hour0 +# tstop = hour0 + framedur_hr*nframes +# times = np.linspace(tstart, tstop, nframes) + +# movie = Movie(frames, times, +# psize, ra, dec, rf=rf, +# polrep=polrep, pol_prim=pol_prim, +# source=source, mjd=ehc.MJD_DEFAULT, pulse=pulse) + +# if zero_pol: +# for pol in list(movie._movdict.keys()): +# if pol==movie.pol_prim: continue +# polframes = np.zeros(frames.shape) +# newmov.add_pol_movie(polframes, pol) + +# return movie + + +def load_movie_hdf5(file_name, pulse=ehc.PULSE_DEFAULT, interp=ehc.INTERP_DEFAULT, + bounds_error=ehc.BOUNDS_ERROR): + """Read in a movie from an hdf5 file and create a Movie object. + + Args: + file_name (str): The name of the hdf5 file. + framedur_sec (float): The frame duration in seconds (overwrites internal timestamps) + pulse (function): The function convolved with the pixel values for continuous image + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + # TODO: Currently only supports one polarization! + with h5py.File(file_name, 'r') as file: + + head = file['header'] + + mjd = int(head.attrs['mjd'].astype(str)) + psize = float(head.attrs['psize'].astype(str)) + source = head.attrs['source'].astype(str) + ra = float(head.attrs['ra'].astype(str)) + dec = float(head.attrs['dec'].astype(str)) + rf = float(head.attrs['rf'].astype(str)) + polrep = head.attrs['polrep'].astype(str) + pol_prim = head.attrs['pol_prim'].astype(str) + + times = file['times'][:] + frames = file[pol_prim][:] + + movie = ehtim.movie.Movie(frames, times, + psize, ra, dec, rf=rf, + interp=interp, bounds_error=bounds_error, + polrep=polrep, pol_prim=pol_prim, + source=source, mjd=mjd, pulse=pulse) + + if polrep == 'stokes': + keys = ['I', 'Q', 'U', 'V'] + elif polrep == 'circ': + keys = ['RR', 'LL', 'RL', 'LR'] + else: + raise Exception("hdf5 polrep is not 'circ' or 'stokes'!") + + for pol in keys: + if pol == movie.pol_prim: + continue + if pol in file.keys(): + polframes = file[pol][:] + if len(polframes): + movie.add_pol_movie(polframes, pol) + file.close() + + return movie + + +def load_movie_txt(basename, nframes, framedur=-1, pulse=ehc.PULSE_DEFAULT, + polrep='stokes', pol_prim=None, zero_pol=True, + interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR): + """Read in a movie from text files and create a Movie object. + + Args: + basename (str): The base name of individual movie frames. + Files should have names basename + 00001, etc. + nframes (int): The total number of frames + framedur (float): The frame duration in seconds + (default = -1, and framedur is taken from file headers) + pulse (function): The function convolved with the pixel values for continuous image + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + imlist = [] + + for i in range(nframes): + filename = basename + "%05d" % i + + sys.stdout.write('\rReading Movie Image %i/%i...' % (i, nframes)) + sys.stdout.flush() + + im = load_im_txt(filename, pulse=pulse, polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol) + imlist.append(im) + + if i == 0: + hour0 = im.time + times = [hour0] + else: + times.append(hour0) + + if framedur != -1: + + framedur_hr = framedur / 3600. + nframes = len(imlist) + tstart = hour0 + tstop = hour0 + framedur_hr * nframes + times = np.linspace(tstart, tstop, nframes) + for kk in range(len(imlist)): + imlist[kk].time = times[kk] + + out_mov = ehtim.movie.merge_im_list(imlist, framedur=framedur, + interp=interp, bounds_error=bounds_error) + + return out_mov + + +def load_movie_fits(basename, nframes, framedur=-1, + interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR, + pulse=ehc.PULSE_DEFAULT, polrep='stokes', pol_prim=None, zero_pol=True): + """Read in a movie from fits files and create a Movie object. + + Args: + basename (str): The base name of individual movie frames. + Files should have names basename + 00001, etc. + nframes (int): The total number of frames + framedur (float): The frame duration in seconds + (default = -1, and framedur is taken from file headers) + pulse (function): The function convolved with the pixel values for continuous image + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + imlist = [] + + for i in range(nframes): + sys.stdout.write('\rReading Movie Image %i/%i...' % (i, nframes)) + sys.stdout.flush() + for tag in ["%02d" % i, "%03d" % i, "%04d" % i, "%05d" % i]: + + try: + filename = basename + tag + '.fits' + + im = load_im_fits(filename, pulse=pulse, polrep=polrep, + pol_prim=pol_prim, zero_pol=zero_pol) + imlist.append(im) + break + except BaseException: + continue + + if i == 0: + hour0 = im.time + else: + pass + + if framedur != -1: + + framedur_hr = framedur / 3600. + nframes = len(imlist) + tstart = hour0 + tstop = hour0 + framedur_hr * nframes + times = np.linspace(tstart, tstop, nframes) + for kk in range(len(imlist)): + imlist[kk].time = times[kk] + + out_mov = ehtim.movie.merge_im_list(imlist, framedur=framedur, + interp=interp, bounds_error=bounds_error) + + return out_mov + + +def load_movie_dat(basename, nframes, startframe=0, framedur_sec=1, psize=-1, + interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR, + ra=ehc.RA_DEFAULT, dec=ehc.DEC_DEFAULT, rf=ehc.RF_DEFAULT, + pulse=ehc.PULSE_DEFAULT): + """Read in a movie from dat files and create a Movie object. + + Args: + basename (str): The base name of individual movie frames. + Files should have names basename + 000001, etc. + nframes (int): The total number of frames + startframe (int): The index of the first frame to load + framedur_sec (float): The frame duration in seconds (default = 1) + psize (float): The pixel size in radian + ra (float): the right ascension of the source (default for SgrA*) + dec (float): the declination of the source (default for SgrA*) + rf (float): The refrence frequency of the observation + pulse (function): The function convolved with the pixel values for continuous image + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + for i in range(startframe, startframe + nframes): + + filename = basename + "%04d" % i + '.dat' + + sys.stdout.write('\rReading Movie Image %i/%i...' % (i - startframe, nframes)) + sys.stdout.flush() + + datatable = np.loadtxt(filename, dtype=np.float64) + + if i == startframe: + sim = np.zeros([nframes, datatable.shape[0]]) + + sim[i - startframe, :] = datatable[:, 2] + + npix = np.sqrt(sim.shape[1]).astype('int') + sim = np.reshape(sim, [sim.shape[0], npix, npix]) + sim = np.array([im.transpose()[::-1, :] for im in sim]) + + # TODO: read frame times from files? + hour0 = 0 + framedur_hr = framedur_sec / 3600. + nframes = len(sim) + tstart = hour0 + tstop = hour0 + framedur_hr * nframes + times = np.linspace(tstart, tstop, nframes) + + return(ehtim.movie.Movie(sim, times, psize, ra, dec, rf, + interp=interp, bounds_error=bounds_error)) + + +################################################################################################### +# Array IO +################################################################################################### +def load_array_txt(filename, ephemdir='ephemeris'): + """Read an array from a text file and return an Array object + Sites with x=y=z=0 are spacecraft - TLE ephemeris loaded from ephemdir + + Args: + filename (str): path to input text file + ephemdir (str): directory with TLE files for spacecraft + + Returns: + arr (Array): Array object loaded from file + """ + + + tdata = np.loadtxt(filename, dtype=bytes, comments='#').astype(str) + if tdata[0][0].lower() == 'site': + tdata = tdata[1:] + + path = os.path.dirname(filename) + + tdataout = [] + if (tdata.shape[1] != 5 and tdata.shape[1] != 13): + raise Exception("Array file should have format: " + + "(name, x, y, z, SEFDR, SEFDL " + + "FR_PAR_ANGLE FR_ELEV_ANGLE FR_OFFSET" + + "DR_RE DR_IM DL_RE DL_IM )") + + elif tdata.shape[1] == 5: + tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[4]), + 0.0, 0.0, + 0.0, 0.0, 0.0), + dtype=ehc.DTARR) for x in tdata] + elif tdata.shape[1] == 13: + tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[5]), + float(x[9]) + 1j * float(x[10]), float(x[11]) + 1j * float(x[12]), + float(x[6]), float(x[7]), float(x[8])), + dtype=ehc.DTARR) for x in tdata] + + # load spacecraft + tdataout = np.array(tdataout) + edata = {} + for line in tdataout: + if np.all(np.array([line['x'], line['y'], line['z']]) == (0., 0., 0.)): + sitename = str(line['site']) + + # TODO ephempath shouldn't always start with array file path + + + try: + ephempath = path + '/' + ephemdir + '/' + sitename + '.tle' + edata[sitename] = np.loadtxt(ephempath, dtype=bytes, + comments='#', delimiter='/').astype(str) + print('loaded spacecraft ephemeris %s' % ephempath) + except IOError: + pass + try: + ephempath = path + '/' + ephemdir + '/' + sitename + edata[sitename] = np.loadtxt(ephempath, dtype=bytes, + comments='#', delimiter='/').astype(str) + print('loaded spacecraft ephemeris %s' % ephempath) + except IOError: + raise Exception('no ephemeris file %s !' % ephempath) + + return ehtim.array.Array(tdataout, ephem=edata) + +################################################################################################## +# Observation IO +################################################################################################## + + +def load_obs_txt(filename, polrep='stokes'): + """Read an observation from a text file. + Args: + fname (str): path to input text file + polrep (str): load data as either 'stokes' or 'circ' + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + if not(polrep in ['stokes', 'circ']): + raise Exception("polrep should be 'stokes' or 'circ' in load_uvfits") + print("Loading text observation: ", filename) + + # Read the header parameters + file = open(filename) + src = ' '.join(file.readline().split()[2:]) + ra = file.readline().split() + ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0 + dec = file.readline().split() + dec = np.sign(float(dec[2])) * (abs(float(dec[2])) + + float(dec[4]) / 60.0 + float(dec[6]) / 3600.0) + mjd = float(file.readline().split()[2]) + rf = float(file.readline().split()[2]) * 1e9 + bw = float(file.readline().split()[2]) * 1e9 + phasecal = bool(file.readline().split()[2]) + ampcal = bool(file.readline().split()[2]) + + # New Header Parameters + x = file.readline().split() + if x[1] == 'OPACITYCAL:': + opacitycal = bool(x[2]) + dcal = bool(file.readline().split()[2]) + frcal = bool(file.readline().split()[2]) + file.readline() + else: + opacitycal = True + dcal = True + frcal = True + file.readline() + + # read the tarr + line = file.readline().split() + tarr = [] + while line[1][0] != "-": + if len(line) == 6: + tarr.append(np.array((line[1], line[2], line[3], line[4], line[5], line[5], + 0, 0, 0, 0, 0), dtype=ehc.DTARR)) + elif len(line) == 14: + tarr.append(np.array((line[1], line[2], line[3], line[4], line[5], line[6], + float(line[10]) + 1j * float(line[11]), + float(line[12]) + 1j * float(line[13]), + line[7], line[8], line[9]), dtype=ehc.DTARR)) + else: + raise Exception("Telescope header doesn't have the right number of fields!") + line = file.readline().split() + tarr = np.array(tarr, dtype=ehc.DTARR) + + # read the polrep + line = file.readline().split() + if line[12] == 'RRamp': + polrep_orig = 'circ' + elif line[12] == 'Iamp': + polrep_orig = 'stokes' + else: + raise Exception("cannot determine original polrep from observation text file!") + file.close() + + # Load the data, convert to list format, return object + datatable = np.loadtxt(filename, dtype=bytes).astype(str) + datatable2 = [] + for row in datatable: + time = float(row[0]) + tint = float(row[1]) + t1 = row[2] + t2 = row[3] + + # Old datatable formats + if datatable.shape[1] < 20: + tau1 = float(row[6]) + tau2 = float(row[7]) + u = float(row[8]) + v = float(row[9]) + vis1 = float(row[10]) * np.exp(1j * float(row[11]) * ehc.DEGREE) + if datatable.shape[1] == 19: + vis2 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE) + vis3 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE) + vis4 = float(row[16]) * np.exp(1j * float(row[17]) * ehc.DEGREE) + sigma1 = sigma2 = sigma3 = sigma4 = float(row[18]) + elif datatable.shape[1] == 17: + vis2 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE) + vis3 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE) + vis4 = 0 + 0j + sigma1 = sigma2 = sigma3 = sigma4 = float(row[16]) + elif datatable.shape[1] == 15: + vis2 = 0 + 0j + vis3 = 0 + 0j + vis4 = 0 + 0j + sigma1 = sigma2 = sigma3 = sigma4 = float(row[12]) + else: + raise Exception('Text file does not have the right number of fields!') + + # Current datatable format + elif datatable.shape[1] == 20: + tau1 = float(row[4]) + tau2 = float(row[5]) + u = float(row[6]) + v = float(row[7]) + vis1 = float(row[8]) * np.exp(1j * float(row[9]) * ehc.DEGREE) + vis2 = float(row[10]) * np.exp(1j * float(row[11]) * ehc.DEGREE) + vis3 = float(row[12]) * np.exp(1j * float(row[13]) * ehc.DEGREE) + vis4 = float(row[14]) * np.exp(1j * float(row[15]) * ehc.DEGREE) + sigma1 = float(row[16]) + sigma2 = float(row[17]) + sigma3 = float(row[18]) + sigma4 = float(row[19]) + + else: + raise Exception('Text file does not have the right number of fields!') + + if polrep_orig == 'stokes': + datatable2.append(np.array((time, tint, t1, t2, tau1, tau2, + u, v, vis1, vis2, vis3, vis4, + sigma1, sigma2, sigma3, sigma4), dtype=ehc.DTPOL_STOKES)) + elif polrep_orig == 'circ': + datatable2.append(np.array((time, tint, t1, t2, tau1, tau2, + u, v, vis1, vis2, vis3, vis4, + sigma1, sigma2, sigma3, sigma4), dtype=ehc.DTPOL_CIRC)) + + # Return the data object + datatable2 = np.array(datatable2) + out = ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable2, tarr, polrep=polrep_orig, + source=src, mjd=mjd, + ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, + dcal=dcal, frcal=frcal) + out = out.switch_polrep(polrep_out=polrep) + return out + +# TODO can we save new telescope array terms and flags to uvfits and load them? +# TODO uv coordinates, multiply by IF freqs and not header FREQ? +def load_obs_uvfits(filename, polrep='stokes', flipbl=False, + allow_singlepol=True, force_singlepol=None, + channel=all, IF=all, remove_nan=False, + ignore_pzero_date=True, + trial_speedups=False): + """Load observation data from a uvfits file. + + Args: + filename (str or HDUList): path to either an input text file or an HDUList object + polrep (str): load data as either 'stokes' or 'circ' + flipbl (bool): flip baseline phases if True. + allow_singlepol (bool): If True and polrep='stokes', + treat single-polarization data as Stokes I + force_singlepol (str): 'R' or 'L' to load only 1 polarization and treat as Stokes I + channel (list): list of channels to average in the import. channel=all averages all + IF (list): list of IFs to average in the import. IF=all averages all + remove_nan (bool): whether or not to remove entries with nan data + + ignore_pzero_date (bool): if True, ignore the offset parameters in DATE field + TODO: what is the correct behavior per AIPS memo 117? + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + if not(polrep in ['stokes', 'circ']): + raise Exception("polrep should be 'stokes' or 'circ' in load_uvfits") + if not(force_singlepol is None or force_singlepol is False) and polrep != 'stokes': + raise Exception( + "force_singlepol is incompatible with polrep!='stokes' in load_uvfits") + + # Load the uvfits file + if isinstance(filename, fits.HDUList): + hdulist = filename.copy() + else: + print("Loading uvfits: ", filename) + hdulist = fits.open(filename) + header = hdulist[0].header + data = hdulist[0].data + + # Load the array data + tnames = hdulist['AIPS AN'].data['ANNAME'] + tnums = hdulist['AIPS AN'].data['NOSTA'] - 1 + xyz = np.real(hdulist['AIPS AN'].data['STABXYZ']) + try: + sefdr = np.real(hdulist['AIPS AN'].data['SEFD']) + sefdl = np.real(hdulist['AIPS AN'].data['SEFD']) # TODO add sefdl to uvfits? + except KeyError: + sefdr = np.zeros(len(tnames)) + sefdl = np.zeros(len(tnames)) + + # TODO - get the *actual* values of these telescope parameters from the uvfits file? + fr_par = np.zeros(len(tnames)) + fr_el = np.zeros(len(tnames)) + fr_off = np.zeros(len(tnames)) + dr = np.zeros(len(tnames)) + 1j * np.zeros(len(tnames)) + dl = np.zeros(len(tnames)) + 1j * np.zeros(len(tnames)) + + tarr = [np.array(( + str(tnames[i]), xyz[i][0], xyz[i][1], xyz[i][2], + sefdr[i], sefdl[i], dr[i], dl[i], + fr_par[i], fr_el[i], fr_off[i]), + dtype=ehc.DTARR) for i in range(len(tnames))] + + tarr = np.array(tarr) + + # Various header parameters + try: + ra = header['OBSRA'] * 12. / 180. + dec = header['OBSDEC'] + except KeyError: + if header['CTYPE6'] == 'RA': + ra = header['CRVAL6'] * 12. / 180. + else: + raise Exception('Cannot find RA!') + if header['CTYPE7'] == 'DEC': + dec = header['CRVAL7'] + else: + raise Exception('Cannot find DEC!') + + src = header['OBJECT'] + rf = hdulist['AIPS AN'].header['FREQ'] + + if header['CTYPE4'] == 'FREQ': + ch1_freq = header['CRVAL4'] + ch_bw = header['CDELT4'] + nchan = header['NAXIS4'] + + else: + raise Exception('Cannot find observing frequencies!') + + nif = 1 + try: + if header['CTYPE5'] == 'IF': + nif = header['NAXIS5'] + except KeyError: + print("no IF in uvfits header!") + + try: + if header['CTYPE3'] == 'STOKES': + if header['CRVAL3'] == 1: + polrep_uvfits = 'stokes' + elif header['CRVAL3'] == -1: + polrep_uvfits = 'circ' + else: + raise Exception("header[CRVAL3] not a recognized polarization basis!") + except BaseException: + raise Exception("STOKES field not in expected header position 'CTYPE3'!") + print('POLREP_UVFITS:', polrep_uvfits) + + if polrep_uvfits == 'stokes' and not(force_singlepol is None): + raise Exception("force_singlepole not implemented on native Stokes uvfits files!") + + # determine the bandwidth + bw = ch_bw * nchan * nif + + # Determine the number of correlation products in the data + num_corr = data['DATA'].shape[5] + print("Number of uvfits Correlation Products:", num_corr) + if num_corr == 1 and force_singlepol is not None: + print("Cannot force single polarization when file is not full polarization.") + force_singlepol = None + + # If the user selects force_singlepol, then we must allow_singlepol for stokes conversion + if force_singlepol is not None and polrep == 'stokes': + allow_singlepol = True + + # Mask to screen bad data + # Reducing to single frequency + + # prepare the arrays of if and channels that will be extracted from the data. + nvis = data['DATA'].shape[0] + full_nchannels = data['DATA'].shape[4] + full_nifs = data['DATA'].shape[3] + if channel == all: + channel = np.arange(0, full_nchannels, 1) + nchannels = full_nchannels + else: + try: + nchannels = len(np.array(channel)) + channel = np.array(channel).reshape(-1) + except TypeError: + channel = np.array([channel]).reshape(-1) + nchannels = len(np.array(channel)) + + if IF == all: + IF = np.arange(0, full_nifs, 1) + nifs = full_nifs + else: + try: + nifs = len(IF) + IF = np.array(IF).reshape(-1) + except TypeError: + IF = np.array([IF]).reshape(-1) + nifs = len(np.array(IF)) + + if (np.max(channel) >= full_nchannels) or (np.min(channel) < 0): + raise Exception('The specified channel does not exist') + if (np.max(IF) >= full_nifs) or (np.min(IF) < 0): + raise Exception('The specified IF does not exist') + + # NOTE: here we are assuming data is in RR, LL, RL, LR basis with the variable names + # BUT: polrep_uvfits will correctly interpret these data as IQUV if necessary + # TODO: change the variable names! + rrweight = data['DATA'][:, 0, 0, IF, channel, 0, 2].reshape(nvis, nifs, nchannels) + if num_corr >= 2: + llweight = data['DATA'][:, 0, 0, IF, channel, 1, 2].reshape(nvis, nifs, nchannels) + else: + llweight = rrweight * 0.0 + if num_corr >= 3: + rlweight = data['DATA'][:, 0, 0, IF, channel, 2, 2].reshape(nvis, nifs, nchannels) + else: + rlweight = rrweight * 0.0 + if num_corr >= 4: + lrweight = data['DATA'][:, 0, 0, IF, channel, 3, 2].reshape(nvis, nifs, nchannels) + else: + lrweight = rrweight * 0.0 + + # If necessary, enforce single polarization + if polrep_uvfits == 'circ': + if force_singlepol in ['L' or 'LL']: + rrweight = rrweight * 0.0 + rlweight = rlweight * 0.0 + lrweight = lrweight * 0.0 + elif force_singlepol in ['R' or 'RR']: + llweight = llweight * 0.0 + rlweight = rlweight * 0.0 + lrweight = lrweight * 0.0 + elif force_singlepol == 'LR': + print('WARNING: Putting LR data in Stokes I') + rrweight = copy.deepcopy(lrweight) + llweight = llweight * 0.0 + rlweight = rlweight * 0.0 + lrweight = lrweight * 0.0 + elif force_singlepol == 'RL': + print('WARNING: Putting RL data in Stokes I') + rrweight = copy.deepcopy(rlweight) + llweight = llweight * 0.0 + rlweight = rlweight * 0.0 + lrweight = lrweight * 0.0 + + # first, catch nans + rrnanmask_2d = (np.isnan(rrweight)) + llnanmask_2d = (np.isnan(llweight)) + rlnanmask_2d = (np.isnan(rlweight)) + lrnanmask_2d = (np.isnan(lrweight)) + + rrweight[rrnanmask_2d] = 0. + llweight[llnanmask_2d] = 0. + rlweight[rlnanmask_2d] = 0. + lrweight[lrnanmask_2d] = 0. + + # look for weights < 0 + rrmask_2d = (rrweight > 0.) + llmask_2d = (llweight > 0.) + rlmask_2d = (rlweight > 0.) + lrmask_2d = (lrweight > 0.) + + # if there is any unmasked data in the frequency column, use it + rrmask = np.any(np.any(rrmask_2d, axis=2), axis=1) + llmask = np.any(np.any(llmask_2d, axis=2), axis=1) + rlmask = np.any(np.any(rlmask_2d, axis=2), axis=1) + lrmask = np.any(np.any(lrmask_2d, axis=2), axis=1) + + # Total intensity mask + if polrep_uvfits == 'circ': + mask = rrmask + llmask + elif polrep_uvfits == 'stokes': + mask = rrmask # remember rr is really I when polrep_uvfits=='stokes'! + + if not np.any(mask): + raise Exception("No unflagged RR or LL data in uvfits file!") + if np.any(~(rrmask * llmask)): + print("Warning: removing flagged data present!") + + # Obs Times + paridx = data.parnames.index("DATE")+1 + if "PSCAL%d"%(paridx) in header.keys(): + jd1scal = header["PSCAL%d"%(paridx)] + else: + jd1scal = 1.0 + if "PZERO%d"%(paridx) in header.keys(): + jd1zero = header["PZERO%d"%(paridx)] + else: + jd1zero = 0.0 + if "PSCAL%d"%(paridx+1) in header.keys(): + jd2scal = header["PSCAL%d"%(paridx+1)] + else: + jd2scal = 1.0 + if "PZERO%d"%(paridx+1) in header.keys(): + jd2zero = header["PZERO%d"%(paridx+1)] + else: + jd2zero = 0.0 + + if ignore_pzero_date: + if jd1zero!=0. or jd2zero!=0.: + print("Warning! ignoring nonzero header PZERO values for DATE. Check your observation mjd/times!") + jd1zero = 0. + jd2zero = 0. + + jds = jd1scal * data['DATE'][mask].astype('d') + jd1zero + jds += jd2scal * data['_DATE'][mask].astype('d') + jd2zero + + mjd = int(np.min(jds) - 2400000.5) + times = (jds - 2400000.5 - mjd) * 24.0 + + try: + scantable = [] + nxtable = hdulist['AIPS NX'] + for scan in nxtable.data: + scan_start = scan['TIME'] # in days since reference date + scan_dur = scan['TIME INTERVAL'] + startvis = scan['START VIS'] - 1 + endvis = scan['END VIS'] - 1 + scantable.append([scan_start - 0.5 * scan_dur, + scan_start + 0.5 * scan_dur]) + scantable = np.array(scantable) * 24 + + except BaseException: + print("No NX table in uvfits!") + scantable = None + + # Integration times + try: + tints = data['INTTIM'][mask] + except KeyError: + tints = np.zeros(len(mask)) + + # Sites - add names + t1c = data['BASELINE'][mask].astype(int) // 256 + t2c = data['BASELINE'][mask].astype(int) - t1c * 256 + t1c = t1c - 1 + t2c = t2c - 1 + + # TODO make site identificantion faster + if trial_speedups and (not np.any(tnums!=np.arange(len(tnums)))): + sites = tarr['site'] + t1 = sites[t1c] + t2 = sites[t2c] + else: # original, slow code + t1 = np.array([tarr[np.where(tnums==i)[0][0]]['site'] for i in t1c]) + t2 = np.array([tarr[np.where(tnums==i)[0][0]]['site'] for i in t2c]) + + # Opacities (not in standard files) + try: + tau1 = data['TAU1'][mask] + tau2 = data['TAU2'][mask] + except KeyError: + tau1 = tau2 = np.zeros(len(t1)) + + # Convert uv in lightsec to lambda by multiplying by rf + try: + u = data['UU---SIN'][mask] * rf + v = data['VV---SIN'][mask] * rf + except KeyError: + try: + u = data['UU'][mask] * rf + v = data['VV'][mask] * rf + except KeyError: + try: + u = data['UU--'][mask] * rf + v = data['VV--'][mask] * rf + except KeyError: + raise Exception("Cant figure out column label for UV coords") + + # Get and coherently average visibility data in frequency + # replace masked vis with nans so they don't mess up the average + rr_2d = data['DATA'][:, 0, 0, IF, channel, 0, 0] + \ + 1j * data['DATA'][:, 0, 0, IF, channel, 0, 1] + rr_2d = rr_2d.reshape(nvis, nifs, nchannels) + if num_corr >= 2: + ll_2d = data['DATA'][:, 0, 0, IF, channel, 1, 0] + \ + 1j * data['DATA'][:, 0, 0, IF, channel, 1, 1] + ll_2d = ll_2d.reshape(nvis, nifs, nchannels) + else: + ll_2d = rr_2d * 0.0 + if num_corr >= 3: + rl_2d = data['DATA'][:, 0, 0, IF, channel, 2, 0] + \ + 1j * data['DATA'][:, 0, 0, IF, channel, 2, 1] + rl_2d = rl_2d.reshape(nvis, nifs, nchannels) + else: + rl_2d = rr_2d * 0.0 + if num_corr >= 4: + lr_2d = data['DATA'][:, 0, 0, IF, channel, 3, 0] + \ + 1j * data['DATA'][:, 0, 0, IF, channel, 3, 1] + lr_2d = lr_2d.reshape(nvis, nifs, nchannels) + else: + lr_2d = rr_2d * 0.0 + + if polrep_uvfits == 'circ': + if force_singlepol == 'LR': + rr_2d = copy.deepcopy(lr_2d) + elif force_singlepol == 'RL': + rr_2d = copy.deepcopy(rl_2d) + + rr_2d[~rrmask_2d] = np.nan + ll_2d[~llmask_2d] = np.nan + rl_2d[~rlmask_2d] = np.nan + lr_2d[~lrmask_2d] = np.nan + + rr = np.nanmean(np.nanmean(rr_2d, axis=2), axis=1)[mask] + ll = np.nanmean(np.nanmean(ll_2d, axis=2), axis=1)[mask] + rl = np.nanmean(np.nanmean(rl_2d, axis=2), axis=1)[mask] + lr = np.nanmean(np.nanmean(lr_2d, axis=2), axis=1)[mask] + + # average the weights + # variances are mean / N , or sum / N^2 + # then replace masked weights with nans so they don't mess up the average + rrweight[~rrmask_2d] = np.nan + llweight[~llmask_2d] = np.nan + rlweight[~rlmask_2d] = np.nan + lrweight[~lrmask_2d] = np.nan + + nsig_rr = np.sum(np.sum(rrmask_2d, axis=2), axis=1).astype(float) + nsig_rr[~rrmask] = np.nan + rrsig = np.sqrt(np.nansum(np.nansum(1. / rrweight, axis=2), axis=1)) / nsig_rr + rrsig = rrsig[mask] + + nsig_ll = np.sum(np.sum(llmask_2d, axis=2), axis=1).astype(float) + nsig_ll[~llmask] = np.nan + llsig = np.sqrt(np.nansum(np.nansum(1. / llweight, axis=2), axis=1)) / nsig_ll + llsig = llsig[mask] + + nsig_rl = np.sum(np.sum(rlmask_2d, axis=2), axis=1).astype(float) + nsig_rl[~rlmask] = np.nan + rlsig = np.sqrt(np.nansum(np.nansum(1. / rlweight, axis=2), axis=1)) / nsig_rl + rlsig = rlsig[mask] + + nsig_lr = np.sum(np.sum(lrmask_2d, axis=2), axis=1).astype(float) + nsig_lr[~lrmask] = np.nan + lrsig = np.sqrt(np.nansum(np.nansum(1. / lrweight, axis=2), axis=1)) / nsig_lr + lrsig = lrsig[mask] + + # Reverse sign of baselines for correct imaging if asked + if flipbl: + u = -u + v = -v + + # determine correct data type: + # TODO add linear! + if polrep_uvfits == 'circ': + dtpol_out = ehc.DTPOL_CIRC + poldict_out = ehc.POLDICT_CIRC + elif polrep_uvfits == 'stokes': + dtpol_out = ehc.DTPOL_STOKES + poldict_out = ehc.POLDICT_STOKES + + #TODO new, faster, + if trial_speedups: + datatable = np.empty((len(times)),dtype=dtpol_out) + datatable['time'] = times + datatable['tint'] = tints + datatable['t1'] = t1 + datatable['t2'] = t2 + datatable['tau1'] = tau1 + datatable['tau2'] = tau2 + datatable['u'] = u + datatable['v'] = v + datatable[poldict_out['vis1']] = rr + datatable[poldict_out['vis2']] = ll + datatable[poldict_out['vis3']] = rl + datatable[poldict_out['vis4']] = lr + datatable[poldict_out['sigma1']] = rrsig + datatable[poldict_out['sigma2']] = llsig + datatable[poldict_out['sigma3']] = rlsig + datatable[poldict_out['sigma4']] = lrsig + else: # original, slower code + datatable = [] + for i in range(len(times)): + datatable.append(np.array + (( + times[i], tints[i], + t1[i], t2[i], tau1[i], tau2[i], + u[i], v[i], + rr[i], ll[i], rl[i], lr[i], + rrsig[i], llsig[i], rlsig[i], lrsig[i] + ), dtype=dtpol_out + )) + datatable = np.array(datatable) + + obs = ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tarr, polrep=polrep_uvfits, + source=src, mjd=mjd, scantable=scantable, + trial_speedups=trial_speedups) + + # TODO -- this is bad and slow, use masks! + if remove_nan: + if polrep_uvfits == 'circ': + for j in range(len(obs.data)): + if np.isnan(obs.data[j]['rrsigma']): + obs.data[j]['rrsigma'] = obs.data[j]['llsigma'] + if np.isnan(obs.data[j]['llsigma']): + obs.data[j]['llsigma'] = obs.data[j]['rrsigma'] + if np.isnan(obs.data[j]['rlsigma']): + obs.data[j]['rlsigma'] = obs.data[j]['rrsigma'] + if np.isnan(obs.data[j]['lrsigma']): + obs.data[j]['lrsigma'] = obs.data[j]['rrsigma'] + else: + print("WARNING: remove_nan not implemented with stokes uvfits files!") + + obs = obs.switch_polrep(polrep, allow_singlepol=allow_singlepol) + + # TODO get calibration flags from uvfits? + return obs + + +def load_obs_oifits(filename, flux=1.0): + """Load data from an oifits file. Does NOT currently support polarization. + Args: + fname (str): path to input text file + flux (float): normalization total flux + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + print('Warning: load_obs_oifits does NOT currently support polarimetric data!') + + # open oifits file and get visibilities + oidata = ehtim.io.oifits.open(filename) + vis_data = oidata.vis + + # get source info + src = oidata.target[0].target + ra = oidata.target[0].raep0.angle + dec = oidata.target[0].decep0.angle + + # get annena info + nAntennas = len(oidata.array[list(oidata.array.keys())[0]].station) + sites = np.array([oidata.array[list(oidata.array.keys())[0] + ].station[i].sta_name for i in range(nAntennas)]) + arrayX = oidata.array[list(oidata.array.keys())[0]].arrxyz[0] + arrayY = oidata.array[list(oidata.array.keys())[0]].arrxyz[1] + arrayZ = oidata.array[list(oidata.array.keys())[0]].arrxyz[2] + x = np.array([arrayX + oidata.array[list(oidata.array.keys())[0]].station[i].staxyz[0] + for i in range(nAntennas)]) + y = np.array([arrayY + oidata.array[list(oidata.array.keys())[0]].station[i].staxyz[1] + for i in range(nAntennas)]) + z = np.array([arrayZ + oidata.array[list(oidata.array.keys())[0]].station[i].staxyz[2] + for i in range(nAntennas)]) + + # get wavelength and corresponding frequencies + wavelength = oidata.wavelength[list(oidata.wavelength.keys())[0]].eff_wave + nWavelengths = wavelength.shape[0] + bandpass = oidata.wavelength[list(oidata.wavelength.keys())[0]].eff_band + frequency = ehc.C / wavelength + + # TODO: this result seems wrong... + bw = np.mean(2 * (np.sqrt(bandpass**2 * frequency**2 + ehc.C**2) - ehc.C) / bandpass) + rf = np.mean(frequency) + + # get the u-v point for each visibility + u = np.array([vis_data[i].ucoord / wavelength for i in range(len(vis_data))]) + v = np.array([vis_data[i].vcoord / wavelength for i in range(len(vis_data))]) + + # get visibility info - currently the phase error is not being used properly + amp = np.array([vis_data[i]._visamp for i in range(len(vis_data))]) + phase = np.array([vis_data[i]._visphi for i in range(len(vis_data))]) + amperr = np.array([vis_data[i]._visamperr for i in range(len(vis_data))]) + visphierr = np.array([vis_data[i]._visphierr for i in range(len(vis_data))]) + timeobs = np.array([vis_data[i].timeobs for i in range(len(vis_data))] + ) # convert to single number + + # return timeobs + time = np.transpose(np.tile(np.array([(ttime.mktime((timeobs[i] + + datetime.timedelta(days=1)).timetuple()) + ) / (60.0 * 60.0) + for i in range(len(timeobs))]), [nWavelengths, 1])) + + # integration time + tint = np.array([vis_data[i].int_time for i in range(len(vis_data))]) + # if not all(tint[0] == item for item in np.reshape(tint, (-1)) ): + # raise TypeError("The time integrations for each visibility are different") + tint = tint[0] + tint = tint * np.ones(amp.shape) + + # get telescope names for each visibility + t1 = np.transpose(np.tile(np.array([vis_data[i].station[0].sta_name + for i in range(len(vis_data))]), [nWavelengths, 1])) + t2 = np.transpose(np.tile(np.array([vis_data[i].station[1].sta_name + for i in range(len(vis_data))]), [nWavelengths, 1])) + + # dummy variables + tau1 = np.zeros(amp.shape) + tau2 = np.zeros(amp.shape) + qvis = np.zeros(amp.shape) + uvis = np.zeros(amp.shape) + vvis = np.zeros(amp.shape) + sefdr = np.zeros(x.shape) + sefdl = np.zeros(x.shape) + fr_par = np.zeros(x.shape) + fr_el = np.zeros(x.shape) + fr_off = np.zeros(x.shape) + dr = np.zeros(x.shape) + 1j * np.zeros(x.shape) + dl = np.zeros(x.shape) + 1j * np.zeros(x.shape) + + # vectorize + time = time.ravel() + tint = tint.ravel() + t1 = t1.ravel() + t2 = t2.ravel() + + tau1 = tau1.ravel() + tau2 = tau2.ravel() + u = u.ravel() + v = v.ravel() + vis = amp.ravel() * np.exp(-1j * phase.ravel() * np.pi / 180.0) + qvis = qvis.ravel() + uvis = uvis.ravel() + vvis = vvis.ravel() + amperr = amperr.ravel() + + # TODO - check that we are properly using the error from the amplitude and phase + # create data tables + datatable = np.array([(time[i], tint[i], t1[i], t2[i], tau1[i], tau2[i], u[i], v[i], + flux * vis[i], qvis[i], uvis[i], vvis[i], + flux * amperr[i], flux * amperr[i], flux * amperr[i], flux * amperr[i] + ) for i in range(len(vis)) + ], dtype=ehc.DTPOL_STOKES) + + tarr = np.array([(sites[i], x[i], y[i], z[i], + sefdr[i], sefdl[i], dr[i], dl[i], + fr_par[i], fr_el[i], fr_off[i], + ) for i in range(nAntennas) + ], dtype=ehc.DTARR) + + # return object + return ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tarr, + polrep='stokes', source=src, mjd=time[0]) + +def load_obs_maps(arrfile, obsspec, ifile, qfile=0, ufile=0, vfile=0, + src=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, ampcal=False, phasecal=False): + """Read an observation from a maps text file and return an Obsdata object. + + Args: + arrfile (str): path to input array file + obsspec (str): path to input obs spec file + ifile (str): path to input Stokes I data file + qfile (str): path to input Stokes Q data file + ufile (str): path to input Stokes U data file + vfile (str): path to input Stokes V data file + src (str): source name + mjd (int): integer observation MJD + ampcal (bool): True if amplitude calibrated + phasecal (bool): True if phase calibrated + + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + # Read telescope parameters from the array file + tdata = np.loadtxt(arrfile, dtype=bytes).astype(str) + tdata = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), + float(x[-1]), float(x[-1]), 0., 0., 0., 0., 0.), + dtype=ehc.DTARR) for x in tdata] + tdata = np.array(tdata) + + # Read parameters from the obs_spec + f = open(obsspec) + stop = False + while not stop: + line = f.readline().split() + if line == [] or line[0] == '\\': + continue + elif line[0] == 'FOV_center_RA': + x = line[2].split(':') + ra = float(x[0]) + float(x[1]) / 60.0 + float(x[2]) / 3600.0 + elif line[0] == 'FOV_center_Dec': + x = line[2].split(':') + dec = np.sign(float(x[0])) * (abs(float(x[0])) + + float(x[1]) / 60.0 + float(x[2]) / 3600.0) + elif line[0] == 'Corr_int_time': + tint = float(line[2]) + elif line[0] == 'Corr_chan_bw': # TODO what if multiple channels? + bw = float(line[2]) * 1e6 # in MHz + elif line[0] == 'Channel': # TODO what if multiple scans with different params? + rf = float(line[2].split(':')[0]) * 1e6 + elif line[0] == 'Scan_start': + x = line[2].split(':') # TODO properly compute MJD! + elif line[0] == 'Endscan': + stop = True + f.close() + + # Load the data, convert to list format, return object + datatable = [] + f = open(ifile) + + for line in f: + line = line.split() + if not (line[0] in ['UV', 'Scan', '\n']): + time = line[0].split(':') + time = float(time[2]) + float(time[3]) / 60.0 + float(time[4]) / 3600.0 + u = float(line[1]) * 1000 + v = float(line[2]) * 1000 + bl = line[4].split('-') + t1 = tdata[int(bl[0]) - 1]['site'] + t2 = tdata[int(bl[1]) - 1]['site'] + tau1 = 0. + tau2 = 0. + vis = float(line[7][:-1]) * np.exp(1j * float(line[8][:-1]) * ehc.DEGREE) + sigma = float(line[10]) + datatable.append(np.array((time, tint, t1, t2, tau1, tau2, + u, v, vis, 0.0, 0.0, 0.0, + sigma, 0.0, 0.0, 0.0), dtype=ehc.DTPOL_STOKES)) + + datatable = np.array(datatable) + + # TODO qfile ufile and vfile must have exactly the same format as ifile! + # add some consistency check + if not qfile == 0: + f = open(qfile) + i = 0 + for line in f: + line = line.split() + if not (line[0] in ['UV', 'Scan', '\n']): + datatable[i]['qvis'] = float(line[7][:-1]) * \ + np.exp(1j * float(line[8][:-1]) * ehc.DEGREE) + datatable[i]['qsigma'] = float(line[10]) + i += 1 + + if not ufile == 0: + f = open(ufile) + i = 0 + for line in f: + line = line.split() + if not (line[0] in ['UV', 'Scan', '\n']): + datatable[i]['uvis'] = float(line[7][:-1]) * \ + np.exp(1j * float(line[8][:-1]) * ehc.DEGREE) + datatable[i]['usigma'] = float(line[10]) + i += 1 + + if not vfile == 0: + f = open(vfile) + i = 0 + for line in f: + line = line.split() + if not (line[0] in ['UV', 'Scan', '\n']): + datatable[i]['vvis'] = float(line[7][:-1]) * \ + np.exp(1j * float(line[8][:-1]) * ehc.DEGREE) + datatable[i]['vsigma'] = float(line[10]) + i += 1 + + # Return the data object + return ehtim.obsdata.Obsdata(ra, dec, rf, bw, datatable, tdata, + source=src, mjd=mjd, polrep='stokes') + +def load_dtype_txt(obs, filename, dtype='cphase'): + + """Load the dtype data in a text file and put it in the already-created obs object + Args: + obs (Obsdata): obsdata object + filename (str): path to output text file + dtype (str): desired data type + Returns: + """ + + print("Loading text observation: ", filename) + + # Read the header parameters + file = open(filename) + src = ' '.join(file.readline().split()[2:]) + ra = file.readline().split() + ra = float(ra[2]) + float(ra[4]) / 60.0 + float(ra[6]) / 3600.0 + dec = file.readline().split() + dec = np.sign(float(dec[2])) * (abs(float(dec[2])) + + float(dec[4]) / 60.0 + float(dec[6]) / 3600.0) + mjd = float(file.readline().split()[2]) + rf = float(file.readline().split()[2]) * 1e9 + bw = float(file.readline().split()[2]) * 1e9 + phasecal = bool(file.readline().split()[2]) + ampcal = bool(file.readline().split()[2]) + + # Load the data, convert to list format, return object + datatable = np.loadtxt(filename, dtype=bytes).astype(str) + + if dtype == 'cphase': + datatable2 = [] + for row in datatable: + time = float(row[0]) + t1 = row[1] + t2 = row[2] + t3 = row[3] + u1 = float(row[4]) + v1 = float(row[5]) + u2 = float(row[6]) + v2 = float(row[7]) + u3 = float(row[8]) + v3 = float(row[9]) + cphase = float(row[10]) + sigmacp = float(row[11]) + datatable2.append(np.array((time, t1, t2, t3, u1, v1, u2, v2, + u3, v3, cphase, sigmacp), dtype=ehc.DTCPHASE)) + obs.cphase = np.array(datatable2) + + elif dtype == 'logcamp': + datatable2 = [] + for row in datatable: + time = float(row[0]) + t1 = row[1] + t2 = row[2] + t3 = row[3] + t4 = row[4] + u1 = float(row[5]) + v1 = float(row[6]) + u2 = float(row[7]) + v2 = float(row[8]) + u3 = float(row[9]) + v3 = float(row[10]) + u4 = float(row[11]) + v4 = float(row[12]) + logcamp = float(row[13]) + sigmalogcamp = float(row[14]) + datatable2.append(np.array((time, t1, t2, t3, t4, u1, v1, u2, v2, u3, + v3, u4, v4, logcamp, sigmalogcamp), dtype=ehc.DTCAMP)) + obs.logcamp = np.array(datatable2) + + elif dtype == 'camp': + datatable2 = [] + for row in datatable: + time = float(row[0]) + t1 = row[1] + t2 = row[2] + t3 = row[3] + t4 = row[4] + u1 = float(row[5]) + v1 = float(row[6]) + u2 = float(row[7]) + v2 = float(row[8]) + u3 = float(row[9]) + v3 = float(row[10]) + u4 = float(row[11]) + v4 = float(row[12]) + camp = float(row[13]) + sigmacamp = float(row[14]) + datatable2.append(np.array((time, t1, t2, t3, t4, u1, v1, u2, v2, + u3, v3, u4, v4, camp, sigmacamp), dtype=ehc.DTCAMP)) + obs.camp = np.array(datatable2) + + elif dtype == 'bs': + datatable2 = [] + for row in datatable: + time = float(row[0]) + t1 = row[1] + t2 = row[2] + t3 = row[3] + u1 = float(row[4]) + v1 = float(row[5]) + u2 = float(row[6]) + v2 = float(row[7]) + u3 = float(row[8]) + v3 = float(row[9]) + bispec = float(row[10]) + sigmab = float(row[11]) + datatable2.append(np.array((time, t1, t2, t3, u1, v1, u2, + v2, u3, v3, bispec, sigmab), dtype=ehc.DTBIS)) + obs.bispec = np.array(datatable2) + + elif dtype == 'amp': + datatable2 = [] + for row in datatable: + time = float(row[0]) + tint = float(row[1]) + t1 = row[2] + t2 = row[3] + u = float(row[4]) + v = float(row[5]) + amp = float(row[6]) + sigmaamp = float(row[7]) + datatable2.append(np.array((time, tint, t1, t2, u, v, amp, sigmaamp), dtype=ehc.DTAMP)) + obs.amp = np.array(datatable2) + + else: + raise Exception(dtype + ' is not a possible data type!') + + return diff --git a/io/oifits.py b/io/oifits.py new file mode 100644 index 00000000..4583dc63 --- /dev/null +++ b/io/oifits.py @@ -0,0 +1,1358 @@ +""" +A module for reading/writing OIFITS files + +This module is NOT related to the OIFITS Python module provided at +http://www.mrao.cam.ac.uk/research/OAS/oi_data/oifits.html +It is a (better) alternative. + +To open an existing OIFITS file, use the oifits.open(filename) +function. This will return an oifits object with the following +members (any of which can be empty dictionaries or numpy arrays): + + array: a dictionary of interferometric arrays, as defined by the + OI_ARRAY tables. The dictionary key is the name of the array + (ARRNAME). + + target: a numpy array of targets, as defined by the rows of the + OI_TARGET table. + + wavelength: a dictionary of wavelength tables (OI_WAVELENGTH). The + dictionary key is the name of the instrument/settings (INSNAME). + + vis, vis2 and t3: numpy arrays of objects containing all the + measurement information. Each list member corresponds to a row in + an OI_VIS/OI_VIS2/OI_T3 table. + +This module makes an ad-hoc, backwards-compatible change to the OIFITS +revision 1 standard originally described by Pauls et al., 2005, PASP, +117, 1255. The OI_VIS and OI_VIS2 tables in OIFITS files produced by +this file contain two additional columns for the correlated flux, +CFLUX and CFLUXERR , which are arrays with a length corresponding to +the number of wavelength elements (just as VISAMP/VIS2DATA). + +The main purpose of this module is to allow easy access to your OIFITS +data within Python, where you can then analyze it in any way you want. +As of version 0.3, the module can now be used to create OIFITS files +from scratch without serious pain. Be warned, creating an array table +from scratch is probably like nailing jelly to a tree. In a future +verison this will become easier. + +The module also provides a simple mechanism for combining multiple +oifits objects, achieved by using the '+' operator on two oifits +objects: result = a + b. The result can then be written to a file +using result.save(filename). + +Many of the parameters and their meanings are not specifically +documented here. However, the nomenclature mirrors that of the OIFITS +standard, so it is recommended to use this module with the PASP +reference above in hand. + +Beginning with version 0.3, the OI_VIS/OI_VIS2/OI_T3 classes now use +masked arrays for convenience, where the mask is defined via the +'flag' member of these classes. Beware of the following subtlety: as +before, the array data are accessed via (for example) OI_VIS.visamp; +however, OI_VIS.visamp is just a method which constructs (on the fly) +a masked array from OI_VIS._visamp, which is where the data are +actually stored. This is done transparently, and the data can be +accessed and modified transparently via the "visamp" hidden attribute. +The same goes for correlated fluxes, differential/closure phases, +triple products, etc. See the notes on the individual classes for a +list of all the "hidden" attributes. + +For further information, contact Paul Boley (boley@mpia-hd.mpg.de). + +""" +from __future__ import division +from __future__ import print_function +from builtins import str +from builtins import object + +import numpy as np +from numpy import double, ma +from astropy.io import fits as pyfits +import datetime +import copy + +__author__ = "Paul Boley" +__email__ = "boley@mpia-hd.mpg.de" +__date__ ='1 October 2012' +__version__ = '0.3.1' +_mjdzero = datetime.datetime(1858, 11, 17) + +matchtargetbyname = False +matchstationbyname = False +refdate = datetime.datetime(2000, 1, 1) + +def _plurals(count): + if count != 1: return 's' + return '' + +def _array_eq(a, b): + "Test whether all the elements of two arrays are equal." + + try: + return not (a != b).any() + except: + return not (a != b) + +class _angpoint(float): + "Convenience object for representing angles." + + def __init__(self, angle): + self.angle = angle + + def __repr__(self): + return '_angpoint(%s)'%self.angle.__repr__() + + def __str__(self): + return "%g degrees"%(self.angle) + + def __eq__(self, other): + return self.angle == other.angle + + def __ne__(self, other): + return not self.__eq__(other) + + def asdms(self): + """Return the value as a string in dms format, + e.g. +25:30:22.55. Useful for declination.""" + angle = self.angle + if angle < 0: + negative = True + angle *= -1.0 + else: + negative = False + degrees = np.floor(angle) + minutes = np.floor((angle - degrees)*60.0) + seconds = (angle - degrees - minutes/60.0)*3600.0 + try: + if negative: + return "-%02d:%02d:%05.2f"%(degrees,minutes,seconds) + else: + return "+%02d:%02d:%05.2f"%(degrees,minutes,seconds) + except TypeError: + return self.__repr__() + + def ashms(self): + """Return the value as a string in hms format, + e.g. 5:12:17.21. Useful for right ascension.""" + angle = self.angle*24.0/360.0 + + hours = np.floor(angle) + minutes = np.floor((angle - hours)*60.0) + seconds = (angle - hours - minutes/60.0)*3600.0 + try: + return "%02d:%02d:%05.2f"%(hours,minutes,seconds) + except TypeError: + return self.__repr__() + +class OI_TARGET(object): + + def __init__(self, target, raep0, decep0, equinox=2000.0, ra_err=0.0, dec_err=0.0, + sysvel=0.0, veltyp='TOPCENT', veldef='OPTICAL', pmra=0.0, pmdec=0.0, + pmra_err=0.0, pmdec_err=0.0, parallax=0.0, para_err=0.0, spectyp='UNKNOWN'): + self.target = target + self.raep0 = _angpoint(raep0) + self.decep0 = _angpoint(decep0) + self.equinox = equinox + self.ra_err = ra_err + self.dec_err = dec_err + self.sysvel = sysvel + self.veltyp = veltyp + self.veldef = veldef + self.pmra = pmra + self.pmdec = pmdec + self.pmra_err = pmra_err + self.pmdec_err = pmdec_err + self.parallax = parallax + self.para_err = para_err + self.spectyp = spectyp + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.target != other.target) or + (self.raep0 != other.raep0) or + (self.decep0 != other.decep0) or + (self.equinox != other.equinox) or + (self.ra_err != other.ra_err) or + (self.dec_err != other.dec_err) or + (self.sysvel != other.sysvel) or + (self.veltyp != other.veltyp) or + (self.veldef != other.veldef) or + (self.pmra != other.pmra) or + (self.pmdec != other.pmdec) or + (self.pmra_err != other.pmra_err) or + (self.pmdec_err != other.pmdec_err) or + (self.parallax != other.parallax) or + (self.para_err != other.para_err) or + (self.spectyp != other.spectyp)) + + def __ne__(self, other): + return not self.__eq__(other) + + def __str__(self): + return "%s: %s %s (%g)"%(self.target, self.raep0.ashms(), self.decep0.asdms(), self.equinox) + + def info(self): + print(str(self)) + +class OI_WAVELENGTH(object): + + def __init__(self, eff_wave, eff_band=None): + self.eff_wave = np.array(eff_wave, dtype=double).reshape(-1) + if eff_band == None: + eff_band = np.zeros_like(eff_wave) + self.eff_band = np.array(eff_band, dtype=double).reshape(-1) + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (not _array_eq(self.eff_wave, other.eff_wave)) or + (not _array_eq(self.eff_band, other.eff_band))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return "%d wavelength%s (%.3g-%.3g um)"%(len(self.eff_wave), _plurals(len(self.eff_wave)), 1e6*np.min(self.eff_wave),1e6*np.max(self.eff_wave)) + + def info(self): + print(str(self)) + + +class OI_VIS(object): + """ + Class for storing visibility amplitude and differential phase data. + To access the data, use the following hidden attributes: + + visamp, visamperr, visphi, visphierr, flag; + and possibly cflux, cfluxerr. + + """ + + def __init__(self, timeobs, int_time, visamp, visamperr, visphi, visphierr, flag, ucoord, + vcoord, wavelength, target, array=None, station=(None,None), cflux=None, cfluxerr=None): + self.timeobs = timeobs + self.array = array + self.wavelength = wavelength + self.target = target + self.int_time = int_time + self._visamp = np.array(visamp, dtype=double).reshape(-1) + self._visamperr = np.array(visamperr, dtype=double).reshape(-1) + self._visphi = np.array(visphi, dtype=double).reshape(-1) + self._visphierr = np.array(visphierr, dtype=double).reshape(-1) + if cflux != None: self._cflux = np.array(cflux, dtype=double).reshape(-1) + else: self._cflux = None + if cfluxerr != None: self._cfluxerr = np.array(cfluxerr, dtype=double).reshape(-1) + else: self._cfluxerr = None + self.flag = np.array(flag, dtype=bool).reshape(-1) + self.ucoord = ucoord + self.vcoord = vcoord + self.station = station + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.timeobs != other.timeobs) or + (self.array != other.array) or + (self.wavelength != other.wavelength) or + (self.target != other.target) or + (self.int_time != other.int_time) or + (self.ucoord != other.ucoord) or + (self.vcoord != other.vcoord) or + (self.array != other.array) or + (self.station != other.station) or + (not _array_eq(self.visamp, other.visamp)) or + (not _array_eq(self.visamperr, other.visamperr)) or + (not _array_eq(self.visphi, other.visphi)) or + (not _array_eq(self.visphierr, other.visphierr)) or + (not _array_eq(self.flag, other.flag))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __getattr__(self, attrname): + if attrname in ('visamp', 'visamperr', 'visphi', 'visphierr'): + return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag) + elif attrname in ('cflux', 'cfluxerr'): + if (self.__dict__['_' + attrname] != None): + return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag) + else: + return None + else: + raise AttributeError(attrname) + + def __setattr__(self, attrname, value): + if attrname in ('visamp', 'visamperr', 'visphi', 'visphierr', 'cflux', 'cfluxerr'): + self.__dict__['_' + attrname] = value + else: + self.__dict__[attrname] = value + + def __repr__(self): + meanvis = ma.mean(self.visamp) + if self.station[0] and self.station[1]: + baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + ')' + else: + baselinename = '' + return '%s %s%s: %d point%s (%d masked), B = %5.1f m, PA = %5.1f deg, = %4.2g'%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.visamp), _plurals(len(self.visamp)), np.sum(self.flag), np.sqrt(self.ucoord**2 + self.vcoord**2), np.arctan(self.ucoord/self.vcoord) * 180.0 / np.pi % 180.0, meanvis) + + def info(self): + print(str(self)) + +class OI_VIS2(object): + """ + Class for storing squared visibility amplitude data. + To access the data, use the following hidden attributes: + + vis2data, vis2err + + """ + def __init__(self, timeobs, int_time, vis2data, vis2err, flag, ucoord, vcoord, wavelength, + target, array=None, station=(None, None)): + self.timeobs = timeobs + self.array = array + self.wavelength = wavelength + self.target = target + self.int_time = int_time + self._vis2data = np.array(vis2data, dtype=double).reshape(-1) + self._vis2err = np.array(vis2err, dtype=double).reshape(-1) + self.flag = np.array(flag, dtype=bool).reshape(-1) + self.ucoord = ucoord + self.vcoord = vcoord + self.station = station + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.timeobs != other.timeobs) or + (self.array != other.array) or + (self.wavelength != other.wavelength) or + (self.target != other.target) or + (self.int_time != other.int_time) or + (self.ucoord != other.ucoord) or + (self.vcoord != other.vcoord) or + (self.array != other.array) or + (self.station != other.station) or + (not _array_eq(self.vis2data, other.vis2data)) or + (not _array_eq(self.vis2err, other.vis2err)) or + (not _array_eq(self.flag, other.flag))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __getattr__(self, attrname): + if attrname in ('vis2data', 'vis2err'): + return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag) + else: + raise AttributeError(attrname) + + def __setattr__(self, attrname, value): + if attrname in ('vis2data', 'vis2err'): + self.__dict__['_' + attrname] = value + else: + self.__dict__[attrname] = value + + def __repr__(self): + meanvis = ma.mean(self.vis2data) + if self.station[0] and self.station[1]: + baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + ')' + else: + baselinename = '' + return "%s %s%s: %d point%s (%d masked), B = %5.1f m, PA = %5.1f deg, = %4.2g"%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.vis2data), _plurals(len(self.vis2data)), np.sum(self.flag), np.sqrt(self.ucoord**2 + self.vcoord**2), np.arctan(self.ucoord/self.vcoord) * 180.0 / np.pi % 180.0, meanvis) + + def info(self): + print(str(self)) + + +class OI_T3(object): + """ + Class for storing triple product and closure phase data. + To access the data, use the following hidden attributes: + + t3amp, t3amperr, t3phi, t3phierr + + """ + + def __init__(self, timeobs, int_time, t3amp, t3amperr, t3phi, t3phierr, flag, u1coord, + v1coord, u2coord, v2coord, wavelength, target, array=None, station=(None,None,None)): + self.timeobs = timeobs + self.array = array + self.wavelength = wavelength + self.target = target + self.int_time = int_time + self._t3amp = np.array(t3amp, dtype=double).reshape(-1) + self._t3amperr = np.array(t3amperr, dtype=double).reshape(-1) + self._t3phi = np.array(t3phi, dtype=double).reshape(-1) + self._t3phierr = np.array(t3phierr, dtype=double).reshape(-1) + self.flag = np.array(flag, dtype=bool).reshape(-1) + self.u1coord = u1coord + self.v1coord = v1coord + self.u2coord = u2coord + self.v2coord = v2coord + self.station = station + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.timeobs != other.timeobs) or + (self.array != other.array) or + (self.wavelength != other.wavelength) or + (self.target != other.target) or + (self.int_time != other.int_time) or + (self.u1coord != other.u1coord) or + (self.v1coord != other.v1coord) or + (self.u2coord != other.u2coord) or + (self.v2coord != other.v2coord) or + (self.array != other.array) or + (self.station != other.station) or + (not _array_eq(self.t3amp, other.t3amp)) or + (not _array_eq(self.t3amperr, other.t3amperr)) or + (not _array_eq(self.t3phi, other.t3phi)) or + (not _array_eq(self.t3phierr, other.t3phierr)) or + (not _array_eq(self.flag, other.flag))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __getattr__(self, attrname): + if attrname in ('t3amp', 't3amperr', 't3phi', 't3phierr'): + return ma.masked_array(self.__dict__['_' + attrname], mask=self.flag) + else: + raise AttributeError(attrname) + + def __setattr__(self, attrname, value): + if attrname in ('vis2data', 'vis2err'): + self.__dict__['_' + attrname] = value + else: + self.__dict__[attrname] = value + + def __repr__(self): + meant3 = np.mean(self.t3amp[np.where(self.flag == False)]) + if self.station[0] and self.station[1] and self.station[2]: + baselinename = ' (' + self.station[0].sta_name + self.station[1].sta_name + self.station[2].sta_name + ')' + else: + baselinename = '' + return "%s %s%s: %d point%s (%d masked), B = %5.1fm, %5.1fm, = %4.2g"%(self.target.target, self.timeobs.strftime('%F %T'), baselinename, len(self.t3amp), _plurals(len(self.t3amp)), np.sum(self.flag), np.sqrt(self.u1coord**2 + self.v1coord**2), np.sqrt(self.u2coord**2 + self.v2coord**2), meant3) + + def info(self): + print(str(self)) + +class OI_STATION(object): + """ This class corresponds to a single row (i.e. single + station/telescope) of an OI_ARRAY table.""" + + def __init__(self, tel_name=None, sta_name=None, diameter=None, staxyz=[None, None, None]): + self.tel_name = tel_name + self.sta_name = sta_name + self.diameter = diameter + self.staxyz = staxyz + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.tel_name != other.tel_name) or + (self.sta_name != other.sta_name) or + (self.diameter != other.diameter) or + (not _array_eq(self.staxyz, other.staxyz))) + + def __ne__(self, other): + return not self.__eq__(other) + + def __repr__(self): + return '%s/%s (%g m)'%(self.sta_name, self.tel_name, self.diameter) + +class OI_ARRAY(object): + """Contains all the data for a single OI_ARRAY table. Note the + hidden convenience attributes latitude, longitude, and altitude.""" + + def __init__(self, frame, arrxyz, stations=()): + self.frame = frame + self.arrxyz = arrxyz + #self.station = stations; + self.station = np.empty(0) + for station in stations: + tel_name, sta_name, sta_index, diameter, staxyz = station + self.station = np.append(self.station, OI_STATION(tel_name=tel_name, sta_name=sta_name, diameter=diameter, staxyz=staxyz)) + + def __eq__(self, other): + + if type(self) != type(other): return False + + equal = not ( + (self.frame != other.frame) or + (not _array_eq(self.arrxyz, other.arrxyz))) + + if not equal: return False + + # If position appears to be the same, check that the stations + # (and ordering) are also the same + if (self.station != other.station).any(): + return False + + return True + + def __ne__(self, other): + return not self.__eq__(other) + + def __getattr__(self, attrname): + if attrname == 'latitude': + radius = np.sqrt((self.arrxyz**2).sum()) + return _angpoint(np.arcsin(self.arrxyz[2]/radius)*180.0/np.pi) + elif attrname == 'longitude': + radius = np.sqrt((self.arrxyz**2).sum()) + xylen = np.sqrt(self.arrxyz[0]**2+self.arrxyz[1]**2) + return _angpoint(np.arcsin(self.arrxyz[1]/xylen)*180.0/np.pi) + elif attrname == 'altitude': + radius = np.sqrt((self.arrxyz**2).sum()) + return radius - 6378100.0 + else: + raise AttributeError(attrname) + + def __repr__(self): + return '%s %s %g m, %d station%s'%(self.latitude.asdms(), self.longitude.asdms(), self.altitude, len(self.station), _plurals(len(self.station))) + + def info(self, verbose=0): + """Print the array's center coordinates. If verbosity >= 1, + print information about each station.""" + print(str(self)) + if verbose >= 1: + for station in self.station: + print(" %s"%str(station)) + + def get_station_by_name(self, name): + + for station in self.station: + if station.sta_name == name: + return station + + raise LookupError('No such station %s'%name) + +class oifits(object): + + def __init__(self): + + self.wavelength = {} + self.target = np.empty(0) + self.array = {} + self.vis = np.empty(0) + self.vis2 = np.empty(0) + self.t3 = np.empty(0) + + def __add__(self, other): + """Consistently combine two separate oifits objects. Note + that targets can be matched by name only (e.g. if coordinates + differ) by setting oifits.matchtargetbyname to True. The same + goes for stations of the array (controlled by + oifits.matchstationbyname)""" + # Don't do anything if the two oifits objects are not CONSISTENT! + if self.isconsistent() == False or other.isconsistent() == False: + print('oifits objects are not consistent, bailing.') + return + + new = copy.deepcopy(self) + if len(other.wavelength): + wavelengthmap = {} + for key in list(other.wavelength.keys()): + if key not in list(new.wavelength.keys()): + new.wavelength[key] = copy.deepcopy(other.wavelength[key]) + elif new.wavelength[key] != other.wavelength[key]: + raise ValueError('Wavelength tables have the same key but differing contents.') + wavelengthmap[id(other.wavelength[key])] = new.wavelength[key] + + if len(other.target): + targetmap = {} + for otarget in other.target: + for ntarget in new.target: + if matchtargetbyname and ntarget.target == otarget.target: + targetmap[id(otarget)] = ntarget + break + elif ntarget == otarget: + targetmap[id(otarget)] = ntarget + break + elif ntarget.target == otarget.target: + print('Found a target with a matching name, but some differences in the target specification. Creating a new target. Set oifits.matchtargetbyname to True to override this behavior.') + # If 'id(otarget)' is not in targetmap, then this is a new + # target and should be added to the array of targets + if id(otarget) not in list(targetmap.keys()): + try: + newkey = list(new.target.keys())[-1]+1 + except: + newkey = 1 + target = copy.deepcopy(otarget) + new.target = np.append(new.target, target) + targetmap[id(otarget)] = target + + if len(other.array): + stationmap = {} + arraymap = {} + for key, otharray in other.array.items(): + arraymap[id(otharray)] = key + if key not in list(new.array.keys()): + new.array[key] = copy.deepcopy(other.array[key]) + # If arrays have the same name but seem to differ, try + # to combine the two (by including the union of both + # sets of stations) + for othsta in other.array[key].station: + for newsta in new.array[key].station: + if newsta == othsta: + stationmap[id(othsta)] = newsta + break + elif matchstationbyname and newsta.sta_name == othsta.sta_name: + stationmap[id(othsta)] = newsta + break + elif newsta.sta_name == othsta.sta_name and matchstationbyname == False: + raise ValueError('Stations have matching names but conflicting data.') + # If 'id(othsta)' is not in the stationmap + # dictionary, then this is a new station and + # should be added to the current array + if id(othsta) not in list(stationmap.keys()): + newsta = copy.deepcopy(othsta) + new.array[key].station = np.append(new.array[key].station, newsta) + stationmap[id(othsta)] = newsta + # Make sure that staxyz of the new station is relative to the new array center + newsta.staxyz = othsta.staxyz - other.array[key].arrxyz + new.array[key].arrxyz + + for vis in other.vis: + if vis not in new.vis: + newvis = copy.copy(vis) + # The wavelength, target, array and station objects + # should point to the appropriate objects inside the + # 'new' structure + newvis.wavelength = wavelengthmap[id(vis.wavelength)] + newvis.target = targetmap[id(vis.target)] + if (vis.array): + newvis.array = new.array[arraymap[id(vis.array)]] + newvis.station = [None, None] + newvis.station[0] = stationmap[id(vis.station[0])] + newvis.station[1] = stationmap[id(vis.station[1])] + new.vis = np.append(new.vis, newvis) + + for vis2 in other.vis2: + if vis2 not in new.vis2: + newvis2 = copy.copy(vis2) + # The wavelength, target, array and station objects + # should point to the appropriate objects inside the + # 'new' structure + newvis2.wavelength = wavelengthmap[id(vis2.wavelength)] + newvis2.target = targetmap[id(vis2.target)] + if (vis2.array): + newvis2.array = new.array[arraymap[id(vis2.array)]] + newvis2.station = [None, None] + newvis2.station[0] = stationmap[id(vis2.station[0])] + newvis2.station[1] = stationmap[id(vis2.station[1])] + new.vis2 = np.append(new.vis2, newvis2) + + for t3 in other.t3: + if t3 not in new.t3: + newt3 = copy.copy(t3) + # The wavelength, target, array and station objects + # should point to the appropriate objects inside the + # 'new' structure + newt3.wavelength = wavelengthmap[id(t3.wavelength)] + newt3.target = targetmap[id(t3.target)] + if (t3.array): + newt3.array = new.array[arraymap[id(t3.array)]] + newt3.station = [None, None, None] + newt3.station[0] = stationmap[id(t3.station[0])] + newt3.station[1] = stationmap[id(t3.station[1])] + newt3.station[2] = stationmap[id(t3.station[2])] + new.t3 = np.append(new.t3, newt3) + + return(new) + + + def __eq__(self, other): + + if type(self) != type(other): return False + + return not ( + (self.wavelength != other.wavelength) or + (self.target != other.target).any() or + (self.array != other.array) or + (self.vis != other.vis).any() or + (self.vis2 != other.vis2).any() or + (self.t3 != other.t3).any()) + + def __ne__(self, other): + return not self.__eq__(other) + + def isvalid(self): + """Returns True of the oifits object is both consistent (as + determined by isconsistent()) and conforms to the OIFITS + standard (according to Pauls et al., 2005, PASP, 117, 1255).""" + + warnings = [] + errors = [] + if not self.isconsistent(): + errors.append('oifits object is not consistent') + if not self.target.size: + errors.append('No OI_TARGET data') + if not self.wavelength: + errors.append('No OI_WAVELENGTH data') + else: + for wavelength in list(self.wavelength.values()): + if len(wavelength.eff_wave) != len(wavelength.eff_band): + errors.append("eff_wave and eff_band are of different lengths for wavelength table '%s'"%key) + if (self.vis.size + self.vis2.size + self.t3.size == 0): + errors.append('Need to have atleast one measurement table (vis, vis2 or t3)') + for vis in self.vis: + nwave = len(vis.wavelength.eff_band) + if (len(vis.visamp) != nwave) or (len(vis.visamperr) != nwave) or (len(vis.visphi) != nwave) or (len(vis.visphierr) != nwave) or (len(vis.flag) != nwave): + errors.append("Data size mismatch for visibility measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave)) + for vis2 in self.vis2: + nwave = len(vis2.wavelength.eff_band) + if (len(vis2.vis2data) != nwave) or (len(vis2.vis2err) != nwave) or (len(vis2.flag) != nwave): + errors.append("Data size mismatch for visibility^2 measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave)) + for t3 in self.t3: + nwave = len(t3.wavelength.eff_band) + if (len(t3.t3amp) != nwave) or (len(t3.t3amperr) != nwave) or (len(t3.t3phi) != nwave) or (len(t3.t3phierr) != nwave) or (len(t3.flag) != nwave): + errors.append("Data size mismatch for visibility measurement 0x%x (wavelength table has a length of %d)"%(id(vis), nwave)) + + if warnings: + print("*** %d warning%s:"%(len(warnings), _plurals(len(warnings)))) + for warning in warnings: + print(' ' + warning) + if errors: + print("*** %d ERROR%s:"%(len(errors), _plurals(len(errors)).upper())) + for error in errors: + print(' ' + error) + + return not (len(warnings) or len(errors)) + + def isconsistent(self): + """Returns True if the object is entirely self-contained, + i.e. all cross-references to wavelength tables, arrays, + stations etc. in the measurements refer to elements which are + stored in the oifits object. Note that an oifits object can + be 'consistent' in this sense without being 'valid' as checked + by isvalid().""" + + for vis in self.vis: + if vis.array and (vis.array not in list(self.array.values())): + print('A visibility measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(vis)) + return False + if ((vis.station[0] and (vis.station[0] not in vis.array.station)) or + (vis.station[1] and (vis.station[1] not in vis.array.station))): + print('A visibility measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(vis)) + return False + if vis.wavelength not in list(self.wavelength.values()): + print('A visibility measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(vis)) + return False + if vis.target not in self.target: + print('A visibility measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(vis)) + return False + + for vis2 in self.vis2: + if vis2.array and (vis2.array not in list(self.array.values())): + print('A visibility^2 measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(vis2)) + return False + if ((vis2.station[0] and (vis2.station[0] not in vis2.array.station)) or + (vis2.station[1] and (vis2.station[1] not in vis2.array.station))): + print('A visibility^2 measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(vis)) + return False + if vis2.wavelength not in list(self.wavelength.values()): + print('A visibility^2 measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(vis2)) + return False + if vis2.target not in self.target: + print('A visibility^2 measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(vis2)) + return False + + for t3 in self.t3: + if t3.array and (t3.array not in list(self.array.values())): + print('A closure phase measurement (0x%x) refers to an array which is not inside the main oifits object.'%id(t3)) + return False + if ((t3.station[0] and (t3.station[0] not in t3.array.station)) or + (t3.station[1] and (t3.station[1] not in t3.array.station)) or + (t3.station[2] and (t3.station[2] not in t3.array.station))): + print('A closure phase measurement (0x%x) refers to a station which is not inside the main oifits object.'%id(t3)) + return False + if t3.wavelength not in list(self.wavelength.values()): + print('A closure phase measurement (0x%x) refers to a wavelength table which is not inside the main oifits object.'%id(t3)) + return False + if t3.target not in self.target: + print('A closure phase measurement (0x%x) refers to a target which is not inside the main oifits object.'%id(t3)) + return False + + return True + + def info(self, recursive=True, verbose=0): + """Print out a summary of the contents of the oifits object. + Set recursive=True to obtain more specific information about + each of the individual components, and verbose to an integer + to increase the verbosity level.""" + + if self.wavelength: + wavelengths = 0 + if recursive: + print("====================================================================") + print("SUMMARY OF WAVELENGTH TABLES") + print("====================================================================") + for key in list(self.wavelength.keys()): + wavelengths += len(self.wavelength[key].eff_wave) + if recursive: print("'%s': %s"%(key, str(self.wavelength[key]))) + print("%d wavelength table%s with %d wavelength%s in total"%(len(self.wavelength), _plurals(len(self.wavelength)), wavelengths, _plurals(wavelengths))) + if self.target.size: + if recursive: + print("====================================================================") + print("SUMMARY OF TARGET TABLES") + print("====================================================================") + for target in self.target: + target.info() + print("%d target%s"%(len(self.target), _plurals(len(self.target)))) + if self.array: + stations = 0 + if recursive: + print("====================================================================") + print("SUMMARY OF ARRAY TABLES") + print("====================================================================") + for key in list(self.array.keys()): + if recursive: + print(key + ':') + self.array[key].info(verbose=verbose) + stations += len(self.array[key].station) + print("%d array%s with %d station%s"%(len(self.array), _plurals(len(self.array)), stations, _plurals(stations))) + if self.vis.size: + if recursive: + print("====================================================================") + print("SUMMARY OF VISIBILITY MEASUREMENTS") + print("====================================================================") + for vis in self.vis: + vis.info() + print("%d visibility measurement%s"%(len(self.vis), _plurals(len(self.vis)))) + if self.vis2.size: + if recursive: + print("====================================================================") + print("SUMMARY OF VISIBILITY^2 MEASUREMENTS") + print("====================================================================") + for vis2 in self.vis2: + vis2.info() + print("%d visibility^2 measurement%s"%(len(self.vis2), _plurals(len(self.vis2)))) + if self.t3.size: + if recursive: + print("====================================================================") + print("SUMMARY OF T3 MEASUREMENTS") + print("====================================================================") + for t3 in self.t3: + t3.info() + print("%d closure phase measurement%s"%(len(self.t3), _plurals(len(self.t3)))) + + def save(self, filename): + """Write the contents of the oifits object to a file in OIFITS + format.""" + + if not self.isconsistent(): + print('oifits object is not consistent, refusing to go further') + return + + hdulist = pyfits.HDUList() + hdu = pyfits.PrimaryHDU() + hdu.header.update('DATE', datetime.datetime.now().strftime(format='%F'), comment='Creation date') + hdu.header.add_comment('Written by OIFITS Python module version %s'%__version__) + hdu.header.add_comment('http://www.mpia-hd.mpg.de/homes/boley/oifits/') + + wavelengthmap = {} + hdulist.append(hdu) + for insname, wavelength in self.wavelength.items(): + wavelengthmap[id(wavelength)] = insname + hdu = pyfits.new_table(pyfits.ColDefs(( + pyfits.Column(name='EFF_WAVE', format='1E', unit='METERS', array=wavelength.eff_wave), + pyfits.Column(name='EFF_BAND', format='1E', unit='METERS', array=wavelength.eff_band) + ))) + hdu.header.update('EXTNAME', 'OI_WAVELENGTH') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdu.header.update('INSNAME', insname, 'Name of detector, for cross-referencing') + hdulist.append(hdu) + + targetmap = {} + if self.target.size: + target_id = [] + target = [] + raep0 = [] + decep0 = [] + equinox = [] + ra_err = [] + dec_err = [] + sysvel = [] + veltyp = [] + veldef = [] + pmra = [] + pmdec = [] + pmra_err = [] + pmdec_err = [] + parallax = [] + para_err = [] + spectyp = [] + for i, targ in enumerate(self.target): + key = i+1 + targetmap[id(targ)] = key + target_id.append(key) + target.append(targ.target) + raep0.append(targ.raep0) + decep0.append(targ.decep0) + equinox.append(targ.equinox) + ra_err.append(targ.ra_err) + dec_err.append(targ.dec_err) + sysvel.append(targ.sysvel) + veltyp.append(targ.veltyp) + veldef.append(targ.veldef) + pmra.append(targ.pmra) + pmdec.append(targ.pmdec) + pmra_err.append(targ.pmra_err) + pmdec_err.append(targ.pmdec_err) + parallax.append(targ.parallax) + para_err.append(targ.para_err) + spectyp.append(targ.spectyp) + + hdu = pyfits.new_table(pyfits.ColDefs(( + pyfits.Column(name='TARGET_ID', format='1I', array=target_id), + pyfits.Column(name='TARGET', format='16A', array=target), + pyfits.Column(name='RAEP0', format='D1', unit='DEGREES', array=raep0), + pyfits.Column(name='DECEP0', format='D1', unit='DEGREES', array=decep0), + pyfits.Column(name='EQUINOX', format='E1', unit='YEARS', array=equinox), + pyfits.Column(name='RA_ERR', format='D1', unit='DEGREES', array=ra_err), + pyfits.Column(name='DEC_ERR', format='D1', unit='DEGREES', array=dec_err), + pyfits.Column(name='SYSVEL', format='D1', unit='M/S', array=sysvel), + pyfits.Column(name='VELTYP', format='A8', array=veltyp), + pyfits.Column(name='VELDEF', format='A8', array=veldef), + pyfits.Column(name='PMRA', format='D1', unit='DEG/YR', array=pmra), + pyfits.Column(name='PMDEC', format='D1', unit='DEG/YR', array=pmdec), + pyfits.Column(name='PMRA_ERR', format='D1', unit='DEG/YR', array=pmra_err), + pyfits.Column(name='PMDEC_ERR', format='D1', unit='DEG/YR', array=pmdec_err), + pyfits.Column(name='PARALLAX', format='E1', unit='DEGREES', array=parallax), + pyfits.Column(name='PARA_ERR', format='E1', unit='DEGREES', array=para_err), + pyfits.Column(name='SPECTYP', format='A16', array=spectyp) + ))) + hdu.header.update('EXTNAME', 'OI_TARGET') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdulist.append(hdu) + + arraymap = {} + stationmap = {} + for arrname, array in self.array.items(): + arraymap[id(array)] = arrname + tel_name = [] + sta_name = [] + sta_index = [] + diameter = [] + staxyz = [] + if array.station.size: + for i, station in enumerate(array.station, 1): + stationmap[id(station)] = i + tel_name.append(station.tel_name) + sta_name.append(station.sta_name) + sta_index.append(i) + diameter.append(station.diameter) + staxyz.append(station.staxyz) + hdu = pyfits.new_table(pyfits.ColDefs(( + pyfits.Column(name='TEL_NAME', format='16A', array=tel_name), + pyfits.Column(name='STA_NAME', format='16A', array=sta_name), + pyfits.Column(name='STA_INDEX', format='1I', array=sta_index), + pyfits.Column(name='DIAMETER', unit='METERS', format='1E', array=diameter), + pyfits.Column(name='STAXYZ', unit='METERS', format='3D', array=staxyz) + ))) + hdu.header.update('EXTNAME', 'OI_ARRAY') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdu.header.update('ARRNAME', arrname, comment='Array name, for cross-referencing') + hdu.header.update('FRAME', array.frame, comment='Coordinate frame') + hdu.header.update('ARRAYX', array.arrxyz[0], comment='Array center x coordinate (m)') + hdu.header.update('ARRAYY', array.arrxyz[1], comment='Array center y coordinate (m)') + hdu.header.update('ARRAYZ', array.arrxyz[2], comment='Array center z coordinate (m)') + hdulist.append(hdu) + + if self.vis.size: + # The tables are grouped by ARRNAME and INSNAME -- all + # observations which have the same ARRNAME and INSNAME are + # put into a single FITS binary table. + tables = {} + for vis in self.vis: + nwave = vis.wavelength.eff_wave.size + if vis.array: + key = (arraymap[id(vis.array)], wavelengthmap[id(vis.wavelength)]) + else: + key = (None, wavelengthmap[id(vis.wavelength)]) + if key in list(tables.keys()): + data = tables[key] + else: + data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[], + 'visamp':[], 'visamperr':[], 'visphi':[], 'visphierr':[], + 'cflux':[], 'cfluxerr':[], 'ucoord':[], 'vcoord':[], + 'sta_index':[], 'flag':[]} + data['target_id'].append(targetmap[id(vis.target)]) + if vis.timeobs: + time = vis.timeobs - refdate + data['time'].append(time.days * 24.0 * 3600.0 + time.seconds) + mjd = (vis.timeobs - _mjdzero).days + (vis.timeobs - _mjdzero).seconds / 3600.0 / 24.0 + data['mjd'].append(mjd) + else: + data['time'].append(None) + data['mjd'].append(None) + data['int_time'].append(vis.int_time) + if nwave == 1: + data['visamp'].append(vis.visamp[0]) + data['visamperr'].append(vis.visamperr[0]) + data['visphi'].append(vis.visphi[0]) + data['visphierr'].append(vis.visphierr[0]) + data['flag'].append(vis.flag[0]) + if vis.cflux != None: + data['cflux'].append(vis.cflux[0]) + else: + data['cflux'].append(None) + if vis.cfluxerr != None: + data['cfluxerr'].append(vis.cfluxerr[0]) + else: + data['cfluxerr'].append(None) + else: + data['visamp'].append(vis.visamp) + data['visamperr'].append(vis.visamperr) + data['visphi'].append(vis.visphi) + data['visphierr'].append(vis.visphierr) + data['flag'].append(vis.flag) + if vis.cflux != None: + data['cflux'].append(vis.cflux) + else: + cflux=np.empty(nwave) + cflux[:]=None + data['cflux'].append(cflux) + if vis.cfluxerr != None: + data['cfluxerr'].append(vis.cfluxerr) + else: + cfluxerr=np.empty(nwave) + cfluxerr[:]=None + data['cfluxerr'].append(cfluxerr) + data['ucoord'].append(vis.ucoord) + data['vcoord'].append(vis.vcoord) + if vis.station[0] and vis.station[1]: + data['sta_index'].append([stationmap[id(vis.station[0])], stationmap[id(vis.station[1])]]) + else: + data['sta_index'].append([-1, -1]) + for key in list(tables.keys()): + data = tables[key] + nwave = self.wavelength[key[1]].eff_wave.size + + hdu = pyfits.new_table(pyfits.ColDefs([ + pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']), + pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']), + pyfits.Column(name='MJD', unit='DAY', format='1D', array=data['mjd']), + pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']), + pyfits.Column(name='VISAMP', format='%dD'%nwave, array=data['visamp']), + pyfits.Column(name='VISAMPERR', format='%dD'%nwave, array=data['visamperr']), + pyfits.Column(name='VISPHI', unit='DEGREES', format='%dD'%nwave, array=data['visphi']), + pyfits.Column(name='VISPHIERR', unit='DEGREES', format='%dD'%nwave, array=data['visphierr']), + pyfits.Column(name='CFLUX', format='%dD'%nwave, array=data['cflux']), + pyfits.Column(name='CFLUXERR', format='%dD'%nwave, array=data['cfluxerr']), + pyfits.Column(name='UCOORD', format='1D', unit='METERS', array=data['ucoord']), + pyfits.Column(name='VCOORD', format='1D', unit='METERS', array=data['vcoord']), + pyfits.Column(name='STA_INDEX', format='2I', array=data['sta_index'], null=-1), + pyfits.Column(name='FLAG', format='%dL'%nwave) + ])) + + # Setting the data of logical field via the + # pyfits.Column call above with length > 1 (eg + # format='171L' above) seems to be broken, atleast as + # of PyFITS 2.2.2 + hdu.data.field('FLAG').setfield(data['flag'], bool) + hdu.header.update('EXTNAME', 'OI_VIS') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdu.header.update('DATE-OBS', refdate.strftime('%F'), comment='Zero-point for table (UTC)') + if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY') + hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table') + hdulist.append(hdu) + + if self.vis2.size: + tables = {} + for vis in self.vis2: + nwave = vis.wavelength.eff_wave.size + if vis.array: + key = (arraymap[id(vis.array)], wavelengthmap[id(vis.wavelength)]) + else: + key = (None, wavelengthmap[id(vis.wavelength)]) + if key in list(tables.keys()): + data = tables[key] + else: + data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[], + 'vis2data':[], 'vis2err':[], 'ucoord':[], 'vcoord':[], + 'sta_index':[], 'flag':[]} + data['target_id'].append(targetmap[id(vis.target)]) + if vis.timeobs: + time = vis.timeobs - refdate + data['time'].append(time.days * 24.0 * 3600.0 + time.seconds) + mjd = (vis.timeobs - _mjdzero).days + (vis.timeobs - _mjdzero).seconds / 3600.0 / 24.0 + data['mjd'].append(mjd) + else: + data['time'].append(None) + data['mjd'].append(None) + data['int_time'].append(vis.int_time) + if nwave == 1: + data['vis2data'].append(vis.vis2data[0]) + data['vis2err'].append(vis.vis2err[0]) + data['flag'].append(vis.flag[0]) + else: + data['vis2data'].append(vis.vis2data) + data['vis2err'].append(vis.vis2err) + data['flag'].append(vis.flag) + data['ucoord'].append(vis.ucoord) + data['vcoord'].append(vis.vcoord) + if vis.station[0] and vis.station[1]: + data['sta_index'].append([stationmap[id(vis.station[0])], stationmap[id(vis.station[1])]]) + else: + data['sta_index'].append([-1, -1]) + for key in list(tables.keys()): + data = tables[key] + nwave = self.wavelength[key[1]].eff_wave.size + + hdu = pyfits.new_table(pyfits.ColDefs([ + pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']), + pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']), + pyfits.Column(name='MJD', format='1D', unit='DAY', array=data['mjd']), + pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']), + pyfits.Column(name='VIS2DATA', format='%dD'%nwave, array=data['vis2data']), + pyfits.Column(name='VIS2ERR', format='%dD'%nwave, array=data['vis2err']), + pyfits.Column(name='UCOORD', format='1D', unit='METERS', array=data['ucoord']), + pyfits.Column(name='VCOORD', format='1D', unit='METERS', array=data['vcoord']), + pyfits.Column(name='STA_INDEX', format='2I', array=data['sta_index'], null=-1), + pyfits.Column(name='FLAG', format='%dL'%nwave, array=data['flag']) + ])) + # Setting the data of logical field via the + # pyfits.Column call above with length > 1 (eg + # format='171L' above) seems to be broken, atleast as + # of PyFITS 2.2.2 + hdu.data.field('FLAG').setfield(data['flag'], bool) + hdu.header.update('EXTNAME', 'OI_VIS2') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdu.header.update('DATE-OBS', refdate.strftime('%F'), comment='Zero-point for table (UTC)') + if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY') + hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table') + hdulist.append(hdu) + + if self.t3.size: + tables = {} + for t3 in self.t3: + nwave = t3.wavelength.eff_wave.size + if t3.array: + key = (arraymap[id(t3.array)], wavelengthmap[id(t3.wavelength)]) + else: + key = (None, wavelengthmap[id(t3.wavelength)]) + if key in list(tables.keys()): + data = tables[key] + else: + data = tables[key] = {'target_id':[], 'time':[], 'mjd':[], 'int_time':[], + 't3amp':[], 't3amperr':[], 't3phi':[], 't3phierr':[], + 'u1coord':[], 'v1coord':[], 'u2coord':[], 'v2coord':[], + 'sta_index':[], 'flag':[]} + data['target_id'].append(targetmap[id(t3.target)]) + if t3.timeobs: + time = t3.timeobs - refdate + data['time'].append(time.days * 24.0 * 3600.0 + time.seconds) + mjd = (t3.timeobs - _mjdzero).days + (t3.timeobs - _mjdzero).seconds / 3600.0 / 24.0 + data['mjd'].append(mjd) + else: + data['time'].append(None) + data['mjd'].append(None) + data['int_time'].append(t3.int_time) + if nwave == 1: + data['t3amp'].append(t3.t3amp[0]) + data['t3amperr'].append(t3.t3amperr[0]) + data['t3phi'].append(t3.t3phi[0]) + data['t3phierr'].append(t3.t3phierr[0]) + data['flag'].append(t3.flag[0]) + else: + data['t3amp'].append(t3.t3amp) + data['t3amperr'].append(t3.t3amperr) + data['t3phi'].append(t3.t3phi) + data['t3phierr'].append(t3.t3phierr) + data['flag'].append(t3.flag) + data['u1coord'].append(t3.u1coord) + data['v1coord'].append(t3.v1coord) + data['u2coord'].append(t3.u2coord) + data['v2coord'].append(t3.v2coord) + if t3.station[0] and t3.station[1] and t3.station[2]: + data['sta_index'].append([stationmap[id(t3.station[0])], stationmap[id(t3.station[1])], stationmap[id(t3.station[2])]]) + else: + data['sta_index'].append([-1, -1, -1]) + for key in list(tables.keys()): + data = tables[key] + nwave = self.wavelength[key[1]].eff_wave.size + + hdu = pyfits.new_table(pyfits.ColDefs(( + pyfits.Column(name='TARGET_ID', format='1I', array=data['target_id']), + pyfits.Column(name='TIME', format='1D', unit='SECONDS', array=data['time']), + pyfits.Column(name='MJD', format='1D', unit='DAY', array=data['mjd']), + pyfits.Column(name='INT_TIME', format='1D', unit='SECONDS', array=data['int_time']), + pyfits.Column(name='T3AMP', format='%dD'%nwave, array=data['t3amp']), + pyfits.Column(name='T3AMPERR', format='%dD'%nwave, array=data['t3amperr']), + pyfits.Column(name='T3PHI', format='%dD'%nwave, unit='DEGREES', array=data['t3phi']), + pyfits.Column(name='T3PHIERR', format='%dD'%nwave, unit='DEGREES', array=data['t3phierr']), + pyfits.Column(name='U1COORD', format='1D', unit='METERS', array=data['u1coord']), + pyfits.Column(name='V1COORD', format='1D', unit='METERS', array=data['v1coord']), + pyfits.Column(name='U2COORD', format='1D', unit='METERS', array=data['u2coord']), + pyfits.Column(name='V2COORD', format='1D', unit='METERS', array=data['v2coord']), + pyfits.Column(name='STA_INDEX', format='3I', array=data['sta_index'], null=-1), + pyfits.Column(name='FLAG', format='%dL'%nwave, array=data['flag']) + ))) + # Setting the data of logical field via the + # pyfits.Column call above with length > 1 (eg + # format='171L' above) seems to be broken, atleast as + # of PyFITS 2.2.2 + hdu.data.field('FLAG').setfield(data['flag'], bool) + hdu.header.update('EXTNAME', 'OI_T3') + hdu.header.update('OI_REVN', 1, 'Revision number of the table definition') + hdu.header.update('DATE-OBS', refdate.strftime('%F'), 'Zero-point for table (UTC)') + if key[0]: hdu.header.update('ARRNAME', key[0], 'Identifies corresponding OI_ARRAY') + hdu.header.update('INSNAME', key[1], 'Identifies corresponding OI_WAVELENGTH table') + hdulist.append(hdu) + + hdulist.writeto(filename, clobber=True) + + + +def open(filename, quiet=False): + """Open an OIFITS file.""" + + newobj = oifits() + targetmap = {} + sta_indices = {} + + if not quiet: + print("Opening %s"%filename) + hdulist = pyfits.open(filename) + # First get all the OI_TARGET, OI_WAVELENGTH and OI_ARRAY tables + for hdu in hdulist: + header = hdu.header + data = hdu.data + if hdu.name == 'OI_WAVELENGTH': + if newobj.wavelength == None: newobj.wavelength = {} + insname = header['INSNAME'] + newobj.wavelength[insname] = OI_WAVELENGTH(data.field('EFF_WAVE'), data.field('EFF_BAND')) + elif hdu.name == 'OI_TARGET': + for row in data: + target_id = row['TARGET_ID'] + target = OI_TARGET(target=row['TARGET'], raep0=row['RAEP0'], decep0=row['DECEP0'], + equinox=row['EQUINOX'], ra_err=row['RA_ERR'], dec_err=row['DEC_ERR'], + sysvel=row['SYSVEL'], veltyp=row['VELTYP'], veldef=row['VELDEF'], + pmra=row['PMRA'], pmdec=row['PMDEC'], pmra_err=row['PMRA_ERR'], + pmdec_err=row['PMDEC_ERR'], parallax=row['PARALLAX'], + para_err=row['PARA_ERR'], spectyp=row['SPECTYP']) + newobj.target = np.append(newobj.target, target) + targetmap[target_id] = target + elif hdu.name == 'OI_ARRAY': + if newobj.array == None: newobj.array = {} + arrname = header['ARRNAME'] + frame = header['FRAME'] + arrxyz = np.array([header['ARRAYX'], header['ARRAYY'], header['ARRAYZ']]) + newobj.array[arrname] = OI_ARRAY(frame, arrxyz, stations=data) + # Save the sta_index for each array, as we will need it + # later to match measurements to stations + sta_indices[arrname] = data.field('sta_index') + + # Then get any science measurements + for hdu in hdulist: + header = hdu.header + data = hdu.data + if hdu.name in ('OI_VIS', 'OI_VIS2', 'OI_T3'): + if 'ARRNAME' in list(header.keys()): + arrname = header['ARRNAME'] + else: + arrname = None + if arrname and newobj.array: + array = newobj.array[arrname] + else: + array = None + wavelength = newobj.wavelength[header['INSNAME']] + if hdu.name == 'OI_VIS': + for row in data: + date = header['DATE-OBS'].split('-') + timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2)) + int_time = row.field('INT_TIME') + visamp = np.reshape(row.field('VISAMP'), -1) + visamperr = np.reshape(row.field('VISAMPERR'), -1) + visphi = np.reshape(row.field('VISPHI'), -1) + visphierr = np.reshape(row.field('VISPHIERR'), -1) + if 'CFLUX' in row.array.names: cflux = np.reshape(row.field('CFLUX'), -1) + else: cflux = None + if 'CFLUXERR' in row.array.names: cfluxerr = np.reshape(row.field('CFLUXERR'), -1) + else: cfluxerr = None + flag = np.reshape(row.field('FLAG'), -1) + ucoord = row.field('UCOORD') + vcoord = row.field('VCOORD') + target = targetmap[row.field('TARGET_ID')] + if array: + sta_index = row.field('STA_INDEX') + s1 = array.station[sta_indices[arrname] == sta_index[0]][0] + s2 = array.station[sta_indices[arrname] == sta_index[1]][0] + station = [s1, s2] + else: + station = [None, None] + newobj.vis = np.append(newobj.vis, OI_VIS(timeobs=timeobs, int_time=int_time, visamp=visamp, + visamperr=visamperr, visphi=visphi, visphierr=visphierr, + flag=flag, ucoord=ucoord, vcoord=vcoord, wavelength=wavelength, + target=target, array=array, station=station, cflux=cflux, + cfluxerr=cfluxerr)) + elif hdu.name == 'OI_VIS2': + for row in data: + date = header['DATE-OBS'].split('-') + timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2)) + int_time = row.field('INT_TIME') + vis2data = np.reshape(row.field('VIS2DATA'), -1) + vis2err = np.reshape(row.field('VIS2ERR'), -1) + flag = np.reshape(row.field('FLAG'), -1) + ucoord = row.field('UCOORD') + vcoord = row.field('VCOORD') + target = targetmap[row.field('TARGET_ID')] + if array: + sta_index = row.field('STA_INDEX') + s1 = array.station[sta_indices[arrname] == sta_index[0]][0] + s2 = array.station[sta_indices[arrname] == sta_index[1]][0] + station = [s1, s2] + else: + station = [None, None] + newobj.vis2 = np.append(newobj.vis2, OI_VIS2(timeobs=timeobs, int_time=int_time, vis2data=vis2data, + vis2err=vis2err, flag=flag, ucoord=ucoord, vcoord=vcoord, + wavelength=wavelength, target=target, array=array, + station=station)) + elif hdu.name == 'OI_T3': + for row in data: + date = header['DATE-OBS'].split('-') + timeobs = datetime.datetime(int(date[0]), int(date[1]), int(date[2])) + datetime.timedelta(seconds=np.around(row.field('TIME'), 2)) + int_time = row.field('INT_TIME') + t3amp = np.reshape(row.field('T3AMP'), -1) + t3amperr = np.reshape(row.field('T3AMPERR'), -1) + t3phi = np.reshape(row.field('T3PHI'), -1) + t3phierr = np.reshape(row.field('T3PHIERR'), -1) + flag = np.reshape(row.field('FLAG'), -1) + u1coord = row.field('U1COORD') + v1coord = row.field('V1COORD') + u2coord = row.field('U2COORD') + v2coord = row.field('V2COORD') + target = targetmap[row.field('TARGET_ID')] + if array: + sta_index = row.field('STA_INDEX') + s1 = array.station[sta_indices[arrname] == sta_index[0]][0] + s2 = array.station[sta_indices[arrname] == sta_index[1]][0] + s3 = array.station[sta_indices[arrname] == sta_index[2]][0] + station = [s1, s2, s3] + else: + station = [None, None, None] + newobj.t3 = np.append(newobj.t3, OI_T3(timeobs=timeobs, int_time=int_time, t3amp=t3amp, + t3amperr=t3amperr, t3phi=t3phi, t3phierr=t3phierr, + flag=flag, u1coord=u1coord, v1coord=v1coord, u2coord=u2coord, + v2coord=v2coord, wavelength=wavelength, target=target, + array=array, station=station)) + + hdulist.close() + if not quiet: + newobj.info(recursive=False) + + return newobj diff --git a/io/save.py b/io/save.py new file mode 100644 index 00000000..da7d11e7 --- /dev/null +++ b/io/save.py @@ -0,0 +1,1017 @@ +# save.py +# functions to save observation & image data from files +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import astropy.io.fits as fits +import datetime +import h5py + +import ehtim.io.writeoifits +import ehtim.io.oifits +from astropy.time import Time + +import ehtim.const_def as ehc +import ehtim.observing.obs_helpers as obsh + +################################################################################################## +# Image IO +################################################################################################## + + +def save_im_txt(im, fname, mjd=False, time=False): + """Save image data to text file. + + Args: + im (Image): image object + fname (str): path to output text file + mjd (int): MJD of saved image + time (float): UTC time of saved image + + Returns: + """ + + # Transform to Stokes parameters: + if im.polrep != 'stokes' or im.pol_prim != 'I': + im = im.switch_polrep(polrep_out='stokes', pol_prim_out=None) + + # Coordinate values + pdimas = im.psize/ehc.RADPERAS + xs = np.array([[j for j in range(im.xdim)] for i in range(im.ydim)]).reshape(im.xdim*im.ydim, 1) + xs = pdimas * (xs[::-1] - im.xdim/2.0) + ys = np.array([[i for j in range(im.xdim)] for i in range(im.ydim)]).reshape(im.xdim*im.ydim, 1) + ys = pdimas * (ys[::-1] - im.xdim/2.0) + + # If V values but no Q/U values, make Q/U zero + if len(im.vvec) and not len(im.qvec): + im.qvec = 0*im.vvec + im.uvec = 0*im.vvec + + # Format Data + if len(im.qvec) and len(im.vvec): + outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1), + (im.qvec).reshape(im.xdim*im.ydim, 1), + (im.uvec).reshape(im.xdim*im.ydim, 1), + (im.vvec).reshape(im.xdim*im.ydim, 1))) + hf = "x (as) y (as) I (Jy/pixel) Q (Jy/pixel) U (Jy/pixel) V (Jy/pixel)" + + fmts = "%10.10f %10.10f %10.10f %10.10f %10.10f %10.10f" + + elif len(im.qvec): + outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1), + (im.qvec).reshape(im.xdim*im.ydim, 1), + (im.uvec).reshape(im.xdim*im.ydim, 1))) + hf = "x (as) y (as) I (Jy/pixel) Q (Jy/pixel) U (Jy/pixel)" + + fmts = "%10.10f %10.10f %10.10f %10.10f %10.10f" + + else: + outdata = np.hstack((xs, ys, (im.imvec).reshape(im.xdim*im.ydim, 1))) + hf = "x (as) y (as) I (Jy/pixel)" + fmts = "%10.10f %10.10f %10.10f" + + # Header + if not mjd: + mjd = float(im.mjd) + if not time: + time = im.time + mjd += (time/24.) + + head = ("SRC: %s \n" % im.source + + "RA: " + obsh.rastring(im.ra) + "\n" + "DEC: " + obsh.decstring(im.dec) + "\n" + + "MJD: %.6f \n" % (float(mjd)) + + "RF: %.4f GHz \n" % (im.rf/1e9) + + "FOVX: %i pix %f as \n" % (im.xdim, pdimas * im.xdim) + + "FOVY: %i pix %f as \n" % (im.ydim, pdimas * im.ydim) + + "------------------------------------\n" + hf) + + # Save + np.savetxt(fname, outdata, header=head, fmt=fmts) + return + +# TODO save image in circular basis? +def save_im_fits(im, fname, mjd=False, time=False): + """Save image data to a fits file. + + Args: + im (Image): image object + fname (str): path to output fits file + mjd (int): MJD of saved image + time (float): UTC time of saved image + + Returns: + """ + + # Transform to Stokes parameters: + if (im.polrep != 'stokes') or (im.pol_prim != 'I'): + im = im.switch_polrep(polrep_out='stokes', pol_prim_out=None) + + # Create header and fill in some values + header = fits.Header() + header['OBJECT'] = im.source + header['CTYPE1'] = 'RA---SIN' + header['CTYPE2'] = 'DEC--SIN' + header['CDELT1'] = -im.psize/ehc.DEGREE + header['CDELT2'] = im.psize/ehc.DEGREE + header['OBSRA'] = im.ra * 180/12. + header['OBSDEC'] = im.dec + header['FREQ'] = im.rf + + # TODO these are the default values for centered images + # TODO support for arbitrary CRPIX? + header['CRPIX1'] = im.xdim/2. + .5 + header['CRPIX2'] = im.ydim/2. + .5 + + if not mjd: + mjd = float(im.mjd) + if not time: + time = im.time + mjd += (time/24.) + + header['MJD'] = float(mjd) + header['TELESCOP'] = 'VLBI' + header['BUNIT'] = 'JY/PIXEL' + header['STOKES'] = 'I' + + # Create the fits image + image = np.reshape(im.imvec, (im.ydim, im.xdim))[::-1, :] # flip y axis! + hdu = fits.PrimaryHDU(image, header=header) + hdulist = [hdu] + if len(im.qvec): + qimage = np.reshape(im.qvec, (im.ydim, im.xdim))[::-1, :] + uimage = np.reshape(im.uvec, (im.ydim, im.xdim))[::-1, :] + header['STOKES'] = 'Q' + hduq = fits.ImageHDU(qimage, name='Q', header=header) + header['STOKES'] = 'U' + hduu = fits.ImageHDU(uimage, name='U', header=header) + hdulist = [hdu, hduq, hduu] + if len(im.vvec): + vimage = np.reshape(im.vvec, (im.ydim, im.xdim))[::-1, :] + header['STOKES'] = 'V' + hduv = fits.ImageHDU(vimage, name='V', header=header) + hdulist.append(hduv) + + hdulist = fits.HDUList(hdulist) + + # Save fits + hdulist.writeto(fname, overwrite=True) + + return + +################################################################################################## +# Movie IO +################################################################################################## + + +def save_mov_hdf5(mov, fname, mjd=False): + """Save movie data to an hdf5 file. + + Args: + mov (Movie): movie object + fname (str): basename of output fits file + mjd (int): MJD of saved movie + + Returns: + """ + + # TODO: Currently only supports one polarization! + with h5py.File(fname, 'w') as file: + + head = file.create_dataset('header', (0,), dtype="S10") + + if mjd is False: + mjd = mov.mjd + + head.attrs['mjd'] = np.string_(str(mjd)) + head.attrs['psize'] = np.string_(str(mov.psize)) + head.attrs['source'] = np.string_(str(mov.source)) + head.attrs['ra'] = np.string_(str(mov.ra)) + head.attrs['dec'] = np.string_(str(mov.dec)) + head.attrs['rf'] = np.string_(str(mov.rf)) + head.attrs['polrep'] = np.string_(str(mov.polrep)) + head.attrs['pol_prim'] = np.string_(str(mov.pol_prim)) + + name = 'times' + times = mov.times + dset = file.create_dataset(name, data=times, dtype='f8') + + name = mov.pol_prim + frames = mov.frames.reshape((mov.nframes, mov.ydim, mov.xdim)) + dset = file.create_dataset(name, data=frames, dtype='f8') + + for pol in list(mov._movdict.keys()): + if pol == mov.pol_prim: + continue + polframes = mov._movdict[pol] + if len(polframes): + polframes = polframes.reshape((mov.nframes, mov.ydim, mov.xdim)) + dset = file.create_dataset(pol, data=polframes, dtype='f8') + return + + +def save_mov_fits(mov, fname, mjd=False): + """Save movie data to series of fits files. + + Args: + mov (Movie): movie object + fname (str): basename of output fits file + mjd (int): MJD of saved movie + + Returns: + """ + + if mjd is False: + mjd = mov.mjd + + for i in range(mov.nframes): + time_frame = mov.times[i] + fname_frame = fname + "%05d" % i + print('saving file '+fname_frame) + frame_im = mov.get_frame(i) + save_im_fits(frame_im, fname_frame, mjd=mjd, time=time_frame) + + return + + +def save_mov_txt(mov, fname, mjd=False): + """Save movie data to series of text files. + + Args: + mov (Movie): movie object + fname (str): basename of output text file + mjd (int): MJD of saved movie + + Returns: + """ + + if mjd is False: + mjd = mov.mjd + + for i in range(mov.nframes): + time_frame = mov.times[i] + fname_frame = fname + "%05d" % i + print('saving file '+fname_frame) + frame_im = mov.get_frame(i) + save_im_txt(frame_im, fname_frame, mjd=mjd, time=time_frame) + + return + + +################################################################################################## +# Array IO +################################################################################################## + +def save_array_txt(arr, fname): + """Save the array data in a text file. + + Args: + arr (Array): array object + fname (str): name of output text file + + Returns: + """ + + if type(arr) == np.ndarray: + tarr = arr + else: + try: + tarr = arr.tarr + except: + print("Array format not recognized!") + + out = ("#Site X(m) Y(m) Z(m) " + + "SEFDR SEFDL FR_PAR FR_EL FR_OFF " + + "DR_RE DR_IM DL_RE DL_IM \n") + for scope in range(len(tarr)): + dat = (tarr[scope]['site'], + tarr[scope]['x'], tarr[scope]['y'], tarr[scope]['z'], + tarr[scope]['sefdr'], tarr[scope]['sefdl'], + tarr[scope]['fr_par'], tarr[scope]['fr_elev'], tarr[scope]['fr_off'], + tarr[scope]['dr'].real, tarr[scope]['dr'].imag, + tarr[scope]['dl'].real, tarr[scope]['dl'].imag + ) + out += "%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" % dat + f = open(fname, 'w') + f.write(out) + f.close() + return + + +################################################################################################## +# Observation IO +################################################################################################## +def save_obs_txt(obs, fname): + """Save the observation data in a text file. + + Args: + obs (Obsdata): obsdata object + fname (str): name of output text file + + Returns: + """ + + # output times must be in utc + obs = obs.switch_timetype(timetype_out='UTC') + + # Get the necessary data and the header + if obs.polrep == 'stokes': + outdata = obs.unpack(['time', 'tint', 't1', 't2', 'tau1', 'tau2', + 'u', 'v', 'amp', 'phase', 'qamp', 'qphase', 'uamp', 'uphase', + 'vamp', 'vphase', + 'sigma', 'qsigma', 'usigma', 'vsigma']) + elif obs.polrep == 'circ': + outdata = obs.unpack(['time', 'tint', 't1', 't2', 'tau1', 'tau2', + 'u', 'v', 'rramp', 'rrphase', 'llamp', 'llphase', 'rlamp', 'rlphase', + 'lramp', 'lrphase', + 'rrsigma', 'llsigma', 'rlsigma', 'lrsigma']) + + else: + raise Exception("obs.polrep not 'stokes' or 'circ'!") + + head = ("SRC: %s \n" % obs.source + + "RA: " + obsh.rastring(obs.ra) + "\n" + "DEC: " + obsh.decstring(obs.dec) + "\n" + + "MJD: %i \n" % obs.mjd + + "RF: %.4f GHz \n" % (obs.rf/1e9) + + "BW: %.4f GHz \n" % (obs.bw/1e9) + + "PHASECAL: %i \n" % obs.phasecal + + "AMPCAL: %i \n" % obs.ampcal + + "OPACITYCAL: %i \n" % obs.opacitycal + + "DCAL: %i \n" % obs.dcal + + "FRCAL: %i \n" % obs.frcal + + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "Site X(m) Y(m) Z(m) " + + "SEFDR SEFDL FR_PAR FR_EL FR_OFF " + + "DR_RE DR_IM DL_RE DL_IM \n" + ) + + for i in range(len(obs.tarr)): + head += ("%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" % + (obs.tarr[i]['site'], + obs.tarr[i]['x'], obs.tarr[i]['y'], obs.tarr[i]['z'], + obs.tarr[i]['sefdr'], obs.tarr[i]['sefdl'], + obs.tarr[i]['fr_par'], obs.tarr[i]['fr_elev'], obs.tarr[i]['fr_off'], + (obs.tarr[i]['dr']).real, (obs.tarr[i]['dr']).imag, + (obs.tarr[i]['dl']).real, (obs.tarr[i]['dl']).imag + )) + + if obs.polrep == 'stokes': + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) tint T1 T2 Tau1 Tau2 U (lambda) V (lambda) " + + "Iamp (Jy) Iphase(d) Qamp (Jy) Qphase(d) Uamp (Jy) Uphase(d) Vamp (Jy) Vphase(d) " + + "Isigma (Jy) Qsigma (Jy) Usigma (Jy) Vsigma (Jy)" + ) + elif obs.polrep == 'circ': + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) tint T1 T2 Tau1 Tau2 U (lambda) V (lambda) " + + "RRamp (Jy) RRphase(d) LLamp (Jy) LLphase(d) RLamp (Jy) RLphase(d) LRamp (Jy) LRphase(d) " + + "RRsigma (Jy) LLsigma (Jy) RLsigma (Jy) LRsigma (Jy)" + ) + + # Format and save the data + fmts = ("%011.8f %4.2f %6s %6s %4.2f %4.2f %16.4f %16.4f " + + "%10.8f %10.4f %10.8f %10.4f %10.8f %10.4f %10.8f %10.4f " + + "%10.8f %10.8f %10.8f %10.8f") + np.savetxt(fname, outdata, header=head, fmt=fmts) + return + + +def save_obs_uvfits(obs, fname=None, force_singlepol=None, polrep_out='circ'): + """Save observation data to uvfits. + + Args: + obs (Obsdata): obsdata object + fname (str): path to output fits file, or None to return HDUList only + force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL' + polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file + Returns: + hdulist (astropy.io.fits.HDUList) + + """ + + # output times must be in utc + obs = obs.switch_timetype(timetype_out='UTC') + + if polrep_out == 'circ': + obs = obs.switch_polrep('circ') + elif polrep_out == 'stokes': + obs = obs.switch_polrep('stokes') + else: + raise Exception("'polrep_out' in 'save_obs_uvfits' must be 'circ' or 'stokes'!") + + hdulist_new = fits.HDUList() + hdulist_new.append(fits.GroupsHDU()) + + ##################### + # AIPS Data TABLE + ##################### + + # Data header (based on the BU format) + MJD_0 = 2400000.5 + header = hdulist_new['PRIMARY'].header + header['OBSRA'] = obs.ra * 180./12. + header['OBSDEC'] = obs.dec + header['OBJECT'] = obs.source + header['MJD'] = float(obs.mjd) + header['DATE-OBS'] = Time(obs.mjd + MJD_0, format='jd', scale='utc').iso[0:10] + header['BSCALE'] = 1.0 + header['BZERO'] = 0.0 + header['BUNIT'] = 'JY' + header['VELREF'] = 3 # TODO ?? + header['EQUINOX'] = 'J2000' + header['ALTRPIX'] = 1.e0 + header['ALTRVAL'] = 0.e0 + header['TELESCOP'] = 'VLBA' # TODO Can we change this field? + header['INSTRUME'] = 'VLBA' + header['OBSERVER'] = 'EHT' + + header['CTYPE2'] = 'COMPLEX' + header['CRVAL2'] = 1.e0 + header['CDELT2'] = 1.e0 + header['CRPIX2'] = 1.e0 + header['CROTA2'] = 0.e0 + header['CTYPE3'] = 'STOKES' + if polrep_out == 'circ': + header['CRVAL3'] = -1.e0 + header['CDELT3'] = -1.e0 + elif polrep_out == 'stokes': + header['CRVAL3'] = 1.e0 + header['CDELT3'] = 1.e0 + header['CRPIX3'] = 1.e0 + header['CROTA3'] = 0.e0 + header['CTYPE4'] = 'FREQ' + header['CRVAL4'] = obs.rf + header['CDELT4'] = obs.bw + header['CRPIX4'] = 1.e0 + header['CROTA4'] = 0.e0 + header['CTYPE6'] = 'RA' + header['CRVAL6'] = header['OBSRA'] + header['CDELT6'] = 1.e0 + header['CRPIX6'] = 1.e0 + header['CROTA6'] = 0.e0 + header['CTYPE7'] = 'DEC' + header['CRVAL7'] = header['OBSDEC'] + header['CDELT7'] = 1.e0 + header['CRPIX7'] = 1.e0 + header['CROTA7'] = 0.e0 + header['PTYPE1'] = 'UU---SIN' + header['PSCAL1'] = 1.0 + header['PZERO1'] = 0.e0 + header['PTYPE2'] = 'VV---SIN' + header['PSCAL2'] = 1.0 + header['PZERO2'] = 0.e0 + header['PTYPE3'] = 'WW---SIN' + header['PSCAL3'] = 1.0 + header['PZERO3'] = 0.e0 + header['PTYPE4'] = 'BASELINE' + header['PSCAL4'] = 1.e0 + header['PZERO4'] = 0.e0 + header['PTYPE5'] = 'DATE' + header['PSCAL5'] = 1.e0 + header['PZERO5'] = 0.e0 + header['PTYPE6'] = 'DATE' + header['PSCAL6'] = 1.e0 + header['PZERO6'] = 0.0 + header['PTYPE7'] = 'INTTIM' + header['PSCAL7'] = 1.e0 + header['PZERO7'] = 0.e0 + header['PTYPE8'] = 'TAU1' + header['PSCAL8'] = 1.e0 + header['PZERO8'] = 0.e0 + header['PTYPE9'] = 'TAU2' + header['PSCAL9'] = 1.e0 + header['PZERO9'] = 0.e0 + header['history'] = "AIPS SORT ORDER='TB'" + + # Get data + + if polrep_out == 'circ': + obsdata = obs.unpack(['time', 'tint', 'u', 'v', + 'rrvis', 'llvis', 'rlvis', 'lrvis', + 'rrsigma', 'llsigma', 'rlsigma', 'lrsigma', + 't1', 't2', 'tau1', 'tau2']) + elif polrep_out == 'stokes': + obsdata = obs.unpack(['time', 'tint', 'u', 'v', 'vis', 'qvis', 'uvis', 'vvis', + 'sigma', 'qsigma', 'usigma', 'vsigma', 't1', 't2', 'tau1', 'tau2']) + + ndat = len(obsdata['time']) + + # times and tints + jds = (2400000.5 + obs.mjd) * np.ones(len(obsdata)) + fractimes = (obsdata['time']/24.0) + tints = obsdata['tint'] + + # Baselines + t1 = [obs.tkey[scope] + 1 for scope in obsdata['t1']] + t2 = [obs.tkey[scope] + 1 for scope in obsdata['t2']] + bl = 256*np.array(t1) + np.array(t2) + + # opacities + tau1 = obsdata['tau1'] + tau2 = obsdata['tau2'] + + # uv are in lightseconds + u = obsdata['u']/obs.rf + v = obsdata['v']/obs.rf + + # rr, ll, lr, rl, weights + + if polrep_out == 'circ': + rr = obsdata['rrvis'] + ll = obsdata['llvis'] + rl = obsdata['rlvis'] + lr = obsdata['lrvis'] + weightrr = 1.0/(obsdata['rrsigma']**2) + weightll = 1.0/(obsdata['llsigma']**2) + weightrl = 1.0/(obsdata['rlsigma']**2) + weightlr = 1.0/(obsdata['lrsigma']**2) + + # If necessary, enforce single polarization + if force_singlepol == 'L': + if obs.polrep == 'stokes': + raise Exception("force_singlepol only works with obs.polrep=='stokes'!") + print("force_singlepol='L': treating Stokes 'I' as LL and ignoring Q,U,V!!") + ll = obsdata['vis'] + rr = rr * 0.0 + rl = rl * 0.0 + lr = lr * 0.0 + weightrr = weightrr * 0.0 + weightrl = weightrl * 0.0 + weightlr = weightlr * 0.0 + elif force_singlepol == 'R': + if obs.polrep == 'stokes': + raise Exception("force_singlepol only works with obs.polrep=='stokes'!") + print("force_singlepol='R': treating Stokes 'I' as RR and ignoring Q,U,V!!") + rr = obsdata['vis'] + ll = rr * 0.0 + rl = rl * 0.0 + lr = lr * 0.0 + weightll = weightll * 0.0 + weightrl = weightrl * 0.0 + weightlr = weightlr * 0.0 + + dat1 = rr + dat2 = ll + dat3 = rl + dat4 = lr + weight1 = weightrr + weight2 = weightll + weight3 = weightrl + weight4 = weightlr + + elif polrep_out == 'stokes': + dat1 = obsdata['vis'] + dat2 = obsdata['qvis'] + dat3 = obsdata['uvis'] + dat4 = obsdata['vvis'] + weight1 = 1.0/(obsdata['sigma']**2) + weight2 = 1.0/(obsdata['qsigma']**2) + weight3 = 1.0/(obsdata['usigma']**2) + weight4 = 1.0/(obsdata['vsigma']**2) + + # Replace nans by zeros (including zero weights) + dat1 = np.nan_to_num(dat1) + dat2 = np.nan_to_num(dat2) + dat3 = np.nan_to_num(dat3) + dat4 = np.nan_to_num(dat4) + weight1 = np.nan_to_num(weight1) + weight2 = np.nan_to_num(weight2) + weight3 = np.nan_to_num(weight3) + weight4 = np.nan_to_num(weight4) + + # Data array + outdat = np.zeros((ndat, 1, 1, 1, 1, 4, 3)) + outdat[:, 0, 0, 0, 0, 0, 0] = np.real(dat1) + outdat[:, 0, 0, 0, 0, 0, 1] = np.imag(dat1) + outdat[:, 0, 0, 0, 0, 0, 2] = weight1 + outdat[:, 0, 0, 0, 0, 1, 0] = np.real(dat2) + outdat[:, 0, 0, 0, 0, 1, 1] = np.imag(dat2) + outdat[:, 0, 0, 0, 0, 1, 2] = weight2 + outdat[:, 0, 0, 0, 0, 2, 0] = np.real(dat3) + outdat[:, 0, 0, 0, 0, 2, 1] = np.imag(dat3) + outdat[:, 0, 0, 0, 0, 2, 2] = weight3 + outdat[:, 0, 0, 0, 0, 3, 0] = np.real(dat4) + outdat[:, 0, 0, 0, 0, 3, 1] = np.imag(dat4) + outdat[:, 0, 0, 0, 0, 3, 2] = weight4 + + # Save data + pars = ['UU---SIN', 'VV---SIN', 'WW---SIN', 'BASELINE', 'DATE', 'DATE', + 'INTTIM', 'TAU1', 'TAU2'] + x = fits.GroupData(outdat, parnames=pars, + pardata=[u, v, np.zeros(ndat), bl, jds, fractimes, tints, tau1, tau2], + bitpix=-32) + + hdulist_new['PRIMARY'].data = x + hdulist_new['PRIMARY'].header = header # TODO necessary ?? + + ##################### + # AIPS AN TABLE + ##################### + + # Load the array data + tarr = obs.tarr + tnames = tarr['site'] + tnums = np.arange(1, len(tarr)+1) + xyz = np.array([[tarr[i]['x'], tarr[i]['y'], tarr[i]['z']] for i in np.arange(len(tarr))]) + sefd = tarr['sefdr'] + + nsta = len(tnames) + col1 = fits.Column(name='ANNAME', format='8A', array=tnames) + col2 = fits.Column(name='STABXYZ', format='3D', unit='METERS', array=xyz) + col3 = fits.Column(name='NOSTA', format='1J', array=tnums) + colfin = fits.Column(name='SEFD', format='1D', array=sefd) + + # TODO these antenna fields+header are questionable - look into them + col4 = fits.Column(name='MNTSTA', format='1J', + array=np.zeros(nsta)) + col5 = fits.Column(name='STAXOF', format='1E', unit='METERS', + array=np.zeros(nsta)) + col6 = fits.Column(name='POLTYA', format='1A', + array=np.array(['R' for i in range(nsta)], dtype='|S1')) + col7 = fits.Column(name='POLAA', format='1E', unit='DEGREES', + array=np.zeros(nsta)) + col8 = fits.Column(name='POLCALA', format='3E', + array=np.zeros((nsta, 3))) + col9 = fits.Column(name='POLTYB', format='1A', + array=np.array(['L' for i in range(nsta)], dtype='|S1')) + col10 = fits.Column(name='POLAB', format='1E', unit='DEGREES', + array=(90.*np.ones(nsta))) + col11 = fits.Column(name='POLCALB', format='3E', + array=np.zeros((nsta, 3))) + col25 = fits.Column(name='ORBPARM', format='1E', + array=np.zeros(0)) + + # Antenna Header params + # TODO do we need to change more of these?? + collist = [col1, col2, col25, col3, col4, col5, col6, col7, col8, col9, col10, col11, colfin] + tbhdu = fits.BinTableHDU.from_columns(fits.ColDefs(collist), name='AIPS AN') + hdulist_new.append(tbhdu) + + head = hdulist_new['AIPS AN'].header + + head['EXTVER'] = 1 + head['ARRAYX'] = 0.e0 + head['ARRAYY'] = 0.e0 + head['ARRAYZ'] = 0.e0 + + # TODO change the reference date + #rdate_tt_new = Time(obs.mjd + MJD_0, format='jd', scale='utc', out_subfmt='date') + #rdate_out = rdate_tt_new.iso + + rdate_tt_new = Time(obs.mjd + MJD_0, format='jd', scale='utc') + rdate_out = rdate_tt_new.iso[0:10] + + rdate_tt_new.out_subfmt = 'float' # TODO -- needed to fix subformat issue in astropy 4.0 + rdate_jd_out = rdate_tt_new.jd + rdate_gstiao_out = rdate_tt_new.sidereal_time('apparent', 'greenwich').degree + rdate_offset_out = (rdate_tt_new.ut1.datetime.second - rdate_tt_new.utc.datetime.second) + rdate_offset_out += 1.e-6*(rdate_tt_new.ut1.datetime.microsecond - + rdate_tt_new.utc.datetime.microsecond) + + head['RDATE'] = rdate_out + head['GSTIA0'] = rdate_gstiao_out + head['DEGPDY'] = 360.9856 + head['UT1UTC'] = rdate_offset_out # difference between UT1 and UTC ? + head['DATUTC'] = 0.e0 + head['TIMESYS'] = 'UTC' + + head['FREQ'] = obs.rf + head['POLARX'] = 0.e0 + head['POLARY'] = 0.e0 + + head['ARRNAM'] = 'VLBA' # TODO must be recognized by aips/casa + head['XYZHAND'] = 'RIGHT' + head['FRAME'] = '????' + head['NUMORB'] = 0 + head['NO_IF'] = 1 # TODO nchan + head['NOPCAL'] = 0 # TODO add pol cal information + head['POLTYPE'] = 'VLBI' + head['FREQID'] = 1 + + hdulist_new['AIPS AN'].header = head # TODO necessary, or is it a pointer? + + ##################### + # AIPS FQ TABLE + ##################### + # Convert types & columns + + nif = 1 + col1 = np.array(1, dtype=np.int32).reshape([nif]) # frqsel + col2 = np.array(0.0, dtype=np.float64).reshape([nif]) # iffreq + col3 = np.array([obs.bw], dtype=np.float32).reshape([nif]) # chwidth + col4 = np.array([obs.bw], dtype=np.float32).reshape([nif]) # bw + col5 = np.array([1], dtype=np.int32).reshape([nif]) # sideband + + col1 = fits.Column(name="FRQSEL", format="1J", array=col1) + col2 = fits.Column(name="IF FREQ", format="%dD" % (nif), array=col2) + col3 = fits.Column(name="CH WIDTH", format="%dE" % (nif), array=col3) + col4 = fits.Column(name="TOTAL BANDWIDTH", format="%dE" % (nif), array=col4) + col5 = fits.Column(name="SIDEBAND", format="%dJ" % (nif), array=col5) + cols = fits.ColDefs([col1, col2, col3, col4, col5]) + + # create table + tbhdu = fits.BinTableHDU.from_columns(cols) + + # add header information + tbhdu.header.append(("NO_IF", nif, "Number IFs")) + tbhdu.header.append(("EXTNAME", "AIPS FQ")) + tbhdu.header.append(("EXTVER", 1)) + hdulist_new.append(tbhdu) + + ##################### + # AIPS NX TABLE + ##################### + + scan_times = [] + scan_time_ints = [] + start_vis = [] + stop_vis = [] + + # TODO make sure jds AND scan_info MUST be time sorted!! + jj = 0 + + ROUND_SCAN_INT = 5 + comp_fac = 3600*24*100 # compare to 100th of a second + scan_arr = obs.scans + print('Building NX table') + if (scan_arr is None or len(scan_arr) == 0): + print("No NX table in saved uvfits") + else: + try: + scan_arr = scan_arr/24. + for scan in scan_arr: + scan_start = round(scan[0], ROUND_SCAN_INT) + scan_stop = round(scan[1], ROUND_SCAN_INT) + scan_dur = (scan_stop - scan_start) + + if jj >= len(fractimes): + # print start_vis, stop_vis + break + + # print ("%.12f %.12f %.12f" % (fractimes[jj], scan_start, scan_stop)) + jd = round(fractimes[jj], ROUND_SCAN_INT)*comp_fac # TODO precision?? + + if ((np.floor(jd) >= np.floor(scan_start*comp_fac)) and + (np.ceil(jd) <= np.ceil(comp_fac*scan_stop))): + start_vis.append(jj) + + # TODO AIPS MEMO 117 says scan_times should be midpoint! + # but AIPS data looks likes it's at the start? + scan_times.append(scan_start + 0.5*scan_dur) # - rdate_jd_out) + scan_time_ints.append(scan_dur) + ceilcut = np.ceil(comp_fac*scan_stop) + while ((jj < len(fractimes) and + np.floor(round(fractimes[jj], ROUND_SCAN_INT)*comp_fac) <= ceilcut)): + jj += 1 + stop_vis.append(jj-1) + else: + continue + + if jj < len(fractimes): + print(scan_arr[-1]) + print(round(scan_arr[-1][0], ROUND_SCAN_INT), + round(scan_arr[-1][1], ROUND_SCAN_INT)) + print(jj, len(jds), round(jds[jj], ROUND_SCAN_INT)) + print("WARNING!!!: in save_uvfits NX table, " + + "didn't get to all entries when computing scan start/stop!") + print(scan_times) + time_nx = fits.Column(name="TIME", format="1D", unit='DAYS', array=np.array(scan_times)) + timeint_nx = fits.Column(name="TIME INTERVAL", format="1E", + unit='DAYS', array=np.array(scan_time_ints)) + sourceid_nx = fits.Column(name="SOURCE ID", format="1J", + unit='', array=np.ones(len(scan_times))) + subarr_nx = fits.Column(name="SUBARRAY", format="1J", unit='', + array=np.ones(len(scan_times))) + freqid_nx = fits.Column(name="FREQ ID", format="1J", unit='', + array=np.ones(len(scan_times))) + startvis_nx = fits.Column(name="START VIS", format="1J", + unit='', array=np.array(start_vis)+1) + endvis_nx = fits.Column(name="END VIS", format="1J", + unit='', array=np.array(stop_vis)+1) + cols = fits.ColDefs([time_nx, timeint_nx, sourceid_nx, subarr_nx, + freqid_nx, startvis_nx, endvis_nx]) + + tbhdu = fits.BinTableHDU.from_columns(cols) + + # header information + tbhdu.header.append(("EXTNAME", "AIPS NX")) + tbhdu.header.append(("EXTVER", 1)) + + hdulist_new.append(tbhdu) + except TypeError: + print("No NX table in saved uvfits") + + # Write final HDUList to file + if fname is not None: + hdulist_new.writeto(fname, overwrite=True) + + return hdulist_new.copy() + + + +def save_obs_oifits(obs, fname, flux=1.0): + """Save visibility data to oifits file. + Polarization data is NOT saved + NOTE: as of 2021, this function is very out-of-date and should be updated + Args: + obs (Obsdata): obsdata object + fname (str): path to output uvfits file. + flux (float): Flux density normalization + Returns: + """ + + # TODO: Add polarization to oifits?? + print('Warning: save_oifits does NOT save polarimetric visibility data!') + + # output times must be in utc + obs = obs.switch_timetype(timetype_out='UTC') + + if (obs.polrep != 'stokes'): + raise Exception("save_obs_oifits only works with polrep 'stokes'!") + + # Normalizing by the total flux passed in - note this is changing the data inside the obs structure + obs.data['vis'] /= flux + obs.data['sigma'] /= flux + + data = obs.unpack(['u', 'v', 'amp', 'phase', 'sigma', 'time', 't1', 't2', 'tint']) + biarr = obs.bispectra(mode="all", count="min") + + # extract the telescope names and parameters + antennaNames = obs.tarr['site'] + sefd = obs.tarr['sefdr'] + antennaX = obs.tarr['x'] + antennaY = obs.tarr['y'] + antennaZ = obs.tarr['z'] + + # TODO: this is incorrect and there is just a dummy variable here + # antennaDiam = -np.ones(antennaX.shape) + antennaDiam = sefd # replace antennaDiam with SEFD for radio observtions + + # create dictionary + union = {} + union = ehtim.io.writeoifits.arrayUnion(antennaNames, union) + + # extract the integration time + intTime = data['tint'][0] + if not all(data['tint'][0] == item for item in np.reshape(data['tint'], (-1))): + raise TypeError("The time integrations for each visibility are different") + + # get visibility information + amp = data['amp'] + phase = data['phase'] + viserror = data['sigma'] + u = data['u'] + v = data['v'] + + # convert antenna name strings to number identifiers + ant1 = ehtim.io.writeoifits.convertStrings(data['t1'], union) + ant2 = ehtim.io.writeoifits.convertStrings(data['t2'], union) + + # convert times to datetime objects + # TODO: these do not correspond to the acutal times + time = data['time'] + dttime = np.array([datetime.datetime.utcfromtimestamp(x*60.0*60.0) + for x in time]) + + # get the bispectrum information + bi = biarr['bispec'] + t3amp = np.abs(bi) + t3phi = np.angle(bi, deg=1) + t3amperr = biarr['sigmab'] + t3phierr = 180.0/np.pi * (1.0/t3amp) * t3amperr + uClosure = np.transpose(np.array([np.array(biarr['u1']), np.array(biarr['u2'])])) + vClosure = np.transpose(np.array([np.array(biarr['v1']), np.array(biarr['v2'])])) + + # convert times to datetime objects + # TODO: these do not correspond to the acutal times + timeClosure = biarr['time'] + dttimeClosure = np.array([datetime.datetime.utcfromtimestamp(x*60.0*60.0) + for x in timeClosure]) + + # convert antenna name strings to number identifiers + biarr_ant1 = ehtim.io.writeoifits.convertStrings(biarr['t1'], union) + biarr_ant2 = ehtim.io.writeoifits.convertStrings(biarr['t2'], union) + biarr_ant3 = ehtim.io.writeoifits.convertStrings(biarr['t3'], union) + antOrder = np.transpose(np.array([biarr_ant1, biarr_ant2, biarr_ant3])) + + # todo: check that putting the negatives on the phase and t3phi is correct + ehtim.io.writeoifits.writeOIFITS(fname, obs.ra, obs.dec, obs.rf, obs.bw, intTime, + amp, viserror, phase, viserror, u, v, ant1, ant2, dttime, + t3amp, t3amperr, t3phi, t3phierr, uClosure, vClosure, antOrder, + dttimeClosure, antennaNames, antennaDiam, + antennaX, antennaY, antennaZ) + + # Un-Normalizing by the total flux passed in + # NOTE this is changing the data inside the obs structure back to what it originally was + obs.data['vis'] *= flux + obs.data['sigma'] *= flux + + return + + +def save_dtype_txt(obs, fname, dtype='cphase'): + """Save the data product of type 'dtype' in a text file. + Args: + obs (Obsdata): obsdata object + fname (str): path to output text file + dtype (str): desired data type + Returns: + """ + + head = ("SRC: %s \n" % obs.source + + "RA: " + obsh.rastring(obs.ra) + "\n" + "DEC: " + obsh.decstring(obs.dec) + "\n" + + "MJD: %i \n" % obs.mjd + + "RF: %.4f GHz \n" % (obs.rf/1e9) + + "BW: %.4f GHz \n" % (obs.bw/1e9) + + "PHASECAL: %i \n" % obs.phasecal + + "AMPCAL: %i \n" % obs.ampcal + + "OPACITYCAL: %i \n" % obs.opacitycal + + "DCAL: %i \n" % obs.dcal + + "FRCAL: %i \n" % obs.frcal + + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "Site X(m) Y(m) Z(m) " + + "SEFDR SEFDL FR_PAR FR_EL FR_OFF " + + "DR_RE DR_IM DL_RE DL_IM \n" + ) + + for i in range(len(obs.tarr)): + head += ("%-8s %15.5f %15.5f %15.5f %8.2f %8.2f %5.2f %5.2f %5.2f %8.4f %8.4f %8.4f %8.4f \n" % + (obs.tarr[i]['site'], + obs.tarr[i]['x'], obs.tarr[i]['y'], obs.tarr[i]['z'], + obs.tarr[i]['sefdr'], obs.tarr[i]['sefdl'], + obs.tarr[i]['fr_par'], obs.tarr[i]['fr_elev'], obs.tarr[i]['fr_off'], + (obs.tarr[i]['dr']).real, (obs.tarr[i]['dr']).imag, + (obs.tarr[i]['dl']).real, (obs.tarr[i]['dl']).imag + )) + + if dtype == 'cphase': + outdata = obs.cphase + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) T1 T2 T3 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) Cphase (d) Sigmacp") + fmts = ("%011.8f %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f") + + elif dtype == 'logcamp': + outdata = obs.logcamp + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) T1 T2 T3 T4 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) U4 (lambda) V4 (lambda) Logcamp Sigmalogca") + fmts = ("%011.8f %6s %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f") + + elif dtype == 'camp': + outdata = obs.camp + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) T1 T2 T3 T4 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) U4 (lambda) V4 (lambda) Camp Sigmaca") + fmts = ("%011.8f %6s %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f") + + elif dtype == 'bs': + outdata = obs.bispec + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) T1 T2 T3 U1 (lambda) V1 (lambda) U2 (lambda) V2 (lambda) U3 (lambda) V3 (lambda) Bispec Sigmab") + fmts = ("%011.8f %6s %6s %6s %16.4f %16.4f %16.4f %16.4f %16.4f %16.4f %10.4f %10.8f") + + elif dtype == 'amp': + outdata = obs.amp + head += ( + "----------------------------------------------------------------------" + + "------------------------------------------------------------------\n" + + "time (hr) tint T1 T2 U (lambda) V (lambda) Amp (Jy) Ampsigma") + fmts = ("%011.8f %4.2f %6s %6s %16.4f %16.4f %10.8f %10.8f") + + else: + raise Exception(dtype + ' is not a possible data type!') + + np.savetxt(fname, outdata, header=head, fmt=fmts) + return diff --git a/io/writeoifits.py b/io/writeoifits.py new file mode 100644 index 00000000..0d5eb250 --- /dev/null +++ b/io/writeoifits.py @@ -0,0 +1,102 @@ +# writeData.py +# functionto save observation to OIFITS +# author: Katie Bouman + +from __future__ import division +from __future__ import print_function +from builtins import range + +import numpy as np +import ehtim.io.oifits +import ehtim.const_def as ehc + + +def writeOIFITS(filename, RA, DEC, frequency, bandWidth, intTime, + visamp, visamperr, visphi, visphierr, u, v, ant1, ant2, timeobs, + t3amp, t3amperr, t3phi, t3phierr, uClosure, vClosure, antOrder, timeClosure, + antennaNames, antennaDiam, antennaX, antennaY, antennaZ): + + speedoflight = ehc.C + flagVis = False # do not flag any data + + # open a new oifits file + data = ehtim.io.oifits.oifits() + + # put in the target information - RA and DEC should be in degrees + name = 'TARGET_NAME' + data.target = np.append(data.target, ehtim.io.oifits.OI_TARGET(name, RA, DEC, veltyp='LSR')) + + # calulate wavelength and bandpass + wavelength = speedoflight/frequency + bandlow = speedoflight/(frequency+(0.5*bandWidth)) + bandhigh = speedoflight/(frequency-(0.5*bandWidth)) + bandpass = bandhigh-bandlow + +# put in the wavelength information - only using a single frequency + data.wavelength['WAVELENGTH_NAME'] = ehtim.io.oifits.OI_WAVELENGTH( + wavelength, eff_band=bandpass) + + # put in information about the telescope stations in the array + stations = [] + for i in range(0, len(antennaNames)): + stations.append((antennaNames[i], antennaNames[i], i+1, + antennaDiam[i], [antennaX[i], antennaY[i], antennaZ[i]])) + data.array['ARRAY_NAME'] = ehtim.io.oifits.OI_ARRAY('GEOCENTRIC', [0, 0, 0], stations) + + print('Warning: set cflux and cfluxerr = False ' + + 'because otherwise problems were being generated ' + + '...are they the total flux density?') + print('Warning: are there any true flags?') + + # put in the visibility information - note this does not include phase errors! + for i in range(0, len(u)): + station_curr = (data.array['ARRAY_NAME'].station[int(ant1[i] - 1)], + data.array['ARRAY_NAME'].station[int(ant2[i] - 1)]) + currVis = ehtim.io.oifits.OI_VIS(timeobs[i], intTime, visamp[i], visamperr[i], + visphi[i], visphierr[i], flagVis, + u[i]*wavelength, v[i]*wavelength, + data.wavelength['WAVELENGTH_NAME'], data.target[0], + array=data.array['ARRAY_NAME'], station=station_curr, + cflux=False, cfluxerr=False) + data.vis = np.append(data.vis, currVis) + + # put in bispectrum information + for j in range(0, len(uClosure)): + station_curr = (data.array['ARRAY_NAME'].station[int(antOrder[j][0] - 1)], + data.array['ARRAY_NAME'].station[int(antOrder[j][1] - 1)], + data.array['ARRAY_NAME'].station[int(antOrder[j][2] - 1)]) + currT3 = ehtim.io.oifits.OI_T3(timeClosure[j], intTime, t3amp[j], t3amperr[j], + t3phi[j], t3phierr[j], flagVis, + uClosure[j][0]*wavelength, vClosure[j][0]*wavelength, + uClosure[j][1]*wavelength, vClosure[j][1]*wavelength, + data.wavelength['WAVELENGTH_NAME'], data.target[0], + array=data.array['ARRAY_NAME'], station=station_curr) + data.t3 = np.append(data.t3, currT3) + + # put in visibility squared information + for k in range(0, len(u)): + station_curr = (data.array['ARRAY_NAME'].station[int(ant1[k] - 1)], + data.array['ARRAY_NAME'].station[int(ant2[k] - 1)]) + currVis2 = ehtim.io.oifits.OI_VIS2(timeobs[k], intTime, visamp[k]**2, + 2.0*visamp[k]*visamperr[k], flagVis, + u[k]*wavelength, v[k]*wavelength, + data.wavelength['WAVELENGTH_NAME'], data.target[0], + array=data.array['ARRAY_NAME'], station=station_curr) + data.vis2 = np.append(data.vis2, currVis2) + +# save oifits file + data.save(filename) + + +def arrayUnion(array, union): + for item in array: + if not (item in list(union.keys())): + union[item] = len(union)+1 + return union + + +def convertStrings(array, union): + returnarray = np.zeros(array.shape) + for i in range(len(array)): + returnarray[i] = union[array[i]] + return returnarray diff --git a/model.py b/model.py new file mode 100644 index 00000000..839dcda4 --- /dev/null +++ b/model.py @@ -0,0 +1,2412 @@ +# model.py +# an interferometric model class + +from __future__ import division +from __future__ import print_function +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import scipy.special as sps +import scipy.integrate as integrate +import scipy.interpolate as interpolate +import copy + +import ehtim.observing.obs_simulate as simobs +import ehtim.observing.pulses + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * +#from ehtim.modeling.modeling_utils import * + +import ehtim.image as image + +from ehtim.const_def import * + +LINE_THICKNESS = 2 # Thickness of 1D models on the image, in pixels +FOV_DEFAULT = 100.*RADPERUAS +NPIX_DEFAULT = 256 +COMPLEX_BASIS = 'abs-arg' # Basis for representing (most) complex quantities: 'abs-arg' or 're-im' + +########################################################################################################################################### +#Model object +########################################################################################################################################### + +def model_params(model_type, model_params=None, fit_pol=False, fit_cpol=False): + """Return the ordered list of model parameters for a specified model type. This order must match that of the gradient function, sample_1model_grad_uv. + """ + + if COMPLEX_BASIS == 're-im': + complex_labels = ['_re','_im'] + elif COMPLEX_BASIS == 'abs-arg': + complex_labels = ['_abs','_arg'] + else: + raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!') + + params = [] + + # Function to add polarimetric parameters; these must be added before stretch parameters + def add_pol(): + if fit_pol: + if model_type.find('mring') == -1: + params.append('pol_frac') + params.append('pol_evpa') + else: + for j in range(-(len(model_params['beta_list_pol'])-1)//2,(len(model_params['beta_list_pol'])+1)//2): + params.append('betapol' + str(j) + complex_labels[0]) + params.append('betapol' + str(j) + complex_labels[1]) + if fit_cpol: + if model_type.find('mring') == -1: + params.append('cpol_frac') + else: + for j in range(len(model_params['beta_list_cpol'])): + if j==0: + params.append('betacpol0') + else: + params.append('betacpol' + str(j) + complex_labels[0]) + params.append('betacpol' + str(j) + complex_labels[1]) + + if model_type == 'point': + params = ['F0','x0','y0'] + add_pol() + elif model_type == 'circ_gauss': + params = ['F0','FWHM','x0','y0'] + add_pol() + elif model_type == 'gauss': + params = ['F0','FWHM_maj','FWHM_min','PA','x0','y0'] + add_pol() + elif model_type == 'disk': + params = ['F0','d','x0','y0'] + add_pol() + elif model_type == 'blurred_disk': + params = ['F0','d','alpha','x0','y0'] + add_pol() + elif model_type == 'crescent': + params = ['F0','d', 'fr', 'fo', 'ff', 'phi','x0','y0'] + add_pol() + elif model_type == 'blurred_crescent': + params = ['F0','d','alpha','fr', 'fo', 'ff', 'phi','x0','y0'] + add_pol() + elif model_type == 'ring': + params = ['F0','d','x0','y0'] + add_pol() + elif model_type == 'stretched_ring': + params = ['F0','d','x0','y0','stretch','stretch_PA'] + add_pol() + elif model_type == 'thick_ring': + params = ['F0','d','alpha','x0','y0'] + add_pol() + elif model_type == 'stretched_thick_ring': + params = ['F0','d','alpha','x0','y0','stretch','stretch_PA'] + add_pol() + elif model_type == 'mring': + params = ['F0','d','x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + elif model_type == 'stretched_mring': + params = ['F0','d','x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + params.append('stretch') + params.append('stretch_PA') + elif model_type == 'thick_mring': + params = ['F0','d','alpha','x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + elif model_type == 'thick_mring_floor': + params = ['F0','d','alpha','ff','x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + elif model_type == 'thick_mring_Gfloor': + params = ['F0','d','alpha','ff','FWHM','x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + + elif model_type == 'stretched_thick_mring': + params = ['F0','d','alpha', 'x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + params.append('stretch') + params.append('stretch_PA') + elif model_type == 'stretched_thick_mring_floor': + params = ['F0','d','alpha','ff', 'x0','y0'] + for j in range(len(model_params['beta_list'])): + params.append('beta' + str(j+1) + complex_labels[0]) + params.append('beta' + str(j+1) + complex_labels[1]) + add_pol() + params.append('stretch') + params.append('stretch_PA') + else: + print('Model ' + model_init.models[j] + ' not recognized.') + params = [] + + return params + +def default_prior(model_type,model_params=None,fit_pol=False,fit_cpol=False): + """Return the default model prior and transformation for a specified model type + """ + + if COMPLEX_BASIS == 're-im': + complex_labels = ['_re','_im'] + complex_priors = [{'prior_type':'flat','min':-0.5,'max':0.5}, {'prior_type':'flat','min':-0.5,'max':0.5}] + complex_priors2 = [{'prior_type':'flat','min':-1,'max':1}, {'prior_type':'flat','min':-1,'max':1}] + elif COMPLEX_BASIS == 'abs-arg': + complex_labels = ['_abs','_arg'] + # Note: angle range here must match np.angle(). Need to properly define wrapped distributions + complex_priors = [{'prior_type':'flat','min':0.0,'max':0.5}, {'prior_type':'flat','min':-np.pi, 'max':np.pi}] + complex_priors2 = [{'prior_type':'flat','min':0.0,'max':1.0}, {'prior_type':'flat','min':-np.pi, 'max':np.pi}] + else: + raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!') + + prior = {'F0':{'prior_type':'none','transform':'log'}, + 'x0':{'prior_type':'none'}, + 'y0':{'prior_type':'none'}} + if model_type == 'point': + pass + elif model_type == 'circ_gauss': + prior['FWHM'] = {'prior_type':'none','transform':'log'} + elif model_type == 'gauss': + prior['FWHM_maj'] = {'prior_type':'positive','transform':'log'} + prior['FWHM_min'] = {'prior_type':'positive','transform':'log'} + prior['PA'] = {'prior_type':'none'} + elif model_type == 'disk': + prior['d'] = {'prior_type':'positive','transform':'log'} + elif model_type == 'blurred_disk': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + elif model_type == 'crescent': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['fr'] = {'prior_type':'flat','min':0,'max':1} + prior['fo'] = {'prior_type':'flat','min':0,'max':1} + prior['ff'] = {'prior_type':'flat','min':0,'max':1} + prior['phi'] = {'prior_type':'flat','min':0,'max':2.*np.pi} + elif model_type == 'blurred_crescent': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + prior['fr'] = {'prior_type':'flat','min':0,'max':1} + prior['fo'] = {'prior_type':'flat','min':0,'max':1} + prior['ff'] = {'prior_type':'flat','min':0,'max':1} + prior['phi'] = {'prior_type':'flat','min':0,'max':2.*np.pi} + elif model_type == 'ring': + prior['d'] = {'prior_type':'positive','transform':'log'} + elif model_type == 'stretched_ring': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['stretch'] = {'prior_type':'positive','transform':'log'} + prior['stretch_PA'] = {'prior_type':'none'} + elif model_type == 'thick_ring': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + elif model_type == 'stretched_thick_ring': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + prior['stretch'] = {'prior_type':'positive','transform':'log'} + prior['stretch_PA'] = {'prior_type':'none'} + elif model_type == 'mring': + prior['d'] = {'prior_type':'positive','transform':'log'} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + elif model_type == 'stretched_mring': + prior['d'] = {'prior_type':'positive','transform':'log'} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + prior['stretch'] = {'prior_type':'positive','transform':'log'} + prior['stretch_PA'] = {'prior_type':'none'} + elif model_type == 'thick_mring': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + elif model_type == 'thick_mring_floor': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + prior['ff'] = {'prior_type':'flat','min':0,'max':1} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + elif model_type == 'thick_mring_Gfloor': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + prior['ff'] = {'prior_type':'flat','min':0,'max':1} + prior['FWHM'] = {'prior_type':'positive','transform':'log'} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + elif model_type == 'stretched_thick_mring': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + prior['stretch'] = {'prior_type':'positive','transform':'log'} + prior['stretch_PA'] = {'prior_type':'none'} + elif model_type == 'stretched_thick_mring_floor': + prior['d'] = {'prior_type':'positive','transform':'log'} + prior['alpha'] = {'prior_type':'positive','transform':'log'} + prior['ff'] = {'prior_type':'flat','min':0,'max':1} + for j in range(len(model_params['beta_list'])): + prior['beta' + str(j+1) + complex_labels[0]] = complex_priors[0] + prior['beta' + str(j+1) + complex_labels[1]] = complex_priors[1] + prior['stretch'] = {'prior_type':'positive','transform':'log'} + prior['stretch_PA'] = {'prior_type':'none'} + else: + print('Model not recognized!') + + if fit_pol: + if model_type.find('mring') == -1: + prior['pol_frac'] = {'prior_type':'flat','min':0.0,'max':1.0} + prior['pol_evpa'] = {'prior_type':'flat','min':0.0,'max':np.pi} + else: + for j in range(-(len(model_params['beta_list_pol'])-1)//2,(len(model_params['beta_list_pol'])+1)//2): + prior['betapol' + str(j) + complex_labels[0]] = complex_priors2[0] + prior['betapol' + str(j) + complex_labels[1]] = complex_priors2[1] + + if fit_cpol: + if model_type.find('mring') == -1: + prior['cpol_frac'] = {'prior_type':'flat','min':-1.0,'max':1.0} + else: + for j in range(len(model_params['beta_list_cpol'])): + if j > 0: + prior['betacpol' + str(j) + complex_labels[0]] = complex_priors2[0] + prior['betacpol' + str(j) + complex_labels[1]] = complex_priors2[1] + else: + prior['betacpol0'] = {'prior_type':'flat','min':-1.0,'max':1.0} + + return prior + +def stretch_xy(x, y, params): + x_stretch = ((x - params['x0']) * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2 / params['stretch']) + + (y - params['y0']) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (1.0/params['stretch'] - 1.0)) + y_stretch = ((y - params['y0']) * (np.cos(params['stretch_PA'])**2 / params['stretch'] + np.sin(params['stretch_PA'])**2) + + (x - params['x0']) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (1.0/params['stretch'] - 1.0)) + return (params['x0'] + x_stretch,params['y0'] + y_stretch) + +def stretch_uv(u, v, params): + u_stretch = (u * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2 * params['stretch']) + + v * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (params['stretch'] - 1.0)) + v_stretch = (v * (np.cos(params['stretch_PA'])**2 * params['stretch'] + np.sin(params['stretch_PA'])**2) + + u * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']) * (params['stretch'] - 1.0)) + return (u_stretch,v_stretch) + +def get_const_polfac(model_type, params, pol): + # Return the scaling factor for models with constant fractional polarization + + if model_type.find('mring') != -1: + # mring models have polarization information specified differently than a constant scaling factor + return 1.0 + + try: + if pol == 'I': + return 1.0 + elif pol == 'Q': + return params['pol_frac'] * np.cos(2.0 * params['pol_evpa']) + elif pol == 'U': + return params['pol_frac'] * np.sin(2.0 * params['pol_evpa']) + elif pol == 'V': + return params['cpol_frac'] + elif pol == 'P': + return params['pol_frac'] * np.exp(1j * 2.0 * params['pol_evpa']) + elif pol == 'RR': + return get_const_polfac(model_type, params, 'I') + get_const_polfac(model_type, params, 'V') + elif pol == 'RL': + return get_const_polfac(model_type, params, 'Q') + 1j*get_const_polfac(model_type, params, 'U') + elif pol == 'LR': + return get_const_polfac(model_type, params, 'Q') - 1j*get_const_polfac(model_type, params, 'U') + elif pol == 'LL': + return get_const_polfac(model_type, params, 'I') - get_const_polfac(model_type, params, 'V') + except Exception: + pass + + return 0.0 + +def sample_1model_xy(x, y, model_type, params, psize=1.*RADPERUAS, pol='I'): + if pol == 'Q': + return np.real(sample_1model_xy(x, y, model_type, params, psize=psize, pol='P')) + elif pol == 'U': + return np.imag(sample_1model_xy(x, y, model_type, params, psize=psize, pol='P')) + elif pol in ['I','V','P']: + pass + else: + raise Exception('Polarization ' + pol + ' not implemented!') + + if model_type == 'point': + val = params['F0'] * (np.abs( x - params['x0']) < psize/2.0) * (np.abs( y - params['y0']) < psize/2.0) + elif model_type == 'circ_gauss': + sigma = params['FWHM'] / (2. * np.sqrt(2. * np.log(2.))) + val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['FWHM']**2) * + np.exp(-((x - params['x0'])**2 + (y - params['y0'])**2)/(2*sigma**2))) + elif model_type == 'gauss': + sigma_maj = params['FWHM_maj'] / (2. * np.sqrt(2. * np.log(2.))) + sigma_min = params['FWHM_min'] / (2. * np.sqrt(2. * np.log(2.))) + cth = np.cos(params['PA']) + sth = np.sin(params['PA']) + val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['FWHM_maj'] * params['FWHM_min']) * + np.exp(-((y - params['y0'])*np.cos(params['PA']) + (x - params['x0'])*np.sin(params['PA']))**2/(2*sigma_maj**2) + + -((x - params['x0'])*np.cos(params['PA']) - (y - params['y0'])*np.sin(params['PA']))**2/(2*sigma_min**2))) + elif model_type == 'disk': + val = params['F0']*psize**2/(np.pi*params['d']**2/4.) * (np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2) < params['d']/2.0) + elif model_type == 'blurred_disk': + # Note: the exact form of a blurred disk requires numeric integration + + # This is the peak brightness of the blurred disk + I_peak = 4.0/(np.pi*params['d']**2) * (1.0 - 2.0**(-params['d']**2/params['alpha']**2)) + + # Constant prefactor + prefac = 32.0 * np.log(2.0)/(np.pi * params['alpha']**2 * params['d']**2) + + def f(r): + return integrate.quad(lambda rp: + prefac * rp * np.exp( -4.0 * np.log(2.0)/params['alpha']**2 * (r**2 + rp**2 - 2.0*r * rp) ) + * sps.ive(0, 8.0*np.log(2.0) * r * rp/params['alpha']**2), + 0, params['d']/2.0, limit=1000, epsabs=I_peak/1e9, epsrel=1.0e-6)[0] + f=np.vectorize(f) + r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2) + + # For images, it's much quicker to do the 1-D problem and interpolate + if np.ndim(r) > 0: + r_min = np.min(r) + r_max = np.max(r) + r_list = np.linspace(r_min, r_max, int((r_max-r_min)/(params['alpha']) * 20)) + if len(r_list) < len(np.ravel(r))/2 and len(r) > 100: + f = interpolate.interp1d(r_list, f(r_list), kind='cubic') + val = params['F0'] * psize**2 * f(r) + elif model_type == 'crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_xy(x, y, 'disk', params0, psize=psize, pol=pol) + val -= sample_1model_xy(x, y, 'disk', params1, psize=psize, pol=pol) + elif model_type == 'blurred_crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_xy(x, y, 'blurred_disk', params0, psize=psize, pol=pol) + val -= sample_1model_xy(x, y, 'blurred_disk', params1, psize=psize, pol=pol) + elif model_type == 'ring': + val = (params['F0']*psize**2/(np.pi*params['d']*psize*LINE_THICKNESS) + * (params['d']/2.0 - psize*LINE_THICKNESS/2 < np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)) + * (params['d']/2.0 + psize*LINE_THICKNESS/2 > np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2))) + elif model_type == 'thick_ring': + r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2) + z = 4.*np.log(2.) * r * params['d']/params['alpha']**2 + val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['alpha']**2) + * np.exp(-4.*np.log(2.)/params['alpha']**2*(r**2 + params['d']**2/4.) + z) + * sps.ive(0, z)) + elif model_type == 'mring': + phi = np.angle((y - params['y0']) + 1j*(x - params['x0'])) + if pol == 'I': + beta_factor = (1.0 + np.sum([2.*np.real(params['beta_list'][m-1] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor = np.real(params['beta_list_cpol'][0]) + np.sum([2.*np.real(params['beta_list_cpol'][m] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * np.exp(1j * m * phi) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + else: + beta_factor = 0.0 + + val = (params['F0']*psize**2/(np.pi*params['d']*psize*LINE_THICKNESS) + * beta_factor + * (params['d']/2.0 - psize*LINE_THICKNESS/2 < np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2)) + * (params['d']/2.0 + psize*LINE_THICKNESS/2 > np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2))) + elif model_type == 'thick_mring': + phi = np.angle((y - params['y0']) + 1j*(x - params['x0'])) + r = np.sqrt((x - params['x0'])**2 + (y - params['y0'])**2) + z = 4.*np.log(2.) * r * params['d']/params['alpha']**2 + if pol == 'I': + beta_factor = (sps.ive(0, z) + np.sum([2.*np.real(sps.ive(m, z) * params['beta_list'][m-1] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor = (sps.ive(0, z) * np.real(params['beta_list_cpol'][0]) + np.sum([2.*np.real(sps.ive(m, z) * params['beta_list_cpol'][m] * np.exp(1j * m * phi)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.ive(m, z) * np.exp(1j * m * phi) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + else: + # Note: not all polarizations accounted for yet (need RR, RL, LR, LL; do these by calling for linear combinations of I, Q, U, V)! + beta_factor = 0.0 + + val = (params['F0']*psize**2 * 4.0 * np.log(2.)/(np.pi * params['alpha']**2) + * np.exp(-4.*np.log(2.)/params['alpha']**2*(r**2 + params['d']**2/4.) + z) + * beta_factor) + elif model_type == 'thick_mring_floor': + val = (1.0 - params['ff']) * sample_1model_xy(x, y, 'thick_mring', params, psize=psize, pol=pol) + val += params['ff'] * sample_1model_xy(x, y, 'blurred_disk', params, psize=psize, pol=pol) + elif model_type == 'thick_mring_Gfloor': + val = (1.0 - params['ff']) * sample_1model_xy(x, y, 'thick_mring', params, psize=psize, pol=pol) + val += params['ff'] * sample_1model_xy(x, y, 'circ_gauss', params, psize=psize, pol=pol) + elif model_type[:9] == 'stretched': + params_stretch = params.copy() + params_stretch['F0'] /= params['stretch'] + val = sample_1model_xy(*stretch_xy(x, y, params), model_type[10:], params_stretch, psize, pol=pol) + else: + print('Model ' + model_type + ' not recognized!') + val = 0.0 + return val * get_const_polfac(model_type, params, pol) + +def sample_1model_uv(u, v, model_type, params, pol='I', jonesdict=None): + if jonesdict is not None: + # Define the various lists + fr1 = jonesdict['fr1'] # Field rotation of site 1 + fr2 = jonesdict['fr2'] # Field rotation of site 2 + DR1 = jonesdict['DR1'] # Right leakage term of site 1 + DL1 = jonesdict['DL1'] # Left leakage term of site 1 + DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2 + DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2 + # Sample the model without leakage + RR = sample_1model_uv(u, v, model_type, params, pol='RR') + RL = sample_1model_uv(u, v, model_type, params, pol='RL') + LR = sample_1model_uv(u, v, model_type, params, pol='LR') + LL = sample_1model_uv(u, v, model_type, params, pol='LL') + # Apply the Jones matrices + RRp = RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2)) + RLp = RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2)) + LRp = LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2)) + LLp = LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2)) + # Return the specified polarization + if pol == 'RR': return RRp + elif pol == 'RL': return RLp + elif pol == 'LR': return LRp + elif pol == 'LL': return LLp + elif pol == 'I': return 0.5 * (RRp + LLp) + elif pol == 'Q': return 0.5 * (LRp + RLp) + elif pol == 'U': return 0.5j* (LRp - RLp) + elif pol == 'V': return 0.5 * (RRp - LLp) + elif pol == 'P': return RLp + else: + raise Exception('Polarization ' + pol + ' not recognized!') + + if pol == 'Q': + return 0.5 * (sample_1model_uv(u, v, model_type, params, pol='P') + np.conj(sample_1model_uv(-u, -v, model_type, params, pol='P'))) + elif pol == 'U': + return -0.5j * (sample_1model_uv(u, v, model_type, params, pol='P') - np.conj(sample_1model_uv(-u, -v, model_type, params, pol='P'))) + elif pol in ['I','V','P']: + pass + elif pol == 'RR': + return sample_1model_uv(u, v, model_type, params, pol='I') + sample_1model_uv(u, v, model_type, params, pol='V') + elif pol == 'LL': + return sample_1model_uv(u, v, model_type, params, pol='I') - sample_1model_uv(u, v, model_type, params, pol='V') + elif pol == 'RL': + return sample_1model_uv(u, v, model_type, params, pol='Q') + 1j*sample_1model_uv(u, v, model_type, params, pol='U') + elif pol == 'LR': + return sample_1model_uv(u, v, model_type, params, pol='Q') - 1j*sample_1model_uv(u, v, model_type, params, pol='U') + else: + raise Exception('Polarization ' + pol + ' not implemented!') + + if model_type == 'point': + val = params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + elif model_type == 'circ_gauss': + val = (params['F0'] + * np.exp(-np.pi**2/(4.*np.log(2.)) * (u**2 + v**2) * params['FWHM']**2) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'gauss': + u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA']) + u_min = u*np.cos(params['PA']) - v*np.sin(params['PA']) + val = (params['F0'] + * np.exp(-np.pi**2/(4.*np.log(2.)) * ((u_maj * params['FWHM_maj'])**2 + (u_min * params['FWHM_min'])**2)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'disk': + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + #Add a small offset to avoid issues with division by zero + z += (z == 0.0) * 1e-10 + val = (params['F0'] * 2.0/z * sps.jv(1, z) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'blurred_disk': + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + #Add a small offset to avoid issues with division by zero + z += (z == 0.0) * 1e-10 + val = (params['F0'] * 2.0/z * sps.jv(1, z) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.)))) + elif model_type == 'crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_uv(u, v, 'disk', params0, pol=pol, jonesdict=jonesdict) + val -= sample_1model_uv(u, v, 'disk', params1, pol=pol, jonesdict=jonesdict) + elif model_type == 'blurred_crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_uv(u, v, 'blurred_disk', params0, pol=pol, jonesdict=jonesdict) + val -= sample_1model_uv(u, v, 'blurred_disk', params1, pol=pol, jonesdict=jonesdict) + elif model_type == 'ring': + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + val = (params['F0'] * sps.jv(0, z) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'thick_ring': + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + val = (params['F0'] * sps.jv(0, z) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'mring': + phi = np.angle(v + 1j*u) + # Flip the baseline sign to match eht-imaging conventions + phi += np.pi + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + + if pol == 'I': + beta_factor = (sps.jv(0, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor = (np.real(params['beta_list_cpol'][0]) * sps.jv(0, z) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + else: + beta_factor = 0.0 + + val = params['F0'] * beta_factor * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + elif model_type == 'thick_mring': + phi = np.angle(v + 1j*u) + # Flip the baseline sign to match eht-imaging conventions + phi += np.pi + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + + if pol == 'I': + beta_factor = (sps.jv(0, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor = (np.real(params['beta_list_cpol'][0]) * sps.jv(0, z) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + else: + beta_factor = 0.0 + + val = (params['F0'] * beta_factor + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + elif model_type == 'thick_mring_floor': + val = (1.0 - params['ff']) * sample_1model_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict) + val += params['ff'] * sample_1model_uv(u, v, 'blurred_disk', params, pol=pol, jonesdict=jonesdict) + elif model_type == 'thick_mring_Gfloor': + val = (1.0 - params['ff']) * sample_1model_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict) + val += params['ff'] * sample_1model_uv(u, v, 'circ_gauss', params, pol=pol, jonesdict=jonesdict) + elif model_type[:9] == 'stretched': + params_stretch = params.copy() + params_stretch['x0'] = 0.0 + params_stretch['y0'] = 0.0 + val = sample_1model_uv(*stretch_uv(u,v,params), model_type[10:], params_stretch, pol=pol) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + else: + print('Model ' + model_type + ' not recognized!') + val = 0.0 + return val * get_const_polfac(model_type, params, pol) + +def sample_1model_graduv_uv(u, v, model_type, params, pol='I', jonesdict=None): + # Gradient of the visibility function, (dV/du, dV/dv) + # This function makes it convenient to, e.g., compute gradients of stretched images and to compute the model centroid + + if jonesdict is not None: + # Define the various lists + fr1 = jonesdict['fr1'] # Field rotation of site 1 + fr2 = jonesdict['fr2'] # Field rotation of site 2 + DR1 = jonesdict['DR1'] # Right leakage term of site 1 + DL1 = jonesdict['DL1'] # Left leakage term of site 1 + DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2 + DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2 + # Sample the model without leakage + RR = sample_1model_graduv_uv(u, v, model_type, params, pol='RR').reshape(2,len(u)) + RL = sample_1model_graduv_uv(u, v, model_type, params, pol='RL').reshape(2,len(u)) + LR = sample_1model_graduv_uv(u, v, model_type, params, pol='LR').reshape(2,len(u)) + LL = sample_1model_graduv_uv(u, v, model_type, params, pol='LL').reshape(2,len(u)) + # Apply the Jones matrices + RRp = (RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2))) + RLp = (RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2))) + LRp = (LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2))) + LLp = (LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2))) + # Return the specified polarization + if pol == 'RR': return RRp + elif pol == 'RL': return RLp + elif pol == 'LR': return LRp + elif pol == 'LL': return LLp + elif pol == 'I': return 0.5 * (RRp + LLp) + elif pol == 'Q': return 0.5 * (LRp + RLp) + elif pol == 'U': return 0.5j* (LRp - RLp) + elif pol == 'V': return 0.5 * (RRp - LLp) + elif pol == 'P': return RLp + else: + raise Exception('Polarization ' + pol + ' not recognized!') + + if pol == 'Q': + return 0.5 * (sample_1model_graduv_uv(u, v, model_type, params, pol='P') + np.conj(sample_1model_graduv_uv(-u, -v, model_type, params, pol='P'))) + elif pol == 'U': + return -0.5j * (sample_1model_graduv_uv(u, v, model_type, params, pol='P') - np.conj(sample_1model_graduv_uv(-u, -v, model_type, params, pol='P'))) + elif pol in ['I','V','P']: + pass + elif pol == 'RR': + return sample_1model_graduv_uv(u, v, model_type, params, pol='I') + sample_1model_graduv_uv(u, v, model_type, params, pol='V') + elif pol == 'LL': + return sample_1model_graduv_uv(u, v, model_type, params, pol='I') - sample_1model_graduv_uv(u, v, model_type, params, pol='V') + elif pol == 'RL': + return sample_1model_graduv_uv(u, v, model_type, params, pol='Q') + 1j*sample_1model_graduv_uv(u, v, model_type, params, pol='U') + elif pol == 'LR': + return sample_1model_graduv_uv(u, v, model_type, params, pol='Q') - 1j*sample_1model_graduv_uv(u, v, model_type, params, pol='U') + else: + raise Exception('Polarization ' + pol + ' not implemented!') + + vis = sample_1model_uv(u, v, model_type, params, jonesdict=jonesdict) + if model_type == 'point': + val = np.array([ 1j * 2.0 * np.pi * params['x0'] * vis, + 1j * 2.0 * np.pi * params['y0'] * vis]) + elif model_type == 'circ_gauss': + val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['FWHM']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis, + (1j * 2.0 * np.pi * params['y0'] - params['FWHM']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis]) + elif model_type == 'gauss': + u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA']) + u_min = u*np.cos(params['PA']) - v*np.sin(params['PA']) + val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['FWHM_maj']**2 * np.pi**2 * u_maj/(2. * np.log(2.)) * np.sin(params['PA']) - params['FWHM_min']**2 * np.pi**2 * u_min/(2. * np.log(2.)) * np.cos(params['PA'])) * vis, + (1j * 2.0 * np.pi * params['y0'] - params['FWHM_maj']**2 * np.pi**2 * u_maj/(2. * np.log(2.)) * np.cos(params['PA']) + params['FWHM_min']**2 * np.pi**2 * u_min/(2. * np.log(2.)) * np.sin(params['PA'])) * vis]) + elif model_type == 'disk': + # Take care of the degenerate origin point by a small offset + #v += (u==0.)*(v==0.)*1e-10 + uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5 + z = np.pi * params['d'] * uvdist + bessel_deriv = 0.5 * (sps.jv( 0, z) - sps.jv( 2, z)) + val = np.array([ (1j * 2.0 * np.pi * params['x0'] - u/uvdist**2) * vis + + params['F0'] * 2./z * np.pi * params['d'] * u/uvdist * bessel_deriv * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + (1j * 2.0 * np.pi * params['y0'] - v/uvdist**2) * vis + + params['F0'] * 2./z * np.pi * params['d'] * v/uvdist * bessel_deriv * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'blurred_disk': + # Take care of the degenerate origin point by a small offset + #u += (u==0.)*(v==0.)*1e-10 + uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5 + z = np.pi * params['d'] * uvdist + blur = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + bessel_deriv = 0.5 * (sps.jv( 0, z) - sps.jv( 2, z)) + val = np.array([ (1j * 2.0 * np.pi * params['x0'] - u/uvdist**2 - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis + + params['F0'] * 2./z * np.pi * params['d'] * u/uvdist * bessel_deriv * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + (1j * 2.0 * np.pi * params['y0'] - v/uvdist**2 - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis + + params['F0'] * 2./z * np.pi * params['d'] * v/uvdist * bessel_deriv * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_graduv_uv(u, v, 'disk', params0, pol=pol, jonesdict=jonesdict) + val -= sample_1model_graduv_uv(u, v, 'disk', params1, pol=pol, jonesdict=jonesdict) + elif model_type == 'blurred_crescent': + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + val = sample_1model_graduv_uv(u, v, 'blurred_disk', params0, pol=pol, jonesdict=jonesdict) + val -= sample_1model_graduv_uv(u, v, 'blurred_disk', params1, pol=pol, jonesdict=jonesdict) + elif model_type == 'ring': + # Take care of the degenerate origin point by a small offset + u += (u==0.)*(v==0.)*1e-10 + uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5 + z = np.pi * params['d'] * uvdist + val = np.array([ 1j * 2.0 * np.pi * params['x0'] * vis + - params['F0'] * np.pi*params['d']*u/uvdist * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 1j * 2.0 * np.pi * params['y0'] * vis + - params['F0'] * np.pi*params['d']*v/uvdist * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'thick_ring': + uvdist = (u**2 + v**2)**0.5 + #Add a small offset to avoid issues with division by zero + uvdist += (uvdist == 0.0) * 1e-10 + z = np.pi * params['d'] * uvdist + val = np.array([ (1j * 2.0 * np.pi * params['x0'] - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis + - params['F0'] * np.pi*params['d']*u/uvdist * sps.jv(1, z) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + (1j * 2.0 * np.pi * params['y0'] - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis + - params['F0'] * np.pi*params['d']*v/uvdist * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'mring': + # Take care of the degenerate origin point by a small offset + u += (u==0.)*(v==0.)*1e-10 + phi = np.angle(v + 1j*u) + # Flip the baseline sign to match eht-imaging conventions + phi += np.pi + uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5 + dphidu = v/uvdist**2 + dphidv = -u/uvdist**2 + z = np.pi * params['d'] * uvdist + + if pol == 'I': + beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0]) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor_u = ( + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)) + beta_factor_v = ( + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)) + else: + beta_factor_u = beta_factor_v = 0.0 + + val = np.array([ + 1j * 2.0 * np.pi * params['x0'] * vis + + params['F0'] * beta_factor_u + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 1j * 2.0 * np.pi * params['y0'] * vis + + params['F0'] * beta_factor_v + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'thick_mring': + # Take care of the degenerate origin point by a small offset + u += (u==0.)*(v==0.)*1e-10 + phi = np.angle(v + 1j*u) + # Flip the baseline sign to match eht-imaging conventions + phi += np.pi + uvdist = (u**2 + v**2 + (u==0.)*(v==0.)*1e-10)**0.5 + dphidu = v/uvdist**2 + dphidv = -u/uvdist**2 + z = np.pi * params['d'] * uvdist + + if pol == 'I': + beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) + + np.sum([params['beta_list'][m-1] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor_u = (-np.pi * params['d'] * u/uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0]) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidu) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + beta_factor_v = (-np.pi * params['d'] * v/uvdist * sps.jv(1, z) + + np.sum([params['beta_list_cpol'][m] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * sps.jv(-m, z) * (-1j * m * dphidv) * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv(-m-1, z) - sps.jv(-m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor_u = (0.0 + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidu) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * u/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)) + beta_factor_v = (0.0 + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * sps.jv( m, z) * ( 1j * m * dphidv) * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + + np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * params['d'] * v/uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0)) + else: + beta_factor_u = beta_factor_v = 0.0 + + val = np.array([ + (1j * 2.0 * np.pi * params['x0'] - params['alpha']**2 * np.pi**2 * u/(2. * np.log(2.))) * vis + + params['F0'] * beta_factor_u + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + (1j * 2.0 * np.pi * params['y0'] - params['alpha']**2 * np.pi**2 * v/(2. * np.log(2.))) * vis + + params['F0'] * beta_factor_v + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) ]) + elif model_type == 'thick_mring_floor': + val = (1.0 - params['ff']) * sample_1model_graduv_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict) + val += params['ff'] * sample_1model_graduv_uv(u, v, 'blurred_disk', params, pol=pol, jonesdict=jonesdict) + elif model_type == 'thick_mring_Gfloor': + val = (1.0 - params['ff']) * sample_1model_graduv_uv(u, v, 'thick_mring', params, pol=pol, jonesdict=jonesdict) + val += params['ff'] * sample_1model_graduv_uv(u, v, 'circ_gauss', params, pol=pol, jonesdict=jonesdict) + elif model_type[:9] == 'stretched': + # Take care of the degenerate origin point by a small offset + u += (u==0.)*(v==0.)*1e-10 + params_stretch = params.copy() + params_stretch['x0'] = 0.0 + params_stretch['y0'] = 0.0 + (u_stretch, v_stretch) = stretch_uv(u,v,params) + + # First calculate the gradient of the unshifted but stretched image + grad0 = sample_1model_graduv_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + grad = grad0.copy() * 0.0 + grad[0] = ( grad0[0] * (np.cos(params['stretch_PA'])**2 + np.sin(params['stretch_PA'])**2*params['stretch']) + + grad0[1] * ((params['stretch'] - 1.0) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']))) + grad[1] = ( grad0[1] * (np.cos(params['stretch_PA'])**2*params['stretch'] + np.sin(params['stretch_PA'])**2) + + grad0[0] * ((params['stretch'] - 1.0) * np.cos(params['stretch_PA']) * np.sin(params['stretch_PA']))) + + # Add the gradient term from the shift + vis = sample_1model_uv(u_stretch, v_stretch, model_type[10:], params_stretch, jonesdict=jonesdict) + grad[0] += vis * 1j * 2.0 * np.pi * params['x0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + grad[1] += vis * 1j * 2.0 * np.pi * params['y0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + + val = grad + else: + print('Model ' + model_type + ' not recognized!') + val = 0.0 + return val * get_const_polfac(model_type, params, pol) + +def sample_1model_grad_leakage_uv_re(u, v, model_type, params, pol, site, hand, jonesdict): + # Convenience function to calculate the gradient with respect to the real part of a specified site/hand leakage + + # Define the various lists + fr1 = jonesdict['fr1'] # Field rotation of site 1 + fr2 = jonesdict['fr2'] # Field rotation of site 2 + DR1 = jonesdict['DR1'] # Right leakage term of site 1 + DL1 = jonesdict['DL1'] # Left leakage term of site 1 + DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2 + DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2 + # Sample the model without leakage + RR = sample_1model_uv(u, v, model_type, params, pol='RR') + RL = sample_1model_uv(u, v, model_type, params, pol='RL') + LR = sample_1model_uv(u, v, model_type, params, pol='LR') + LL = sample_1model_uv(u, v, model_type, params, pol='LL') + + # Figure out which terms to include in the gradient + DR1mask = 0.0 + (hand == 'R') * (jonesdict['t1'] == site) + DR2mask = 0.0 + (hand == 'R') * (jonesdict['t2'] == site) + DL1mask = 0.0 + (hand == 'L') * (jonesdict['t1'] == site) + DL2mask = 0.0 + (hand == 'L') * (jonesdict['t2'] == site) + + # These are the leakage gradient terms + RRp = LR * DR1mask * np.exp( 2j*fr1) + RL * DR2mask * np.exp(-2j*fr2) + LL * DR1mask * DR2 * np.exp( 2j*(fr1-fr2)) + LL * DR1 * DR2mask * np.exp( 2j*(fr1-fr2)) + RLp = LL * DR1mask * np.exp( 2j*fr1) + RR * DL2mask * np.exp( 2j*fr2) + LR * DR1mask * DL2 * np.exp( 2j*(fr1+fr2)) + LR * DR1 * DL2mask * np.exp( 2j*(fr1+fr2)) + LRp = RR * DL1mask * np.exp(-2j*fr1) + LL * DR2mask * np.exp(-2j*fr2) + RL * DL1mask * DR2 * np.exp(-2j*(fr1+fr2)) + RL * DL1 * DR2mask * np.exp(-2j*(fr1+fr2)) + LLp = LR * DL2mask * np.exp( 2j*fr2) + RL * DL1mask * np.exp(-2j*fr1) + RR * DL1mask * DL2 * np.exp(-2j*(fr1-fr2)) + RR * DL1 * DL2mask * np.exp(-2j*(fr1-fr2)) + + # Return the specified polarization + if pol == 'RR': return RRp + elif pol == 'RL': return RLp + elif pol == 'LR': return LRp + elif pol == 'LL': return LLp + elif pol == 'I': return 0.5 * (RRp + LLp) + elif pol == 'Q': return 0.5 * (LRp + RLp) + elif pol == 'U': return 0.5j* (LRp - RLp) + elif pol == 'V': return 0.5 * (RRp - LLp) + elif pol == 'P': return RLp + else: + raise Exception('Polarization ' + pol + ' not recognized!') + +def sample_1model_grad_leakage_uv_im(u, v, model_type, params, pol, site, hand, jonesdict): + # Convenience function to calculate the gradient with respect to the imaginary part of a specified site/hand leakage + # The tricky thing here is the conjugation of the second leakage site, flipping the sign of the gradient + + # Define the various lists + fr1 = jonesdict['fr1'] # Field rotation of site 1 + fr2 = jonesdict['fr2'] # Field rotation of site 2 + DR1 = jonesdict['DR1'] # Right leakage term of site 1 + DL1 = jonesdict['DL1'] # Left leakage term of site 1 + DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2 + DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2 + # Sample the model without leakage + RR = sample_1model_uv(u, v, model_type, params, pol='RR') + RL = sample_1model_uv(u, v, model_type, params, pol='RL') + LR = sample_1model_uv(u, v, model_type, params, pol='LR') + LL = sample_1model_uv(u, v, model_type, params, pol='LL') + + # Figure out which terms to include in the gradient + DR1mask = 0.0 + (hand == 'R') * (jonesdict['t1'] == site) + DR2mask = 0.0 + (hand == 'R') * (jonesdict['t2'] == site) + DL1mask = 0.0 + (hand == 'L') * (jonesdict['t1'] == site) + DL2mask = 0.0 + (hand == 'L') * (jonesdict['t2'] == site) + + # These are the leakage gradient terms + RRp = 1j*( LR * DR1mask * np.exp( 2j*fr1) - RL * DR2mask * np.exp(-2j*fr2) + LL * DR1mask * DR2 * np.exp( 2j*(fr1-fr2)) - LL * DR1 * DR2mask * np.exp( 2j*(fr1-fr2))) + RLp = 1j*( LL * DR1mask * np.exp( 2j*fr1) - RR * DL2mask * np.exp( 2j*fr2) + LR * DR1mask * DL2 * np.exp( 2j*(fr1+fr2)) - LR * DR1 * DL2mask * np.exp( 2j*(fr1+fr2))) + LRp = 1j*( RR * DL1mask * np.exp(-2j*fr1) - LL * DR2mask * np.exp(-2j*fr2) + RL * DL1mask * DR2 * np.exp(-2j*(fr1+fr2)) - RL * DL1 * DR2mask * np.exp(-2j*(fr1+fr2))) + LLp = 1j*(-LR * DL2mask * np.exp( 2j*fr2) + RL * DL1mask * np.exp(-2j*fr1) + RR * DL1mask * DL2 * np.exp(-2j*(fr1-fr2)) - RR * DL1 * DL2mask * np.exp(-2j*(fr1-fr2))) + + # Return the specified polarization + if pol == 'RR': return RRp + elif pol == 'RL': return RLp + elif pol == 'LR': return LRp + elif pol == 'LL': return LLp + elif pol == 'I': return 0.5 * (RRp + LLp) + elif pol == 'Q': return 0.5 * (LRp + RLp) + elif pol == 'U': return 0.5j* (LRp - RLp) + elif pol == 'V': return 0.5 * (RRp - LLp) + elif pol == 'P': return RLp + else: + raise Exception('Polarization ' + pol + ' not recognized!') + +def sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + # Gradient of the model for each model parameter + + if jonesdict is not None: + # Define the various lists + fr1 = jonesdict['fr1'] # Field rotation of site 1 + fr2 = jonesdict['fr2'] # Field rotation of site 2 + DR1 = jonesdict['DR1'] # Right leakage term of site 1 + DL1 = jonesdict['DL1'] # Left leakage term of site 1 + DR2 = np.conj(jonesdict['DR2']) # Right leakage term of site 2 + DL2 = np.conj(jonesdict['DL2']) # Left leakage term of site 2 + # Sample the gradients without leakage + RR = sample_1model_grad_uv(u, v, model_type, params, pol='RR', fit_pol=fit_pol, fit_cpol=fit_cpol) + RL = sample_1model_grad_uv(u, v, model_type, params, pol='RL', fit_pol=fit_pol, fit_cpol=fit_cpol) + LR = sample_1model_grad_uv(u, v, model_type, params, pol='LR', fit_pol=fit_pol, fit_cpol=fit_cpol) + LL = sample_1model_grad_uv(u, v, model_type, params, pol='LL', fit_pol=fit_pol, fit_cpol=fit_cpol) + # Apply the Jones matrices + RRp = (RR + LR * DR1 * np.exp( 2j*fr1) + RL * DR2 * np.exp(-2j*fr2) + LL * DR1 * DR2 * np.exp( 2j*(fr1-fr2))) + RLp = (RL + LL * DR1 * np.exp( 2j*fr1) + RR * DL2 * np.exp( 2j*fr2) + LR * DR1 * DL2 * np.exp( 2j*(fr1+fr2))) + LRp = (LR + RR * DL1 * np.exp(-2j*fr1) + LL * DR2 * np.exp(-2j*fr2) + RL * DL1 * DR2 * np.exp(-2j*(fr1+fr2))) + LLp = (LL + LR * DL2 * np.exp( 2j*fr2) + RL * DL1 * np.exp(-2j*fr1) + RR * DL1 * DL2 * np.exp(-2j*(fr1-fr2))) + # Return the specified polarization + if pol == 'RR': grad = RRp + elif pol == 'RL': grad = RLp + elif pol == 'LR': grad = LRp + elif pol == 'LL': grad = LLp + elif pol == 'I': grad = 0.5 * (RRp + LLp) + elif pol == 'Q': grad = 0.5 * (LRp + RLp) + elif pol == 'U': grad = 0.5j* (LRp - RLp) + elif pol == 'V': grad = 0.5 * (RRp - LLp) + elif pol == 'P': grad = RLp + else: + raise Exception('Polarization ' + pol + ' not recognized!') + # If necessary, add the gradient components from the leakage terms + # Each leakage term has two corresponding gradient terms: d/dRe and d/dIm. + if fit_leakage: + # 'leakage_fit' is a list of tuples [site, 'R' or 'L'] denoting the fitted leakage terms + for (site, hand) in jonesdict['leakage_fit']: + grad = np.vstack([grad, sample_1model_grad_leakage_uv_re(u, v, model_type, params, pol, site, hand, jonesdict), sample_1model_grad_leakage_uv_im(u, v, model_type, params, pol, site, hand, jonesdict)]) + + return grad + + if pol == 'Q': + return 0.5 * (sample_1model_grad_uv(u, v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol) + np.conj(sample_1model_grad_uv(-u, -v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol))) + elif pol == 'U': + return -0.5j * (sample_1model_grad_uv(u, v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol) - np.conj(sample_1model_grad_uv(-u, -v, model_type, params, pol='P', fit_pol=fit_pol, fit_cpol=fit_cpol))) + elif pol in ['I','V','P']: + pass + elif pol == 'RR': + return sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=fit_pol, fit_cpol=fit_cpol) + sample_1model_grad_uv(u, v, model_type, params, pol='V', fit_pol=fit_pol, fit_cpol=fit_cpol) + elif pol == 'LL': + return sample_1model_grad_uv(u, v, model_type, params, pol='I', fit_pol=fit_pol, fit_cpol=fit_cpol) - sample_1model_grad_uv(u, v, model_type, params, pol='V', fit_pol=fit_pol, fit_cpol=fit_cpol) + elif pol == 'RL': + return sample_1model_grad_uv(u, v, model_type, params, pol='Q', fit_pol=fit_pol, fit_cpol=fit_cpol) + 1j*sample_1model_grad_uv(u, v, model_type, params, pol='U', fit_pol=fit_pol, fit_cpol=fit_cpol) + elif pol == 'LR': + return sample_1model_grad_uv(u, v, model_type, params, pol='Q', fit_pol=fit_pol, fit_cpol=fit_cpol) - 1j*sample_1model_grad_uv(u, v, model_type, params, pol='U', fit_pol=fit_pol, fit_cpol=fit_cpol) + else: + raise Exception('Polarization ' + pol + ' not implemented!') + + if model_type == 'point': # F0, x0, y0 + val = np.array([ np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 1j * 2.0 * np.pi * u * params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 1j * 2.0 * np.pi * v * params['F0'] * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))]) + elif model_type == 'circ_gauss': # F0, FWHM, x0, y0 + gauss = (params['F0'] * np.exp(-np.pi**2/(4.*np.log(2.)) * (u**2 + v**2) * params['FWHM']**2) + *np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + val = np.array([ 1.0/params['F0'] * gauss, + -np.pi**2/(2.*np.log(2.)) * (u**2 + v**2) * params['FWHM'] * gauss, + 1j * 2.0 * np.pi * u * gauss, + 1j * 2.0 * np.pi * v * gauss]) + elif model_type == 'gauss': # F0, FWHM_maj, FWHM_min, PA, x0, y0 + u_maj = u*np.sin(params['PA']) + v*np.cos(params['PA']) + u_min = u*np.cos(params['PA']) - v*np.sin(params['PA']) + vis = (params['F0'] + * np.exp(-np.pi**2/(4.*np.log(2.)) * ((u_maj * params['FWHM_maj'])**2 + (u_min * params['FWHM_min'])**2)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + val = np.array([ 1.0/params['F0'] * vis, + -np.pi**2/(2.*np.log(2.)) * params['FWHM_maj'] * u_maj**2 * vis, + -np.pi**2/(2.*np.log(2.)) * params['FWHM_min'] * u_min**2 * vis, + -np.pi**2/(2.*np.log(2.)) * (params['FWHM_maj']**2 - params['FWHM_min']**2) * u_maj * u_min * vis, + 1j * 2.0 * np.pi * u * vis, + 1j * 2.0 * np.pi * v * vis]) + elif model_type == 'disk': # F0, d, x0, y0 + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + #Add a small offset to avoid issues with division by zero + z += (z == 0.0) * 1e-10 + vis = (params['F0'] * 2.0/z * sps.jv(1, z) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + val = np.array([ 1.0/params['F0'] * vis, + -(params['F0'] * 2.0/z * sps.jv(2, z) * np.pi * (u**2 + v**2)**0.5 * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) , + 1j * 2.0 * np.pi * u * vis, + 1j * 2.0 * np.pi * v * vis]) + elif model_type == 'blurred_disk': # F0, d, alpha, x0, y0 + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + #Add a small offset to avoid issues with division by zero + z += (z == 0.0) * 1e-10 + blur = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + vis = (params['F0'] * 2.0/z * sps.jv(1, z) + * blur + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + val = np.array([ 1.0/params['F0'] * vis, + -params['F0'] * 2.0/z * sps.jv(2, z) * np.pi * (u**2 + v**2)**0.5 * blur * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + -np.pi**2 * (u**2 + v**2) * params['alpha']/(2.*np.log(2.)) * vis, + 1j * 2.0 * np.pi * u * vis, + 1j * 2.0 * np.pi * v * vis]) + elif model_type == 'crescent': #['F0','d', 'fr', 'fo', 'phi','x0','y0'] + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + + grad0 = sample_1model_grad_uv(u, v, 'disk', params0, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + grad1 = sample_1model_grad_uv(u, v, 'disk', params1, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + + # Add the derivatives one by one + grad = [] + + # F0 + grad.append( 1.0/(1.0-(1-ff)*fr**2)*grad0[0] - (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*grad1[0] ) + + # d + grad.append( grad0[1] - fr*grad1[1] - 0.5 * (1.0 - fr) * fo * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) ) + + # fr + grad.append( 2.0*params['F0']*(1-ff)*fr/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) - params['d']*grad1[1] + r * fo * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) ) + + # fo + grad.append( -r * (1-fr) * (np.sin(phi) * grad1[2] + np.cos(phi) * grad1[3]) ) + + # ff + grad.append( -params['F0']*fr**2/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) ) + + # phi + grad.append( -r*(1-fr)*fo* (np.cos(phi) * grad1[2] - np.sin(phi) * grad1[3]) ) + + # x0, y0 + grad.append( grad0[2] - grad1[2] ) + grad.append( grad0[3] - grad1[3] ) + + val = np.array(grad) + elif model_type == 'blurred_crescent': #['F0','d','alpha','fr', 'fo', 'phi','x0','y0'] + phi = params['phi'] + fr = params['fr'] + fo = params['fo'] + ff = params['ff'] + r = params['d'] / 2. + params0 = {'F0': 1.0/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d'], 'alpha':params['alpha'], 'x0': params['x0'], 'y0': params['y0']} + params1 = {'F0': (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*params['F0'], 'd':params['d']*fr, 'alpha':params['alpha'], 'x0': params['x0'] + r*(1-fr)*fo*np.sin(phi), 'y0': params['y0'] + r*(1-fr)*fo*np.cos(phi)} + + grad0 = sample_1model_grad_uv(u, v, 'blurred_disk', params0, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + grad1 = sample_1model_grad_uv(u, v, 'blurred_disk', params1, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + + # Add the derivatives one by one + grad = [] + + # F0 + grad.append( 1.0/(1.0-(1-ff)*fr**2)*grad0[0] - (1-ff)*fr**2/(1.0-(1-ff)*fr**2)*grad1[0] ) + + # d + grad.append( grad0[1] - fr*grad1[1] - 0.5 * (1.0 - fr) * fo * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) ) + + # alpha + grad.append( grad0[2] - grad1[2] ) + + # fr + grad.append( 2.0*params['F0']*(1-ff)*fr/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) - params['d']*grad1[1] + r * fo * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) ) + + # fo + grad.append( -r * (1-fr) * (np.sin(phi) * grad1[3] + np.cos(phi) * grad1[4]) ) + + # ff + grad.append( -params['F0']*fr**2/(1.0 - (1-ff)*fr**2)**2 * (grad0[0] - grad1[0]) ) + + # phi + grad.append( -r*(1-fr)*fo* (np.cos(phi) * grad1[3] - np.sin(phi) * grad1[4]) ) + + # x0, y0 + grad.append( grad0[3] - grad1[3] ) + grad.append( grad0[4] - grad1[4] ) + + val = np.array(grad) + + elif model_type == 'ring': # F0, d, x0, y0 + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + val = np.array([ sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + -np.pi * (u**2 + v**2)**0.5 * params['F0'] * sps.jv(1, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 2.0 * np.pi * 1j * u * params['F0'] * sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])), + 2.0 * np.pi * 1j * v * params['F0'] * sps.jv(0, z) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))]) + elif model_type == 'thick_ring': # F0, d, alpha, x0, y0 + z = np.pi * params['d'] * (u**2 + v**2)**0.5 + vis = (params['F0'] * sps.jv(0, z) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + val = np.array([ 1.0/params['F0'] * vis, + -(params['F0'] * np.pi * (u**2 + v**2)**0.5 * sps.jv(1, z) + * np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))), + -np.pi**2 * (u**2 + v**2) * params['alpha']/(2.*np.log(2.)) * vis, + 1j * 2.0 * np.pi * u * vis, + 1j * 2.0 * np.pi * v * vis]) + elif model_type in ['mring','thick_mring']: # F0, d, [alpha], x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ... + phi = np.angle(v + 1j*u) + # Flip the baseline sign to match eht-imaging conventions + phi += np.pi + uvdist = (u**2 + v**2)**0.5 + z = np.pi * params['d'] * uvdist + if model_type == 'thick_mring': + alpha_factor = np.exp(-(np.pi * params['alpha'] * (u**2 + v**2)**0.5)**2/(4. * np.log(2.))) + else: + alpha_factor = 1 + + # Only one of the beta_lists will affect the measurement and have non-zero gradients. Figure out which: + # These are for the derivatives wrt diameter + if pol == 'I': + beta_factor = (-np.pi * uvdist * sps.jv(1, z) + + np.sum([params['beta_list'][m-1] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0) + + np.sum([np.conj(params['beta_list'][m-1]) * 0.5 * (sps.jv( -m-1, z) - sps.jv( -m+1, z)) * np.pi * uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list'])+1)],axis=0)) + elif pol == 'P' and len(params['beta_list_pol']) > 0: + num_coeff = len(params['beta_list_pol']) + beta_factor = np.sum([params['beta_list_pol'][m + (num_coeff-1)//2] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(-(num_coeff-1)//2,(num_coeff+1)//2)],axis=0) + elif pol == 'V' and len(params['beta_list_cpol']) > 0: + beta_factor = (-np.pi * uvdist * sps.jv(1, z) * np.real(params['beta_list_cpol'][0]) + + np.sum([params['beta_list_cpol'][m] * 0.5 * (sps.jv( m-1, z) - sps.jv( m+1, z)) * np.pi * uvdist * np.exp( 1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0) + + np.sum([np.conj(params['beta_list_cpol'][m]) * 0.5 * (sps.jv( -m-1, z) - sps.jv( -m+1, z)) * np.pi * uvdist * np.exp(-1j * m * (phi - np.pi/2.)) for m in range(1,len(params['beta_list_cpol']))],axis=0)) + else: + beta_factor = 0.0 + + vis = sample_1model_uv(u, v, model_type, params, pol=pol, jonesdict=jonesdict) + grad = [ 1.0/params['F0'] * vis, + (params['F0'] * alpha_factor * beta_factor * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])))] + if model_type == 'thick_mring': + grad.append(-np.pi**2/(2.*np.log(2)) * uvdist**2 * params['alpha'] * vis) + grad.append(1j * 2.0 * np.pi * u * vis) + grad.append(1j * 2.0 * np.pi * v * vis) + + if pol=='I': + # Add derivatives of the beta terms + for m in range(1,len(params['beta_list'])+1): + beta_grad_re = params['F0'] * alpha_factor * ( + sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) + sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + beta_grad_im = 1j * params['F0'] * alpha_factor * ( + sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) - sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + if COMPLEX_BASIS == 're-im': + grad.append(beta_grad_re) + grad.append(beta_grad_im) + elif COMPLEX_BASIS == 'abs-arg': + beta_abs = np.abs(params['beta_list'][m-1]) + beta_arg = np.angle(params['beta_list'][m-1]) + grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg)) + grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im) + else: + raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!') + else: + [grad.append(np.zeros_like(grad[0])) for _ in range(2*len(params['beta_list']))] + + if pol=='P' and fit_pol: + # Add derivatives of the beta_pol terms + num_coeff = len(params['beta_list_pol']) + for m in range(-(num_coeff-1)//2,(num_coeff+1)//2): + beta_grad_re = params['F0'] * alpha_factor * sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + beta_grad_im = 1j * beta_grad_re + if COMPLEX_BASIS == 're-im': + grad.append(beta_grad_re) + grad.append(beta_grad_im) + elif COMPLEX_BASIS == 'abs-arg': + beta_abs = np.abs(params['beta_list_pol'][m+(num_coeff-1)//2]) + beta_arg = np.angle(params['beta_list_pol'][m+(num_coeff-1)//2]) + grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg)) + grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im) + else: + raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!') + elif pol!='P' and fit_pol: + [grad.append(np.zeros_like(grad[0])) for _ in range(2*len(params['beta_list_pol']))] + + if pol=='V' and fit_cpol: + # Add derivatives of the beta_cpol terms + num_coeff = len(params['beta_list_cpol']) - 1 + + # First do the beta0 mode (real) + beta_grad_re = params['F0'] * alpha_factor * sps.jv( 0, z) + grad.append(beta_grad_re) + + # Now do the remaining modes (complex) + for m in range(1,num_coeff+1): + beta_grad_re = params['F0'] * alpha_factor * ( + sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) + sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + beta_grad_im = 1j * params['F0'] * alpha_factor * ( + sps.jv( m, z) * np.exp( 1j * m * (phi - np.pi/2.)) - sps.jv(-m, z) * np.exp(-1j * m * (phi - np.pi/2.)) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + if COMPLEX_BASIS == 're-im': + grad.append(beta_grad_re) + grad.append(beta_grad_im) + elif COMPLEX_BASIS == 'abs-arg': + beta_abs = np.abs(params['beta_list_cpol'][m]) + beta_arg = np.angle(params['beta_list_cpol'][m]) + grad.append(beta_grad_re * np.cos(beta_arg) + beta_grad_im * np.sin(beta_arg)) + grad.append(-beta_abs * np.sin(beta_arg) * beta_grad_re + beta_abs * np.cos(beta_arg) * beta_grad_im) + else: + raise Exception('COMPLEX_BASIS ' + COMPLEX_BASIS + ' not recognized!') + elif pol!='V' and fit_cpol: + [grad.append(np.zeros_like(grad[0])) for _ in range(2*len(params['beta_list_cpol'])-1)] + + val = np.array(grad) + elif model_type == 'thick_mring_floor': # F0, d, [alpha], ff, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ... + # We need to stich together the two gradients for the mring and the disk; we also need to add the gradient for the floor fraction ff + grad_mring = (1.0 - params['ff']) * sample_1model_grad_uv(u, v, 'thick_mring', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol) + grad_disk = params['ff'] * sample_1model_grad_uv(u, v, 'blurred_disk', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol) + + # mring: F0, d, alpha, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ... + # disk: F0, d, alpha, x0, y0 + + # Here are derivatives for F0, d, and alpha + grad = [] + for j in range(3): + grad.append(grad_mring[j] + grad_disk[j]) + + # Here is the derivative for ff + grad.append( params['F0'] * (grad_disk[0]/params['ff'] - grad_mring[0]/(1.0 - params['ff'])) ) + + # Now the derivatives for x0 and y0 + grad.append(grad_mring[3] + grad_disk[3]) + grad.append(grad_mring[4] + grad_disk[4]) + + # Add remaining gradients for the mring + for j in range(5,len(grad_mring)): + grad.append(grad_mring[j]) + + val = np.array(grad) + elif model_type == 'thick_mring_Gfloor': # F0, d, [alpha], ff, FWHM, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ... + # We need to stich together the two gradients for the mring and the gaussian; we also need to add the gradient for the floor fraction ff + grad_mring = (1.0 - params['ff']) * sample_1model_grad_uv(u, v, 'thick_mring', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol) + grad_gauss = params['ff'] * sample_1model_grad_uv(u, v, 'circ_gauss', params, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol) + + # mring: F0, d, alpha, x0, y0, beta1_re/abs, beta1_im/arg, beta2_re/abs, beta2_im/arg, ... + # gauss: F0, [d, alpha] FWHM, x0, y0 + + grad = [] + grad.append(grad_mring[0] + grad_gauss[0]) # Here are derivatives for F0 + + # Here are derivatives for d, and alpha + grad.append(grad_mring[1]) + grad.append(grad_mring[2]) + + # Here is the derivative for ff + grad.append( params['F0'] * (grad_gauss[0]/params['ff'] - grad_mring[0]/(1.0 - params['ff'])) ) + + # Now the derivatives for FWHM + grad.append(grad_gauss[1]) + + # Now the derivatives for x0 and y0 + grad.append(grad_mring[3] + grad_gauss[2]) + grad.append(grad_mring[4] + grad_gauss[3]) + + # Add remaining gradients for the mring + for j in range(5,len(grad_mring)): + grad.append(grad_mring[j]) + + val = np.array(grad) + elif model_type[:9] == 'stretched': + # Start with the model visibility + vis = sample_1model_uv(u, v, model_type, params, pol=pol, jonesdict=jonesdict) + + # Next, calculate the gradient wrt model parameters other than stretch and stretch_PA + # These are the same as the gradient of the unstretched model on stretched baselines + params_stretch = params.copy() + params_stretch['x0'] = 0.0 + params_stretch['y0'] = 0.0 + (u_stretch, v_stretch) = stretch_uv(u,v,params) + grad = (sample_1model_grad_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol) + * np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0']))) + + # Add the gradient terms for the centroid + grad[model_params(model_type, params).index('x0')] = 1j * 2.0 * np.pi * u * vis + grad[model_params(model_type, params).index('y0')] = 1j * 2.0 * np.pi * v * vis + + # Now calculate the gradient with respect to stretch and stretch PA + grad_uv = sample_1model_graduv_uv(u_stretch, v_stretch, model_type[10:], params_stretch, pol=pol) + grad_stretch = grad_uv.copy() * 0.0 + grad_stretch[0] = ( grad_uv[0] * (u * np.sin(params['stretch_PA'])**2 + v * np.sin(params['stretch_PA']) * np.cos(params['stretch_PA'])) + + grad_uv[1] * (v * np.cos(params['stretch_PA'])**2 + u * np.sin(params['stretch_PA']) * np.cos(params['stretch_PA']))) + grad_stretch[0] *= np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + + grad_stretch[1] = ( grad_uv[0] * (params['stretch'] - 1.0) * ( u * np.sin(2.0 * params['stretch_PA']) + v * np.cos(2.0 * params['stretch_PA'])) + + grad_uv[1] * (params['stretch'] - 1.0) * (-v * np.sin(2.0 * params['stretch_PA']) + u * np.cos(2.0 * params['stretch_PA']))) + grad_stretch[1] *= np.exp(1j * 2.0 * np.pi * (u * params['x0'] + v * params['y0'])) + + val = np.concatenate([grad, grad_stretch]) + else: + print('Model ' + model_type + ' not recognized!') + val = 0.0 + + grad = val * get_const_polfac(model_type, params, pol) + + if (fit_pol or fit_cpol) and model_type.find('mring') == -1: + # Add gradient contributions for models that have constant polarization + if fit_pol: + # Add gradient wrt pol_frac if the polarization is P, otherwise ignore + grad_params = copy.deepcopy(params) + grad_params['pol_frac'] = 1.0 + grad = np.vstack([grad, (pol == 'P') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)]) + + # Add gradient wrt pol_evpa if the polarization is P, otherwise ignore + grad_params = copy.deepcopy(params) + grad_params['pol_frac'] *= 2j + grad = np.vstack([grad, (pol == 'P') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)]) + if fit_cpol: + # Add gradient wrt cpol_frac + grad_params = copy.deepcopy(params) + grad_params['cpol_frac'] = 1.0 + grad = np.vstack([grad, (pol == 'V') * sample_1model_uv(u, v, model_type, grad_params, pol=pol, jonesdict=jonesdict)]) + + return grad + +def sample_model_xy(models, params, x, y, psize=1.*RADPERUAS, pol='I'): + return np.sum(sample_1model_xy(x, y, models[j], params[j], psize=psize,pol=pol) for j in range(len(models))) + +def sample_model_uv(models, params, u, v, pol='I', jonesdict=None): + return np.sum(sample_1model_uv(u, v, models[j], params[j], pol=pol, jonesdict=jonesdict) for j in range(len(models))) + +def sample_model_graduv_uv(models, params, u, v, pol='I', jonesdict=None): + # Gradient of a sum of models wrt (u,v) + return np.sum([sample_1model_graduv_uv(u, v, models[j], params[j], pol=pol, jonesdict=jonesdict) for j in range(len(models))],axis=0) + +def sample_model_grad_uv(models, params, u, v, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + # Gradient of a sum of models for each parameter + if fit_leakage == False: + return np.concatenate([sample_1model_grad_uv(u, v, models[j], params[j], pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) for j in range(len(models))]) + else: + # Need to sum the leakage contributions + allgrad = [sample_1model_grad_uv(u, v, models[j], params[j], pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) for j in range(len(models))] + n_leakage = len(jonesdict['leakage_fit'])*2 + grad = np.concatenate([allgrad[j][:-n_leakage] for j in range(len(models))]) + grad_leakage = np.sum([allgrad[j][-n_leakage:] for j in range(len(models))],axis=0) + return np.concatenate([grad, grad_leakage]) + +def blur_circ_1model(model_type, params, fwhm): + """Blur a single model, returning new model type and associated parameters + + Args: + fwhm (float) : Full width at half maximum of the kernel (radians) + + Returns: + (dict) : Dictionary with new 'model_type' and new 'params' + """ + + model_type_blur = model_type + params_blur = params.copy() + + if model_type == 'point': + model_type_blur = 'circ_gauss' + params_blur['FWHM'] = fwhm + elif model_type == 'circ_gauss': + params_blur['FWHM'] = (params_blur['FWHM']**2 + fwhm**2)**0.5 + elif model_type == 'gauss': + params_blur['FWHM_maj'] = (params_blur['FWHM_maj']**2 + fwhm**2)**0.5 + params_blur['FWHM_min'] = (params_blur['FWHM_min']**2 + fwhm**2)**0.5 + elif 'thick' in model_type or 'blurred' in model_type: + params_blur['alpha'] = (params_blur['alpha']**2 + fwhm**2)**0.5 + elif model_type == 'disk': + model_type_blur = 'blurred_' + model_type + params_blur['alpha'] = fwhm + elif model_type == 'crescent': + model_type_blur = 'blurred_' + model_type + params_blur['alpha'] = fwhm + elif model_type == 'ring' or model_type == 'mring': + model_type_blur = 'thick_' + model_type + params_blur['alpha'] = fwhm + elif model_type == 'stretched_ring' or model_type == 'stretched_mring': + model_type_blur = 'stretched_thick_' + model_type[10:] + params_blur['alpha'] = fwhm + else: + raise Exception("A blurred " + model_type + " is not yet a supported model!") + + return {'model_type':model_type_blur, 'params':params_blur} + +class Model(object): + """A model with analytic representations in the image and visibility domains. + + Attributes: + """ + + def __init__(self, ra=RA_DEFAULT, dec=DEC_DEFAULT, pa=0.0, + polrep='stokes', pol_prim=None, + rf=RF_DEFAULT, source=SOURCE_DEFAULT, + mjd=MJD_DEFAULT, time=0.): + + """A model with analytic representations in the image and visibility domains. + + Args: + + Returns: + """ + + # The model is a sum of component models, each defined by a tag and associated parameters + self.pol_prim = 'I' + self.polrep = 'stokes' + self._imdict = {'I':{'models':[],'params':[]},'Q':{'models':[],'params':[]},'U':{'models':[],'params':[]},'V':{'models':[],'params':[]}} + + # Save the image metadata + self.ra = float(ra) + self.dec = float(dec) + self.pa = float(pa) + self.rf = float(rf) + self.source = str(source) + self.mjd = int(mjd) + if time > 24: + self.mjd += int((time - time % 24)/24) + self.time = float(time % 24) + else: + self.time = time + + @property + def models(self): + return self._imdict[self.pol_prim]['models'] + + @models.setter + def models(self, model_list): + self._imdict[self.pol_prim]['models'] = model_list + + @property + def params(self): + return self._imdict[self.pol_prim]['params'] + + @params.setter + def params(self, param_list): + self._imdict[self.pol_prim]['params'] = param_list + + def copy(self): + """Return a copy of the Model object. + + Args: + + Returns: + (Model): copy of the Model. + """ + out = Model(ra=self.ra, dec=self.dec, pa=self.pa, polrep=self.polrep, pol_prim=self.pol_prim,rf=self.rf,source=self.source,mjd=self.mjd,time=self.time) + out.models = copy.deepcopy(self.models) + out.params = copy.deepcopy(self.params.copy()) + return out + + def switch_polrep(self, polrep_out='stokes', pol_prim_out=None): + + """Return a new model with the polarization representation changed + Args: + polrep_out (str): the polrep of the output data + pol_prim_out (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for circ + + Returns: + (Model): new Model object with potentially different polrep + """ + + # Note: this currently does nothing, but it is put here for compatibility with functions such as selfcal + if polrep_out not in ['stokes','circ']: + raise Exception("polrep_out must be either 'stokes' or 'circ'") + if pol_prim_out is None: + if polrep_out=='stokes': pol_prim_out = 'I' + elif polrep_out=='circ': pol_prim_out = 'RR' + + return self.copy() + + def N_models(self): + """Return the number of model components + + Args: + + Returns: + (int): number of model components + """ + return len(self.models) + + def total_flux(self): + """Return the total flux of the model in Jy. + + Args: + + Returns: + (float) : model total flux (Jy) + """ + return np.real(self.sample_uv(0,0)) + + def blur_circ(self, fwhm): + """Return a new model, equal to the current one convolved with a circular Gaussian kernel + + Args: + fwhm (float) : Full width at half maximum of the kernel (radians) + + Returns: + (Model) : Blurred model + """ + + out = self.copy() + + for j in range(len(out.models)): + blur_model = blur_circ_1model(out.models[j], out.params[j], fwhm) + out.models[j] = blur_model['model_type'] + out.params[j] = blur_model['params'] + + return out + + def add_point(self, F0 = 1.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a point source model. + + Args: + F0 (float): The total flux of the point source (Jy) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + + out = self.copy() + out.models.append('point') + out.params.append({'F0':F0,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_circ_gauss(self, F0 = 1.0, FWHM = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a circular Gaussian model. + + Args: + F0 (float): The total flux of the Gaussian (Jy) + FWHM (float): The FWHM of the Gaussian (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('circ_gauss') + out.params.append({'F0':F0,'FWHM':FWHM,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_gauss(self, F0 = 1.0, FWHM_maj = 50.*RADPERUAS, FWHM_min = 50.*RADPERUAS, PA = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add an anisotropic Gaussian model. + + Args: + F0 (float): The total flux of the Gaussian (Jy) + FWHM_maj (float): The FWHM of the Gaussian major axis (radians) + FWHM_min (float): The FWHM of the Gaussian minor axis (radians) + PA (float): Position angle of the major axis, east of north (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('gauss') + out.params.append({'F0':F0,'FWHM_maj':FWHM_maj,'FWHM_min':FWHM_min,'PA':PA,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_disk(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a circular disk model. + + Args: + F0 (float): The total flux of the disk (Jy) + d (float): The diameter (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('disk') + out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_blurred_disk(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a circular disk model that is blurred with a circular Gaussian kernel. + + Args: + F0 (float): The total flux of the disk (Jy) + d (float): The diameter (radians) + alpha (float): The blurring (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('blurred_disk') + out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_crescent(self, F0 = 1.0, d = 50.*RADPERUAS, fr = 0.0, fo = 0.0, ff = 0.0, phi = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a crescent model. + + Args: + F0 (float): The total flux of the disk (Jy) + d (float): The diameter (radians) + fr (float): Fractional radius of the inner subtracted disk with respect to the radius of the outer disk + fo (float): Fractional offset of the inner disk from the center of the outer disk + ff (float): Fractional brightness of the inner disk + phi (float): angle of offset of the inner disk + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('crescent') + out.params.append({'F0':F0,'d':d,'fr':fr, 'fo':fo, 'ff':ff, 'phi':phi, 'x0':x0, 'y0':y0, 'pol_frac':pol_frac, 'pol_evpa':pol_evpa, 'cpol_frac':cpol_frac}) + return out + + def add_blurred_crescent(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, fr = 0.0, fo = 0.0, ff = 0.0, phi = 0.0, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + + """Add a circular disk model that is blurred with a circular Gaussian kernel. + + Args: + F0 (float): The total flux of the disk (Jy) + d (float): The diameter (radians) + alpha (float) :The blurring (FWHM of Gaussian convolution) (radians) + fr (float): Fractional radius of the inner subtracted disk with respect to the radius of the outer disk + fo (float): Fractional offset of the inner disk from the center of the outer disk + ff (float): Fractional brightness of the inner disk + phi (float): angle of offset of the inner disk + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('blurred_crescent') + out.params.append({'F0':F0,'d':d,'alpha':alpha, 'fr':fr, 'fo':fo, 'ff':ff, 'phi':phi, 'x0':x0, 'y0':y0, 'pol_frac':pol_frac, 'pol_evpa':pol_evpa, 'cpol_frac':cpol_frac}) + return out + + def add_ring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a ring model with infinitesimal thickness. + + Args: + F0 (float): The total flux of the ring (Jy) + d (float): The diameter (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('ring') + out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_stretched_ring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, stretch = 1.0, stretch_PA = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a stretched ring model with infinitesimal thickness. + + Args: + F0 (float): The total flux of the ring (Jy) + d (float): The diameter (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + stretch (float): The stretch to apply (1.0 = no stretch) + stretch_PA (float): Position angle of the stretch, east of north (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('stretched_ring') + out.params.append({'F0':F0,'d':d,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_thick_ring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a ring model with finite thickness, determined by circular Gaussian convolution of a thin ring. + For details, see Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy) + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('thick_ring') + out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_stretched_thick_ring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, stretch = 1.0, stretch_PA = 0.0, pol_frac = 0.0, pol_evpa = 0.0, cpol_frac = 0.0): + """Add a ring model with finite thickness, determined by circular Gaussian convolution of a thin ring. + For details, see Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy) + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + stretch (float): The stretch to apply (1.0 = no stretch) + stretch_PA (float): Position angle of the stretch, east of north (radians) + + Returns: + (Model): Updated Model + """ + out = self.copy() + out.models.append('stretched_thick_ring') + out.params.append({'F0':F0,'d':d,'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'pol_frac':pol_frac,'pol_evpa':pol_evpa,'cpol_frac':cpol_frac}) + return out + + def add_mring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion. + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The diameter (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('mring') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'x0':x0,'y0':y0}) + return out + + def add_stretched_mring(self, F0 = 1.0, d = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0): + """Add a stretched ring model with azimuthal brightness variations determined by a Fourier mode expansion. + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The diameter (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + stretch (float): The stretch to apply (1.0 = no stretch) + stretch_PA (float): Position angle of the stretch, east of north (radians) + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('stretched_mring') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA}) + return out + + def add_thick_mring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution. + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('thick_mring') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0}) + return out + + def add_thick_mring_floor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion, thickness determined by circular Gaussian convolution, and a floor + The floor is a blurred disk, with diameter d and blurred FWHM alpha + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + ff (float): The fraction of the total flux in the floor + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('thick_mring_floor') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'ff':ff}) + return out + + def add_thick_mring_Gfloor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, FWHM = 50.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion, thickness determined by circular Gaussian convolution, and a floor + The floor is a circular Gaussian, with size FWHM + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the model + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + FWHM (float): The Gaussian FWHM + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + ff (float): The fraction of the total flux in the floor + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('thick_mring_Gfloor') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'ff':ff,'FWHM':FWHM}) + return out + + def add_stretched_thick_mring(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution. + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + stretch (float): The stretch to apply (1.0 = no stretch) + stretch_PA (float): Position angle of the stretch, east of north (radians) + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('stretched_thick_mring') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA}) + return out + + def add_stretched_thick_mring_floor(self, F0 = 1.0, d = 50.*RADPERUAS, alpha = 10.*RADPERUAS, ff=0.0, x0 = 0.0, y0 = 0.0, beta_list = None, beta_list_pol = None, beta_list_cpol = None, stretch = 1.0, stretch_PA = 0.0): + """Add a ring model with azimuthal brightness variations determined by a Fourier mode expansion and thickness determined by circular Gaussian convolution. + For details, see Eq. 18-20 of https://arxiv.org/abs/1907.04329 + The Gaussian convolution calculation is a trivial generalization of Appendix G of https://iopscience.iop.org/article/10.3847/2041-8213/ab0e85/pdf + + Args: + F0 (float): The total flux of the ring (Jy), which is also beta_0. + d (float): The ring diameter (radians) + alpha (float): The ring thickness (FWHM of Gaussian convolution) (radians) + x0 (float): The x-coordinate (radians) + y0 (float): The y-coordinate (radians) + beta_list (list): List of complex Fourier coefficients, [beta_1, beta_2, ...]. + Negative indices are determined by the condition beta_{-m} = beta_m*. + Indices are all scaled by F0 = beta_0, so they are dimensionless. + stretch (float): The stretch to apply (1.0 = no stretch) + stretch_PA (float): Position angle of the stretch, east of north (radians) + Returns: + (Model): Updated Model + """ + if beta_list is None: beta_list = [] + if beta_list_pol is None: beta_list_pol = [] + if beta_list_cpol is None: beta_list_cpol = [] + + out = self.copy() + if beta_list is None: + beta_list = [0.0] + out.models.append('stretched_thick_mring_floor') + out.params.append({'F0':F0,'d':d,'beta_list':np.array(beta_list, dtype=np.complex_),'beta_list_pol':np.array(beta_list_pol, dtype=np.complex_),'beta_list_cpol':np.array(beta_list_cpol, dtype=np.complex_),'alpha':alpha,'x0':x0,'y0':y0,'stretch':stretch,'stretch_PA':stretch_PA,'ff':ff}) + return out + + def sample_xy(self, x, y, psize=1.*RADPERUAS, pol='I'): + """Sample model image on the specified x and y coordinates + + Args: + x (float): x coordinate (dimensionless) + y (float): y coordinate (dimensionless) + + Returns: + (float): Image brightness (Jy/radian^2) + """ + return sample_model_xy(self.models, self.params, x, y, psize=psize, pol=pol) + + def sample_uv(self, u, v, polrep_obs='Stokes', pol='I', jonesdict=None): + """Sample model visibility on the specified u and v coordinates + + Args: + u (float): u coordinate (dimensionless) + v (float): v coordinate (dimensionless) + + Returns: + (complex): complex visibility (Jy) + """ + return sample_model_uv(self.models, self.params, u, v, pol=pol, jonesdict=jonesdict) + + def sample_graduv_uv(self, u, v, pol='I', jonesdict=None): + """Sample model visibility gradient on the specified u and v coordinates wrt (u,v) + + Args: + u (float): u coordinate (dimensionless) + v (float): v coordinate (dimensionless) + + Returns: + (complex): complex visibility (Jy) + """ + return sample_model_graduv_uv(self.models, self.params, u, v, pol=pol, jonesdict=jonesdict) + + def sample_grad_uv(self, u, v, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """Sample model visibility gradient on the specified u and v coordinates wrt all model parameters + + Args: + u (float): u coordinate (dimensionless) + v (float): v coordinate (dimensionless) + + Returns: + (complex): complex visibility (Jy) + """ + return sample_model_grad_uv(self.models, self.params, u, v, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + + def centroid(self, pol=None): + """Compute the location of the image centroid (corresponding to the polarization pol) + Note: This quantity is only well defined for total intensity + + Args: + pol (str): The polarization for which to find the image centroid + + Returns: + (np.array): centroid positions (x0,y0) in radians + """ + + if pol is None: pol=self.pol_prim + if not (pol in list(self._imdict.keys())): + raise Exception("for polrep==%s, pol must be in " % + self.polrep + ",".join(list(self._imdict.keys()))) + + return np.real(self.sample_graduv_uv(0,0)/(2.*np.pi*1j))/self.total_flux() + + def default_prior(self,fit_pol=False,fit_cpol=False): + return [default_prior(self.models[j],self.params[j],fit_pol=fit_pol,fit_cpol=fit_cpol) for j in range(self.N_models())] + + def display(self, fov=FOV_DEFAULT, npix=NPIX_DEFAULT, polrep='stokes', pol_prim=None, pulse=PULSE_DEFAULT, time=0., **kwargs): + return self.make_image(fov, npix, polrep, pol_prim, pulse, time).display(**kwargs) + + def make_image(self, fov, npix, polrep='stokes', pol_prim=None, pulse=PULSE_DEFAULT, time=0.): + """Sample the model onto a square image. + + Args: + fov (float): the field of view of each axis in radians + npix (int): the number of pixels on each axis + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The image frequency in Hz + + source (str): The source name + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + pulse (function): The function convolved with the pixel values for continuous image. + + mjd (int): The integer MJD of the image + time (float): The observing time of the image (UTC hours) + + Returns: + (Image): an image object + """ + + pdim = fov/float(npix) + npix = int(npix) + imarr = np.zeros((npix,npix)) + outim = image.Image(imarr, pdim, self.ra, self.dec, + polrep=polrep, pol_prim=pol_prim, + rf=self.rf, source=self.source, mjd=self.mjd, time=time, pulse=pulse) + + return self.image_same(outim) + + def image_same(self, im): + """Create an image of the model with parameters equal to a reference image. + + Args: + im (Image): the reference image + + Returns: + (Image): image of the model + """ + out = im.copy() + xlist = np.arange(0,-im.xdim,-1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0,-im.ydim,-1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + x_grid, y_grid = np.meshgrid(xlist, ylist) + imarr = self.sample_xy(x_grid, y_grid, im.psize) + out.imvec = imarr.flatten() # Change this to init with image_args + + # Add the remaining polarizations + for pol in ['Q','U','V']: + out.add_pol_image(self.sample_xy(x_grid, y_grid, im.psize, pol=pol), pol) + + return out + + def observe_same_nonoise(self, obs, **kwargs): + """Observe the model on the same baselines as an existing observation, without noise. + + Args: + obs (Obsdata): the existing observation + + Returns: + (Obsdata): an observation object with no noise + """ + + # Copy data to be safe + obsdata = copy.deepcopy(obs.data) + + # Load optional parameters + jonesdict = kwargs.get('jonesdict',None) + + # Compute visibilities and put them into the obsdata + if obs.polrep=='stokes': + obsdata['vis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='I', jonesdict=jonesdict) + obsdata['qvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='Q', jonesdict=jonesdict) + obsdata['uvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='U', jonesdict=jonesdict) + obsdata['vvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='V', jonesdict=jonesdict) + elif obs.polrep=='circ': + obsdata['rrvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='RR', jonesdict=jonesdict) + obsdata['rlvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='RL', jonesdict=jonesdict) + obsdata['lrvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='LR', jonesdict=jonesdict) + obsdata['llvis'] = self.sample_uv(obs.data['u'], obs.data['v'], pol='LL', jonesdict=jonesdict) + + obs_no_noise = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs.polrep, + ampcal=True, phasecal=True, opacitycal=True, + dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + + return obs_no_noise + + def observe_same(self, obs_in, add_th_noise=True, sgrscat=False, ttype=False, # Note: sgrscat and ttype are kept for consistency with comp_plots + opacitycal=True, ampcal=True, phasecal=True, + dcal=True, frcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, neggains=False, + jones=False, inv_jones=False, + tau=TAUDEF, taup=GAINPDEF, + gain_offset=GAINPDEF, gainp=GAINPDEF, + dterm_offset=DTERMPDEF, + rlratio_std=0.,rlphase_std=0., + caltable_path=None, seed=False, **kwargs): + + """Observe the image on the same baselines as an existing observation object and add noise. + + Args: + obs_in (Obsdata): the existing observation + + add_th_noise (bool): if True, baseline-dependent thermal noise is added + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + dcal (bool): if False, time-dependent gaussian errors added to D-terms. + + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + meaning that you have overestimated your telescope's performance + + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to a priori calibrate data + + tau (float): the base opacity at all sites, + or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + caltable_path (string): The path and prefix of a saved caltable + + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + + Returns: + (Obsdata): an observation object + """ + + if seed!=False: + np.random.seed(seed=seed) + + obs = self.observe_same_nonoise(obs_in, **kwargs) + + # Jones Matrix Corruption & Calibration + if jones: + obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, dcal=dcal, frcal=frcal, + rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + gainp=gainp, taup=taup, gain_offset=gain_offset, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + caltable_path=caltable_path, seed=seed) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, + dcal=dcal, frcal=frcal, + timetype=obs.timetype, scantable=obs.scans) + + if inv_jones: + obsdata = simobs.apply_jones_inverse(obs, + opacitycal=opacitycal, dcal=dcal, frcal=frcal) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + #these are always set to True after inverse jones call + + + # No Jones Matrices, Add noise the old way + # NOTE There is an asymmetry here - in the old way, we don't offer the ability to *not* + # unscale estimated noise. + else: + + if caltable_path: + print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix') + + obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, + ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + gainp=gainp, taup=taup, gain_offset=gain_offset, seed=seed) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + #these are always set to True after inverse jones call + + return obs + + def observe(self, array, tint, tadv, tstart, tstop, bw, + mjd=None, timetype='UTC', polrep_obs=None, + elevmin=ELEV_LOW, elevmax=ELEV_HIGH, + no_elevcut_space=False, + fix_theta_GMST=False, add_th_noise=True, + opacitycal=True, ampcal=True, phasecal=True, + dcal=True, frcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + jones=False, inv_jones=False, + tau=TAUDEF, taup=GAINPDEF, + gainp=GAINPDEF, gain_offset=GAINPDEF, + dterm_offset=DTERMPDEF, rlratio_std=0.,rlphase_std=0., + seed=False, **kwargs): + + """Generate baselines from an array object and observe the image. + + Args: + array (Array): an array object containing sites with which to generate baselines + tint (float): the scan integration time in seconds + tadv (float): the uniform cadence between scans in seconds + tstart (float): the start time of the observation in hours + tstop (float): the end time of the observation in hours + bw (float): the observing bandwidth in Hz + + mjd (int): the mjd of the observation, if set as different from the image mjd + timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC' + elevmin (float): station minimum elevation in degrees + elevmax (float): station maximum elevation in degrees + no_elevcut_space (bool): if True, do not apply elevation cut to orbiters + + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + dcal (bool): if False, time-dependent gaussian errors added to Jones matrices D-terms. + + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + otherwise uses old formalism without D-terms + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to calibrate data + + tau (float): the base opacity at all sites, + or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + seed (int): seeds the random component of noise added. DO NOT set to 0! + + Returns: + (Obsdata): an observation object + """ + + # Generate empty observation + print("Generating empty observation file . . . ") + + if mjd == None: + mjd = self.mjd + if polrep_obs is None: + polrep_obs=self.polrep + + obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop, mjd=mjd, + polrep=polrep_obs, tau=tau, timetype=timetype, + elevmin=elevmin, elevmax=elevmax, + no_elevcut_space=no_elevcut_space, fix_theta_GMST=fix_theta_GMST) + + # Observe on the same baselines as the empty observation and add noise + obs = self.observe_same(obs, add_th_noise=add_th_noise, + opacitycal=opacitycal,ampcal=ampcal, + phasecal=phasecal,dcal=dcal, + frcal=frcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + gainp=gainp,gain_offset=gain_offset, + tau=tau, taup=taup, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + jones=jones, inv_jones=inv_jones, seed=seed, **kwargs) + + obs.mjd = mjd + + return obs + + def save_txt(self,filename): + # Header + import ehtim.observing.obs_helpers as obshelp + mjd = float(self.mjd) + time = self.time + mjd += (time/24.) + + head = ("SRC: %s \n" % self.source + + "RA: " + obshelp.rastring(self.ra) + "\n" + + "DEC: " + obshelp.decstring(self.dec) + "\n" + + "MJD: %.6f \n" % (float(mjd)) + + "RF: %.4f GHz" % (self.rf/1e9)) + # Models + out = [] + for j in range(self.N_models()): + out.append(self.models[j]) + out.append(str(self.params[j]).replace('\n','').replace('complex128','np.complex128').replace('array','np.array')) + np.savetxt(filename, out, header=head, fmt="%s") + + def load_txt(self,filename): + lines = open(filename).read().splitlines() + + src = ' '.join(lines[0].split()[2:]) + ra = lines[1].split() + self.ra = float(ra[2]) + float(ra[4])/60.0 + float(ra[6])/3600.0 + dec = lines[2].split() + self.dec = np.sign(float(dec[2])) * (abs(float(dec[2])) + float(dec[4])/60.0 + float(dec[6])/3600.0) + mjd_float = float(lines[3].split()[2]) + self.mjd = int(mjd_float) + self.time = (mjd_float - self.mjd) * 24 + self.rf = float(lines[4].split()[2]) * 1e9 + + self.models = lines[5::2] + self.params = [eval(x) for x in lines[6::2]] + +def load_txt(filename): + out = Model() + out.load_txt(filename) + return out diff --git a/modeling/__init__.py b/modeling/__init__.py new file mode 100644 index 00000000..2d00f4e5 --- /dev/null +++ b/modeling/__init__.py @@ -0,0 +1,9 @@ +""" +.. module:: ehtim.modeling + :platform: Unix + :synopsis: EHT Modeling Utilities: modeling functions + +.. moduleauthor:: Michael Johnson (mjohnson@cfa.harvard.edu) + +""" +from ..const_def import * diff --git a/modeling/modeling_utils.py b/modeling/modeling_utils.py new file mode 100644 index 00000000..d58c637c --- /dev/null +++ b/modeling/modeling_utils.py @@ -0,0 +1,2700 @@ +# modeling_utils.py +# General modeling functions for total intensity VLBI data +# +# Copyright (C) 2020 Michael Johnson +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +## TODO ## +# >> return jonesdict for all data types <- requires significant modification to eht-imaging +# >> Deal with nans in fitting (mask chisqdata) <- mostly done +# >> Add optional transform for leakage and gains + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import +from builtins import range + +import string +import time +import numpy as np +import scipy.optimize as opt +import scipy.ndimage as nd +import scipy.ndimage.filters as filt +import matplotlib.pyplot as plt +import scipy.special as sps +import scipy.stats as stats +import copy + +import ehtim.obsdata as obsdata +import ehtim.image as image +import ehtim.model as model +import ehtim.caltable as caltable + +from ehtim.const_def import * +from ehtim.observing.obs_helpers import * +from ehtim.statistics.dataframes import * + +#from IPython import display + +################################################################################################## +# Constants & Definitions +################################################################################################## + +MAXLS = 100 # maximum number of line searches in L-BFGS-B +NHIST = 100 # number of steps to store for hessian approx +MAXIT = 100 # maximum number of iterations +STOP = 1.e-8 # convergence criterion + +BOUNDS_MIN = -1e4 +BOUNDS_MAX = 1e4 +BOUNDS_GAUSS_NSIGMA = 10. +BOUNDS_EXP_NSIGMA = 10. + +PRIOR_MIN = 1e-200 # to avoid problems with log-prior + +DATATERMS = ['vis', 'bs', 'amp', 'cphase', 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'logamp', 'pvis', 'm', 'rlrr', 'rlll', 'lrrr', 'lrll','rrll','llrr','polclosure'] + +nit = 0 # global variable to track the iteration number in the plotting callback +globdict = {} # global dictionary with all parameters related to the model fitting (mainly for efficient parallelization, but also very useful for debugging) + +# Details on each fitted parameter (convenience rescaling factor and associated unit) +PARAM_DETAILS = {'F0':[1.,'Jy'], 'FWHM':[RADPERUAS,'uas'], 'FWHM_maj':[RADPERUAS,'uas'], 'FWHM_min':[RADPERUAS,'uas'], + 'd':[RADPERUAS,'uas'], 'PA':[np.pi/180.,'deg'], 'alpha':[RADPERUAS,'uas'], 'ff':[1.,''], + 'x0':[RADPERUAS,'uas'], 'y0':[RADPERUAS,'uas'], 'stretch':[1.,''], 'stretch_PA':[np.pi/180.,'deg'], + 'arg':[np.pi/180.,'deg'], 'evpa':[np.pi/180.,'deg'], 'phi':[np.pi/180.,'deg']} + +GAIN_PRIOR_DEFAULT = {'prior_type':'lognormal','sigma':0.1,'mu':0.0,'shift':-1.0} +LEAKAGE_PRIOR_DEFAULT = {'prior_type':'flat','min':-0.5,'max':0.5} +N_POSTERIOR_SAMPLES = 100 + +################################################################################################## +# Priors +################################################################################################## + +def cdf(x, prior_params): + """Compute the cumulative distribution function CDF(x) of a given prior at a given point x + + Args: + x (float): Value at which to compute the CDF + prior_params (dict): Dictionary with information about the prior + + Returns: + float: CDF(x) + """ + if prior_params['prior_type'] == 'flat': + return ( (x > prior_params['max']) * 1.0 + + (x > prior_params['min']) * (x < prior_params['max']) * (x - prior_params['min'])/(prior_params['max'] - prior_params['min'])) + elif prior_params['prior_type'] == 'gauss': + return 0.5 * (1.0 + sps.erf( (x - prior_params['mean'])/(prior_params['std'] * np.sqrt(2.0)) )) + elif prior_params['prior_type'] == 'exponential': + return (1.0 - np.exp(-x/prior_params['std'])) * (x >= 0.0) + elif prior_params['prior_type'] == 'lognormal': + return (x > prior_params['shift']) * (0.5 * sps.erfc( (prior_params['mu'] - np.log(x - prior_params['shift']))/(np.sqrt(2.0) * prior_params['sigma']))) + elif prior_params['prior_type'] == 'positive': + raise Exception('CDF is not defined for prior type "positive"') + elif prior_params['prior_type'] == 'none': + raise Exception('CDF is not defined for prior type "none"') + elif prior_params['prior_type'] == 'fixed': + raise Exception('CDF is not defined for prior type "fixed"') + else: + raise Exception('Prior type ' + prior_params['prior_type'] + ' not recognized!') + +def cdf_inverse(x, prior_params): + """Compute the inverse cumulative distribution function of a given prior at a given point 0 <= x <= 1 + + Args: + x (float): Value at which to compute the inverse CDF + prior_params (dict): Dictionary with information about the prior + + Returns: + float: Inverse CDF at x + """ + if prior_params['prior_type'] == 'flat': + return prior_params['min'] * (1.0 - x) + prior_params['max'] * x + elif prior_params['prior_type'] == 'gauss': + return prior_params['mean'] - np.sqrt(2.0) * prior_params['std'] * sps.erfcinv(2.0 * x) + elif prior_params['prior_type'] == 'exponential': + return prior_params['std'] * np.log(1.0/(1.0 - x)) + elif prior_params['prior_type'] == 'lognormal': + return np.exp( prior_params['mu'] - np.sqrt(2.0) * prior_params['sigma'] * sps.erfcinv(2.0 * x)) + prior_params['shift'] + elif prior_params['prior_type'] == 'positive': + raise Exception('CDF is not defined for prior type "positive"') + elif prior_params['prior_type'] == 'none': + raise Exception('CDF is not defined for prior type "none"') + elif prior_params['prior_type'] == 'fixed': + raise Exception('CDF is not defined for prior type "fixed"') + else: + raise Exception('Prior type ' + prior_params['prior_type'] + ' not recognized!') + +def param_bounds(prior_params): + """Compute the parameter boundaries associated with a given prior + + Args: + prior_params (dict): Dictionary with information about the prior + + Returns: + list: 2-element list specifying the allowed parameter range: [min,max] + """ + if prior_params.get('transform','') == 'cdf': + bounds = [0.0, 1.0] + elif prior_params['prior_type'] == 'flat': + bounds = [prior_params['min'],prior_params['max']] + elif prior_params['prior_type'] == 'gauss': + bounds = [prior_params['mean'] - prior_params['std'] * BOUNDS_GAUSS_NSIGMA, prior_params['mean'] + prior_params['std'] * BOUNDS_GAUSS_NSIGMA] + elif prior_params['prior_type'] == 'exponential': + bounds = [PRIOR_MIN, BOUNDS_EXP_NSIGMA * prior_params['std']] + elif prior_params['prior_type'] == 'lognormal': + bounds = [prior_params['shift'], prior_params['shift'] + np.exp(prior_params['mu'] + BOUNDS_GAUSS_NSIGMA * prior_params['sigma'])] + elif prior_params['prior_type'] == 'positive': + bounds = [PRIOR_MIN, BOUNDS_MAX] + elif prior_params['prior_type'] == 'none': + bounds = [BOUNDS_MIN,BOUNDS_MAX] + elif prior_params['prior_type'] == 'fixed': + bounds = [1.0, 1.0] + else: + print('Prior type not recognized!') + bounds = [BOUNDS_MIN,BOUNDS_MAX] + + return bounds + +def prior_func(x, prior_params): + """Compute the value of a 1-D prior P(x) at a specified value x. + + Args: + x (float): Value at which to compute the prior + prior_params (dict): Dictionary with information about the prior + + Returns: + float: Prior value P(x) + """ + + if prior_params['prior_type'] == 'flat': + return (x >= prior_params['min']) * (x <= prior_params['max']) * 1.0/(prior_params['max'] - prior_params['min']) + PRIOR_MIN + elif prior_params['prior_type'] == 'gauss': + return 1./((2.*np.pi)**0.5 * prior_params['std']) * np.exp(-(x - prior_params['mean'])**2/(2.*prior_params['std']**2)) + elif prior_params['prior_type'] == 'exponential': + return (1./prior_params['std'] * np.exp(-x/prior_params['std'])) * (x >= 0.0) + PRIOR_MIN + elif prior_params['prior_type'] == 'lognormal': + return (x > prior_params['shift']) * ( + 1.0/((2.0*np.pi)**0.5 * prior_params['sigma'] * (x - prior_params['shift'])) + * np.exp( -(np.log(x - prior_params['shift']) - prior_params['mu'])**2/(2.0 * prior_params['sigma']**2) ) ) + elif prior_params['prior_type'] == 'positive': + return (x >= 0.0) * 1.0 + PRIOR_MIN + elif prior_params['prior_type'] == 'none': + return 1.0 + elif prior_params['prior_type'] == 'fixed': + return 1.0 + else: + print('Prior not recognized!') + return 1.0 + +def prior_grad_func(x, prior_params): + """Compute the value of the derivative of a 1-D prior, dP/dx at a specified value x. + + Args: + x (float): Value at which to compute the prior derivative + prior_params (dict): Dictionary with information about the prior + + Returns: + float: Prior derivative value dP/dx(x) + """ + + if prior_params['prior_type'] == 'flat': + return 0.0 + elif prior_params['prior_type'] == 'gauss': + return -(x - prior_params['mean'])/((2.*np.pi)**0.5 * prior_params['std']**3) * np.exp(-(x - prior_params['mean'])**2/(2.*prior_params['std']**2)) + elif prior_params['prior_type'] == 'exponential': + return (-1./prior_params['std']**2 * np.exp(-x/prior_params['std'])) * (x >= 0.0) + elif prior_params['prior_type'] == 'lognormal': + return (x > prior_params['shift']) * ( + (prior_params['mu'] - prior_params['sigma']**2 - np.log(x - prior_params['shift'])) + / ((2.0*np.pi)**0.5 * prior_params['sigma']**3 * (x - prior_params['shift'])**2) + * np.exp( -(np.log(x - prior_params['shift']) - prior_params['mu'])**2/(2.0 * prior_params['sigma']**2) ) ) + elif prior_params['prior_type'] == 'positive': + return 0.0 + elif prior_params['prior_type'] == 'none': + return 0.0 + elif prior_params['prior_type'] == 'fixed': + return 0.0 + else: + print('Prior not recognized!') + return 0.0 + +def transform_param(x, x_prior, inverse=True): + """Compute a specified coordinate transformation T(x) of a parameter value x + + Args: + x (float): Untransformed value + x_prior (dict): Dictionary with information about the transformation + inverse (bool): Whether to compute the forward or inverse transform. + + Returns: + float: Transformed parameter value + """ + + try: + transform = x_prior['transform'] + except: + transform = 'none' + pass + + if transform == 'log': + if inverse: + return np.exp(x) + else: + return np.log(x) + elif transform == 'cdf': + if inverse: + return cdf_inverse(x, x_prior) + else: + return cdf(x, x_prior) + else: + return x + +def transform_grad_param(x, x_prior): + """Compute the gradient of a specified coordinate transformation T(x) of a parameter value x + + Args: + x (float): Untransformed value + x_prior (dict): Dictionary with information about the transformation + + Returns: + float: Gradient of transformation, dT/dx(x) + """ + + try: + transform = x_prior['transform'] + except: + transform = 'none' + pass + + if transform == 'log': + return np.exp(x) + elif transform == 'cdf': + return 1.0/prior_func(transform_param(x,x_prior),x_prior) + else: + return 1.0 + +################################################################################################## +# Helper functions +################################################################################################## +def shrink_prior(prior, model, shrink=0.1): + """Shrink a specified prior volume by centering on a specified fitted model + + Args: + prior (list): Model prior (list of dictionaries, one per model component) + model (Model): Model to draw central values from + shrink (float): Factor to shrink each prior width by + + Returns: + prior (list): Model prior with restricted volume + """ + + prior_shrunk = copy.deepcopy(prior) + f = 1.0 + + #TODO: this doesn't work for beta lists yet! + + for j in range(len(prior_shrunk)): + for key in prior_shrunk[j].keys(): + if prior_shrunk[j][key]['prior_type'] == 'flat': + x = model.params[j][key] + w = prior_shrunk[j][key]['max'] - prior_shrunk[j][key]['min'] + prior_shrunk[j][key]['min'] = x - w/2 + prior_shrunk[j][key]['max'] = x + w/2 + if prior_shrunk[j][key]['min'] < prior[j][key]['min']: prior_shrunk[j][key]['min'] = prior[j][key]['min'] + if prior_shrunk[j][key]['max'] > prior[j][key]['max']: prior_shrunk[j][key]['max'] = prior[j][key]['max'] + f *= (prior_shrunk[j][key]['max'] - prior_shrunk[j][key]['min'])/w + else: + pass + + print('(New Prior Volume)/(Original Prior Volume:',f) + + return prior_shrunk + +def selfcal(Obsdata, model, + gain_init=None, gain_prior=None, + minimizer_func='scipy.optimize.minimize', minimizer_kwargs=None, + bounds=None, use_bounds=True, + processes=-1, msgtype='bar', quiet=True, **kwargs): + """Self-calibrate a specified observation to a given model, accounting for gain priors + + Args: + + Returns: + """ + + # This is just a convenience function. It will call modeler_func() scan-by-scan fitting only gains. + # This function differs from ehtim.calibrating.self_cal in the inclusion of gain priors + tlist = Obsdata.tlist() + res_list = [] + for j in range(len(tlist)): + if msgtype not in ['none','']: + prog_msg(j, len(tlist), msgtype, j-1) + obs = Obsdata.copy() + obs.data = tlist[j] + res_list.append(modeler_func(obs, model, model_prior=None, d1='amp', + fit_model=False, fit_gains=True,gain_init=gain_init,gain_prior=gain_prior, + minimizer_func=minimizer_func, minimizer_kwargs=minimizer_kwargs, + bounds=bounds, use_bounds=use_bounds, processes=-1, quiet=quiet, **kwargs)) + + # Assemble a single caltable to return + allsites = Obsdata.tarr['site'] + caldict = res_list[0]['caltable'].data + for j in range(1,len(tlist)): + row = res_list[j]['caltable'].data + for site in allsites: + try: dat = row[site] + except KeyError: continue + + try: caldict[site] = np.append(caldict[site], row[site]) + except KeyError: caldict[site] = dat + + ct = caltable.Caltable(obs.ra, obs.dec, obs.rf, obs.bw, caldict, obs.tarr, + source=obs.source, mjd=obs.mjd, timetype=obs.timetype) + + return ct + +def make_param_map(model_init, model_prior, minimizer_func, fit_model, fit_pol=False, fit_cpol=False): + # Define the mapping between solved parameters and the model + # Each fitted model parameter can be rescaled to give values closer to order unity + + param_map = [] # Define mapping for every fitted parameter: model component #, parameter name, rescale multiplier internal, unit, rescale multiplier external + param_mask = [] # True or False for whether to fit each model parameter (because the gradient is computed for all model parameters) + for j in range(model_init.N_models()): + params = model.model_params(model_init.models[j],model_init.params[j], fit_pol=fit_pol, fit_cpol=fit_cpol) + for param in params: + if fit_model == False: + param_mask.append(False) + elif model_prior[j][param]['prior_type'] != 'fixed': + param_mask.append(True) + param_type = param + if len(param_type.split('_')) == 2 and param_type not in PARAM_DETAILS: + param_type = param_type.split('_')[1] + try: + if model_prior[j][param].get('transform','') == 'cdf' or minimizer_func in ['dynesty_static','dynesty_dynamic','pymc3']: + param_map.append([j,param,1,PARAM_DETAILS[param_type][1],PARAM_DETAILS[param_type][0]]) + else: + param_map.append([j,param,PARAM_DETAILS[param_type][0],PARAM_DETAILS[param_type][1],PARAM_DETAILS[param_type][0]]) + except: + param_map.append([j,param,1,'',1]) + pass + else: + param_mask.append(False) + return (param_map, param_mask) + +def compute_likelihood_constants(d1, d2, d3, d4, sigma1, sigma2, sigma3, sigma4): + # Compute the correct data weights (hyperparameters) and the correct extra constant for the log-likelihood + alpha_d1 = alpha_d2 = alpha_d3 = alpha_d4 = ln_norm1 = ln_norm2 = ln_norm3 = ln_norm4 = 0.0 + + try: + alpha_d1 = 0.5 * len(sigma1) + ln_norm1 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma1)) + except: pass + try: + alpha_d2 = 0.5 * len(sigma2) + ln_norm2 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma2)) + except: pass + try: + alpha_d3 = 0.5 * len(sigma3) + ln_norm3 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma3)) + except: pass + try: + alpha_d4 = 0.5 * len(sigma4) + ln_norm4 = -np.sum(np.log((2.0*np.pi)**0.5 * sigma4)) + except: pass + + # If using closure phase, the sigma is given in degrees, not radians! + # Use the correct von Mises normalization if using closure phase + if d1 in ['cphase','cphase_diag']: + ln_norm1 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma1 * DEGREE)**2))) + if d2 in ['cphase','cphase_diag']: + ln_norm2 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma2 * DEGREE)**2))) + if d3 in ['cphase','cphase_diag']: + ln_norm3 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma3 * DEGREE)**2))) + if d4 in ['cphase','cphase_diag']: + ln_norm4 = -np.sum(np.log(2.0*np.pi*sps.ive(0, 1.0/(sigma4 * DEGREE)**2))) + + if d1 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']: + alpha_d1 *= 2 + ln_norm1 *= 2 + if d2 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']: + alpha_d2 *= 2 + ln_norm2 *= 2 + if d3 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']: + alpha_d3 *= 2 + ln_norm3 *= 2 + if d4 in ['vis','bs','m','pvis','rrll','llrr','lrll','rlll','lrrr','rlrr','polclosure']: + alpha_d4 *= 2 + ln_norm4 *= 2 + ln_norm = ln_norm1 + ln_norm2 + ln_norm3 + ln_norm4 + + return (alpha_d1, alpha_d2, alpha_d3, alpha_d4, ln_norm) + +def default_gain_prior(sites): + print('No gain prior specified. Defaulting to ' + str(GAIN_PRIOR_DEFAULT) + ' for all sites.') + gain_prior = {} + for site in sites: + gain_prior[site] = GAIN_PRIOR_DEFAULT + return gain_prior + +def caltable_to_gains(caltab, gain_list): + # Generate an ordered list of gains from a caltable + # gain_list is a set of tuples (time, site) + gains = [np.abs(caltab.data[site]['rscale'][caltab.data[site]['time'] == time][0]) - 1.0 for (time, site) in gain_list] + return gains + +def make_gain_map(Obsdata, gain_prior): + # gain_list gives all unique (time,site) pairs + # gains_t1 gives the gain index for the first site in each measurement + # gains_t2 gives the gain index for the second site in each measurement + gain_list = [] + for j in range(len(Obsdata.data)): + if ([Obsdata.data[j]['time'],Obsdata.data[j]['t1']] not in gain_list) and (gain_prior[Obsdata.data[j]['t1']]['prior_type'] != 'fixed'): + gain_list.append([Obsdata.data[j]['time'],Obsdata.data[j]['t1']]) + if ([Obsdata.data[j]['time'],Obsdata.data[j]['t2']] not in gain_list) and (gain_prior[Obsdata.data[j]['t2']]['prior_type'] != 'fixed'): + gain_list.append([Obsdata.data[j]['time'],Obsdata.data[j]['t2']]) + + # Now determine the appropriate mapping; use the final index for all ignored gains, which default to 1 + def gain_index(j, tnum): + try: + return gain_list.index([Obsdata.data[j]['time'],Obsdata.data[j][tnum]]) + except: + return len(gain_list) + + gains_t1 = [gain_index(j, 't1') for j in range(len(Obsdata.data))] + gains_t2 = [gain_index(j, 't2') for j in range(len(Obsdata.data))] + + return (gain_list, gains_t1, gains_t2) + +def make_bounds(model_prior, param_map, gain_prior, gain_list, n_gains, leakage_fit, leakage_prior): + bounds = [] + for j in range(len(param_map)): + pm = param_map[j] + pb = param_bounds(model_prior[pm[0]][pm[1]]) + if (model_prior[pm[0]][pm[1]]['prior_type'] not in ['positive','none','fixed']) and (model_prior[pm[0]][pm[1]].get('transform','') != 'cdf'): + pb[0] = transform_param(pb[0]/pm[2], model_prior[pm[0]][pm[1]], inverse=False) + pb[1] = transform_param(pb[1]/pm[2], model_prior[pm[0]][pm[1]], inverse=False) + bounds.append(pb) + for j in range(n_gains): + pb = param_bounds(gain_prior[gain_list[j][1]]) + if (gain_prior[gain_list[j][1]]['prior_type'] not in ['positive','none','fixed']) and (gain_prior[gain_list[j][1]].get('transform','') != 'cdf'): + pb[0] = transform_param(pb[0], gain_prior[gain_list[j][1]], inverse=False) + pb[1] = transform_param(pb[1], gain_prior[gain_list[j][1]], inverse=False) + bounds.append(pb) + for j in range(len(leakage_fit)): + for cpart in ['re','im']: + prior = leakage_prior[leakage_fit[j][0]][leakage_fit[j][1]][cpart] + pb = param_bounds(prior) + if (prior['prior_type'] not in ['positive','none','fixed']) and (prior.get('transform','') != 'cdf'): + pb[0] = transform_param(pb[0], prior, inverse=False) + pb[1] = transform_param(pb[1], prior, inverse=False) + bounds.append(pb) + + return np.array(bounds) + +# Determine multiplicative factor for the gains (amplitude only) +def gain_factor(dtype,gains,gains_t1,gains_t2, fit_or_marginalize_gains): + global globdict + + if not fit_or_marginalize_gains: + if globdict['gain_init'] == None: + return 1 + else: + gains = globdict['gain_init'] + + if globdict['marginalize_gains']: + gains = globdict['gain_init'] + + if dtype in ['amp','vis']: + gains_wzero = np.append(gains,0.0) + return (1.0 + gains_wzero[gains_t1])*(1.0 + gains_wzero[gains_t2]) + else: + return 1 + +def gain_factor_separate(dtype,gains,gains_t1,gains_t2, fit_or_marginalize_gains): + # Determine the pair of multiplicative factors for the gains (amplitude only) + # Note: these are not displaced by unity! + global globdict + + if not fit_or_marginalize_gains: + if globdict['gain_init'] == None: + return (0., 0.) + else: + gains = globdict['gain_init'] + + if globdict['marginalize_gains']: + gains = globdict['gain_init'] + + if dtype in ['amp','vis']: + gains_wzero = np.append(gains,0.0) + return (gains_wzero[gains_t1], gains_wzero[gains_t2]) + else: + return (0, 0) + +def prior_leakage(leakage, leakage_fit, leakage_prior, fit_leakage): + # Compute the log-prior contribution from the fitted leakage terms + if fit_leakage: + cparts = ['re','im'] + return np.sum([np.log(prior_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]])) for j in range(len(leakage))]) + else: + return 0.0 + +def prior_leakage_grad(leakage, leakage_fit, leakage_prior, fit_leakage): + # Compute the log-prior contribution to the gradient from the leakages + if fit_leakage: + cparts = ['re','im'] + f = np.array([prior_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage))]) + df = np.array([prior_grad_func(leakage[j], leakage_prior[leakage_fit[j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage))]) + return df/f + else: + return [] + +def prior_gain(gains, gain_list, gain_prior, fit_gains): + # Compute the log-prior contribution from the gains + if fit_gains: + return np.sum([np.log(prior_func(gains[j], gain_prior[gain_list[j][1]])) for j in range(len(gains))]) + else: + return 0.0 + +def prior_gain_grad(gains, gain_list, gain_prior, fit_gains): + # Compute the log-prior contribution to the gradient from the gains + if fit_gains: + f = np.array([prior_func(gains[j], gain_prior[gain_list[j][1]]) for j in range(len(gains))]) + df = np.array([prior_grad_func(gains[j], gain_prior[gain_list[j][1]]) for j in range(len(gains))]) + return df/f + else: + return [] + +def transform_params(params, param_map, minimizer_func, model_prior, inverse=True): + if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']: + return [transform_param(params[j], model_prior[param_map[j][0]][param_map[j][1]], inverse=inverse) for j in range(len(params))] + else: + # For dynesty or pymc3, over-ride all specified parameter transformations to assume CDF mapping to the hypercube + # However, the passed parameters to the objective function and gradient are *not* transformed (i.e., they are not in the hypercube), thus the transformation does not need to be inverted + return params + +def set_params(params, trial_model, param_map, minimizer_func, model_prior): + tparams = transform_params(params, param_map, minimizer_func, model_prior) + + for j in range(len(params)): + if param_map[j][1] in trial_model.params[param_map[j][0]].keys(): + trial_model.params[param_map[j][0]][param_map[j][1]] = tparams[j] * param_map[j][2] + else: # In this case, the parameter is a list of complex numbers, so the real/imaginary or abs/arg components need to be assigned + if param_map[j][1].find('cpol') != -1: + param_type = 'beta_list_cpol' + idx = int(param_map[j][1].split('_')[0][8:]) + elif param_map[j][1].find('pol') != -1: + param_type = 'beta_list_pol' + idx = int(param_map[j][1].split('_')[0][7:]) + (len(trial_model.params[param_map[j][0]][param_type])-1)//2 + elif param_map[j][1].find('beta') != -1: + param_type = 'beta_list' + idx = int(param_map[j][1].split('_')[0][4:]) - 1 + else: + raise Exception('Unsure how to interpret ' + param_map[j][1]) + + curval = trial_model.params[param_map[j][0]][param_type][idx] + if '_' not in param_map[j][1]: # This is for beta0 of cpol + trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] + elif param_map[j][1][-2:] == 're': + trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] + np.imag(curval)*1j + elif param_map[j][1][-2:] == 'im': + trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] * 1j + np.real(curval) + elif param_map[j][1][-3:] == 'abs': + trial_model.params[param_map[j][0]][param_type][idx] = tparams[j] * param_map[j][2] * np.exp(1j * np.angle(curval)) + elif param_map[j][1][-3:] == 'arg': + trial_model.params[param_map[j][0]][param_type][idx] = np.abs(curval) * np.exp(1j * tparams[j] * param_map[j][2]) + else: + print('Parameter ' + param_map[j][1] + ' not understood!') + +# Define prior +def prior(params, param_map, model_prior, minimizer_func): + tparams = transform_params(params, param_map, minimizer_func, model_prior) + return np.sum([np.log(prior_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]])) for j in range(len(params))]) + +def prior_grad(params, param_map, model_prior, minimizer_func): + tparams = transform_params(params, param_map, minimizer_func, model_prior) + f = np.array([prior_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]]) for j in range(len(params))]) + df = np.array([prior_grad_func(tparams[j]*param_map[j][2], model_prior[param_map[j][0]][param_map[j][1]]) for j in range(len(params))]) + return df/f + +# Define constraint functions +def flux_constraint(trial_model, alpha_flux, flux): + if alpha_flux == 0.0: + return 0.0 + + return ((trial_model.total_flux() - flux)/flux)**2 + +def flux_constraint_grad(trial_model, alpha_flux, flux, params, param_map): + if alpha_flux == 0.0: + return 0.0 + + fluxmask = np.zeros_like(params) + for j in range(len(param_map)): + if param_map[j][1] == 'F0': + fluxmask[j] = 1.0 + + return 2.0 * (trial_model.total_flux() - flux)/flux * fluxmask + +################################################################################################## +# Define the chi^2 and chi^2 gradient functions +################################################################################################## +def laplace_approximation(trial_model, dtype, data, uv, sigma, gains_t1, gains_t2): + # Compute the approximate contribution to the log-likelihood by marginalizing over gains + global globdict + + if globdict['marginalize_gains'] == True and dtype == 'amp': + # Add the log-likelihood term from analytic gain marginalization + # Create the Hessian matrix for the argument of the exponential + gain_hess = np.zeros((len(globdict['gain_list']), len(globdict['gain_list']))) + + # Add the terms from the likelihood + gain = gain_factor(dtype,None,gains_t1,gains_t2,True) + amp_model = np.abs(trial_model.sample_uv(uv[:,0],uv[:,1])) # TODO: Add polarization! + amp_bar = gain*data + sigma_bar = gain*sigma + (g1, g2) = gain_factor_separate(dtype,None,gains_t1,gains_t2,True) + + # Each amplitude *measurement* (not fitted gain parameter!) contributes to the hessian in four places; two diagonal and two off-diagonal + for j in range(len(gain)): + gain_hess[gains_t1[j],gains_t1[j]] += amp_model[j] * (3.0 * amp_model[j] - 2.0 * amp_bar[j])/((1.0 + g1[j])**2 * sigma_bar[j]**2) + gain_hess[gains_t2[j],gains_t2[j]] += amp_model[j] * (3.0 * amp_model[j] - 2.0 * amp_bar[j])/((1.0 + g2[j])**2 * sigma_bar[j]**2) + gain_hess[gains_t1[j],gains_t2[j]] += amp_model[j] * (2.0 * amp_model[j] - amp_bar[j])/((1.0 + g1[j])*(1.0 + g2[j]) * sigma_bar[j]**2) + gain_hess[gains_t2[j],gains_t1[j]] += amp_model[j] * (2.0 * amp_model[j] - amp_bar[j])/((1.0 + g1[j])*(1.0 + g2[j]) * sigma_bar[j]**2) + + # Add contributions from the prior to the diagonal. This ranges over the fitted gain parameters. + # Note: for the Laplace approximation, only Gaussian gain priors have any effect! + for j in range(len(globdict['gain_list'])): + t = globdict['gain_list'][j][1] + if globdict['gain_prior'][t]['prior_type'] == 'gauss': + gain_hess[j,j] += 1.0/globdict['gain_prior'][t]['std'] + elif globdict['gain_prior'][t]['prior_type'] == 'flat': + gain_hess[j,j] += 0.0 + elif globdict['gain_prior'][t]['prior_type'] == 'exponential': + gain_hess[j,j] += 0.0 + elif globdict['gain_prior'][t]['prior_type'] == 'fixed': + gain_hess[j,j] += 0.0 + else: + raise Exception('Gain prior not implemented!') + return np.log((2.0 * np.pi)**(len(gain)/2.0) * np.abs(np.linalg.det(gain_hess))**-0.5) + else: + return 0.0 + +def laplace_list(): + global globdict + l1 = laplace_approximation(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['gains_t1'], globdict['gains_t2']) + l2 = laplace_approximation(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['gains_t1'], globdict['gains_t2']) + l3 = laplace_approximation(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['gains_t1'], globdict['gains_t2']) + l4 = laplace_approximation(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['gains_t1'], globdict['gains_t2']) + return (l1, l2, l3, l4) + +def chisq_wgain(trial_model, dtype, data, uv, sigma, pol, jonesdict, gains, gains_t1, gains_t2, fit_or_marginalize_gains): + global globdict + gain = gain_factor(dtype,gains,gains_t1,gains_t2,fit_or_marginalize_gains) + log_likelihood = chisq(trial_model, uv, gain*data, gain*sigma, dtype, pol, jonesdict) + return log_likelihood + +def chisqgrad_wgain(trial_model, dtype, data, uv, sigma, jonesdict, gains, gains_t1, gains_t2, fit_or_marginalize_gains, param_mask, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False): + gain = gain_factor(dtype,gains,gains_t1,gains_t2,fit_or_marginalize_gains) + return chisqgrad(trial_model, uv, gain*data, gain*sigma, jonesdict, dtype, param_mask, pol, fit_or_marginalize_gains, gains, gains_t1, gains_t2, fit_pol, fit_cpol, fit_leakage) + +def chisq_list(gains): + global globdict + chi2_1 = chisq_wgain(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['pol1'], globdict['jonesdict1'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains']) + chi2_2 = chisq_wgain(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['pol2'], globdict['jonesdict2'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains']) + chi2_3 = chisq_wgain(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['pol3'], globdict['jonesdict3'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains']) + chi2_4 = chisq_wgain(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['pol4'], globdict['jonesdict4'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains']) + + return (chi2_1, chi2_2, chi2_3, chi2_4) + +def update_leakage(leakage): + # This function updates the 'jonesdict' entries based on current leakage estimates + # leakage is list of the fitted parameters (re and im are separate) + # station_leakages is the dictionary containing all station leakages, some of which may be fixed + global globdict + if len(leakage) == 0: return + + station_leakages = globdict['station_leakages'] + leakage_fit = globdict['leakage_fit'] + # First, update the entries in the leakage dictionary + for j in range(len(leakage)//2): + station_leakages[leakage_fit[j][0]][leakage_fit[j][1]] = leakage[2*j] + 1j * leakage[2*j + 1] + + # Now, recompute the jonesdict objects + for j in range(1,4): + jonesdict = globdict['jonesdict' + str(j)] + if jonesdict is not None: + if type(jonesdict) is dict: + jonesdict['DR1'] = np.array([station_leakages[jonesdict['t1'][_]]['R'] for _ in range(len(jonesdict['t1']))]) + jonesdict['DR2'] = np.array([station_leakages[jonesdict['t2'][_]]['R'] for _ in range(len(jonesdict['t1']))]) + jonesdict['DL1'] = np.array([station_leakages[jonesdict['t1'][_]]['L'] for _ in range(len(jonesdict['t1']))]) + jonesdict['DL2'] = np.array([station_leakages[jonesdict['t2'][_]]['L'] for _ in range(len(jonesdict['t1']))]) + jonesdict['leakage_fit'] = globdict['leakage_fit'] + else: + # In this case, the data product requires a list of jonesdicts + for jonesdict2 in jonesdict: + jonesdict2['DR1'] = np.array([station_leakages[jonesdict2['t1'][_]]['R'] for _ in range(len(jonesdict2['t1']))]) + jonesdict2['DR2'] = np.array([station_leakages[jonesdict2['t2'][_]]['R'] for _ in range(len(jonesdict2['t1']))]) + jonesdict2['DL1'] = np.array([station_leakages[jonesdict2['t1'][_]]['L'] for _ in range(len(jonesdict2['t1']))]) + jonesdict2['DL2'] = np.array([station_leakages[jonesdict2['t2'][_]]['L'] for _ in range(len(jonesdict2['t1']))]) + jonesdict2['leakage_fit'] = globdict['leakage_fit'] + +################################################################################################## +# Define the objective function and gradient +################################################################################################## +def objfunc(params, force_posterior=False): + global globdict + # Note: model parameters can have transformations applied; gains and leakage do not + set_params(params[:globdict['n_params']], globdict['trial_model'], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior']) + gains = params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])] + leakage = params[(globdict['n_params'] + globdict['n_gains']):] + update_leakage(leakage) + + if globdict['marginalize_gains']: + # Ugh, the use of global variables totally messes this up + _globdict = globdict + # This doesn't handle the passed gain_init properly because the dimensions are incorrect + _globdict['gain_init'] = caltable_to_gains(selfcal(globdict['Obsdata'], globdict['trial_model'], gain_init=None, gain_prior=globdict['gain_prior'], msgtype='none'),globdict['gain_list']) + globdict = _globdict + + (chi2_1, chi2_2, chi2_3, chi2_4) = chisq_list(gains) + datterm = ( globdict['alpha_d1'] * chi2_1 + + globdict['alpha_d2'] * chi2_2 + + globdict['alpha_d3'] * chi2_3 + + globdict['alpha_d4'] * chi2_4) + + if globdict['marginalize_gains']: + (l1, l2, l3, l4) = laplace_list() + datterm += l1 + l2 + l3 + l4 + + if (globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']) or force_posterior: + priterm = prior(params[:globdict['n_params']], globdict['param_map'], globdict['model_prior'], globdict['minimizer_func']) + priterm += prior_gain(params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])], globdict['gain_list'], globdict['gain_prior'], globdict['fit_gains']) + priterm += prior_leakage(params[(globdict['n_params'] + globdict['n_gains']):], globdict['leakage_fit'], globdict['leakage_prior'], globdict['fit_leakage']) + else: + priterm = 0.0 + fluxterm = globdict['alpha_flux'] * flux_constraint(globdict['trial_model'], globdict['alpha_flux'], globdict['flux']) + + return datterm - priterm + fluxterm - globdict['ln_norm'] + +def objgrad(params): + global globdict + set_params(params[:globdict['n_params']], globdict['trial_model'], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior']) + gains = params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])] + leakage = params[(globdict['n_params'] + globdict['n_gains']):] + update_leakage(leakage) + + datterm = ( globdict['alpha_d1'] * chisqgrad_wgain(globdict['trial_model'], globdict['d1'], globdict['data1'], globdict['uv1'], globdict['sigma1'], globdict['jonesdict1'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol1'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage']) + + globdict['alpha_d2'] * chisqgrad_wgain(globdict['trial_model'], globdict['d2'], globdict['data2'], globdict['uv2'], globdict['sigma2'], globdict['jonesdict2'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol2'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage']) + + globdict['alpha_d3'] * chisqgrad_wgain(globdict['trial_model'], globdict['d3'], globdict['data3'], globdict['uv3'], globdict['sigma3'], globdict['jonesdict3'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol3'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage']) + + globdict['alpha_d4'] * chisqgrad_wgain(globdict['trial_model'], globdict['d4'], globdict['data4'], globdict['uv4'], globdict['sigma4'], globdict['jonesdict4'], gains, globdict['gains_t1'], globdict['gains_t2'], globdict['fit_gains'] + globdict['marginalize_gains'], globdict['param_mask'], globdict['pol4'], globdict['fit_pol'], globdict['fit_cpol'], globdict['fit_leakage'])) + + if globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']: + priterm = np.concatenate([prior_grad(params[:globdict['n_params']], globdict['param_map'], + globdict['model_prior'], globdict['minimizer_func']), + prior_gain_grad(params[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])], + globdict['gain_list'], globdict['gain_prior'], globdict['fit_gains']), + prior_leakage_grad(params[(globdict['n_params'] + globdict['n_gains']):], globdict['leakage_fit'], + globdict['leakage_prior'], globdict['fit_leakage'])]) + else: + priterm = 0.0 + fluxterm = globdict['alpha_flux'] * flux_constraint_grad(params, globdict['alpha_flux'], globdict['flux'], params, globdict['param_map']) + + grad = datterm - priterm + fluxterm + + if globdict['minimizer_func'] not in ['dynesty_static','dynesty_dynamic','pymc3']: + for j in range(globdict['n_params']): + grad[j] *= globdict['param_map'][j][2] * transform_grad_param(params[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) + else: + # For dynesty or pymc3, over-ride all specified parameter transformations to assume CDF + # However, the passed parameters are *not* transformed (i.e., they are not in the hypercube) + # The Jacobian still needs to account for the parameter transformation + for j in range(len(params)): + if j < globdict['n_params']: + j2 = j + x = params[j2] + prior_params = globdict['model_prior'][globdict['param_map'][j2][0]][globdict['param_map'][j2][1]] + grad[j] /= prior_func(x,prior_params) + elif j < globdict['n_params'] + globdict['n_gains']: + j2 = j-globdict['n_params'] + x = gains[j2] + prior_params = globdict['gain_prior'][globdict['gain_list'][j2][1]] + grad[j] /= prior_func(x, prior_params) + else: + cparts = ['re','im'] + j2 = j-globdict['n_params']-globdict['n_gains'] + x = leakage[j2] + prior_params = globdict['leakage_prior'][globdict['leakage_fit'][j2//2][0]][globdict['leakage_fit'][j2//2][1]][cparts[j2%2]] + grad[j] /= prior_func(x, prior_params) + + if globdict['test_gradient']: + print('Testing the gradient at ',params) + import copy + dx = 1e-5 + grad_numeric = np.zeros(len(grad)) + f1 = objfunc(params) + print('Objective Function:',f1) + print('\nNumeric Gradient Check: Analytic Numeric') + for j in range(len(grad)): + if globdict['minimizer_func'] in ['dynesty_static','dynesty_dynamic','pymc3']: + dx = np.abs(params[j]) * 1e-6 + + params2 = copy.deepcopy(params) + params2[j] += dx + f2 = objfunc(params2) + grad_numeric[j] = (f2 - f1)/dx + + if globdict['minimizer_func'] in ['dynesty_static','dynesty_dynamic','pymc3']: + if j < globdict['n_params']: + j2 = j + x = params[j2] + prior_params = globdict['model_prior'][globdict['param_map'][j2][0]][globdict['param_map'][j2][1]] + grad_numeric[j] /= prior_func(x,prior_params) + elif j < globdict['n_params'] + globdict['n_gains']: + j2 = j-globdict['n_params'] + x = gains[j2] + prior_params = globdict['gain_prior'][globdict['gain_list'][j2][1]] + grad_numeric[j] /= prior_func(x, prior_params) + else: + cparts = ['re','im'] + j2 = j-globdict['n_params']-globdict['n_gains'] + x = leakage[j2] + prior_params = globdict['leakage_prior'][globdict['leakage_fit'][j2//2][0]][globdict['leakage_fit'][j2//2][1]][cparts[j2%2]] + grad_numeric[j] /= prior_func(x, prior_params) + + if j < globdict['n_params']: + print('\nNumeric Gradient Check:',globdict['param_map'][j][0],globdict['param_map'][j][1],grad[j],grad_numeric[j]) + else: + print('\nNumeric Gradient Check:',grad[j],grad_numeric[j]) + + return grad + +################################################################################################## +# Modeler +################################################################################################## +def modeler_func(Obsdata, model_init, model_prior, + d1='vis', d2=False, d3=False, d4=False, + pol1='I', pol2='I', pol3='I', pol4='I', + normchisq = False, alpha_d1=0, alpha_d2=0, alpha_d3=0, alpha_d4=0, + flux=1.0, alpha_flux=0, + fit_model=True, fit_pol=False, fit_cpol=False, + fit_gains=False,marginalize_gains=False,gain_init=None,gain_prior=None, + fit_leakage=False, leakage_init=None, leakage_prior=None, + fit_noise_model=False, + minimizer_func='scipy.optimize.minimize', + minimizer_kwargs=None, + bounds=None, use_bounds=False, + processes=-1, + test_gradient=False, quiet=False, **kwargs): + + """Fit a specified model. + + Args: + Obsdata (Obsdata): The Obsdata object with VLBI data + model_init (Model): The Model object to fit + model_prior (dict): Priors for each model parameter + + d1 (str): The first data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag', 'm' + d2 (str): The second data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag', 'm' + d3 (str): The third data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag', 'm' + d4 (str): The fourth data term; options are 'vis', 'bs', 'amp', 'cphase', 'cphase_diag' 'camp', 'logcamp', 'logcamp_diag', 'm' + + normchisq (bool): If False (default), automatically assign weights alpha_d1-3 to match the true log-likelihood. + alpha_d1 (float): The first data term weighting. + alpha_d2 (float): The second data term weighting. Default value of zero will automatically assign weights to match the true log-likelihood. + alpha_d3 (float): The third data term weighting. Default value of zero will automatically assign weights to match the true log-likelihood. + alpha_d4 (float): The fourth data term weighting. Default value of zero will automatically assign weights to match the true log-likelihood. + + flux (float): Total flux of the fitted model + alpha_flux (float): Hyperparameter controlling how strongly to constrain that the total flux matches the specified flux. + + fit_model (bool): Whether or not to fit the model parameters + fit_pol (bool): Whether or not to fit linear polarization parameters + fit_cpol (bool): Whether or not to fit circular polarization parameters + fit_gains (bool): Whether or not to fit time-dependent amplitude gains for each station + marginalize_gains (bool): Whether or not to perform analytic gain marginalization (via the Laplace approximation to the posterior) + + gain_init (list or caltable): Initial gain amplitudes to apply; these can be specified even if gains aren't fitted + gain_prior (dict): Dictionary with the gain prior for each site. + + minimizer_func (str): Minimizer function to use. Current options are: + 'scipy.optimize.minimize' + 'scipy.optimize.dual_annealing' + 'scipy.optimize.basinhopping' + 'dynesty_static' + 'dynesty_dynamic' + 'pymc3' + minimizer_kwargs (dict): kwargs passed to the minimizer. + + bounds (list): List of parameter bounds for the fitted parameters (will automatically compute if needed) + use_bounds (bool): Whether or not to use bounds when fitting (required for some minimizers) + + processes (int): Number of processes to use for a multiprocessing pool. -1 disables multiprocessing; 0 uses all that are available. Only used for dynesty. + + Returns: + dict: Dictionary with fitted model ('model') and other diagnostics that are minimizer-dependent + """ + + global nit, globdict + nit = n_params = 0 + ln_norm = 0.0 + + if fit_model == False and fit_gains == False and fit_leakage == False: + raise Exception('Both fit_model, fit_gains, and fit_leakage are False. Must fit something!') + + if fit_gains == True and marginalize_gains == True: + raise Exception('Both fit_gains and marginalize_gains are True. Cannot do both!') + + if fit_gains == False and marginalize_gains == False and gain_init is not None: + if not quiet: print('Both fit_gains and marginalize_gains are False but gain_init was passed. Applying these gains as a fixed correction!') + + if minimizer_kwargs is None: + minimizer_kwargs = {} + + # Specifications for verbosity during fits + show_updates = kwargs.get('show_updates',True) + update_interval = kwargs.get('update_interval',1) + run_nested_kwargs = kwargs.get('run_nested_kwargs',{}) + + # Make sure data and regularizer options are ok + if not d1 and not d2 and not d3 and not d4: + raise Exception("Must have at least one data term!") + if (not ((d1 in DATATERMS) or d1==False)) or (not ((d2 in DATATERMS) or d2==False)) or (not ((d3 in DATATERMS) or d3==False)) or (not ((d4 in DATATERMS) or d4==False)): + raise Exception("Invalid data term: valid data terms are: " + ' '.join(DATATERMS)) + + # Create the trial model + trial_model = model_init.copy() + + # Define mapping for every fitted parameter: model component index, parameter name, rescale multiplier, unit + (param_map, param_mask) = make_param_map(model_init, model_prior, minimizer_func, fit_model, fit_pol, fit_cpol) + + # Get data and info for the data terms + if type(Obsdata) is obsdata.Obsdata: + (data1, sigma1, uv1, jonesdict1) = chisqdata(Obsdata, d1, pol=pol1) + (data2, sigma2, uv2, jonesdict2) = chisqdata(Obsdata, d2, pol=pol2) + (data3, sigma3, uv3, jonesdict3) = chisqdata(Obsdata, d3, pol=pol3) + (data4, sigma4, uv4, jonesdict4) = chisqdata(Obsdata, d4, pol=pol4) + elif type(Obsdata) is list: + # Combine a list of observations into one. + # Allow these to be from multiple sources for polarimetric zero-baseline purposes. + # Main thing for different sources is to compute field rotation before combining + def combine_data(d1,s1,u1,j1,d2,s2,u2,j2): + d = np.concatenate([d1,d2]) + s = np.concatenate([s1,s2]) + u = np.concatenate([u1,u2]) + j = j1.copy() + for key in ['fr1', 'fr2', 't1', 't2', 'DR1', 'DR2', 'DL1', 'DL2']: + j[key] = np.concatenate([j1[key],j2[key]]) + return (d, s, u, j) + + (data1, sigma1, uv1, jonesdict1) = chisqdata(Obsdata[0], d1, pol=pol1) + (data2, sigma2, uv2, jonesdict2) = chisqdata(Obsdata[0], d2, pol=pol2) + (data3, sigma3, uv3, jonesdict3) = chisqdata(Obsdata[0], d3, pol=pol3) + (data4, sigma4, uv4, jonesdict4) = chisqdata(Obsdata[0], d4, pol=pol4) + for j in range(1,len(Obsdata)): + (data1b, sigma1b, uv1b, jonesdict1b) = chisqdata(Obsdata[j], d1, pol=pol1) + (data2b, sigma2b, uv2b, jonesdict2b) = chisqdata(Obsdata[j], d2, pol=pol2) + (data3b, sigma3b, uv3b, jonesdict3b) = chisqdata(Obsdata[j], d3, pol=pol3) + (data4b, sigma4b, uv4b, jonesdict4b) = chisqdata(Obsdata[j], d4, pol=pol4) + + if data1b is not False: + (data1, sigma1, uv1, jonesdict1) = combine_data(data1,sigma1,uv1,jonesdict1,data1b,sigma1b,uv1b,jonesdict1b) + if data2b is not False: + (data2, sigma2, uv2, jonesdict2) = combine_data(data2,sigma2,uv2,jonesdict2,data2b,sigma2b,uv2b,jonesdict2b) + if data3b is not False: + (data3, sigma3, uv3, jonesdict3) = combine_data(data3,sigma3,uv3,jonesdict3,data3b,sigma3b,uv3b,jonesdict3b) + if data4b is not False: + (data4, sigma4, uv4, jonesdict4) = combine_data(data4,sigma4,uv4,jonesdict4,data4b,sigma4b,uv4b,jonesdict4b) + + alldata = np.concatenate([_.data for _ in Obsdata]) + Obsdata = Obsdata[0] + Obsdata.data = alldata + else: + raise Exception("Observation format not recognized!") + + if fit_leakage or leakage_init is not None: + # Determine what leakage terms must be fitted. At most, this would be L & R complex leakages terms for every site + # leakage_fit is a list of tuples [site, hand] that will be fitted + leakage_fit = [] + if fit_leakage: + import copy # Error on the next line if this isn't done again. Why python, why?!? + # Start with the list of all sites + sites = list(set(np.concatenate(Obsdata.unpack(['t1','t2']).tolist()))) + + # Add missing entries to leakage_prior + # leakage_prior is a nested dictionary with keys of station, hand, re/im + leakage_prior_init = copy.deepcopy(leakage_prior) + if leakage_prior_init is None: leakage_prior_init = {} + leakage_prior = {} + for s in sites: + leakage_prior[s] = {} + for pol in ['R','L']: + leakage_prior[s][pol] = {} + for cpart in ['re','im']: + # check to see if a prior is specified for the complex part, the pol, or the site (in that order) + if leakage_prior_init.get(s,{}).get(pol,{}).get(cpart,{}).get('prior_type','') != '': + leakage_prior[s][pol][cpart] = leakage_prior_init[s][pol][cpart] + elif leakage_prior_init.get(s,{}).get(pol,{}).get('prior_type','') != '': + leakage_prior[s][pol][cpart] = copy.deepcopy(leakage_prior_init[s][pol]) + elif leakage_prior_init.get(s,{}).get('prior_type','') != '': + leakage_prior[s][pol][cpart] = copy.deepcopy(leakage_prior_init[s]) + else: + leakage_prior[s][pol][cpart] = copy.deepcopy(LEAKAGE_PRIOR_DEFAULT) + + if Obsdata.polrep == 'stokes': + for s in sites: + for pol in ['R','L']: + if leakage_prior[s][pol]['re']['prior_type'] == 'fixed': continue + leakage_fit.append([s,pol]) + else: + vislist = Obsdata.unpack(['t1','t2','rlvis','lrvis']) + # Only fit leakage for sites that include cross hand visibilities + DR = list(set(np.concatenate([vislist[~np.isnan(vislist['rlvis'])]['t1'], vislist[~np.isnan(vislist['lrvis'])]['t2']]))) + DL = list(set(np.concatenate([vislist[~np.isnan(vislist['lrvis'])]['t1'], vislist[~np.isnan(vislist['rlvis'])]['t2']]))) + [leakage_fit.append([s,'R']) for s in DR if leakage_prior[s]['R']['re']['prior_type'] != 'fixed'] + [leakage_fit.append([s,'L']) for s in DL if leakage_prior[s]['L']['re']['prior_type'] != 'fixed'] + sites = list(set(np.concatenate([DR,DL]))) + + if type(leakage_init) is dict: + station_leakages = copy.deepcopy(leakage_init) + else: + station_leakages = {} + + # Add missing entries to station_leakages + for s in sites: + for pol in ['R','L']: + if s not in station_leakages.keys(): + station_leakages[s] = {} + if 'R' not in station_leakages[s].keys(): + station_leakages[s]['R'] = 0.0 + if 'L' not in station_leakages[s].keys(): + station_leakages[s]['L'] = 0.0 + else: + # Disable leakage computations + jonesdict1 = jonesdict2 = jonesdict3 = None + leakage_fit = [] + station_leakages = None + + if normchisq == False: + if not quiet: print('Assigning data weights to give the correct log-likelihood...') + (alpha_d1, alpha_d2, alpha_d3, alpha_d4, ln_norm) = compute_likelihood_constants(d1, d2, d3, d4, sigma1, sigma2, sigma3, sigma4) + else: + ln_norm = 0.0 + + # Determine the mapping between solution gains and the input visibilities + # Use passed gains even if fit_gains=False and marginalize_gains=False + # NOTE: THERE IS A PROBLEM IN THIS IMPLEMENTATION. A fixed gain prior is ignored. However, gain_init may still want to apply a constant correction, especially when passing a caltable. + # We should maybe have two gain lists: one for constant gains and one for fitted gains + mean_g1 = mean_g2 = 0.0 + if fit_gains or marginalize_gains: + if gain_prior is None: + gain_prior = default_gain_prior(Obsdata.tarr['site']) + (gain_list, gains_t1, gains_t2) = make_gain_map(Obsdata, gain_prior) + if type(gain_init) == caltable.Caltable: + if not quiet: print('Converting gain_init from caltable to a list') + gain_init = caltable_to_gains(gain_init, gain_list) + if gain_init is None: + if not quiet: print('Initializing all gain corrections to be zero') + gain_init = np.zeros(len(gain_list)) + else: + if len(gain_init) != len(gain_list): + raise Exception('Gain initialization has incorrect dimensions! %d %d' % (len(gain_init), len(gain_list))) + if fit_gains: + n_gains = len(gain_list) + elif marginalize_gains: + n_gains = 0 + else: + if gain_init is None: + n_gains = 0 + gain_list = [] + gains_t1 = gains_t2 = None + else: + if gain_prior is None: + gain_prior = default_gain_prior(Obsdata.tarr['site']) + (gain_list, gains_t1, gains_t2) = make_gain_map(Obsdata, gain_prior) + if type(gain_init) == caltable.Caltable: + if not quiet: print('Converting gain_init from caltable to a list') + gain_init = caltable_to_gains(gain_init, gain_list) + + if fit_leakage: + leakage_init = np.zeros(len(leakage_fit) * 2) + for j in range(len(leakage_init)//2): + leakage_init[2*j] = np.real(station_leakages[leakage_fit[j][0]][leakage_fit[j][1]]) + leakage_init[2*j + 1] = np.imag(station_leakages[leakage_fit[j][0]][leakage_fit[j][1]]) + else: + leakage_init = [] + + # Initial parameters + param_init = [] + for j in range(len(param_map)): + pm = param_map[j] + if param_map[j][1] in trial_model.params[param_map[j][0]].keys(): + param_init.append(transform_param(model_init.params[pm[0]][pm[1]]/pm[2], model_prior[pm[0]][pm[1]],inverse=False)) + else: # In this case, the parameter is a list of complex numbers, so the real/imaginary or abs/arg components need to be assigned + if param_map[j][1].find('cpol') != -1: + param_type = 'beta_list_cpol' + idx = int(param_map[j][1].split('_')[0][8:]) + elif param_map[j][1].find('pol') != -1: + param_type = 'beta_list_pol' + idx = int(param_map[j][1].split('_')[0][7:]) + (len(trial_model.params[param_map[j][0]][param_type])-1)//2 + elif param_map[j][1].find('beta') != -1: + param_type = 'beta_list' + idx = int(param_map[j][1].split('_')[0][4:]) - 1 + else: + raise Exception('Unsure how to interpret ' + param_map[j][1]) + + curval = model_init.params[param_map[j][0]][param_type][idx] + if '_' not in param_map[j][1]: + param_init.append(transform_param(np.real( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False)) + elif param_map[j][1][-2:] == 're': + param_init.append(transform_param(np.real( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False)) + elif param_map[j][1][-2:] == 'im': + param_init.append(transform_param(np.imag( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False)) + elif param_map[j][1][-3:] == 'abs': + param_init.append(transform_param(np.abs( model_init.params[pm[0]][param_type][idx]/pm[2]), model_prior[pm[0]][pm[1]],inverse=False)) + elif param_map[j][1][-3:] == 'arg': + param_init.append(transform_param(np.angle(model_init.params[pm[0]][param_type][idx])/pm[2], model_prior[pm[0]][pm[1]],inverse=False)) + else: + if not quiet: print('Parameter ' + param_map[j][1] + ' not understood!') + n_params = len(param_init) + + # Note: model parameters can have transformations applied; gains and leakage do not + if fit_gains: # Do not add these if marginalize_gains == True + param_init += list(gain_init) + if fit_leakage: + param_init += list(leakage_init) + + if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']: + # Define bounds (irrelevant for dynesty or pymc3) + if use_bounds == False and minimizer_func in ['scipy.optimize.dual_annealing']: + if not quiet: print('Bounds are required for ' + minimizer_func + '! Setting use_bounds=True.') + use_bounds = True + if use_bounds == False and bounds is not None: + if not quiet: print('Bounds passed but use_bounds=False; setting use_bounds=True.') + use_bounds = True + if bounds is None and use_bounds: + if not quiet: print('No bounds passed. Setting nominal bounds.') + bounds = make_bounds(model_prior, param_map, gain_prior, gain_list, n_gains, leakage_fit, leakage_prior) + if use_bounds == False: + bounds = None + + # Gather global variables into a dictionary + globdict = {'trial_model':trial_model, + 'd1':d1, 'd2':d2, 'd3':d3, 'd4':d4, + 'pol1':pol1, 'pol2':pol2, 'pol3':pol3, 'pol4':pol4, + 'data1':data1, 'sigma1':sigma1, 'uv1':uv1, 'jonesdict1':jonesdict1, + 'data2':data2, 'sigma2':sigma2, 'uv2':uv2, 'jonesdict2':jonesdict2, + 'data3':data3, 'sigma3':sigma3, 'uv3':uv3, 'jonesdict3':jonesdict3, + 'data4':data4, 'sigma4':sigma4, 'uv4':uv4, 'jonesdict4':jonesdict4, + 'alpha_d1':alpha_d1, 'alpha_d2':alpha_d2, 'alpha_d3':alpha_d3, 'alpha_d4':alpha_d4, + 'n_params': n_params, 'n_gains':n_gains, 'n_leakage':len(leakage_init), + 'model_prior':model_prior, 'param_map':param_map, 'param_mask':param_mask, + 'gain_prior':gain_prior, 'gain_list':gain_list, 'gain_init':gain_init, + 'fit_leakage':fit_leakage, 'leakage_init':leakage_init, 'leakage_fit':leakage_fit, 'station_leakages':station_leakages, 'leakage_prior':leakage_prior, + 'show_updates':show_updates, 'update_interval':update_interval, 'gains_t1':gains_t1, 'gains_t2':gains_t2, + 'minimizer_func':minimizer_func,'Obsdata':Obsdata, + 'fit_pol':fit_pol, 'fit_cpol':fit_cpol, + 'flux':flux, 'alpha_flux':alpha_flux, 'fit_gains':fit_gains, 'marginalize_gains':marginalize_gains, 'ln_norm':ln_norm, 'param_init':param_init, 'test_gradient':test_gradient} + if fit_leakage: + update_leakage(leakage_init) + + + # Define the function that reports progress + def plotcur(params_step, *args): + global nit, globdict + if globdict['show_updates'] and (nit % globdict['update_interval'] == 0) and (quiet == False): + if globdict['n_params'] > 0: + print('Params:',params_step[:globdict['n_params']]) + print('Transformed Params:',transform_params(params_step[:globdict['n_params']], globdict['param_map'], globdict['minimizer_func'], globdict['model_prior'])) + gains = params_step[globdict['n_params']:(globdict['n_params'] + globdict['n_gains'])] + leakage = params_step[(globdict['n_params'] + globdict['n_gains']):] + if len(leakage): + print('leakage:',leakage) + update_leakage(leakage) + (chi2_1, chi2_2, chi2_3, chi2_4) = chisq_list(gains) + print("i: %d chi2_1: %0.2f chi2_2: %0.2f chi2_3: %0.2f chi2_4: %0.2f prior: %0.2f" % (nit, chi2_1, chi2_2, chi2_3, chi2_4, prior(params_step[:globdict['n_params']], globdict['param_map'], globdict['model_prior'], globdict['minimizer_func']))) + nit += 1 + + # Print initial statistics + if not quiet: + print("Initial Objective Function: %f" % (objfunc(param_init))) + if d1 in DATATERMS: + print("Total Data 1: ", (len(data1))) + if d2 in DATATERMS: + print("Total Data 2: ", (len(data2))) + if d3 in DATATERMS: + print("Total Data 3: ", (len(data3))) + if d4 in DATATERMS: + print("Total Data 4: ", (len(data4))) + print("Total Fitted Real Parameters #: ",(len(param_init))) + print("Fitted Model Parameters: ",[_[1] for _ in param_map]) + print('Fitting Leakage Terms for:',leakage_fit) + plotcur(param_init) + + # Run the minimization + tstart = time.time() + ret = {} + if minimizer_func == 'scipy.optimize.minimize': + min_kwargs = {'method':minimizer_kwargs.get('method','L-BFGS-B'), + 'options':{'maxiter':MAXIT, 'ftol':STOP, 'maxcor':NHIST,'gtol':STOP,'maxls':MAXLS}} + + if 'options' in minimizer_kwargs.keys(): + for key in minimizer_kwargs['options'].keys(): + min_kwargs['options'][key] = minimizer_kwargs['options'][key] + + for key in minimizer_kwargs.keys(): + if key in ['options','method']: + continue + else: + min_kwargs[key] = minimizer_kwargs[key] + + res = opt.minimize(objfunc, param_init, jac=objgrad, callback=plotcur, bounds=bounds, **min_kwargs) + elif minimizer_func == 'scipy.optimize.dual_annealing': + min_kwargs = {} + min_kwargs['local_search_options'] = {'jac':objgrad, + 'method':'L-BFGS-B','options':{'maxiter':MAXIT, 'ftol':STOP, 'maxcor':NHIST,'gtol':STOP,'maxls':MAXLS}} + if 'local_search_options' in minimizer_kwargs.keys(): + for key in minimizer_kwargs['local_search_options'].keys(): + min_kwargs['local_search_options'][key] = minimizer_kwargs['local_search_options'][key] + + for key in minimizer_kwargs.keys(): + if key in ['local_search_options']: + continue + min_kwargs[key] = minimizer_kwargs[key] + + res = opt.dual_annealing(objfunc, x0=param_init, bounds=bounds, callback=plotcur, **min_kwargs) + elif minimizer_func == 'scipy.optimize.basinhopping': + min_kwargs = {} + for key in minimizer_kwargs.keys(): + min_kwargs[key] = minimizer_kwargs[key] + + res = opt.basinhopping(objfunc, param_init, **min_kwargs) + elif minimizer_func == 'pymc3': + ######################## + ## Sample using pymc3 ## + ######################## + import pymc3 as pm + import theano + import theano.tensor as tt + + # To simplfy things, we'll use cdf transforms to map everything to a hypercube, as in dynesty + + # First, define a theano Op for our likelihood function + # This is based on the example here: https://docs.pymc.io/notebooks/blackbox_external_likelihood.html + class LogLike(tt.Op): + itypes = [tt.dvector] # expects a vector of parameter values when called + otypes = [tt.dscalar] # outputs a single scalar value (the log likelihood) + + def __init__(self, objfunc, objgrad): + # add inputs as class attributes + self.objfunc = objfunc + self.objgrad = objgrad + self.logpgrad = LogLikeGrad(objfunc, objgrad) + + def prior_transform(self, u): + # This function transforms samples from the unit hypercube (u) to the target prior (x) + global globdict + cparts = ['re','im'] + model_params_u = u[:n_params] + gain_params_u = u[n_params:(n_params+n_gains)] + leakage_params_u = u[(n_params+n_gains):] + model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))] + gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))] + leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))] + return np.concatenate([model_params_x, gain_params_x, leakage_params_x]) + + def perform(self, node, inputs, outputs): + # the method that is used when calling the Op + theta, = inputs # this will contain my variables + # Transform from the hypercube to the prior + x = self.prior_transform(theta) + + # call the log-likelihood function + logl = -self.objfunc(x) + + outputs[0][0] = np.array(logl) # output the log-likelihood + + def grad(self, inputs, g): + # the method that calculates the vector-Jacobian product + # http://deeplearning.net/software/theano_versions/dev/extending/op.html#grad + theta, = inputs + return [g[0]*self.logpgrad(theta)] + + class LogLikeGrad(tt.Op): + """ + This Op will be called with a vector of values and also return a vector of + values - the gradients in each dimension. + """ + itypes = [tt.dvector] + otypes = [tt.dvector] + + def __init__(self, objfunc, objgrad): + self.objfunc = objfunc + self.objgrad = objgrad + + def prior_transform(self, u): + # This function transforms samples from the unit hypercube (u) to the target prior (x) + global globdict + cparts = ['re','im'] + model_params_u = u[:n_params] + gain_params_u = u[n_params:(n_params+n_gains)] + leakage_params_u = u[(n_params+n_gains):] + model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))] + gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))] + leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))] + return np.concatenate([model_params_x, gain_params_x, leakage_params_x]) + + def perform(self, node, inputs, outputs): + theta, = inputs + x = self.prior_transform(theta) + outputs[0][0] = -self.objgrad(x) + + # create the log-likelihood Op + logl = LogLike(objfunc, objgrad) + + # Define the sampler keywords + min_kwargs = {} + for key in minimizer_kwargs.keys(): + min_kwargs[key] = minimizer_kwargs[key] + + # Define the initial value if not passed + if 'start' not in min_kwargs.keys(): + cparts = ['re','im'] + model_params_x = param_init[:n_params] + gain_params_x = param_init[n_params:(n_params+n_gains)] + leakage_params_x = param_init[(n_params+n_gains):] + model_params_u = [cdf( model_params_x[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_x))] + gain_params_u = [cdf( gain_params_x[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_x))] + leakage_params_u = [cdf(leakage_params_x[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_x))] + param_init_u = np.concatenate([model_params_u, gain_params_u, leakage_params_u]) + min_kwargs['start'] = {} + for j in range(len(param_init)): + min_kwargs['start']['var' + str(j)] = param_init_u[j] + + # Setup the sampler + with pm.Model() as model: + theta = tt.as_tensor_variable([ pm.Uniform('var' + str(j), lower=0., upper=1.) for j in range(len(param_init)) ]) + pm.DensityDist('likelihood', lambda v: logl(v), observed={'v': theta}) + trace = pm.sample(**min_kwargs) + + # Extract useful sampling diagnostics. + samples_u = np.vstack([trace['var' + str(j)] for j in range(len(param_init))]).T # samples in the hypercube + samples = np.array([logl.prior_transform(u) for u in samples_u]) # samples + mean = np.mean(samples,axis=0) + var = np.var(samples,axis=0) + + # Compute the log-posterior + if not quiet: print('Calculating the posterior values for the samples...') + logposterior = np.array([-objfunc(x, force_posterior=True) for x in samples]) + + # Select the MAP + j_MAP = np.argmax(logposterior) + MAP = samples[j_MAP] + + # Return a model determined by the MAP + set_params(MAP[:n_params], trial_model, param_map, minimizer_func, model_prior) + gains = MAP[n_params:(n_params+n_gains)] + leakage = MAP[(n_params+n_gains):] + update_leakage(leakage) + + # Return the sampler + ret['trace'] = trace + ret['mean'] = mean + ret['map'] = MAP + ret['std'] = var**0.5 + ret['samples'] = samples + ret['logposterior'] = logposterior + + # Return a set of models from the posterior + posterior_models = [] + for j in range(N_POSTERIOR_SAMPLES): + posterior_model = trial_model.copy() + set_params(samples[-j][:n_params], posterior_model, param_map, minimizer_func, model_prior) + posterior_models.append(posterior_model) + ret['posterior_models'] = posterior_models + + # Return data that has been rescaled based on 'natural' units for each parameter + import copy + samples_natural = copy.deepcopy(samples) + samples_natural[:,:n_params] /= np.array([_[4] for _ in param_map]) + ret['samples_natural'] = samples_natural + + # Return the names of the fitted parameters + labels = [] + labels_natural = [] + for _ in param_map: + labels.append(_[1].replace('_','-')) + labels_natural.append(_[1].replace('_','-')) + if _[3] != '': + labels_natural[-1] += ' (' + _[3] + ')' + for _ in gain_list: + labels.append(str(_[0]) + ' ' + _[1]) + labels_natural.append(str(_[0]) + ' ' + _[1]) + for _ in leakage_fit: + for coord in ['re','im']: + labels.append(_[0] + ',' + _[1] + ',' + coord) + labels_natural.append(_[0] + ',' + _[1] + ',' + coord) + + ret['labels'] = labels + ret['labels_natural'] = labels_natural + elif minimizer_func in ['dynesty_static','dynesty_dynamic']: + ########################## + ## Sample using dynesty ## + ########################## + import dynesty + from dynesty import utils as dyfunc + # Define the functions that dynesty requires + def prior_transform(u): + # This function transforms samples from the unit hypercube (u) to the target prior (x) + global globdict + cparts = ['re','im'] + model_params_u = u[:n_params] + gain_params_u = u[n_params:(n_params+n_gains)] + leakage_params_u = u[(n_params+n_gains):] + model_params_x = [cdf_inverse( model_params_u[j], globdict['model_prior'][globdict['param_map'][j][0]][globdict['param_map'][j][1]]) for j in range(len(model_params_u))] + gain_params_x = [cdf_inverse( gain_params_u[j], globdict['gain_prior'][globdict['gain_list'][j][1]]) for j in range(len(gain_params_u))] + leakage_params_x = [cdf_inverse(leakage_params_u[j], globdict['leakage_prior'][globdict['leakage_fit'][j//2][0]][leakage_fit[j//2][1]][cparts[j%2]]) for j in range(len(leakage_params_u))] + return np.concatenate([model_params_x, gain_params_x, leakage_params_x]) + + def loglike(x): + return -objfunc(x) + + def grad(x): + return -objgrad(x) + + # Setup a multiprocessing pool if needed + if processes >= 0: + import pathos.multiprocessing as mp + from multiprocessing import cpu_count + if processes == 0: processes = int(cpu_count()) + + # Ensure efficient memory allocation among the processes and separate trial models for each + def init(_globdict): + global globdict + globdict = _globdict + if processes >= 0: + globdict['trial_model'] = globdict['trial_model'].copy() + + return + + pool = mp.Pool(processes=processes, initializer=init, initargs=(globdict,)) + if not quiet: print('Using a pool with %d processes' % processes) + else: + pool = processes = None + + # Setup the sampler + if minimizer_func == 'dynesty_static': + sampler = dynesty.NestedSampler(loglike, prior_transform, ndim=len(param_init), gradient=grad, pool=pool, queue_size=processes, **minimizer_kwargs) + else: + sampler = dynesty.DynamicNestedSampler(loglike, prior_transform, ndim=len(param_init), gradient=grad, pool=pool, queue_size=processes, **minimizer_kwargs) + + # Run the sampler + sampler.run_nested(**run_nested_kwargs) + + # Print the sampler summary + res = sampler.results + if not quiet: + try: res.summary() + except: pass + + # Extract useful sampling diagnostics. + samples = res.samples # samples + weights = np.exp(res.logwt - res.logz[-1]) # normalized weights + mean, cov = dyfunc.mean_and_cov(samples, weights) + + # Compute the log-posterior + if not quiet: print('Calculating the posterior values for the samples...') + if pool is not None: + from functools import partial + def logpost(j): + return -objfunc(samples[j], force_posterior=True) + + logposterior = pool.map(logpost, range(len(samples))) + else: + logposterior = np.array([-objfunc(x, force_posterior=True) for x in samples]) + + # Close the pool (this may not be the desired behavior if the sampling is to be iterative!) + if pool is not None: + pool.close() + + # Select the MAP + j_MAP = np.argmax(logposterior) + MAP = samples[j_MAP] + + # Resample from the posterior + samples = dyfunc.resample_equal(samples, weights) + + # Return a model determined by the MAP + set_params(MAP[:n_params], trial_model, param_map, minimizer_func, model_prior) + gains = MAP[n_params:(n_params+n_gains)] + leakage = MAP[(n_params+n_gains):] + update_leakage(leakage) + + # Return the sampler + ret['sampler'] = sampler + ret['mean'] = mean + ret['map'] = MAP + ret['std'] = cov.diagonal()**0.5 + ret['samples'] = samples + ret['logposterior'] = logposterior + + # Return a set of models from the posterior + posterior_models = [] + for j in range(N_POSTERIOR_SAMPLES): + posterior_model = trial_model.copy() + set_params(samples[j][:n_params], posterior_model, param_map, minimizer_func, model_prior) + posterior_models.append(posterior_model) + ret['posterior_models'] = posterior_models + + # Return data that has been rescaled based on 'natural' units for each parameter + import copy + res_natural = copy.deepcopy(res) + res_natural.samples[:,:n_params] /= np.array([_[4] for _ in param_map]) + samples_natural = samples[:,:n_params]/np.array([_[4] for _ in param_map]) + ret['res_natural'] = res_natural + ret['samples_natural'] = samples_natural + + # Return the names of the fitted parameters + labels = [] + labels_natural = [] + for _ in param_map: + labels.append(_[1].replace('_','-')) + labels_natural.append(_[1].replace('_','-')) + if _[3] != '': + labels_natural[-1] += ' (' + _[3] + ')' + for _ in gain_list: + labels.append(str(_[0]) + ' ' + _[1]) + labels_natural.append(str(_[0]) + ' ' + _[1]) + for _ in leakage_fit: + for coord in ['re','im']: + labels.append(_[0] + ',' + _[1] + ',' + coord) + labels_natural.append(_[0] + ',' + _[1] + ',' + coord) + + ret['labels'] = labels + ret['labels_natural'] = labels_natural + else: + raise Exception('Minimizer function ' + minimizer_func + ' is not recognized!') + + # Format and print summary and fitted parameters + tstop = time.time() + trial_model = globdict['trial_model'] + + if not quiet: + print("\ntime: %f s" % (tstop - tstart)) + print("\nFitted Parameters:") + if minimizer_func not in ['dynesty_static','dynesty_dynamic','pymc3']: + out = res.x + set_params(out[:n_params], trial_model, param_map, minimizer_func, model_prior) + gains = out[n_params:(n_params + n_gains)] + leakage = out[(n_params + n_gains):] + update_leakage(leakage) + tparams = transform_params(out[:n_params], param_map, minimizer_func, model_prior) + if not quiet: + cur_idx = -1 + if len(param_map): + print('Model Parameters:') + for j in range(len(param_map)): + if param_map[j][0] != cur_idx: + cur_idx = param_map[j][0] + print(model_init.models[cur_idx] + ' (component %d/%d):' % (cur_idx+1,model_init.N_models())) + print(('\t' + param_map[j][1] + ': %f ' + param_map[j][3]) % (tparams[j] * param_map[j][2]/param_map[j][4])) + print('\n') + + if len(leakage_fit): + print('Leakage (%; re, im):') + for j in range(len(leakage_fit)): + print('\t' + leakage_fit[j][0] + ', ' + leakage_fit[j][1] + ': %2.2f %2.2f' % (leakage[2*j]*100,leakage[2*j + 1]*100)) + print('\n') + + print("Final Chi^2_1: %f Chi^2_2: %f Chi^2_3: %f Chi^2_4: %f" % chisq_list(gains)) + print("J: %f" % res.fun) + print(res.message) + else: + if not quiet: + cur_idx = -1 + if len(param_map): + print('Model Parameters (mean and std):') + for j in range(len(param_map)): + if param_map[j][0] != cur_idx: + cur_idx = param_map[j][0] + print(model_init.models[cur_idx] + ' (component %d/%d):' % (cur_idx+1,model_init.N_models())) + print(('\t' + param_map[j][1] + ': %f +/- %f ' + param_map[j][3]) % (ret['mean'][j] * param_map[j][2]/param_map[j][4], ret['std'][j] * param_map[j][2]/param_map[j][4])) + print('\n') + + if len(leakage_fit): + print('Leakage (%; re, im):') + for j in range(len(leakage_fit)): + j2 = 2*j + n_params + n_gains + print(('\t' + leakage_fit[j][0] + ', ' + leakage_fit[j][1] + + ': %2.2f +/- %2.2f, %2.2f +/- %2.2f') + % (mean[j2]*100,cov[j2,j2]**0.5 * 100,mean[j2+1]*100,cov[j2+1,j2+1]**0.5 * 100)) + print('\n') + + # Return fitted model + ret['model'] = trial_model + ret['param_map'] = param_map + ret['chisq_list'] = chisq_list(gains) + try: ret['res'] = res + except: pass + + if fit_gains: + ret['gains'] = gains + + # Create and return a caltable + caldict = {} + for site in set(np.array(gain_list)[:,1]): + caldict[site] = [] + + for j in range(len(gains)): + caldict[gain_list[j][1]].append((gain_list[j][0], (1.0 + gains[j]), (1.0 + gains[j]))) + + for site in caldict.keys(): + caldict[site] = np.array(caldict[site], dtype=DTCAL) + + ct = caltable.Caltable(Obsdata.ra, Obsdata.dec, Obsdata.rf, Obsdata.bw, caldict, Obsdata.tarr, + source=Obsdata.source, mjd=Obsdata.mjd, timetype=Obsdata.timetype) + ret['caltable'] = ct + + # If relevant, return useful quantities associated with the leakage + if station_leakages is not None: + ret['station_leakages'] = station_leakages + tarr = Obsdata.tarr.copy() + for s in station_leakages.keys(): + if 'R' in station_leakages[s].keys(): tarr[Obsdata.tkey[s]]['dr'] = station_leakages[s]['R'] + if 'L' in station_leakages[s].keys(): tarr[Obsdata.tkey[s]]['dl'] = station_leakages[s]['L'] + ret['tarr'] = tarr + + return ret + +################################################################################################## +# Wrapper Functions +################################################################################################## + +def chisq(model, uv, data, sigma, dtype, pol='I', jonesdict=None): + """return the chi^2 for the appropriate dtype + """ + + chisq = 1 + if not dtype in DATATERMS: + return chisq + + if dtype == 'vis': + chisq = chisq_vis(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'amp': + chisq = chisq_amp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'logamp': + chisq = chisq_logamp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'bs': + chisq = chisq_bs(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'cphase': + chisq = chisq_cphase(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'cphase_diag': + chisq = chisq_cphase_diag(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'camp': + chisq = chisq_camp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'logcamp': + chisq = chisq_logcamp(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'logcamp_diag': + chisq = chisq_logcamp_diag(model, uv, data, sigma, pol=pol, jonesdict=jonesdict) + elif dtype == 'pvis': + chisq = chisq_pvis(model, uv, data, sigma, jonesdict=jonesdict) + elif dtype == 'm': + chisq = chisq_m(model, uv, data, sigma, jonesdict=jonesdict) + elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']: + chisq = chisq_fracpol(dtype[:2],dtype[2:],model, uv, data, sigma, jonesdict=jonesdict) + elif dtype == 'polclosure': + chisq = chisq_polclosure(model, uv, data, sigma, jonesdict=jonesdict) + + return chisq + +def chisqgrad(model, uv, data, sigma, jonesdict, dtype, param_mask, pol='I', fit_gains=False, gains=None, gains_t1=None, gains_t2=None, fit_pol=False, fit_cpol=False, fit_leakage=False): + """return the chi^2 gradient for the appropriate dtype + """ + global globdict + + n_chisqgrad = len(param_mask) + if fit_leakage: + n_chisqgrad += 2*len(globdict['leakage_fit']) + + chisqgrad = np.zeros(n_chisqgrad) + if fit_gains: + gaingrad = np.zeros_like(gains) + else: + gaingrad = np.array([]) + + # Now we need to be sure to put the gradient in the correct order: model parameters, then gains, then leakage + param_mask_full = np.zeros(len(chisqgrad), dtype=bool) + leakage_mask_full = np.zeros(len(chisqgrad), dtype=bool) + param_mask_full[:len(param_mask)] = param_mask + leakage_mask_full[len(param_mask):] = ~leakage_mask_full[len(param_mask):] + + if not dtype in DATATERMS: + return np.concatenate([chisqgrad[param_mask_full],gaingrad,chisqgrad[leakage_mask_full]]) + + if dtype == 'vis': + chisqgrad = chisqgrad_vis(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'amp': + chisqgrad = chisqgrad_amp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + + if fit_gains: + i1 = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict) + amp_samples = np.abs(i1) + amp = data + pp = ((amp - amp_samples) * amp_samples) / (sigma**2) + gaingrad = 2.0/(1.0 + np.array(gains)) * np.array([np.sum(pp[(np.array(gains_t1) == j) + (np.array(gains_t2) == j)]) for j in range(len(gains))])/len(data) + elif dtype == 'logamp': + chisqgrad = chisqgrad_logamp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'bs': + chisqgrad = chisqgrad_bs(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'cphase': + chisqgrad = chisqgrad_cphase(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'cphase_diag': + chisqgrad = chisqgrad_cphase_diag(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'camp': + chisqgrad = chisqgrad_camp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'logcamp': + chisqgrad = chisqgrad_logcamp(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'logcamp_diag': + chisqgrad = chisqgrad_logcamp_diag(model, uv, data, sigma, pol=pol, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'pvis': + chisqgrad = chisqgrad_pvis(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype == 'm': + chisqgrad = chisqgrad_m(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']: + chisqgrad = chisqgrad_fracpol(dtype[:2],dtype[2:],model, uv, data, sigma, jonesdict=jonesdict, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage) + elif dtype == 'polclosure': + chisqgrad = chisqgrad_polclosure(model, uv, data, sigma, fit_pol=fit_pol, fit_cpol=fit_cpol, fit_leakage=fit_leakage, jonesdict=jonesdict) + + return np.concatenate([chisqgrad[param_mask_full],gaingrad,chisqgrad[leakage_mask_full]]) + +def chisqdata(Obsdata, dtype, pol='I', **kwargs): + + """Return the data, sigma, and matrices for the appropriate dtype + """ + + (data, sigma, uv, jonesdict) = (False, False, False, None) + + if dtype == 'vis': + (data, sigma, uv, jonesdict) = chisqdata_vis(Obsdata, pol=pol, **kwargs) + elif dtype == 'amp' or dtype == 'logamp': + (data, sigma, uv, jonesdict) = chisqdata_amp(Obsdata, pol=pol,**kwargs) + elif dtype == 'bs': + (data, sigma, uv, jonesdict) = chisqdata_bs(Obsdata, pol=pol,**kwargs) + elif dtype == 'cphase': + (data, sigma, uv, jonesdict) = chisqdata_cphase(Obsdata, pol=pol,**kwargs) + elif dtype == 'cphase_diag': + (data, sigma, uv, jonesdict) = chisqdata_cphase_diag(Obsdata, pol=pol,**kwargs) + elif dtype == 'camp': + (data, sigma, uv, jonesdict) = chisqdata_camp(Obsdata, pol=pol,**kwargs) + elif dtype == 'logcamp': + (data, sigma, uv, jonesdict) = chisqdata_logcamp(Obsdata, pol=pol,**kwargs) + elif dtype == 'logcamp_diag': + (data, sigma, uv, jonesdict) = chisqdata_logcamp_diag(Obsdata, pol=pol,**kwargs) + elif dtype == 'pvis': + (data, sigma, uv, jonesdict) = chisqdata_pvis(Obsdata, pol=pol,**kwargs) + elif dtype == 'm': + (data, sigma, uv, jonesdict) = chisqdata_m(Obsdata, pol=pol,**kwargs) + elif dtype in ['rrll','llrr','rlrr','rlll','lrrr','lrll']: + (data, sigma, uv, jonesdict) = chisqdata_fracpol(Obsdata,dtype[:2],dtype[2:],jonesdict=jonesdict) + elif dtype == 'polclosure': + (data, sigma, uv, jonesdict) = chisqdata_polclosure(Obsdata,jonesdict=jonesdict) + + return (data, sigma, uv, jonesdict) + + +################################################################################################## +# Chi-squared and Gradient Functions +################################################################################################## + +def chisq_vis(model, uv, vis, sigma, pol='I', jonesdict=None): + """Visibility chi-squared""" + + samples = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict) + return np.sum(np.abs((samples-vis)/sigma)**2)/(2*len(vis)) + +def chisqgrad_vis(model, uv, vis, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the visibility chi-squared""" + samples = model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict) + wdiff = (vis - samples)/(sigma**2) + grad = model.sample_grad_uv(uv[:,0],uv[:,1],fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + + out = -np.real(np.dot(grad.conj(), wdiff))/len(vis) + return out + +def chisq_amp(model, uv, amp, sigma, pol='I', jonesdict=None): + """Visibility Amplitudes (normalized) chi-squared""" + + amp_samples = np.abs(model.sample_uv(uv[:,0],uv[:,1], pol=pol, jonesdict=jonesdict)) + return np.sum(np.abs((amp - amp_samples)/sigma)**2)/len(amp) + +def chisqgrad_amp(model, uv, amp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the amplitude chi-squared""" + + i1 = model.sample_uv(uv[:,0],uv[:,1], jonesdict=jonesdict) + amp_samples = np.abs(i1) + + pp = ((amp - amp_samples) * amp_samples) / (sigma**2) / i1 + grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + out = (-2.0/len(amp)) * np.real(np.dot(grad, pp)) + return out + +def chisq_bs(model, uv, bis, sigma, pol='I', jonesdict=None): + """Bispectrum chi-squared""" + + bisamples = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) * model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) * model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + chisq= np.sum(np.abs(((bis - bisamples)/sigma))**2)/(2.*len(bis)) + return chisq + +def chisqgrad_bs(model, uv, bis, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the bispectrum chi-squared""" + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + bisamples = V1 * V2 * V3 + wdiff = ((bis - bisamples).conj())/(sigma**2) + pt1 = wdiff * V2 * V3 + pt2 = wdiff * V1 * V3 + pt3 = wdiff * V1 * V2 + out = -np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T))/len(bis) + return out + +def chisq_cphase(model, uv, clphase, sigma, pol='I', jonesdict=None): + """Closure Phases (normalized) chi-squared""" + clphase = clphase * DEGREE + sigma = sigma * DEGREE + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + + clphase_samples = np.angle(V1 * V2 * V3) + chisq= (2.0/len(clphase)) * np.sum((1.0 - np.cos(clphase-clphase_samples))/(sigma**2)) + return chisq + +def chisqgrad_cphase(model, uv, clphase, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the closure phase chi-squared""" + clphase = clphase * DEGREE + sigma = sigma * DEGREE + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + clphase_samples = np.angle(V1 * V2 * V3) + + pref = np.sin(clphase - clphase_samples)/(sigma**2) + pt1 = pref/V1 + pt2 = pref/V2 + pt3 = pref/V3 + out = -(2.0/len(clphase)) * np.imag(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T)) + return out + +def chisq_cphase_diag(model, uv, clphase_diag, sigma, pol='I', jonesdict=None): + """Diagonalized closure phases (normalized) chi-squared""" + clphase_diag = np.concatenate(clphase_diag) * DEGREE + sigma = np.concatenate(sigma) * DEGREE + + uv_diag = uv[0] + tform_mats = uv[1] + + clphase_diag_samples = [] + for iA, uv3 in enumerate(uv_diag): + i1 = model.sample_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,jonesdict=jonesdict) + i2 = model.sample_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,jonesdict=jonesdict) + i3 = model.sample_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,jonesdict=jonesdict) + + clphase_samples = np.angle(i1 * i2 * i3) + clphase_diag_samples.append(np.dot(tform_mats[iA],clphase_samples)) + clphase_diag_samples = np.concatenate(clphase_diag_samples) + + chisq = (2.0/len(clphase_diag)) * np.sum((1.0 - np.cos(clphase_diag-clphase_diag_samples))/(sigma**2)) + return chisq + +def chisqgrad_cphase_diag(model, uv, clphase_diag, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the diagonalized closure phase chi-squared""" + clphase_diag = clphase_diag * DEGREE + sigma = sigma * DEGREE + + uv_diag = uv[0] + tform_mats = uv[1] + + deriv = np.zeros(len(model.sample_grad_uv(0,0,pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict))) + for iA, uv3 in enumerate(uv_diag): + + i1 = model.sample_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,jonesdict=jonesdict) + i2 = model.sample_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,jonesdict=jonesdict) + i3 = model.sample_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,jonesdict=jonesdict) + + i1_grad = model.sample_grad_uv(uv3[0][:,0],uv3[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + i2_grad = model.sample_grad_uv(uv3[1][:,0],uv3[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + i3_grad = model.sample_grad_uv(uv3[2][:,0],uv3[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + clphase_samples = np.angle(i1 * i2 * i3) + clphase_diag_samples = np.dot(tform_mats[iA],clphase_samples) + + clphase_diag_measured = clphase_diag[iA] + clphase_diag_sigma = sigma[iA] + + term1 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i1)),i1_grad.T) + term2 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i2)),i2_grad.T) + term3 = np.dot(np.dot((np.sin(clphase_diag_measured-clphase_diag_samples)/(clphase_diag_sigma**2.0)),(tform_mats[iA]/i3)),i3_grad.T) + deriv += -2.0*np.imag(term1 + term2 + term3) + + deriv *= 1.0/np.float(len(np.concatenate(clphase_diag))) + + return deriv + +def chisq_camp(model, uv, clamp, sigma, pol='I', jonesdict=None): + """Closure Amplitudes (normalized) chi-squared""" + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict) + + clamp_samples = np.abs(V1 * V2 / (V3 * V4)) + chisq = np.sum(np.abs((clamp - clamp_samples)/sigma)**2)/len(clamp) + return chisq + +def chisqgrad_camp(model, uv, clamp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the closure amplitude chi-squared""" + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict) + V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V4_grad = model.sample_grad_uv(uv[3][:,0],uv[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + clamp_samples = np.abs((V1 * V2)/(V3 * V4)) + + pp = ((clamp - clamp_samples) * clamp_samples)/(sigma**2) + pt1 = pp/V1 + pt2 = pp/V2 + pt3 = -pp/V3 + pt4 = -pp/V4 + out = (-2.0/len(clamp)) * np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T) + np.dot(pt4, V4_grad.T)) + return out + +def chisq_logcamp(model, uv, log_clamp, sigma, pol='I', jonesdict=None): + """Log Closure Amplitudes (normalized) chi-squared""" + + a1 = np.abs(model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict)) + a2 = np.abs(model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict)) + a3 = np.abs(model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict)) + a4 = np.abs(model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict)) + + samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4) + chisq = np.sum(np.abs((log_clamp - samples)/sigma)**2) / (len(log_clamp)) + return chisq + +def chisqgrad_logcamp(model, uv, log_clamp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the Log closure amplitude chi-squared""" + + V1 = model.sample_uv(uv[0][:,0],uv[0][:,1],pol=pol,jonesdict=jonesdict) + V2 = model.sample_uv(uv[1][:,0],uv[1][:,1],pol=pol,jonesdict=jonesdict) + V3 = model.sample_uv(uv[2][:,0],uv[2][:,1],pol=pol,jonesdict=jonesdict) + V4 = model.sample_uv(uv[3][:,0],uv[3][:,1],pol=pol,jonesdict=jonesdict) + V1_grad = model.sample_grad_uv(uv[0][:,0],uv[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V2_grad = model.sample_grad_uv(uv[1][:,0],uv[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V3_grad = model.sample_grad_uv(uv[2][:,0],uv[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + V4_grad = model.sample_grad_uv(uv[3][:,0],uv[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + log_clamp_samples = np.log(np.abs(V1)) + np.log(np.abs(V2)) - np.log(np.abs(V3)) - np.log(np.abs(V4)) + + pp = (log_clamp - log_clamp_samples) / (sigma**2) + pt1 = pp / V1 + pt2 = pp / V2 + pt3 = -pp / V3 + pt4 = -pp / V4 + out = (-2.0/len(log_clamp)) * np.real(np.dot(pt1, V1_grad.T) + np.dot(pt2, V2_grad.T) + np.dot(pt3, V3_grad.T) + np.dot(pt4, V4_grad.T)) + return out + +def chisq_logcamp_diag(model, uv, log_clamp_diag, sigma, pol='I', jonesdict=None): + """Diagonalized log closure amplitudes (normalized) chi-squared""" + + log_clamp_diag = np.concatenate(log_clamp_diag) + sigma = np.concatenate(sigma) + + uv_diag = uv[0] + tform_mats = uv[1] + + log_clamp_diag_samples = [] + for iA, uv4 in enumerate(uv_diag): + + a1 = np.abs(model.sample_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,jonesdict=jonesdict)) + a2 = np.abs(model.sample_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,jonesdict=jonesdict)) + a3 = np.abs(model.sample_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,jonesdict=jonesdict)) + a4 = np.abs(model.sample_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,jonesdict=jonesdict)) + + log_clamp_samples = np.log(a1) + np.log(a2) - np.log(a3) - np.log(a4) + log_clamp_diag_samples.append(np.dot(tform_mats[iA],log_clamp_samples)) + + log_clamp_diag_samples = np.concatenate(log_clamp_diag_samples) + + chisq = np.sum(np.abs((log_clamp_diag - log_clamp_diag_samples)/sigma)**2) / (len(log_clamp_diag)) + return chisq + +def chisqgrad_logcamp_diag(model, uv, log_clamp_diag, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the diagonalized log closure amplitude chi-squared""" + + uv_diag = uv[0] + tform_mats = uv[1] + + deriv = np.zeros(len(model.sample_grad_uv(0,0,pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict))) + for iA, uv4 in enumerate(uv_diag): + + i1 = model.sample_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,jonesdict=jonesdict) + i2 = model.sample_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,jonesdict=jonesdict) + i3 = model.sample_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,jonesdict=jonesdict) + i4 = model.sample_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,jonesdict=jonesdict) + + i1_grad = model.sample_grad_uv(uv4[0][:,0],uv4[0][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + i2_grad = model.sample_grad_uv(uv4[1][:,0],uv4[1][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + i3_grad = model.sample_grad_uv(uv4[2][:,0],uv4[2][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + i4_grad = model.sample_grad_uv(uv4[3][:,0],uv4[3][:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + log_clamp_samples = np.log(np.abs(i1)) + np.log(np.abs(i2)) - np.log(np.abs(i3)) - np.log(np.abs(i4)) + log_clamp_diag_samples = np.dot(tform_mats[iA],log_clamp_samples) + + log_clamp_diag_measured = log_clamp_diag[iA] + log_clamp_diag_sigma = sigma[iA] + + term1 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i1)),i1_grad.T) + term2 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i2)),i2_grad.T) + term3 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i3)),i3_grad.T) + term4 = np.dot(np.dot(((log_clamp_diag_measured-log_clamp_diag_samples)/(log_clamp_diag_sigma**2.0)),(tform_mats[iA]/i4)),i4_grad.T) + deriv += -2.0*np.real(term1 + term2 - term3 - term4) + + deriv *= 1.0/np.float(len(np.concatenate(log_clamp_diag))) + + return deriv + +def chisq_logamp(model, uv, amp, sigma, pol='I', jonesdict=None): + """Log Visibility Amplitudes (normalized) chi-squared""" + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + amp_samples = np.abs(model.sample_uv(uv[:,0],uv[:,1],pol=pol,jonesdict=jonesdict)) + return np.sum(np.abs((np.log(amp) - np.log(amp_samples))/logsigma)**2)/len(amp) + +def chisqgrad_logamp(model, uv, amp, sigma, pol='I', fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the Log amplitude chi-squared""" + + # to lowest order the variance on the logarithm of a quantity x is + # sigma^2_log(x) = sigma^2/x^2 + logsigma = sigma / amp + + i1 = model.sample_uv(uv[:,0],uv[:,1],pol=pol,jonesdict=jonesdict) + amp_samples = np.abs(i1) + + V_grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol=pol,fit_pol=fit_pol,fit_cpol=fit_cpol,jonesdict=jonesdict) + + pp = ((np.log(amp) - np.log(amp_samples))) / (logsigma**2) / i1 + out = (-2.0/len(amp)) * np.real(np.dot(pp, V_grad.T)) + return out + + +def chisq_pvis(model, uv, pvis, psigma, jonesdict=None): + """Polarimetric visibility chi-squared + """ + + psamples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict) + return np.sum(np.abs((psamples-pvis)/psigma)**2)/(2*len(pvis)) + +def chisqgrad_pvis(model, uv, pvis, psigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """Polarimetric visibility chi-squared gradient + """ + samples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict) + wdiff = (pvis - samples)/(psigma**2) + grad = model.sample_grad_uv(uv[:,0],uv[:,1],pol='P',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + + out = -np.real(np.dot(grad.conj(), wdiff))/len(pvis) + return out + +def chisq_m(model, uv, m, msigma, jonesdict=None): + """Polarimetric ratio chi-squared + """ + + msamples = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict)/model.sample_uv(uv[:,0],uv[:,1],pol='I',jonesdict=jonesdict) + + return np.sum(np.abs((m - msamples))**2/(msigma**2)) / (2*len(m)) + +def chisqgrad_m(model, uv, mvis, msigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the polarimetric ratio chisq + """ + + samp_P = model.sample_uv(uv[:,0],uv[:,1],pol='P',jonesdict=jonesdict) + samp_I = model.sample_uv(uv[:,0],uv[:,1],pol='I',jonesdict=jonesdict) + grad_P = model.sample_grad_uv(uv[:,0],uv[:,1],pol='P',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + grad_I = model.sample_grad_uv(uv[:,0],uv[:,1],pol='I',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + + msamples = samp_P/samp_I + wdiff = (mvis - msamples)/(msigma**2) + # Get the gradient from the quotient rule + grad = ( grad_P * samp_I - grad_I * samp_P)/samp_I**2 + + return -np.real(np.dot(grad.conj(), wdiff))/len(mvis) + +def chisq_fracpol(upper, lower, model, uv, m, msigma, jonesdict=None): + """Polarimetric ratio chi-squared + """ + + msamples = model.sample_uv(uv[:,0],uv[:,1],pol=upper.upper(),jonesdict=jonesdict)/model.sample_uv(uv[:,0],uv[:,1],pol=lower.upper(),jonesdict=jonesdict) + + return np.sum(np.abs((m - msamples))**2/(msigma**2)) / (2*len(m)) + +def chisqgrad_fracpol(upper, lower, model, uv, mvis, msigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the polarimetric ratio chisq + """ + + samp_upper = model.sample_uv(uv[:,0],uv[:,1],pol=upper.upper(),jonesdict=jonesdict) + samp_lower = model.sample_uv(uv[:,0],uv[:,1],pol=lower.upper(),jonesdict=jonesdict) + grad_upper = model.sample_grad_uv(uv[:,0],uv[:,1],pol=upper.upper(),fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + grad_lower = model.sample_grad_uv(uv[:,0],uv[:,1],pol=lower.upper(),fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + + msamples = samp_upper/samp_lower + wdiff = (mvis - msamples)/(msigma**2) + # Get the gradient from the quotient rule + grad = ( grad_upper * samp_lower - grad_lower * samp_upper)/samp_lower**2 + + return -np.real(np.dot(grad.conj(), wdiff))/len(mvis) + +def chisq_polclosure(model, uv, vis, sigma, jonesdict=None): + """Polarimetric ratio chi-squared + """ + + RL = model.sample_uv(uv[:,0],uv[:,1],pol='RL',jonesdict=jonesdict) + LR = model.sample_uv(uv[:,0],uv[:,1],pol='LR',jonesdict=jonesdict) + RR = model.sample_uv(uv[:,0],uv[:,1],pol='RR',jonesdict=jonesdict) + LL = model.sample_uv(uv[:,0],uv[:,1],pol='LL',jonesdict=jonesdict) + samples = (RL * LR)/(RR * LL) + + return np.sum(np.abs((vis - samples))**2/(sigma**2)) / (2*len(vis)) + +def chisqgrad_polclosure(model, uv, vis, sigma, fit_pol=False, fit_cpol=False, fit_leakage=False, jonesdict=None): + """The gradient of the polarimetric ratio chisq + """ + + RL = model.sample_uv(uv[:,0],uv[:,1],pol='RL',jonesdict=jonesdict) + LR = model.sample_uv(uv[:,0],uv[:,1],pol='LR',jonesdict=jonesdict) + RR = model.sample_uv(uv[:,0],uv[:,1],pol='RR',jonesdict=jonesdict) + LL = model.sample_uv(uv[:,0],uv[:,1],pol='LL',jonesdict=jonesdict) + + dRL = model.sample_grad_uv(uv[:,0],uv[:,1],pol='RL',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + dLR = model.sample_grad_uv(uv[:,0],uv[:,1],pol='LR',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + dRR = model.sample_grad_uv(uv[:,0],uv[:,1],pol='RR',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + dLL = model.sample_grad_uv(uv[:,0],uv[:,1],pol='LL',fit_pol=fit_pol,fit_cpol=fit_cpol,fit_leakage=fit_leakage,jonesdict=jonesdict) + + samples = (RL * LR)/(RR * LL) + wdiff = (vis - samples)/(sigma**2) + + # Get the gradient from the quotient rule + samp_upper = RL * LR + samp_lower = RR * LL + grad_upper = RL * dLR + dRL * LR + grad_lower = RR * dLL + dRR * LL + grad = ( grad_upper * samp_lower - grad_lower * samp_upper)/samp_lower**2 + + return -np.real(np.dot(grad.conj(), wdiff))/len(vis) + + +################################################################################################## +# Chi^2 Data functions +################################################################################################## +def apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol): + """apply systematic noise to VISIBILITIES or AMPLITUDES + data_arr should have fields 't1','t2','u','v','vis','amp','sigma' + + returns: (uv, vis, amp, sigma) + """ + + vtype=vis_poldict[pol] + atype=amp_poldict[pol] + etype=sig_poldict[pol] + + t1 = data_arr['t1'] + t2 = data_arr['t2'] + + sigma = data_arr[etype] + amp = data_arr[atype] + try: + vis = data_arr[vtype] + except ValueError: + vis = amp.astype('c16') + + snrmask = np.abs(amp/sigma) >= snrcut + + if type(systematic_noise) is dict: + sys_level = np.zeros(len(t1)) + for i in range(len(t1)): + if t1[i] in systematic_noise.keys(): + t1sys = systematic_noise[t1[i]] + else: + t1sys = 0. + if t2[i] in systematic_noise.keys(): + t2sys = systematic_noise[t2[i]] + else: + t2sys = 0. + + if t1sys<0 or t2sys<0: + sys_level[i] = -1 + else: + sys_level[i] = np.sqrt(t1sys**2 + t2sys**2) + else: + sys_level = np.sqrt(2)*systematic_noise*np.ones(len(t1)) + + mask = sys_level>=0. + mask = snrmask * mask + + sigma = np.linalg.norm([sigma, sys_level*np.abs(amp)], axis=0)[mask] + vis = vis[mask] + amp = amp[mask] + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1)))[mask] + return (uv, vis, amp, sigma) + +def make_jonesdict(Obsdata, data_arr): + # Make a dictionary with entries needed to form the Jones matrices + # Currently, this only works for data types on a single baseline (e.g., closure quantities aren't supported yet) + + # Get the names of each station for every measurement + t1 = data_arr['t1'] + t2 = data_arr['t2'] + + # Get the elevation of each station + el1 = data_arr['el1']*np.pi/180. + el2 = data_arr['el2']*np.pi/180. + + # Get the parallactic angle of each station + par1 = data_arr['par_ang1']*np.pi/180. + par2 = data_arr['par_ang2']*np.pi/180. + + # Compute the full field rotation angle for each site, based information in the Obsdata Array + fr_elev1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_elev'] for o in data_arr]) + fr_elev2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_elev'] for o in data_arr]) + fr_par1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_par'] for o in data_arr]) + fr_par2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_par'] for o in data_arr]) + fr_off1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['fr_off'] for o in data_arr]) + fr_off2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['fr_off'] for o in data_arr]) + fr1 = fr_elev1*el1 + fr_par1*par1 + fr_off1*np.pi/180. + fr2 = fr_elev2*el2 + fr_par2*par2 + fr_off2*np.pi/180. + + # Now populate the left and right D-term entries based on the Obsdata Array + DR1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['dr'] for o in data_arr]) + DL1 = np.array([Obsdata.tarr[Obsdata.tkey[o['t1']]]['dl'] for o in data_arr]) + DR2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['dr'] for o in data_arr]) + DL2 = np.array([Obsdata.tarr[Obsdata.tkey[o['t2']]]['dl'] for o in data_arr]) + + return {'fr1':fr1,'fr2':fr2,'t1':t1,'t2':t2, + 'DR1':DR1, 'DR2':DR2, 'DL1':DL1, 'DL2':DL2} + +def chisqdata_vis(Obsdata, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrix for visibilities + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise',0.) + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + weighting = kwargs.get('weighting','natural') + + # unpack data + vtype=vis_poldict[pol] + atype=amp_poldict[pol] + etype=sig_poldict[pol] + data_arr = Obsdata.unpack(['t1','t2','u','v',vtype,atype,etype,'el1','el2','par_ang1','par_ang2'], debias=debias) + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + jonesdict = make_jonesdict(Obsdata, data_arr) + + return (vis, sigma, uv, jonesdict) + +def chisqdata_amp(Obsdata, pol='I',**kwargs): + """Return the data, sigmas, and fourier matrix for visibility amplitudes + """ + + # unpack keyword args + systematic_noise = kwargs.get('systematic_noise',0.) + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + weighting = kwargs.get('weighting','natural') + + # unpack data + vtype=vis_poldict[pol] + atype=amp_poldict[pol] + etype=sig_poldict[pol] + if (Obsdata.amp is None) or (len(Obsdata.amp)==0) or pol!='I': + data_arr = Obsdata.unpack(['time','t1','t2','u','v',vtype,atype,etype,'el1','el2','par_ang1','par_ang2'], debias=debias) + + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed amplitude table in amplitude chi^2!") + if not type(Obsdata.amp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed amplitude table is not a numpy rec array!") + data_arr = Obsdata.amp + + + # apply systematic noise and SNR cut + # TODO -- after pre-computed?? + (uv, vis, amp, sigma) = apply_systematic_noise_snrcut(data_arr, systematic_noise, snrcut, pol) + + # data weighting + if weighting=='uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + jonesdict = make_jonesdict(Obsdata, data_arr) + + return (amp, sigma, uv, jonesdict) + +def chisqdata_bs(Obsdata, pol='I',**kwargs): + """return the data, sigmas, and fourier matrices for bispectra + """ + + # unpack keyword args + #systematic_noise = kwargs.get('systematic_noise',0.) #this will break with a systematic noise dict + maxset = kwargs.get('maxset',False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + weighting = kwargs.get('weighting','natural') + + # unpack data + vtype=vis_poldict[pol] + if (Obsdata.bispec is None) or (len(Obsdata.bispec)==0) or pol!='I': + biarr = Obsdata.bispectra(mode="all", vtype=vtype, count=count,snrcut=snrcut) + + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed bispectrum table in cphase chi^2!") + if not type(Obsdata.bispec) in [np.ndarray, np.recarray]: + raise Exception("pre-computed bispectrum table is not a numpy rec array!") + biarr = Obsdata.bispec + # reduce to a minimal set + if count!='max': + biarr = reduce_tri_minimal(Obsdata, biarr) + + uv1 = np.hstack((biarr['u1'].reshape(-1,1), biarr['v1'].reshape(-1,1))) + uv2 = np.hstack((biarr['u2'].reshape(-1,1), biarr['v2'].reshape(-1,1))) + uv3 = np.hstack((biarr['u3'].reshape(-1,1), biarr['v3'].reshape(-1,1))) + bi = biarr['bispec'] + sigma = biarr['sigmab'] + + #add systematic noise + #sigma = np.linalg.norm([biarr['sigmab'], systematic_noise*np.abs(biarr['bispec'])], axis=0) + + # data weighting + if weighting=='uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + return (bi, sigma, (uv1, uv2, uv3), None) + +def chisqdata_cphase(Obsdata, pol='I',**kwargs): + """Return the data, sigmas, and fourier matrices for closure phases + """ + + # unpack keyword args + maxset = kwargs.get('maxset',False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + systematic_cphase_noise = kwargs.get('systematic_cphase_noise',0.) + weighting = kwargs.get('weighting','natural') + + # unpack data + vtype=vis_poldict[pol] + if (Obsdata.cphase is None) or (len(Obsdata.cphase)==0) or pol!='I': + clphasearr = Obsdata.c_phases(mode="all", vtype=vtype, count=count, uv_min=uv_min, snrcut=snrcut) + else: #TODO precomputed with not Stokes I + print("Using pre-computed cphase table in cphase chi^2!") + if not type(Obsdata.cphase) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure phase table is not a numpy rec array!") + clphasearr = Obsdata.cphase + # reduce to a minimal set + if count!='max': + clphasearr = reduce_tri_minimal(Obsdata, clphasearr) + + uv1 = np.hstack((clphasearr['u1'].reshape(-1,1), clphasearr['v1'].reshape(-1,1))) + uv2 = np.hstack((clphasearr['u2'].reshape(-1,1), clphasearr['v2'].reshape(-1,1))) + uv3 = np.hstack((clphasearr['u3'].reshape(-1,1), clphasearr['v3'].reshape(-1,1))) + clphase = clphasearr['cphase'] + sigma = clphasearr['sigmacp'] + + #add systematic cphase noise (in DEGREES) + sigma = np.linalg.norm([sigma, systematic_cphase_noise*np.ones(len(sigma))], axis=0) + + # data weighting + if weighting=='uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + return (clphase, sigma, (uv1, uv2, uv3), None) + +def chisqdata_cphase_diag(Obsdata, pol='I',**kwargs): + """Return the data, sigmas, and fourier matrices for diagonalized closure phases + """ + + # unpack keyword args + maxset = kwargs.get('maxset',False) + uv_min = kwargs.get('cp_uv_min', False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + + # unpack data + vtype=vis_poldict[pol] + clphasearr = Obsdata.c_phases_diag(vtype=vtype,count=count,snrcut=snrcut,uv_min=uv_min) + + # loop over timestamps + clphase_diag = [] + sigma_diag = [] + uv_diag = [] + tform_mats = [] + for ic, cl in enumerate(clphasearr): + + # get diagonalized closure phases and errors + clphase_diag.append(cl[0]['cphase']) + sigma_diag.append(cl[0]['sigmacp']) + + # get uv arrays + u1 = cl[2][:,0].astype('float') + v1 = cl[3][:,0].astype('float') + uv1 = np.hstack((u1.reshape(-1,1), v1.reshape(-1,1))) + + u2 = cl[2][:,1].astype('float') + v2 = cl[3][:,1].astype('float') + uv2 = np.hstack((u2.reshape(-1,1), v2.reshape(-1,1))) + + u3 = cl[2][:,2].astype('float') + v3 = cl[3][:,2].astype('float') + uv3 = np.hstack((u3.reshape(-1,1), v3.reshape(-1,1))) + + # compute Fourier matrices + uv = (uv1, + uv2, + uv3 + ) + uv_diag.append(uv) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # combine Fourier and transformation matrices into tuple for outputting + uvmatrices = (np.array(uv_diag),np.array(tform_mats)) + + return (np.array(clphase_diag), np.array(sigma_diag), uvmatrices, None) + +def chisqdata_camp(Obsdata, pol='I',**kwargs): + """Return the data, sigmas, and fourier matrices for closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset',False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + weighting = kwargs.get('weighting','natural') + + # unpack data & mask low snr points + vtype=vis_poldict[pol] + if (Obsdata.camp is None) or (len(Obsdata.camp)==0) or pol!='I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, ctype='camp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed closure amplitude table in closure amplitude chi^2!") + if not type(Obsdata.camp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.camp + # reduce to a minimal set + if count!='max': + clamparr = reduce_quad_minimal(Obsdata, clamparr, ctype='camp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1,1), clamparr['v1'].reshape(-1,1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1,1), clamparr['v2'].reshape(-1,1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1,1), clamparr['v3'].reshape(-1,1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1,1), clamparr['v4'].reshape(-1,1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting=='uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + return (clamp, sigma, (uv1, uv2, uv3, uv4), None) + +def chisqdata_logcamp(Obsdata, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset',False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + weighting = kwargs.get('weighting','natural') + + # unpack data & mask low snr points + vtype=vis_poldict[pol] + if (Obsdata.logcamp is None) or (len(Obsdata.logcamp)==0) or pol!='I': + clamparr = Obsdata.c_amplitudes(mode='all', count=count, vtype=vtype, ctype='logcamp', debias=debias, snrcut=snrcut) + else: # TODO -- pre-computed with not stokes I? + print("Using pre-computed log closure amplitude table in log closure amplitude chi^2!") + if not type(Obsdata.logcamp) in [np.ndarray, np.recarray]: + raise Exception("pre-computed log closure amplitude table is not a numpy rec array!") + clamparr = Obsdata.logcamp + # reduce to a minimal set + if count!='max': + clamparr = reduce_quad_minimal(Obsdata, clamparr, ctype='logcamp') + + uv1 = np.hstack((clamparr['u1'].reshape(-1,1), clamparr['v1'].reshape(-1,1))) + uv2 = np.hstack((clamparr['u2'].reshape(-1,1), clamparr['v2'].reshape(-1,1))) + uv3 = np.hstack((clamparr['u3'].reshape(-1,1), clamparr['v3'].reshape(-1,1))) + uv4 = np.hstack((clamparr['u4'].reshape(-1,1), clamparr['v4'].reshape(-1,1))) + clamp = clamparr['camp'] + sigma = clamparr['sigmaca'] + + # data weighting + if weighting=='uniform': + sigma = np.median(sigma) * np.ones(len(sigma)) + + return (clamp, sigma, (uv1, uv2, uv3, uv4), None) + +def chisqdata_logcamp_diag(Obsdata, pol='I', **kwargs): + """Return the data, sigmas, and fourier matrices for diagonalized log closure amplitudes + """ + # unpack keyword args + maxset = kwargs.get('maxset',False) + if maxset: count='max' + else: count='min' + + snrcut = kwargs.get('snrcut',0.) + debias = kwargs.get('debias',True) + + # unpack data & mask low snr points + vtype=vis_poldict[pol] + clamparr = Obsdata.c_log_amplitudes_diag(vtype=vtype,count=count,debias=debias,snrcut=snrcut) + + # loop over timestamps + clamp_diag = [] + sigma_diag = [] + uv_diag = [] + tform_mats = [] + for ic, cl in enumerate(clamparr): + + # get diagonalized log closure amplitudes and errors + clamp_diag.append(cl[0]['camp']) + sigma_diag.append(cl[0]['sigmaca']) + + # get uv arrays + u1 = cl[2][:,0].astype('float') + v1 = cl[3][:,0].astype('float') + uv1 = np.hstack((u1.reshape(-1,1), v1.reshape(-1,1))) + + u2 = cl[2][:,1].astype('float') + v2 = cl[3][:,1].astype('float') + uv2 = np.hstack((u2.reshape(-1,1), v2.reshape(-1,1))) + + u3 = cl[2][:,2].astype('float') + v3 = cl[3][:,2].astype('float') + uv3 = np.hstack((u3.reshape(-1,1), v3.reshape(-1,1))) + + u4 = cl[2][:,3].astype('float') + v4 = cl[3][:,3].astype('float') + uv4 = np.hstack((u4.reshape(-1,1), v4.reshape(-1,1))) + + # compute Fourier matrices + uv = (uv1, + uv2, + uv3, + uv4 + ) + uv_diag.append(uv) + + # get transformation matrix for this timestamp + tform_mats.append(cl[4].astype('float')) + + # combine Fourier and transformation matrices into tuple for outputting + uvmatrices = (np.array(uv_diag),np.array(tform_mats)) + + return (np.array(clamp_diag), np.array(sigma_diag), uvmatrices, None) + +def chisqdata_pvis(Obsdata, pol='I', **kwargs): + data_arr = Obsdata.unpack(['t1','t2','u','v','pvis','psigma','el1','el2','par_ang1','par_ang2'], conj=True) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + mask = np.isfinite(data_arr['pvis'] + data_arr['psigma']) # don't include nan (missing data) or inf (division by zero) + jonesdict = make_jonesdict(Obsdata, data_arr[mask]) + return (data_arr['pvis'][mask], data_arr['psigma'][mask], uv[mask], jonesdict) + +def chisqdata_m(Obsdata, pol='I',**kwargs): + debias = kwargs.get('debias',True) + data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2'], conj=True, debias=False) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + mask = np.isfinite(data_arr['m'] + data_arr['msigma']) # don't include nan (missing data) or inf (division by zero) + jonesdict = make_jonesdict(Obsdata, data_arr[mask]) + return (data_arr['m'][mask], data_arr['msigma'][mask], uv[mask], jonesdict) + +def chisqdata_fracpol(Obsdata, pol_upper,pol_lower,**kwargs): + debias = kwargs.get('debias',True) + data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2','rrvis','rlvis','lrvis','llvis','rramp','rlamp','lramp','llamp','rrsigma','rlsigma','lrsigma','llsigma'], conj=False, debias=True) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + + upper = data_arr[pol_upper + 'vis'] + lower = data_arr[pol_lower + 'vis'] + upper_amp = data_arr[pol_upper + 'amp'] + lower_amp = data_arr[pol_lower + 'amp'] + upper_sig = data_arr[pol_upper + 'sigma'] + lower_sig = data_arr[pol_lower + 'sigma'] + + sig = ((upper_sig/lower_amp)**2 + (lower_sig*upper_amp/lower_amp**2)**2)**0.5 + + # Mask bad data + mask = np.isfinite(upper + lower + sig) # don't include nan (missing data) or inf (division by zero) + jonesdict = make_jonesdict(Obsdata, data_arr[mask]) + + return ((upper/lower)[mask], sig[mask], uv[mask], jonesdict) + +def chisqdata_polclosure(Obsdata, **kwargs): + debias = kwargs.get('debias',True) + data_arr = Obsdata.unpack(['t1','t2','u','v','m','msigma','el1','el2','par_ang1','par_ang2','rrvis','rlvis','lrvis','llvis','rramp','rlamp','lramp','llamp','rrsigma','rlsigma','lrsigma','llsigma'], conj=False, debias=True) + uv = np.hstack((data_arr['u'].reshape(-1,1), data_arr['v'].reshape(-1,1))) + + RL = data_arr['rlvis'] + LR = data_arr['lrvis'] + RR = data_arr['rrvis'] + LL = data_arr['llvis'] + vis = (RL * LR)/(RR * LL) + sig = (np.abs(LR/(LL*RR) * data_arr['rlsigma'])**2 + +np.abs(RL/(LL*RR) * data_arr['lrsigma'])**2 + +np.abs(LR*RL/(RR**2*LL) * data_arr['rrsigma'])**2 + +np.abs(RL*LR/(LL**2*RR) * data_arr['llsigma'])**2)**0.5 + + # Mask bad data + mask = np.isfinite(vis + sig) # don't include nan (missing data) or inf (division by zero) + jonesdict = make_jonesdict(Obsdata, data_arr[mask]) + + return (vis[mask], sig[mask], uv[mask], jonesdict) diff --git a/movie.py b/movie.py new file mode 100644 index 00000000..d0855e15 --- /dev/null +++ b/movie.py @@ -0,0 +1,1935 @@ +# movie.py +# a interferometric movie class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import string +import numpy as np +import scipy.interpolate +import scipy.ndimage.filters as filt + +import ehtim.image +import ehtim.obsdata +import ehtim.observing.obs_simulate as simobs +import ehtim.io.save +import ehtim.io.load +import ehtim.const_def as ehc +import ehtim.observing.obs_helpers as obsh + +INTERPOLATION_KINDS = ['linear', 'nearest', 'zero', 'slinear', + 'quadratic', 'cubic', 'previous', 'next'] + +################################################################################################### +# Movie object +################################################################################################### + + +class Movie(object): + + """A polarimetric movie (in units of Jy/pixel). + + Attributes: + pulse (function): The function convolved with the pixel values for continuous image + psize (float): The pixel dimension in radians + xdim (int): The number of pixels along the x dimension + ydim (int): The number of pixels along the y dimension + mjd (int): The integer MJD of the image + source (str): The astrophysical source name + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The image frequency in Hz + + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, or RR,LL,LR,RL for Circular + interp (str): Interpolation method, for scipy.interpolate.interp1d 'kind' keyword + (e.g. 'linear', 'nearest', 'quadratic', 'cubic', 'previous', 'next'...) + bounds_error (bool): if False, return nearest frame when outside [start_hr, stop_hr] + + + times (list): The list of frame time stamps in hours + + _movdict (dict): The dictionary with the lists of frames + """ + + def __init__(self, frames, times, psize, ra, dec, + rf=ehc.RF_DEFAULT, polrep='stokes', pol_prim=None, + pulse=ehc.PULSE_DEFAULT, source=ehc.SOURCE_DEFAULT, + mjd=ehc.MJD_DEFAULT, + bounds_error=ehc.BOUNDS_ERROR, interp=ehc.INTERP_DEFAULT): + """A polarimetric image (in units of Jy/pixel). + + Args: + frames (list): The list of 2D frames; each is a Jy/pixel array + times (list): The list of frame time stamps in hours + psize (float): The pixel dimension in radians + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The image frequency in Hz + + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + pulse (function): The function convolved with the pixel values for continuous image. + source (str): The source name + mjd (int): The integer MJD of the image + interp (str): Interpolation method, for scipy.interpolate.interp1d 'kind' keyword + (e.g. 'linear', 'nearest', 'quadratic', 'cubic', 'previous', 'next'...) + bounds_error (bool): if False, return nearest frame when outside [start_hr, stop_hr] + + Returns: + (Image): the Image object + """ + + if len(frames[0].shape) != 2: + raise Exception("frames must each be a 2D numpy array") + + if len(frames) != len(times): + raise Exception("len(frames) != len(times) !") + + if not (interp in INTERPOLATION_KINDS): + raise Exception( + "'interp' must be a valid argument for scipy.interpolate.interp1d: " + + string.join(INTERPOLATION_KINDS)) + + self.times = times + start_hr = np.min(self.times) + self.mjd = int(mjd) + if start_hr > 24: + self.mjd += int((start_hr - start_hr % 24)/24) + self.start_hr = float(start_hr % 24) + else: + self.start_hr = start_hr + self.stop_hr = np.max(self.times) + self.duration = self.stop_hr - self.start_hr + + # frame shape parameters + self.nframes = len(frames) + self.polrep = polrep + self.pulse = pulse + self.psize = float(psize) + self.xdim = frames[0].shape[1] + self.ydim = frames[0].shape[0] + + # the list of frames + frames = np.array([image.flatten() for image in frames]) + self.interp = interp + self.bounds_error = bounds_error + + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=interp, + bounds_error=bounds_error, fill_value=fill_value) + + if polrep == 'stokes': + if pol_prim is None: + pol_prim = 'I' + if pol_prim == 'I': + self._movdict = {'I': frames, 'Q': [], 'U': [], 'V': []} + self._fundict = {'I': fun, 'Q': None, 'U': None, 'V': None} + elif pol_prim == 'V': + self._movdict = {'I': [], 'Q': [], 'U': [], 'V': frames} + self._fundict = {'I': None, 'Q': None, 'U': None, 'V': fun} + elif pol_prim == 'Q': + self._movdict = {'I': [], 'Q': frames, 'U': [], 'V': []} + self._fundict = {'I': None, 'Q': fun, 'U': None, 'V': None} + elif pol_prim == 'U': + self._movdict = {'I': [], 'Q': [], 'U': frames, 'V': []} + self._fundict = {'I': None, 'Q': None, 'U': frames, 'V': None} + else: + raise Exception("for polrep=='stokes', pol_prim must be 'I','Q','U', or 'V'!") + + elif polrep == 'circ': + if pol_prim is None: + print("polrep is 'circ' and no pol_prim specified! Setting pol_prim='RR'") + pol_prim = 'RR' + if pol_prim == 'RR': + self._movdict = {'RR': frames, 'LL': [], 'RL': [], 'LR': []} + self._fundict = {'RR': fun, 'LL': None, 'RL': None, 'LR': None} + elif pol_prim == 'LL': + self._movdict = {'RR': [], 'LL': frames, 'RL': [], 'LR': []} + self._fundict = {'RR': None, 'LL': fun, 'RL': None, 'LR': None} + else: + raise Exception("for polrep=='circ', pol_prim must be 'RR' or 'LL'!") + + self.pol_prim = pol_prim + + self.ra = float(ra) + self.dec = float(dec) + self.rf = float(rf) + self.source = str(source) + self.pa = 0.0 # TODO: The pa needs to be properly implemented in the movie object + # TODO: What is this doing?? + + @property + def frames(self): + frames = self._movdict[self.pol_prim] + return frames + + @frames.setter + def frames(self, frames): + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("imvec size is not consistent with xdim*ydim!") + # TODO -- more checks on consistency with the existing pol data??? + + frames = np.array(frames) + self._movdict[self.pol_prim] = frames + + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict[self.pol_prim] = fun + + @property + def iframes(self): + + if self.polrep != 'stokes': + raise Exception( + "iframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()") + + frames = self._movdict['I'] + return frames + + @iframes.setter + def iframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['I'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['I'] = fun + + @property + def qframes(self): + + if self.polrep != 'stokes': + raise Exception( + "qframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()") + + frames = self._movdict['Q'] + return frames + + @qframes.setter + def qframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['Q'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['Q'] = fun + + @property + def uframes(self): + + if self.polrep != 'stokes': + raise Exception( + "uframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()") + + frames = self._movdict['U'] + + return frames + + @uframes.setter + def uframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['U'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['U'] = fun + + @property + def vframes(self): + + if self.polrep != 'stokes': + raise Exception( + "vframes is not defined unless self.polrep=='stokes' -- try self.switch_polrep()") + + frames = self._movdict['V'] + + return frames + + @vframes.setter + def vframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['V'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['V'] = fun + + @property + def rrframes(self): + + if self.polrep != 'circ': + raise Exception( + "rrframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()") + + frames = self._movdict['RR'] + return frames + + @rrframes.setter + def rrframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['RR'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['RR'] = fun + + @property + def llframes(self): + + if self.polrep != 'circ': + raise Exception( + "llframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()") + + frames = self._movdict['LL'] + return frames + + @llframes.setter + def llframes(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['LL'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['LL'] = fun + + @property + def rlvec(self): + + if self.polrep != 'circ': + raise Exception( + "rlframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()") + + frames = self._movdict['RL'] + return frames + + @rlvec.setter + def rlvec(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['RL'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['RL'] = fun + + @property + def lrvec(self): + + if self.polrep != 'circ': + raise Exception( + "lrframes is not defined unless self.polrep=='circ' -- try self.switch_polrep()") + + frames = self._movdict['LR'] + return frames + + @lrvec.setter + def lrvec(self, frames): + + if len(frames[0]) != self.xdim*self.ydim: + raise Exception("vec size is not consistent with xdim*ydim!") + + # TODO -- more checks on the consistency of the imvec with the existing pol data??? + frames = np.array(frames) + self._movdict['LR'] = frames + fill_value = (frames[0], frames[-1]) + fun = scipy.interpolate.interp1d(self.times, frames.T, kind=self.interp, + bounds_error=self.bounds_error, fill_value=fill_value) + self._fundict['LR'] = fun + + def movie_args(self): + """"Copy arguments for making a new Movie into a list and dictonary + """ + + frames2D = self.frames.reshape((self.nframes, self.ydim, self.xdim)) + arglist = [frames2D.copy(), self.times.copy(), self.psize, self.ra, self.dec] + #arglist = [frames2D, self.times, self.psize, self.ra, self.dec] + argdict = {'rf': self.rf, 'polrep': self.polrep, 'pol_prim': self.pol_prim, + 'pulse': self.pulse, 'source': self.source, + 'mjd': self.mjd, 'interp': self.interp, 'bounds_error': self.bounds_error} + + return (arglist, argdict) + + def copy(self): + """Return a copy of the Movie object. + + Args: + + Returns: + (Movie): copy of the Image. + """ + + arglist, argdict = self.movie_args() + + # Make new movie with primary polarization + newmov = Movie(*arglist, **argdict) + + # Copy over all polarization movies + for pol in list(self._movdict.keys()): + if pol == self.pol_prim: + continue + polframes = self._movdict[pol] + if len(polframes): + polframes = polframes.reshape((self.nframes, self.ydim, self.xdim)) + newmov.add_pol_movie(polframes, pol) + + return newmov + + def reset_interp(self, interp=None, bounds_error=None): + """Reset the movie interpolation function to change the interp. type or change the frames + + Args: + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside [start_hr, stop_hr] + """ + + if interp is None: + interp = self.interp + if bounds_error is None: + bounds_error = self.bounds_error + + # Copy over all polarization movies + for pol in list(self._movdict.keys()): + polframes = self._movdict[pol] + if len(polframes): + fill_value = (polframes[0], polframes[-1]) + fun = scipy.interpolate.interp1d(self.times, polframes.T, kind=interp, + fill_value=fill_value, bounds_error=bounds_error) + self._fundict[pol] = fun + else: + self._fundict[pol] = None + + self.interp = interp + self.bounds_error = bounds_error + return + + def offset_time(self, t_offset): + """Offset the movie in time by t_offset + + Args: + t_offset (float): offset time in hours + Returns: + + """ + mov = self.copy() + mov.start_hr += t_offset + mov.stop_hr += t_offset + mov.times += t_offset + mov.reset_interp(interp=mov.interp, bounds_error=mov.bounds_error) + return mov + + def add_pol_movie(self, movie, pol): + """Add another movie polarization. + + Args: + movie (list): list of 2D frames (possibly complex) in a Jy/pixel array + pol (str): The image type: 'I','Q','U','V' for stokes, 'RR','LL','RL','LR' for circ + """ + if not(len(movie) == self.nframes): + raise Exception("new pol movies must have same length as primary movie!") + + if pol == self.pol_prim: + raise Exception("new pol in add_pol_movie is the same as pol_prim!") + if np.any(np.array([image.shape != (self.ydim, self.xdim) for image in movie])): + raise Exception("add_pol_movie image shapes incompatible with primary image!") + if not (pol in list(self._movdict.keys())): + raise Exception("for polrep==%s, pol in add_pol_movie must be in " % + self.polrep + ",".join(list(self._movdict.keys()))) + + if self.polrep == 'stokes': + if pol == 'I': + self.iframes = [image.flatten() for image in movie] + elif pol == 'Q': + self.qframes = [image.flatten() for image in movie] + elif pol == 'U': + self.uframes = [image.flatten() for image in movie] + elif pol == 'V': + self.vframes = [image.flatten() for image in movie] + + if len(self.iframes) > 0: + fill_value = (self.iframes[0], self.iframes[-1]) + ifun = scipy.interpolate.interp1d(self.times, self.iframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + ifun = None + if len(self.vframes) > 0: + fill_value = (self.vframes[0], self.vframes[-1]) + vfun = scipy.interpolate.interp1d(self.times, self.vframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + vfun = None + if len(self.qframes) > 0: + fill_value = (self.qframes[0], self.qframes[-1]) + qfun = scipy.interpolate.interp1d(self.times, self.qframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + qfun = None + if len(self.uframes) > 0: + fill_value = (self.uframes[0], self.uframes[-1]) + ufun = scipy.interpolate.interp1d(self.times, self.uframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + ufun = None + + self._movdict = {'I': self.iframes, 'Q': self.qframes, + 'U': self.uframes, 'V': self.vframes} + self._fundict = {'I': ifun, 'Q': qfun, 'U': ufun, 'V': vfun} + + elif self.polrep == 'circ': + if pol == 'RR': + self.rrframes = [image.flatten() for image in movie] + elif pol == 'LL': + self.llframes = [image.flatten() for image in movie] + elif pol == 'RL': + self.rlframes = [image.flatten() for image in movie] + elif pol == 'LR': + self.lrframes = [image.flatten() for image in movie] + + if len(self.rrframes) > 0: + fill_value = (self.rrframes[0], self.rrframes[-1]) + rrfun = scipy.interpolate.interp1d(self.times, self.rrframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + rrfun = None + if len(self.llframes) > 0: + fill_value = (self.llframes[0], self.llframes[-1]) + llfun = scipy.interpolate.interp1d(self.times, self.llframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + llfun = None + if len(self.rlframes) > 0: + fill_value = (self.rlframes[0], self.rlframes[-1]) + rlfun = scipy.interpolate.interp1d(self.times, self.rlframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + rlfun = None + if len(self.lrframes) > 0: + fill_value = (self.lrframes[0], self.lrframes[-1]) + lrfun = scipy.interpolate.interp1d(self.times, self.lrframes.T, kind=self.interp, + fill_value=fill_value, + bounds_error=self.bounds_error) + else: + lrfun = None + + self._movdict = {'RR': self.rrframes, 'LL': self.llframes, + 'RL': self.rlframes, 'LR': self.lrframes} + self._fundict = {'RR': rrfun, 'LL': llfun, 'RL': rlfun, 'LR': lrfun} + return + + # TODO deprecated -- replace with generic add_pol_movie + def add_qu(self, qmovie, umovie): + """Add Stokes Q and U movies. self.polrep must be 'stokes' + + Args: + qmovie (list): list of 2D Stokes Q frames in Jy/pixel array + umovie (list): list of 2D Stokes U frames in Jy/pixel array + + Returns: + """ + + if self.polrep != 'stokes': + raise Exception("polrep must be 'stokes' for add_qu() !") + self.add_pol_movie(qmovie, 'Q') + self.add_pol_movie(umovie, 'U') + + return + + # TODO deprecated -- replace with generic add_pol_movie + def add_v(self, vmovie): + """Add Stokes V movie. self.polrep must be 'stokes' + + Args: + vmovie (list): list of 2D Stokes Q frames in Jy/pixel array + + Returns: + """ + + if self.polrep != 'stokes': + raise Exception("polrep must be 'stokes' for add_v() !") + self.add_pol_movie(vmovie, 'V') + + return + + def switch_polrep(self, polrep_out='stokes', pol_prim_out=None): + """Return a new movie with the polarization representation changed + + Args: + polrep_out (str): the polrep of the output data + pol_prim_out (str): The default movie: I,Q,U or V for Stokes, + RR,LL,LR,RL for Circular + + Returns: + (Movie): new movie object with potentially different polrep + """ + + if polrep_out not in ['stokes', 'circ']: + raise Exception("polrep_out must be either 'stokes' or 'circ'") + if pol_prim_out is None: + if polrep_out == 'stokes': + pol_prim_out = 'I' + elif polrep_out == 'circ': + pol_prim_out = 'RR' + + # Simply copy if the polrep is unchanged + if polrep_out == self.polrep and pol_prim_out == self.pol_prim: + return self.copy() + + # Assemble a dictionary of new polarization vectors + framedim = (self.nframes, self.ydim, self.xdim) + if polrep_out == 'stokes': + if self.polrep == 'stokes': + movdict = {'I': self.iframes, 'Q': self.qframes, + 'U': self.uframes, 'V': self.vframes} + else: + if len(self.rrframes) == 0 or len(self.llframes) == 0: + iframes = [] + vframes = [] + else: + iframes = 0.5*(self.rrframes.reshape(framedim) + + self.llframes.reshape(framedim)) + vframes = 0.5*(self.rrframes.reshape(framedim) - + self.llframes.reshape(framedim)) + + if len(self.rlframes) == 0 or len(self.lrframes) == 0: + qframes = [] + uframes = [] + else: + qframes = np.real(0.5*(self.lrframes.reshape(framedim) + + self.rlframes.reshape(framedim))) + uframes = np.real(0.5j*(self.lrframes.reshape(framedim) - + self.rlframes.reshape(framedim))) + + movdict = {'I': iframes, 'Q': qframes, 'U': uframes, 'V': vframes} + + elif polrep_out == 'circ': + if self.polrep == 'circ': + movdict = {'RR': self.rrframes, 'LL': self.llframes, + 'RL': self.rlframes, 'LR': self.lrframes} + else: + if len(self.iframes) == 0 or len(self.vframes) == 0: + rrframes = [] + llframes = [] + else: + rrframes = (self.iframes.reshape(framedim) + self.vframes.reshape(framedim)) + llframes = (self.iframes.reshape(framedim) - self.vframes.reshape(framedim)) + + if len(self.qframes) == 0 or len(self.uframes) == 0: + rlframes = [] + lrframes = [] + else: + rlframes = (self.qframes.reshape(framedim) + 1j*self.uframes.reshape(framedim)) + lrframes = (self.qframes.reshape(framedim) - 1j*self.uframes.reshape(framedim)) + + movdict = {'RR': rrframes, 'LL': llframes, 'RL': rlframes, 'LR': lrframes} + + # Assemble the new movie + frames = movdict[pol_prim_out] + if len(frames) == 0: + raise Exception("switch_polrep to " + + "%s with pol_prim_out=%s, \n" % (polrep_out, pol_prim_out) + + "output movie is not defined") + + # Make new movie with primary polarization + arglist, argdict = self.movie_args() + arglist[0] = frames + argdict['polrep'] = polrep_out + argdict['pol_prim'] = pol_prim_out + newmov = Movie(*arglist, **argdict) + + # Add in any other polarizations + for pol in list(movdict.keys()): + if pol == pol_prim_out: + continue + polframes = movdict[pol] + if len(polframes): + polframes = polframes.reshape((self.nframes, self.ydim, self.xdim)) + newmov.add_pol_movie(polframes, pol) + + return newmov + + def orth_chi(self): + """Rotate the EVPA 90 degrees + + Args: + + Returns: + (Image): movie with rotated EVPA + """ + mov = self.copy() + if mov.polrep == 'stokes': + mov.qframes *= -1 + mov.uframes *= -1 + elif mov.polrep == 'circ': + mov.lrframes *= -1 + mov.rlframes *= -1 + + return mov + + def fovx(self): + """Return the movie fov in x direction in radians. + + Args: + + Returns: + (float) : movie fov in x direction (radian) + """ + + return self.psize * self.xdim + + def fovy(self): + """Returns the movie fov in y direction in radians. + + Args: + + Returns: + (float) : movie fov in y direction (radian) + """ + + return self.psize * self.ydim + + @property + def lightcurve(self): + """Return the total flux over time of the image in Jy. + + Args: + + Returns: + (numpy.Array) : image total flux (Jy) over time + """ + if self.polrep == 'stokes': + flux = [np.sum(ivec) for ivec in self.iframes] + elif self.polrep == 'circ': + flux = [0.5*(np.sum(self.rrframes[i])+np.sum(self.llframes[i])) + for i in range(self.nframes)] + + return np.array(flux) + + def lin_polfrac_curve(self): + """Return the total fractional linear polarized flux over time + + Args: + + Returns: + (numpy.ndarray) : image fractional linear polarized flux per frame + """ + if self.polrep == 'stokes': + frac = [np.abs(np.sum(self.qframes[i] + 1j*self.uframes[i])) / + np.abs(np.sum(self.iframes[i])) + for i in range(self.nframes)] + elif self.polrep == 'circ': + frac = [2*np.abs(np.sum(self.rlframes[i])) / + np.abs(np.sum(self.rrframes[i]+self.llframes[i])) + for i in range(self.nframes)] + return np.array(frac) + + def circ_polfrac_curve(self): + """Return the (signed) total fractional circular polarized flux over time + + Args: + + Returns: + (numpy.ndarray) : image fractional circular polarized flux per frame + """ + if self.polrep == 'stokes': + frac = [np.sum(self.vframes[i]) / np.abs(np.sum(self.iframes[i])) + for i in range(self.nframes)] + elif self.polrep == 'circ': + frac = [np.sum(self.rrframes[i]-self.llframes[i]) / + np.abs(np.sum(self.rrframes[i] + self.llframes[i])) + for i in range(self.nframes)] + + return np.array(frac) + + def get_image(self, time): + """Return an Image at time + + Args: + time (float): the time in hours + + Returns: + (Image): the Image object at the given time + """ + + if (time < self.start_hr): + if not(self.bounds_error): + pass + # print ("time %f before movie start time %f" % (time, self.start_hr)) + # print ("returning constant frame 0! \n") + else: + raise Exception("time %f must be in the range %f - %f" % + (time, self.start_hr, self.stop_hr)) + + if (time > self.stop_hr): + if not(self.bounds_error): + pass + # print ("time %f after movie stop time %f" % (time, self.stop_hr)) + # print ("returning constant frame -1! \n") + else: + raise Exception("time %f must be in the range %f - %f" % + (time, self.start_hr, self.stop_hr)) + + # interpolate the imvec to the given time + imvec = self._fundict[self.pol_prim](time) + + # Make the primary image + imarr = imvec.reshape(self.ydim, self.xdim) + outim = ehtim.image.Image(imarr, self.psize, self.ra, self.dec, self.pa, + polrep=self.polrep, pol_prim=self.pol_prim, time=time, + rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse) + + # Copy over the rest of the polarizations + for pol in list(self._movdict.keys()): + if pol == self.pol_prim: + continue + polframes = self._movdict[pol] + if len(polframes): + polvec = self._fundict[pol](time) + polarr = polvec.reshape(self.ydim, self.xdim).copy() + outim.add_pol_image(polarr, pol) + + return outim + + def get_frame(self, n): + """Return an Image of the nth frame + + Args: + n (int): the frame number + + Returns: + (Image): the Image object of the nth frame + """ + + if n < 0 or n >= len(self.frames): + raise Exception("n must be in the range 0 - %i" % self.nframes) + + time = self.times[n] + + # Make the primary image + imarr = self.frames[n].reshape(self.ydim, self.xdim) + outim = ehtim.image.Image(imarr, self.psize, self.ra, self.dec, self.pa, + polrep=self.polrep, pol_prim=self.pol_prim, time=time, + rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse) + + # Copy over the rest of the polarizations + for pol in list(self._movdict.keys()): + if pol == self.pol_prim: + continue + polframes = self._movdict[pol] + if len(polframes): + polvec = polframes[n] + polarr = polvec.reshape(self.ydim, self.xdim).copy() + outim.add_pol_image(polarr, pol) + + return outim + + def im_list(self): + """Return a list of the movie frames + + Args: + + Returns: + (list): list of Image objects + """ + + return [self.get_frame(j) for j in range(self.nframes)] + + def avg_frame(self): + """Coherently Average the movie frames into a single image. + + Returns: + (Image) : averaged image of all frames + """ + + # Make the primary image + avg_imvec = np.mean(np.array(self.frames), axis=0) + avg_imarr = avg_imvec.reshape(self.ydim, self.xdim) + outim = ehtim.image.Image(avg_imarr, self.psize, self.ra, self.dec, self.pa, + polrep=self.polrep, pol_prim=self.pol_prim, time=self.start_hr, + rf=self.rf, source=self.source, mjd=self.mjd, pulse=self.pulse) + + # Copy over the rest of the average polarizations + for pol in list(self._movdict.keys()): + if pol == self.pol_prim: + continue + polframes = self._movdict[pol] + if len(polframes): + avg_polvec = np.mean(np.array(polframes), axis=0) + avg_polarr = avg_polvec.reshape(self.ydim, self.xdim) + outim.add_pol_image(avg_polarr, pol) + + return outim + + def blur_circ(self, fwhm_x, fwhm_t, fwhm_x_pol=0): + """Apply a Gaussian filter to a list of images. + + Args: + fwhm_x (float): circular beam size for spatial blurring in radians + fwhm_t (float): temporal blurring in frames + fwhm_x_pol (float): circular beam size for Stokes Q,U,V spatial blurring in radians + Returns: + (Image): output image list + """ + + # Unpack the frames + frames = self.im_list() + + # Blur Stokes I + sigma_x = fwhm_x / self.psize / (2. * np.sqrt(2. * np.log(2.))) + sigma_t = fwhm_t / (2. * np.sqrt(2. * np.log(2.))) + sigma_x_pol = fwhm_x_pol / self.psize / (2. * np.sqrt(2. * np.log(2.))) + + arr = np.array([im.imvec.reshape(self.ydim, self.xdim) for im in frames]) + arr = filt.gaussian_filter(arr, (sigma_t, sigma_x, sigma_x)) + + # Make a new blurred movie + arglist, argdict = self.movie_args() + arglist[0] = arr + movie_blur = Movie(*arglist, **argdict) + + # Process the remaining polarizations + for pol in list(self._movdict.keys()): + if pol == self.pol_prim: + continue + polframes = self._movdict[pol] + + if len(polframes): + arr = np.array([imvec.reshape(self.ydim, self.xdim) for imvec in polframes]) + arr = filt.gaussian_filter(arr, (sigma_t, sigma_x_pol, sigma_x_pol)) + movie_blur.add_pol_movie(arr, pol) + + return movie_blur + + def observe_same_nonoise(self, obs, repeat=False, sgrscat=False, + ttype="nfft", fft_pad_factor=2, + zero_empty_pol=True, verbose=True): + """Observe the movie on the same baselines as an existing observation + without adding noise. + + Args: + obs (Obsdata): existing observation with baselines where the FT will be sampled + repeat (bool): if True, repeat the movie to fill up the observation interval + sgrscat (bool): if True, the visibilites are blurred by the Sgr A* scattering kernel + ttype (str): if "fast", use FFT to produce visibilities. Else "direct" for DTFT + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist. + Otherwise return None + verbose (bool): Boolean value controls output prints. + + Returns: + (Obsdata): an observation object + """ + + # Check for agreement in coordinates and frequency + tolerance = 1e-8 + if (np.abs(self.ra - obs.ra) > tolerance) or (np.abs(self.dec - obs.dec) > tolerance): + raise Exception("Movie coordinates are not the same as observtion coordinates!") + if (np.abs(self.rf - obs.rf)/obs.rf > tolerance): + raise Exception("Movie frequency is not the same as observation frequency!") + + if ttype == 'direct' or ttype == 'fast' or ttype == 'nfft': + if verbose: print("Producing clean visibilities from movie with " + ttype + " FT . . . ") + else: + raise Exception("ttype=%s, options for ttype are 'direct', 'fast', 'nfft'" % ttype) + + # Get data + obslist = obs.tlist() + + obstimes = np.array([obsdata[0]['time'] for obsdata in obslist]) + + if (obstimes < self.start_hr).any(): + if repeat: + print("Some observation times before movie start time %f" % self.start_hr) + print("Looping movie before start\n") + elif not(self.bounds_error): + print("Some observation times before movie start time %f" % self.start_hr) + print("bounds_error is False: using constant frame 0 before start_hr! \n") + else: + raise Exception("Some observation times before movie start time %f" % self.start_hr) + if (obstimes > self.stop_hr).any(): + if repeat: + print("Some observation times after movie stop time %f" % self.stop_hr) + print("Looping movie after stop\n") + elif not(self.bounds_error): + print("Some observation times after movie stop time %f" % self.stop_hr) + print("bounds_error is False: using constant frame -1 after stop_hr! \n") + else: + raise Exception("Some observation times after movie stop time %f" % self.stop_hr) + + # Observe nearest frame + obsdata_out = [] + + for i in range(len(obslist)): + obsdata = obslist[i] + + # Frame number + time = obsdata[0]['time'] + + if self.bounds_error: + if (time < self.start_hr or time > self.stop_hr): + if repeat: + time = self.start_hr + np.mod(time - self.start_hr, self.duration) + else: + raise Exception("Obs time %f outside movie range %f--%f" % + (time, self.start_hr, self.stop_hr)) + + # Get the frame visibilities + uv = obsh.recarr_to_ndarr(obsdata[['u', 'v']], 'f8') + + try: + im = self.get_image(time) + except ValueError: + raise Exception("Interpolation error for time %f: movie range %f--%f" % + (time, self.start_hr, self.stop_hr)) + + data = simobs.sample_vis(im, uv, sgrscat=sgrscat, polrep_obs=obs.polrep, + ttype=ttype, fft_pad_factor=fft_pad_factor, + zero_empty_pol=zero_empty_pol, verbose=verbose) + verbose = False # only print for one frame + + # Put visibilities into the obsdata + if obs.polrep == 'stokes': + obsdata['vis'] = data[0] + if not(data[1] is None): + obsdata['qvis'] = data[1] + obsdata['uvis'] = data[2] + obsdata['vvis'] = data[3] + + elif obs.polrep == 'circ': + obsdata['rrvis'] = data[0] + if not(data[1] is None): + obsdata['llvis'] = data[1] + if not(data[2] is None): + obsdata['rlvis'] = data[2] + obsdata['lrvis'] = data[3] + + if len(obsdata_out): + obsdata_out = np.hstack((obsdata_out, obsdata)) + else: + obsdata_out = obsdata + + obsdata_out = np.array(obsdata_out, dtype=obs.poltype) + obs_no_noise = ehtim.obsdata.Obsdata(self.ra, self.dec, self.rf, obs.bw, + obsdata_out, obs.tarr, + source=self.source, mjd=np.floor(obs.mjd), + polrep=obs.polrep, + ampcal=True, phasecal=True, opacitycal=True, + dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + + return obs_no_noise + + def observe_same(self, obs_in, repeat=False, + ttype='nfft', fft_pad_factor=2, + sgrscat=False, add_th_noise=True, + jones=False, inv_jones=False, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0.,rlphase_std=0., + caltable_path=None, seed=False, sigmat=None, verbose=True): + """Observe the image on the same baselines as an existing observation object and add noise. + + Args: + obs_in (Obsdata): existing observation with baselines where the FT will be sampled + repeat (bool): if True, repeat the movie to fill up the observation interval + ttype (str): "fast" or "nfft" or "direct" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to a priori calibrate data + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + dcal (bool): if False, time-dependent gaussian errors added to D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + sigmat (float): temporal std for a Gaussian Process used to generate gain noise. + if sigmat=None then an iid gain noise is applied. + verbose (bool): print updates and warnings + + Returns: + (Obsdata): an observation object + + """ + + if seed: + np.random.seed(seed=seed) + + # print("Producing clean visibilities from movie . . . ") + obs = self.observe_same_nonoise(obs_in, repeat=repeat, sgrscat=sgrscat, + ttype=ttype, fft_pad_factor=fft_pad_factor, + zero_empty_pol=True, verbose=verbose) + + # Jones Matrix Corruption & Calibration + if jones: + obsdata = simobs.add_jones_and_noise(obs, add_th_noise=add_th_noise, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, frcal=frcal, dcal=dcal, + rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std, rlphase_std=rlphase_std, + caltable_path=caltable_path, + seed=seed, sigmat=sigmat, verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, opacitycal=opacitycal, + dcal=dcal, frcal=frcal, + timetype=obs.timetype, scantable=obs.scans) + + if inv_jones: + obsdata = simobs.apply_jones_inverse(obs, + opacitycal=opacitycal, dcal=dcal, frcal=frcal, + verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + + # No Jones Matrices, Add noise the old way + # TODO There is an asymmetry here - in the old way, we don't offer the ability to + # *not* unscale estimated noise. + else: + if caltable_path: + print('WARNING: the caltable is only saved if you apply noise with a Jones Matrix') + + obsdata = simobs.add_noise(obs, add_th_noise=add_th_noise, + opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, gain_offset=gain_offset, gainp=gainp, + caltable_path=caltable_path, seed=seed, sigmat=sigmat, + verbose=verbose) + + obs = ehtim.obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, obsdata, obs.tarr, + source=obs.source, mjd=obs.mjd, polrep=obs_in.polrep, + ampcal=ampcal, phasecal=phasecal, + opacitycal=True, dcal=True, frcal=True, + timetype=obs.timetype, scantable=obs.scans) + + return obs + + def observe(self, array, tint, tadv, tstart, tstop, bw, repeat=False, + mjd=None, timetype='UTC', polrep_obs=None, + elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, + no_elevcut_space=False, + ttype='nfft', fft_pad_factor=2, fix_theta_GMST=False, + sgrscat=False, add_th_noise=True, + jones=False, inv_jones=False, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + tau=ehc.TAUDEF, taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0.,rlphase_std=0., + caltable_path=None, seed=False, sigmat=None, verbose=True): + """Generate baselines from an array object and observe the movie. + + Args: + array (Array): an array object containing sites with which to generate baselines + tint (float): the scan integration time in seconds + tadv (float): the uniform cadence between scans in seconds + tstart (float): the start time of the observation in hours + tstop (float): the end time of the observation in hours + bw (float): the observing bandwidth in Hz + repeat (bool): if True, repeat the movie to fill up the observation interval + + mjd (int): the mjd of the observation, if set as different from the image mjd + timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC' + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + elevmin (float): station minimum elevation in degrees + elevmax (float): station maximum elevation in degrees + no_elevcut_space (bool): if True, do not apply elevation cut to orbiters + + ttype (str): "fast", "nfft" or "dtft" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in the FFT + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + otherwise uses old formalism without D-terms + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to calibrate data + + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrix. + dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + tau (float): the base opacity at all sites, or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + sigmat (float): temporal std for a Gaussian Process used to generate gain noise. + if sigmat=None then an iid gain noise is applied. + verbose (bool): print updates and warnings + + Returns: + (Obsdata): an observation object + + """ + + # Generate empty observation + if verbose: print("Generating empty observation file . . . ") + if mjd is None: + mjd = self.mjd + if polrep_obs is None: + polrep_obs = self.polrep + + obs = array.obsdata(self.ra, self.dec, self.rf, bw, tint, tadv, tstart, tstop, + mjd=mjd, polrep=polrep_obs, + tau=tau, timetype=timetype, + elevmin=elevmin, elevmax=elevmax, + no_elevcut_space=no_elevcut_space, + fix_theta_GMST=fix_theta_GMST) + + # Observe on the same baselines as the empty observation and add noise + obs = self.observe_same(obs, repeat=repeat, + ttype=ttype, fft_pad_factor=fft_pad_factor, + sgrscat=sgrscat, + add_th_noise=add_th_noise, + jones=jones, inv_jones=inv_jones, + opacitycal=opacitycal, ampcal=ampcal, + phasecal=phasecal, dcal=dcal, + frcal=frcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std,rlphase_std=rlphase_std, + caltable_path=caltable_path, seed=seed, sigmat=sigmat, + verbose=verbose) + + return obs + + def observe_vex(self, vex, source, synchronize_start=True, t_int=0.0, + polrep_obs=None, ttype='nfft', fft_pad_factor=2, + fix_theta_GMST=False, + sgrscat=False, add_th_noise=True, + jones=False, inv_jones=False, + opacitycal=True, ampcal=True, phasecal=True, + frcal=True, dcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + tau=ehc.TAUDEF, taup=ehc.GAINPDEF, + gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + dterm_offset=ehc.DTERMPDEF, + caltable_path=None, seed=False, sigmat=None, verbose=True): + """Generate baselines from a vex file and observe the movie. + + Args: + vex (Vex): an vex object containing sites and scan information + source (str): the source to observe + synchronize_start (bool): if True, the start of the movie is defined + as the start of the observations + t_int (float): if not zero, overrides the vex scan lengths + + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representation + ttype (str): "fast" or "nfft" or "dtft" + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v + + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* kernel + add_th_noise (bool): if True, baseline-dependent thermal noise is added + + jones (bool): if True, uses Jones matrix to apply mis-calibration effects + otherwise uses old formalism without D-terms + inv_jones (bool): if True, applies estimated inverse Jones matrix + (not including random terms) to calibrate data + opacitycal (bool): if False, time-dependent gaussian errors are added to opacities + ampcal (bool): if False, time-dependent gaussian errors are added to station gains + phasecal (bool): if False, time-dependent station-based random phases are added + frcal (bool): if False, feed rotation angle terms are added to Jones matrix. + dcal (bool): if False, time-dependent gaussian errors added to Jones matrix D-terms. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + tau (float): the base opacity at all sites, + or a dict giving one opacity per site + taup (float): the fractional std. dev. of the random error on the opacities + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + gainp (float): the fractional std. dev. of the random error on the gains + dterm_offset (float): the base dterm offset at all sites, + or a dict giving one dterm offset per site + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + sigmat (float): temporal std for a Gaussian Process used to generate gain noise. + if sigmat=None then an iid gain noise is applied. + verbose (bool): print updates and warnings + + Returns: + (Obsdata): an observation object + + """ + + if polrep_obs is None: + polrep_obs = self.polrep + + obs_List = [] + movie = self.copy() + + if synchronize_start: + movie.mjd = vex.sched[0]['mjd_floor'] + movie.start_hr = vex.sched[0]['start_hr'] + + movie_start = float(movie.mjd) + movie.start_hr/24.0 + movie_end = float(movie.mjd) + movie.stop_hr/24.0 + + print("Movie MJD Range: ", movie_start, movie_end) + + snapshot = 1.0 + if t_int > 0.0: + snapshot = 0.0 + + for i_scan in range(len(vex.sched)): + if vex.sched[i_scan]['source'] != source: + continue + scankeys = list(vex.sched[i_scan]['scan'].keys()) + subarray = vex.array.make_subarray([vex.sched[i_scan]['scan'][key]['site'] + for key in scankeys]) + + if snapshot == 1.0: + t_int = np.max(np.array([vex.sched[i_scan]['scan'][site] + ['scan_sec'] for site in scankeys])) + print(t_int) + + vex_scan_start_mjd = float(vex.sched[i_scan]['mjd_floor']) + vex_scan_start_mjd += vex.sched[i_scan]['start_hr']/24.0 + + vex_scan_length_mjd = vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0/24.0 + vex_scan_stop_mjd = vex_scan_start_mjd + vex_scan_length_mjd + + print("Scan MJD Range: ", vex_scan_start_mjd, vex_scan_stop_mjd) + + if vex_scan_start_mjd < movie_start or vex_scan_stop_mjd > movie_end: + continue + + t_start = vex.sched[i_scan]['start_hr'] + t_stop = t_start + vex.sched[i_scan]['scan'][0]['scan_sec']/3600.0 - ehc.EP + + mjd = vex.sched[i_scan]['mjd_floor'] + obs = subarray.obsdata(movie.ra, movie.dec, movie.rf, vex.bw_hz, + t_int, t_int, t_start, t_stop, + mjd=mjd, polrep=polrep_obs, tau=tau, + elevmin=.01, elevmax=89.99, timetype='UTC', + fix_theta_GMST=fix_theta_GMST) + obs_List.append(obs) + + if len(obs_List) == 0: + raise Exception("Movie has no overlap with the vex file") + + obs = ehtim.obsdata.merge_obs(obs_List) + + obsout = movie.observe_same(obs, repeat=False, + ttype=ttype, fft_pad_factor=fft_pad_factor, + sgrscat=sgrscat, add_th_noise=add_th_noise, + jones=jones, inv_jones=inv_jones, + opacitycal=opacitycal, ampcal=ampcal, phasecal=phasecal, + frcal=frcal, dcal=dcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, + neggains=neggains, + taup=taup, + gain_offset=gain_offset, gainp=gainp, + dterm_offset=dterm_offset, + caltable_path=caltable_path, seed=seed,sigmat=sigmat, + verbose=verbose) + + return obsout + + def save_txt(self, fname): + """Save the Movie data to individual text files with filenames basename + 00001, etc. + + Args: + fname (str): basename of output files + + Returns: + """ + + ehtim.io.save.save_mov_txt(self, fname) + + return + + def save_fits(self, fname): + """Save the Movie data to individual fits files with filenames basename + 00001, etc. + + Args: + fname (str): basename of output files + + Returns: + """ + + ehtim.io.save.save_mov_fits(self, fname) + return + + def save_hdf5(self, fname): + """Save the Movie data to a single hdf5 file. + + Args: + fname (str): output file name + + Returns: + """ + + ehtim.io.save.save_mov_hdf5(self, fname) + return + + def export_mp4(self, out='movie.mp4', fps=10, dpi=120, + interp='gaussian', scale='lin', dynamic_range=1000.0, cfun='afmhot', + nvec=20, pcut=0.01, plotp=False, gamma=0.5, frame_pad_factor=1, + label_time=False, verbose=False): + """Save the Movie to an mp4 file + """ + + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + if self.polrep != 'stokes': + raise Exception("export_mp4 requires self.polrep=='stokes' -- try self.switch_polrep()") + + if (interp in ['gauss', 'gaussian', 'Gaussian', 'Gauss']): + interp = 'gaussian' + else: + interp = 'nearest' + + if scale == 'lin': + unit = 'Jy/pixel' + elif scale == 'log': + unit = 'log(Jy/pixel)' + elif scale == 'gamma': + unit = '(Jy/pixel)^gamma' + else: + raise Exception("Scale not recognized!") + + fig = plt.figure() + maxi = np.max(np.concatenate([im for im in self.frames])) + + if len(self.qframes) and plotp: + thin = self.xdim//nvec + mask = (self.frames[0]).reshape(self.ydim, self.xdim) > pcut * np.max(self.frames[0]) + mask2 = mask[::thin, ::thin] + x = (np.array([[i for i in range(self.xdim)] + for j in range(self.ydim)])[::thin, ::thin])[mask2] + y = (np.array([[j for i in range(self.xdim)] + for j in range(self.ydim)])[::thin, ::thin])[mask2] + a = (-np.sin(np.angle(self.qframes[0]+1j*self.uframes[0]) / + 2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2] + b = (np.cos(np.angle(self.qframes[0]+1j*self.uframes[0]) / + 2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2] + + m = (np.abs(self.qframes[0] + 1j*self.uframes[0]) / + self.frames[0]).reshape(self.ydim, self.xdim) + m[np.logical_not(mask)] = 0 + + Q1 = plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01*self.xdim, units='x', pivot='mid', color='k', + angles='uv', scale=1.0/thin) + Q2 = plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005*self.xdim, units='x', pivot='mid', color='w', + angles='uv', scale=1.1/thin) + + def im_data(n): + + n_data = int((n-n % frame_pad_factor)/frame_pad_factor) + + if len(self.qframes) and plotp: + a = (-np.sin(np.angle(self.qframes[n_data]+1j*self.uframes[n_data] + )/2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2] + b = (np.cos(np.angle(self.qframes[n_data]+1j*self.uframes[n_data] + )/2).reshape(self.ydim, self.xdim)[::thin, ::thin])[mask2] + + Q1.set_UVC(a, b) + Q2.set_UVC(a, b) + + if scale == 'lin': + return self.frames[n_data].reshape((self.ydim, self.xdim)) + elif scale == 'log': + return np.log(self.frames[n_data].reshape( + (self.ydim, self.xdim)) + maxi/dynamic_range) + elif scale == 'gamma': + return (self.frames[n_data]**(gamma)).reshape((self.ydim, self.xdim)) + + plt_im = plt.imshow(im_data(0), cmap=plt.get_cmap(cfun), interpolation=interp) + plt.colorbar(plt_im, fraction=0.046, pad=0.04, label=unit) + + if scale == 'lin': + + plt_im.set_clim([0, maxi]) + else: + plt_im.set_clim([np.log(maxi/dynamic_range), np.log(maxi)]) + + xticks = obsh.ticks(self.xdim, self.psize/ehc.RADPERAS/1e-6) + yticks = obsh.ticks(self.ydim, self.psize/ehc.RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel(r'Relative RA ($\mu$as)') + plt.ylabel(r'Relative Dec ($\mu$as)') + + fig.set_size_inches([5, 5]) + plt.tight_layout() + + def update_img(n): + if verbose: + print("processing frame {0} of {1}".format(n, len(self.frames)*frame_pad_factor)) + plt_im.set_data(im_data(n)) + + if label_time: + time = self.times[n] + time_str = ("%02d:%02d:%02d" % (int(time), (time*60) % 60, (time*3600) % 60)) + fig.suptitle(time_str) + + return plt_im + + ani = animation.FuncAnimation(fig, update_img, len( + self.frames)*frame_pad_factor, interval=1e3/fps) + writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6) + ani.save(out, writer=writer, dpi=dpi) + +################################################################################################## +# Movie creation functions +################################################################################################## + + +def export_multipanel_mp4(input_list, out='movie.mp4', start_hr=None, stop_hr=None, nframes=100, + fov=None, npix=None, + nrows=1, fps=10, dpi=120, verbose=False, titles=None, + panel_size=4.0, common_scale=False, scale='linear', label_type='scale', + has_cbar=False, **kwargs): + """Export a movie comparing multiple movies in a grid. + + Args: + input_list (list): The list of input Movies or Images + out (string): The output filename + start_hr (float): The start time in hours. If None, defaults to first start time + end_hr (float): The end time in hours. If None, defaults to last start time + nframes (int): The number of frames in the output movie + fov (float): If specified, use this field of view for all panels + npix (int): If specified, use this linear pixel dimension for all panels + nrows (int): Number of rows in movie + fps (int): Frames per second + titles (list): List of panel titles for input_list + panel_size (float): Size of individual panels (inches) + + """ + import matplotlib + matplotlib.use('agg') + import matplotlib.pyplot as plt + import matplotlib.animation as animation + + if start_hr is None: + try: + start_hr = np.min([x.start_hr for x in input_list if hasattr(x, 'start_hr')]) + except ValueError: + raise Exception("no movies in input_list!") + + if stop_hr is None: + try: + stop_hr = np.max([x.stop_hr for x in input_list if hasattr(x, 'stop_hr')]) + except ValueError: + raise Exception("no movies in input_list!") + + print("%s will have %i frames in the range %f-%f hr" % (out, nframes, start_hr, stop_hr)) + + ncols = int(np.ceil(len(input_list)/nrows)) + suptitle_space = 0.6 # inches + w = panel_size*ncols + h = panel_size*nrows + suptitle_space + tgap = suptitle_space / h + bgap = .1 + rgap = .1 + lgap = .1 + subw = (1-lgap-rgap)/ncols + subh = (1-tgap-bgap)/nrows + print("Rows: " + str(nrows)) + print("Cols: " + str(ncols)) + + fig = plt.figure(figsize=(w, h)) + ax_all = [[] for j in range(nrows)] + for y in range(nrows): + for x in range(ncols): + ax = fig.add_axes([lgap+subw*x, bgap+subh*(nrows-y-1), subw, subh]) + ax_all[y].append(ax) + + times = np.linspace(start_hr, stop_hr, nframes) + hr_step = times[1]-times[0] + mjd_step = hr_step/24. + + im_List_Set = [[x.get_image(time) if hasattr(x, 'get_image') else x.copy() for time in times] + for x in input_list] + + if fov and npix: + im_List_Set = [[x.regrid_image(fov, npix) for x in y] for y in im_List_Set] + else: + print('not rescaling images to common fov and npix!') + + maxi = [np.max([im.imvec for im in im_List_Set[j]]) for j in range(len(im_List_Set))] + if common_scale: + maxi = np.max(maxi) + 0.0*maxi + + i = 0 + for y in range(nrows): + for x in range(ncols): + if i >= len(im_List_Set): + ax_all[y][x].set_visible(False) + else: + kwargs.get('ttype', 'nfft') + if (y == nrows-1 and x == 0) or fov is None: + label_type_cur = label_type + else: + label_type_cur = 'none' + + im_List_Set[i][0].display(axis=ax_all[y][x], scale=scale, + label_type=label_type_cur, has_cbar=has_cbar, **kwargs) + if y == nrows-1 and x == 0: + plt.xlabel(r'Relative RA ($\mu$as)') + plt.ylabel(r'Relative Dec ($\mu$as)') + else: + plt.xlabel('') + plt.ylabel('') + if not titles: + ax_all[y][x].set_title('') + else: + ax_all[y][x].set_title(titles[i]) + i = i+1 + + def im_data(i, n): + if scale == 'linear': + return im_List_Set[i][n].imvec.reshape((im_List_Set[i][n].ydim, im_List_Set[i][n].xdim)) + else: + return np.log(im_List_Set[i][n].imvec.reshape( + (im_List_Set[i][n].ydim, im_List_Set[i][n].xdim)) + 1e-20) + + def update_img(n): + if verbose: + print("processing frame {0} of {1}".format(n, len(im_List_Set[0]))) + i = 0 + for y in range(nrows): + for x in range(ncols): + ax_all[y][x].images[0].set_data(im_data(i, n)) + i = i+1 + if i >= len(im_List_Set): + break + + if mjd_step > 0.1: + # , verticalalignment=verticalalignment) + fig.suptitle('MJD: ' + str(im_List_Set[0][n].mjd)) + else: + time = im_List_Set[0][n].time + time_str = ("%d:%02d.%02d" % (int(time), (time*60) % 60, (time*3600) % 60)) + fig.suptitle(time_str) + + return + + ani = animation.FuncAnimation(fig, update_img, len(im_List_Set[0]), interval=1e3/fps) + writer = animation.writers['ffmpeg'](fps=fps, bitrate=1e6) + ani.save(out, writer=writer, dpi=dpi) + + +def merge_im_list(imlist, framedur=-1, interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR): + """Merge a list of image objects into a movie object. + + Args: + imlist (list): list of Image objects + framedur (float): duration of a movie frame in seconds + use to override times in the individual movies + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr] + + Returns: + (Movie): a Movie object assembled from the images + """ + framelist = [] + nframes = len(imlist) + + print("\nMerging %i frames from MJD %i %.2f hr to MJD %i %.2f hr" % ( + nframes, imlist[0].mjd, imlist[0].time, imlist[-1].mjd, imlist[-1].time)) + + for i in range(nframes): + im = imlist[i] + if i == 0: + polrep0 = im.polrep + pol_prim0 = im.pol_prim + movdict = {key: [] for key in list(im._imdict.keys())} + psize0 = im.psize + xdim0 = im.xdim + ydim0 = im.ydim + ra0 = im.ra + dec0 = im.dec + rf0 = im.rf + src0 = im.source + mjd0 = im.mjd + hour0 = im.time + pulse = im.pulse + times = [hour0] + else: + if (im.polrep != polrep0): + raise Exception("polrep of image %i != polrep of image 0!" % i) + if (im.psize != psize0): + raise Exception("psize of image %i != psize of image 0!" % i) + if (im.xdim != xdim0): + raise Exception("xdim of image %i != xdim of image 0!" % i) + if (im.ydim != ydim0): + raise Exception("ydim of image %i != ydim of image 0!" % i) + if (im.ra != ra0): + raise Exception("RA of image %i != RA of image 0!" % i) + if (im.dec != dec0): + raise Exception("DEC of image %i != DEC of image 0!" % i) + if (im.rf != rf0): + raise Exception("rf of image %i != rf of image 0!" % i) + if (im.source != src0): + raise Exception("source of image %i != src of image 0!" % i) + if (im.mjd < mjd0): + raise Exception("mjd of image %i < mjd of image 0!" % i) + + hour = im.time + if im.mjd > mjd0: + hour += 24*(im.mjd - mjd0) + times.append(hour) + + imarr = im.imvec.reshape(ydim0, xdim0) + framelist.append(imarr) + + # Look for other polarizations + for pol in list(movdict.keys()): + polvec = im._imdict[pol] + if len(polvec): + polarr = polvec.reshape(ydim0, xdim0) + movdict[pol].append(polarr) + else: + if movdict[pol]: + raise Exception("all frames in merge_im_list must have the same pol layout: " + + "error in frame %i" % i) + + # assume equispaced with a given framedur instead of reading the individual image times + if framedur != -1: + framedur_hr = framedur/3600. + tstart = hour0 + tstop = hour0 + framedur_hr*nframes + times = np.linspace(tstart, tstop, nframes) + + elif len(set(times)) < len(framelist): + raise Exception("image times have duplicates!") + + # Make new movie with primary polarization + newmov = Movie(framelist, times, + psize0, ra0, dec0, interp=interp, bounds_error=bounds_error, + polrep=polrep0, pol_prim=pol_prim0, + rf=rf0, source=src0, mjd=mjd0, pulse=pulse) + + # Copy over all polarization movies + for pol in list(movdict.keys()): + if pol == newmov.pol_prim: + continue + polframes = np.array(movdict[pol]) + if len(polframes): + polframes = polframes.reshape((newmov.nframes, newmov.ydim, newmov.xdim)) + newmov.add_pol_movie(polframes, pol) + + return newmov + + +def load_hdf5(file_name, + pulse=ehc.PULSE_DEFAULT, interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR): + """Read in a movie from an hdf5 file and create a Movie object. + + Args: + file_name (str): The name of the hdf5 file. + pulse (function): The function convolved with the pixel values for continuous image + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + return ehtim.io.load.load_movie_hdf5(file_name, pulse=pulse, interp=interp, + bounds_error=bounds_error) + + +def load_txt(basename, nframes, + framedur=-1, pulse=ehc.PULSE_DEFAULT, + polrep='stokes', pol_prim=None, zero_pol=True, + interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR): + """Read in a movie from text files and create a Movie object. + + Args: + basename (str): The base name of individual movie frames. + Files should have names basename + 00001, etc. + nframes (int): The total number of frames + framedur (float): The frame duration in seconds + if famedur==-1, frame duration taken from file headers + pulse (function): The function convolved with the pixel values for continuous image + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + return ehtim.io.load.load_movie_txt(basename, nframes, framedur=framedur, pulse=pulse, + polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol, + interp=interp, bounds_error=bounds_error) + + +def load_fits(basename, nframes, + framedur=-1, pulse=ehc.PULSE_DEFAULT, + polrep='stokes', pol_prim=None, zero_pol=True, + interp=ehc.INTERP_DEFAULT, bounds_error=ehc.BOUNDS_ERROR): + """Read in a movie from fits files and create a Movie object. + + Args: + basename (str): The base name of individual movie frames. + Files should have names basename + 00001, etc. + nframes (int): The total number of frames + framedur (float): The frame duration in seconds. + if famedur==-1, frame duration taken from file headers + pulse (function): The function convolved with the pixel values for continuous image + polrep (str): polarization representation, either 'stokes' or 'circ' + pol_prim (str): The default image: I,Q,U or V for Stokes, RR,LL,LR,RL for Circular + zero_pol (bool): If True, loads any missing polarizations as zeros + interp (str): Interpolation method, input to scipy.interpolate.interp1d kind keyword + bounds_error (bool): if False, return nearest frame outside interval [start_hr, stop_hr] + + Returns: + Movie: a Movie object + """ + + return ehtim.io.load.load_movie_fits(basename, nframes, framedur=framedur, pulse=pulse, + polrep=polrep, pol_prim=pol_prim, zero_pol=zero_pol, + interp=interp, bounds_error=bounds_error) diff --git a/obsdata.py b/obsdata.py new file mode 100644 index 00000000..532b9217 --- /dev/null +++ b/obsdata.py @@ -0,0 +1,5006 @@ +# obsdata.py +# a interferometric observation class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import string +import copy +import numpy as np +import numpy.lib.recfunctions as rec +import matplotlib.pyplot as plt +import scipy.optimize as opt +import scipy.spatial as spatial +import itertools as it +import sys + +try: + import pandas as pd +except ImportError: + print("Warning: pandas not installed!") + print("Please install pandas to use statistics package!") + + +import ehtim.image +import ehtim.io.save +import ehtim.io.load +import ehtim.const_def as ehc +import ehtim.observing.obs_helpers as obsh +import ehtim.statistics.dataframes as ehdf + +import warnings +warnings.filterwarnings("ignore", + message="Casting complex values to real discards the imaginary part") + +RAPOS = 0 +DECPOS = 1 +RFPOS = 2 +BWPOS = 3 +DATPOS = 4 +TARRPOS = 5 + +################################################################################################## +# Obsdata object +################################################################################################## + + +class Obsdata(object): + + """A polarimetric VLBI observation of visibility amplitudes and phases (in Jy). + + Attributes: + source (str): The source name + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + mjd (int): The integer MJD of the observation + tstart (float): The start time of the observation in hours + tstop (float): The end time of the observation in hours + rf (float): The observation frequency in Hz + bw (float): The observation bandwidth in Hz + timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC' + polrep (str): polarization representation, either 'stokes' or 'circ' + + tarr (numpy.recarray): The array of telescope data with datatype DTARR + tkey (dict): A dictionary of rows in the tarr for each site name + data (numpy.recarray): the basic data with datatype DTPOL_STOKES or DTPOL_CIRC + scantable (numpy.recarray): The array of scan information + + ampcal (bool): True if amplitudes calibrated + phasecal (bool): True if phases calibrated + opacitycal (bool): True if time-dependent opacities correctly accounted for in sigmas + frcal (bool): True if feed rotation calibrated out of visibilities + dcal (bool): True if D terms calibrated out of visibilities + + amp (numpy.recarray): An array of (averaged) visibility amplitudes + bispec (numpy.recarray): An array of (averaged) bispectra + cphase (numpy.recarray): An array of (averaged) closure phases + cphase_diag (numpy.recarray): An array of (averaged) diagonalized closure phases + camp (numpy.recarray): An array of (averaged) closure amplitudes + logcamp (numpy.recarray): An array of (averaged) log closure amplitudes + logcamp_diag (numpy.recarray): An array of (averaged) diagonalized log closure amps + """ + + def __init__(self, ra, dec, rf, bw, datatable, tarr, scantable=None, + polrep='stokes', source=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, timetype='UTC', + ampcal=True, phasecal=True, opacitycal=True, dcal=True, frcal=True, + trial_speedups=False, reorder=True): + """A polarimetric VLBI observation of visibility amplitudes and phases (in Jy). + + Args: + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The observation frequency in Hz + bw (float): The observation bandwidth in Hz + + datatable (numpy.recarray): the basic data with datatype DTPOL_STOKES or DTPOL_CIRC + tarr (numpy.recarray): The array of telescope data with datatype DTARR + scantable (numpy.recarray): The array of scan information + + polrep (str): polarization representation, either 'stokes' or 'circ' + source (str): The source name + mjd (int): The integer MJD of the observation + timetype (str): How to interpret tstart and tstop; either 'GMST' or 'UTC' + + ampcal (bool): True if amplitudes calibrated + phasecal (bool): True if phases calibrated + opacitycal (bool): True if time-dependent opacities correctly accounted for in sigmas + frcal (bool): True if feed rotation calibrated out of visibilities + dcal (bool): True if D terms calibrated out of visibilities + + Returns: + obsdata (Obsdata): an Obsdata object + """ + + if len(datatable) == 0: + raise Exception("No data in input table!") + if not (datatable.dtype in [ehc.DTPOL_STOKES, ehc.DTPOL_CIRC]): + raise Exception("Data table dtype should be DTPOL_STOKES or DTPOL_CIRC") + + # Polarization Representation + if polrep == 'stokes': + self.polrep = 'stokes' + self.poldict = ehc.POLDICT_STOKES + self.poltype = ehc.DTPOL_STOKES + elif polrep == 'circ': + self.polrep = 'circ' + self.poldict = ehc.POLDICT_CIRC + self.poltype = ehc.DTPOL_CIRC + else: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + # Set the various observation parameters + self.source = str(source) + self.ra = float(ra) + self.dec = float(dec) + self.rf = float(rf) + self.bw = float(bw) + self.ampcal = bool(ampcal) + self.phasecal = bool(phasecal) + self.opacitycal = bool(opacitycal) + self.dcal = bool(dcal) + self.frcal = bool(frcal) + + if timetype not in ['GMST', 'UTC']: + raise Exception("timetype must be 'GMST' or 'UTC'") + self.timetype = timetype + + # Save the data + self.data = datatable + self.scans = scantable + + # Telescope array: default ordering is by sefd + self.tarr = tarr + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + if np.any(self.tarr['sefdr'] != 0) or np.any(self.tarr['sefdl'] != 0): + self.reorder_tarr_sefd(reorder_baselines=False) + + # reorder baselines to uvfits convention + if reorder: + self.reorder_baselines(trial_speedups=trial_speedups) # comment out for Closure Invariants + else: + self.data = np.array(sorted(self.data, key=lambda x: x['time'])) + + + # Get tstart, mjd and tstop + times = self.unpack(['time'])['time'] + self.tstart = times[0] + self.mjd = int(mjd) + self.tstop = times[-1] + if self.tstop < self.tstart: + self.tstop += 24.0 + + # Saved closure quantity arrays + self.amp = None + self.bispec = None + self.cphase = None + self.cphase_diag = None + self.camp = None + self.logcamp = None + self.logcamp_diag = None + + @property + def tarr(self): + return self._tarr + + @tarr.setter + def tarr(self, tarr): + self._tarr = tarr + self.tkey = {tarr[i]['site']: i for i in range(len(tarr))} + + def obsdata_args(self): + """"Copy arguments for making a new Obsdata into a list and dictonary + """ + + arglist = [self.ra, self.dec, self.rf, self.bw, self.data, self.tarr] + argdict = {'scantable': self.scans, 'polrep': self.polrep, 'source': self.source, + 'mjd': self.mjd, 'timetype': self.timetype, + 'ampcal': self.ampcal, 'phasecal': self.phasecal, 'opacitycal': self.opacitycal, + 'dcal': self.dcal, 'frcal': self.frcal} + return (arglist, argdict) + + def copy(self): + """Copy the observation object. + + Args: + + Returns: + (Obsdata): a copy of the Obsdata object. + """ + + # TODO: Do we want to copy over e.g. closure tables? + newobs = copy.deepcopy(self) + + return newobs + + def switch_timetype(self, timetype_out='UTC'): + """Return a new observation with the time type switched + + Args: + timetype (str): "UTC" or "GMST" + + Returns: + (Obsdata): new Obsdata object with potentially different timetype + """ + + if timetype_out not in ['GMST', 'UTC']: + raise Exception("timetype_out must be 'GMST' or 'UTC'") + + out = self.copy() + if timetype_out == self.timetype: + return out + + if timetype_out == 'UTC': + out.data['time'] = obsh.gmst_to_utc(out.data['time'], out.mjd) + if timetype_out == 'GMST': + out.data['time'] = obsh.utc_to_gmst(out.data['time'], out.mjd) + + out.timetype = timetype_out + return out + + def switch_polrep(self, polrep_out='stokes', allow_singlepol=True, singlepol_hand='R'): + """Return a new observation with the polarization representation changed + + Args: + polrep_out (str): the polrep of the output data + allow_singlepol (bool): If True, treat single-polarization data as Stokes I + when converting from 'circ' polrep to 'stokes' + singlepol_hand (str): 'R' or 'L'; determines which parallel-hand is assumed + when converting 'stokes' to 'circ' if only I is present + + Returns: + (Obsdata): new Obsdata object with potentially different polrep + """ + + if polrep_out not in ['stokes', 'circ']: + raise Exception("polrep_out must be either 'stokes' or 'circ'") + if polrep_out == self.polrep: + return self.copy() + elif polrep_out == 'stokes': # circ -> stokes + data = np.empty(len(self.data), dtype=ehc.DTPOL_STOKES) + rrmask = np.isnan(self.data['rrvis']) + llmask = np.isnan(self.data['llvis']) + + for f in ehc.DTPOL_STOKES: + f = f[0] + if f in ['time', 'tint', 't1', 't2', 'tau1', 'tau2', 'u', 'v']: + data[f] = self.data[f] + elif f == 'vis': + data[f] = 0.5 * (self.data['rrvis'] + self.data['llvis']) + elif f == 'qvis': + data[f] = 0.5 * (self.data['lrvis'] + self.data['rlvis']) + elif f == 'uvis': + data[f] = 0.5j * (self.data['lrvis'] - self.data['rlvis']) + elif f == 'vvis': + data[f] = 0.5 * (self.data['rrvis'] - self.data['llvis']) + elif f in ['sigma', 'vsigma']: + data[f] = 0.5 * np.sqrt(self.data['rrsigma']**2 + self.data['llsigma']**2) + elif f in ['qsigma', 'usigma']: + data[f] = 0.5 * np.sqrt(self.data['rlsigma']**2 + self.data['lrsigma']**2) + + if allow_singlepol: + # In cases where only one polarization is present + # use it as an estimator for Stokes I + data['vis'][rrmask] = self.data['llvis'][rrmask] + data['sigma'][rrmask] = self.data['llsigma'][rrmask] + + data['vis'][llmask] = self.data['rrvis'][llmask] + data['sigma'][llmask] = self.data['rrsigma'][llmask] + + elif polrep_out == 'circ': # stokes -> circ + data = np.empty(len(self.data), dtype=ehc.DTPOL_CIRC) + Vmask = np.isnan(self.data['vvis']) + + for f in ehc.DTPOL_CIRC: + f = f[0] + if f in ['time', 'tint', 't1', 't2', 'tau1', 'tau2', 'u', 'v']: + data[f] = self.data[f] + elif f == 'rrvis': + data[f] = (self.data['vis'] + self.data['vvis']) + elif f == 'llvis': + data[f] = (self.data['vis'] - self.data['vvis']) + elif f == 'rlvis': + data[f] = (self.data['qvis'] + 1j * self.data['uvis']) + elif f == 'lrvis': + data[f] = (self.data['qvis'] - 1j * self.data['uvis']) + elif f in ['rrsigma', 'llsigma']: + data[f] = np.sqrt(self.data['sigma']**2 + self.data['vsigma']**2) + elif f in ['rlsigma', 'lrsigma']: + data[f] = np.sqrt(self.data['qsigma']**2 + self.data['usigma']**2) + + if allow_singlepol: + # In cases where only Stokes I is present, copy it to a specified parallel-hand + prefix = singlepol_hand.lower() + singlepol_hand.lower() # rr or ll + if prefix not in ['rr', 'll']: + raise Exception('singlepol_hand must be R or L') + + data[prefix + 'vis'][Vmask] = self.data['vis'][Vmask] + data[prefix + 'sigma'][Vmask] = self.data['sigma'][Vmask] + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = data + argdict['polrep'] = polrep_out + newobs = Obsdata(*arglist, **argdict) + + return newobs + + def reorder_baselines(self, trial_speedups=False): + """Reorder baselines to match uvfits convention, based on the telescope array ordering + """ + if trial_speedups: + self.reorder_baselines_trial_speedups() + + else: # original code + + # Time partition the datatable + datatable = self.data.copy() + datalist = [] + for key, group in it.groupby(datatable, lambda x: x['time']): + #print(key*60*60,len(list(group))) + datalist.append(np.array([obs for obs in group])) + + # loop through all data + obsdata = [] + for tlist in datalist: + blpairs = [] + for dat in tlist: + # Remove conjugate baselines + if not (set((dat['t1'], dat['t2']))) in blpairs: + + # Reverse the baseline in the right order for uvfits: + if(self.tkey[dat['t2']] < self.tkey[dat['t1']]): + + (dat['t1'], dat['t2']) = (dat['t2'], dat['t1']) + (dat['tau1'], dat['tau2']) = (dat['tau2'], dat['tau1']) + dat['u'] = -dat['u'] + dat['v'] = -dat['v'] + + if self.polrep == 'stokes': + dat['vis'] = np.conj(dat['vis']) + dat['qvis'] = np.conj(dat['qvis']) + dat['uvis'] = np.conj(dat['uvis']) + dat['vvis'] = np.conj(dat['vvis']) + elif self.polrep == 'circ': + dat['rrvis'] = np.conj(dat['rrvis']) + dat['llvis'] = np.conj(dat['llvis']) + # must switch l & r !! + rl = dat['rlvis'].copy() + lr = dat['lrvis'].copy() + dat['rlvis'] = np.conj(lr) + dat['lrvis'] = np.conj(rl) + + # You also have to switch the errors for the coherency! + rlerr = dat['rlsigma'].copy() + lrerr = dat['lrsigma'].copy() + dat["rlsigma"] = lrerr + dat["lrsigma"] = rlerr + + else: + raise Exception("polrep must be either 'stokes' or 'circ'") + + # Append the data point + blpairs.append(set((dat['t1'], dat['t2']))) + obsdata.append(dat) + + obsdata = np.array(obsdata, dtype=self.poltype) + + # Timesort data + obsdata = obsdata[np.argsort(obsdata, order=['time', 't1'])] + + # Save the data + self.data = obsdata + + return + + def reorder_baselines_trial_speedups(self): + """Reorder baselines to match uvfits convention, based on the telescope array ordering + """ + + dat = self.data.copy() + + ############ Ensure correct baseline order + # TODO can these be faster? + t1nums = np.fromiter([self.tkey[t] for t in dat['t1']],int) + t2nums = np.fromiter([self.tkey[t] for t in dat['t2']],int) + + # which entries are in the wrong telescope order? + ordermask = t2nums < t1nums + + # flip the order of these entries + t1 = dat['t1'].copy() + t2 = dat['t2'].copy() + tau1 = dat['tau1'].copy() + tau2 = dat['tau2'].copy() + + dat['t1'][ordermask] = t2[ordermask] + dat['t2'][ordermask] = t1[ordermask] + dat['tau1'][ordermask] = tau2[ordermask] + dat['tau2'][ordermask] = tau1[ordermask] + dat['u'][ordermask] *= -1 + dat['v'][ordermask] *= -1 + + if self.polrep=='stokes': + dat['vis'][ordermask] = np.conj(dat['vis'][ordermask]) + dat['qvis'][ordermask] = np.conj(dat['qvis'][ordermask]) + dat['uvis'][ordermask] = np.conj(dat['uvis'][ordermask]) + dat['vvis'][ordermask] = np.conj(dat['vvis'][ordermask]) + + elif self.polrep == 'circ': + dat['rrvis'][ordermask] = np.conj(dat['rrvis'][ordermask]) + dat['llvis'][ordermask] = np.conj(dat['llvis'][ordermask]) + rl = dat['rlvis'].copy() + lr = dat['lrvis'].copy() + dat['rlvis'][ordermask] = np.conj(lr[ordermask]) + dat['lrvis'][ordermask] = np.conj(rl[ordermask]) + + # Also need to switch error matrix + rle = dat['rlsigma'].copy() + lre = dat['lrsigma'].copy() + dat['rlsigma'][ordermask] = lre[ordermask] + dat['lrsigma'][ordermask] = rle[ordermask] + + else: + raise Exception("polrep must be either 'stokes' or 'circ'") + + # Remove duplicate or conjugate entries at any timestep + # Since telescope order has been sorted conjugates should appear as duplicates + timeblcombos = np.vstack((dat['time'],t1nums,t2nums)).T + uniqdat, uniqdatinv = np.unique(timeblcombos,axis=0, return_inverse=True) + + if len(uniqdat) != len(dat): + print("WARNING: removing duplicate/conjuagte points in reorder_baselines!") + deletemask = np.ones(len(dat)).astype(bool) + + for j in len(uniqdat): + idxs = np.argwhere(uniqdatinv==j)[:,0] + for idx in idxs[1:]: # delete all but first occurance + deletemask[idx] = False + + # remove duplicates + dat_unique = dat[deletemask] + + # sort data + dat = dat[np.argsort(dat, order=['time', 't1'])] + + # save the data + self.data = dat + + return + + def reorder_tarr_sefd(self, reorder_baselines=True): + """Reorder the telescope array by SEFD minimal to maximum. + """ + + sorted_list = sorted(self.tarr, key=lambda x: np.sqrt(x['sefdr']**2 + x['sefdl']**2)) + self.tarr = np.array(sorted_list, dtype=ehc.DTARR) + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + if reorder_baselines: + self.reorder_baselines() + + return + + def reorder_tarr_snr(self, reorder_baselines=True): + """Reorder the telescope array by median SNR maximal to minimal. + """ + + snr = self.unpack(['t1', 't2', 'snr']) + snr_median = [np.median(snr[(snr['t1'] == site) + (snr['t2'] == site)]['snr']) + for site in self.tarr['site']] + idx = np.argsort(snr_median)[::-1] + self.tarr = self.tarr[idx] + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + if reorder_baselines: + self.reorder_baselines() + + return + + def reorder_tarr_random(self, reorder_baselines=True): + """Randomly reorder the telescope array. + """ + + idx = np.arange(len(self.tarr)) + np.random.shuffle(idx) + self.tarr = self.tarr[idx] + self.tkey = {self.tarr[i]['site']: i for i in range(len(self.tarr))} + if reorder_baselines: + self.reorder_baselines() + + return + + def data_conj(self): + """Make a data array including all conjugate baselines. + + Args: + + Returns: + (numpy.recarray): a copy of the Obsdata.data table including all conjugate baselines. + """ + + data = np.empty(2 * len(self.data), dtype=self.poltype) + + # Add the conjugate baseline data + for f in self.poltype: + f = f[0] + if f in ['t1', 't2', 'tau1', 'tau2']: + if f[-1] == '1': + f2 = f[:-1] + '2' + else: + f2 = f[:-1] + '1' + data[f] = np.hstack((self.data[f], self.data[f2])) + + elif f in ['u', 'v']: + data[f] = np.hstack((self.data[f], -self.data[f])) + + elif f in [self.poldict['vis1'], self.poldict['vis2'], + self.poldict['vis3'], self.poldict['vis4']]: + if self.polrep == 'stokes': + data[f] = np.hstack((self.data[f], np.conj(self.data[f]))) + elif self.polrep == 'circ': + if f in ['rrvis', 'llvis']: + data[f] = np.hstack((self.data[f], np.conj(self.data[f]))) + elif f == 'rlvis': + data[f] = np.hstack((self.data['rlvis'], np.conj(self.data['lrvis']))) + elif f == 'lrvis': + data[f] = np.hstack((self.data['lrvis'], np.conj(self.data['rlvis']))) + + # ALSO SWITCH THE ERRORS! + else: + raise Exception("polrep must be either 'stokes' or 'circ'") + # The conjugate baselines need the transpose error terms. + elif f == "rlsigma": + data[f] = np.hstack((self.data["rlsigma"], self.data["lrsigma"])) + elif f == "lrsigma": + data[f] = np.hstack((self.data["lrsigma"], self.data["rlsigma"])) + + else: + data[f] = np.hstack((self.data[f], self.data[f])) + + # Sort the data by time + data = data[np.argsort(data['time'])] + + return data + + def tlist(self, conj=False, t_gather=0., scan_gather=False): + """Group the data in a list of equal time observation datatables. + + Args: + conj (bool): True if tlist_out includes conjugate baselines. + t_gather (float): Grouping timescale (in seconds). 0.0 indicates no grouping. + scan_gather (bool): If true, gather data into scans + + Returns: + (list): a list of data tables containing time-partitioned data + """ + + if conj: + data = self.data_conj() + else: + data = self.data + + # partition the data by time + datalist = [] + + if t_gather <= 0.0 and not scan_gather: + # Only group measurements at the same time + for key, group in it.groupby(data, lambda x: x['time']): + datalist.append(np.array([obs for obs in group])) + elif t_gather > 0.0 and not scan_gather: + # Group measurements in time + for key, group in it.groupby(data, lambda x: int(x['time'] / (t_gather / 3600.0))): + datalist.append(np.array([obs for obs in group])) + else: + # Group measurements by scan + if ((self.scans is None) or + np.any([scan is None for scan in self.scans]) or + len(self.scans) == 0): + print("No scan table in observation. Adding scan table before gathering...") + self.add_scans() + + for key, group in it.groupby( + data, lambda x: np.searchsorted(self.scans[:, 0], x['time'])): + datalist.append(np.array([obs for obs in group])) + + # return np.array(datalist, dtype=object) + return datalist + + + def split_obs(self, t_gather=0., scan_gather=False): + """Split single observation into multiple observation files, one per scan.. + + Args: + t_gather (float): Grouping timescale (in seconds). 0.0 indicates no grouping. + scan_gather (bool): If true, gather data into scans + + Returns: + (list): list of single-scan Obsdata objects + """ + + tlist = self.tlist(t_gather=t_gather, scan_gather=scan_gather) + + print("Splitting Observation File into " + str(len(tlist)) + " times") + arglist, argdict = self.obsdata_args() + + # note that the tarr of the output includes all sites, + # even those that don't participate in the scan + splitlist = [] + for tdata in tlist: + arglist[DATPOS] = tdata + splitlist.append(Obsdata(*arglist, **argdict)) + + return splitlist + + + def getClosestScan(self, time, splitObs=None): + """Split observation by scan and grab scan closest to timestamp + + Args: + time (float): Time (GMST) you want to find the scan closest to + splitObs (bool): a list of Obsdata objects, output from split_obs, to save time + + Returns: + (Obsdata): Obsdata object composed of scan closest to time + """ + + ## check if splitObs has been passed in alread ## + if splitObs is None: + splitObs = self.split_obs() + + ## check for the scan with the closest start time to time arg ## + ## TODO: allow user to choose start time, end time, or mid-time + closest_index = 0 + delta_t = 1e22 + for s, s_obs in enumerate(splitObs): + dt = abs(s_obs.tstart - time) + if dt < delta_t: + delta_t = dt + closest_index = s + + print(f"Using scan with time {splitObs[closest_index].tstart}.") + return splitObs[closest_index] + + + def bllist(self, conj=False): + """Group the data in a list of same baseline datatables. + + Args: + conj (bool): True if tlist_out includes conjugate baselines. + + Returns: + (list): a list of data tables containing baseline-partitioned data + """ + + if conj: + data = self.data_conj() + else: + data = self.data + + # partition the data by baseline + datalist = [] + idx = np.lexsort((data['t2'], data['t1'])) + for key, group in it.groupby(data[idx], lambda x: set((x['t1'], x['t2']))): + datalist.append(np.array([obs for obs in group])) + + return np.array(datalist, dtype=object) + + def unpack_bl(self, site1, site2, fields, ang_unit='deg', debias=False, timetype=False): + """Unpack the data over time on the selected baseline site1-site2. + + Args: + site1 (str): First site name + site2 (str): Second site name + fields (list): list of unpacked quantities from available quantities in FIELDS + ang_unit (str): 'deg' for degrees and 'rad' for radian phases + debias (bool): True to debias visibility amplitudes + timetype (str): 'GMST' or 'UTC' changes what is returned for 'time' + + Returns: + (numpy.recarray): unpacked numpy array with data in fields requested + """ + + if timetype is False: + timetype = self.timetype + + # If we only specify one field + if timetype not in ['GMST', 'UTC', 'utc', 'gmst']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + allfields = ['time'] + + if not isinstance(fields, list): + allfields.append(fields) + else: + for i in range(len(fields)): + allfields.append(fields[i]) + + # Get the data from data table on the selected baseline + allout = [] + tlist = self.tlist(conj=True) + for scan in tlist: + for obs in scan: + if (obs['t1'], obs['t2']) == (site1, site2): + obs = np.array([obs]) + out = self.unpack_dat(obs, allfields, ang_unit=ang_unit, + debias=debias, timetype=timetype) + + allout.append(out) + + return np.array(allout) + + def unpack(self, fields, mode='all', ang_unit='deg', debias=False, conj=False, timetype=False): + """Unpack the data for the whole observation . + + Args: + fields (list): list of unpacked quantities from availalbe quantities in FIELDS + mode (str): 'all' returns all data in single table, + 'time' groups output by equal time, 'bl' groups by baseline + ang_unit (str): 'deg' for degrees and 'rad' for radian phases + debias (bool): True to debias visibility amplitudes + conj (bool): True to include conjugate baselines + timetype (str): 'GMST' or 'UTC' changes what is returned for 'time' + + Returns: + (numpy.recarray): unpacked numpy array with data in fields requested + + """ + + if mode not in ('time', 'all', 'bl'): + raise Exception("possible options for mode are 'time', 'all' and 'bl'") + + # If we only specify one field + if not isinstance(fields, list): + fields = [fields] + + if mode == 'all': + if conj: + data = self.data_conj() + else: + data = self.data + allout = self.unpack_dat(data, fields, ang_unit=ang_unit, + debias=debias, timetype=timetype) + + elif mode == 'time': + allout = [] + tlist = self.tlist(conj=True) + for scan in tlist: + out = self.unpack_dat(scan, fields, ang_unit=ang_unit, + debias=debias, timetype=timetype) + allout.append(out) + + elif mode == 'bl': + allout = [] + bllist = self.bllist() + for bl in bllist: + out = self.unpack_dat(bl, fields, ang_unit=ang_unit, + debias=debias, timetype=timetype) + allout.append(out) + + return allout + + def unpack_dat(self, data, fields, conj=False, ang_unit='deg', debias=False, timetype=False): + """Unpack the data in a passed data recarray. + + Args: + data (numpy.recarray): data recarray of format DTPOL_STOKES or DTPOL_CIRC + fields (list): list of unpacked quantities from availalbe quantities in FIELDS + conj (bool): True to include conjugate baselines + ang_unit (str): 'deg' for degrees and 'rad' for radian phases + debias (bool): True to debias visibility amplitudes + timetype (str): 'GMST' or 'UTC' changes what is returned for 'time' + + Returns: + (numpy.recarray): unpacked numpy array with data in fields requested + + """ + + if ang_unit == 'deg': + angle = ehc.DEGREE + else: + angle = 1.0 + + # If we only specify one field + if isinstance(fields, str): + fields = [fields] + + if not timetype: + timetype = self.timetype + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + # Get field data + allout = [] + for field in fields: + if field in ["time", "time_utc", "time_gmst"]: + out = data['time'] + ty = 'f8' + elif field in ["u", "v", "tint", "tau1", "tau2"]: + out = data[field] + ty = 'f8' + elif field in ["uvdist"]: + out = np.abs(data['u'] + 1j * data['v']) + ty = 'f8' + elif field in ["t1", "el1", "par_ang1", "hr_ang1"]: + sites = data["t1"] + keys = [self.tkey[site] for site in sites] + tdata = self.tarr[keys] + out = sites + ty = 'U32' + elif field in ["t2", "el2", "par_ang2", "hr_ang2"]: + sites = data["t2"] + keys = [self.tkey[site] for site in sites] + tdata = self.tarr[keys] + out = sites + ty = 'U32' + elif field in ['vis', 'amp', 'phase', 'snr', 'sigma', 'sigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['vis'] + sig = data['sigma'] + elif self.polrep == 'circ': + out = 0.5 * (data['rrvis'] + data['llvis']) + sig = 0.5 * np.sqrt(data['rrsigma']**2 + data['llsigma']**2) + elif field in ['qvis', 'qamp', 'qphase', 'qsnr', 'qsigma', 'qsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['qvis'] + sig = data['qsigma'] + elif self.polrep == 'circ': + out = 0.5 * (data['lrvis'] + data['rlvis']) + sig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2) + elif field in ['uvis', 'uamp', 'uphase', 'usnr', 'usigma', 'usigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['uvis'] + sig = data['usigma'] + elif self.polrep == 'circ': + out = 0.5j * (data['lrvis'] - data['rlvis']) + sig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2) + elif field in ['vvis', 'vamp', 'vphase', 'vsnr', 'vsigma', 'vsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['vvis'] + sig = data['vsigma'] + elif self.polrep == 'circ': + out = 0.5 * (data['rrvis'] - data['llvis']) + sig = 0.5 * np.sqrt(data['rrsigma']**2 + data['llsigma']**2) + elif field in ['pvis', 'pamp', 'pphase', 'psnr', 'psigma', 'psigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['qvis'] + 1j * data['uvis'] + sig = np.sqrt(data['qsigma']**2 + data['usigma']**2) + elif self.polrep == 'circ': + out = data['rlvis'] + sig = data['rlsigma'] + elif field in ['m', 'mamp', 'mphase', 'msnr', 'msigma', 'msigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = (data['qvis'] + 1j * data['uvis']) / data['vis'] + sig = obsh.merr(data['sigma'], data['qsigma'], data['usigma'], data['vis'], out) + elif self.polrep == 'circ': + out = 2 * data['rlvis'] / (data['rrvis'] + data['llvis']) + sig = obsh.merr2(data['rlsigma'], data['rrsigma'], data['llsigma'], + 0.5 * (data['rrvis'] + data['llvis']), out) + elif field in ['evis', 'eamp', 'ephase', 'esnr', 'esigma', 'esigma_phase']: + ty = 'c16' + ang = np.arctan2(data['u'], data['v']) # TODO: correct convention EofN? + if self.polrep == 'stokes': + q = data['qvis'] + u = data['uvis'] + qsig = data['qsigma'] + usig = data['usigma'] + elif self.polrep == 'circ': + q = 0.5 * (data['lrvis'] + data['rlvis']) + u = 0.5j * (data['lrvis'] - data['rlvis']) + qsig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2) + usig = qsig + out = (np.cos(2 * ang) * q + np.sin(2 * ang) * u) + sig = np.sqrt(0.5 * ((np.cos(2 * ang) * qsig)**2 + (np.sin(2 * ang) * usig)**2)) + elif field in ['bvis', 'bamp', 'bphase', 'bsnr', 'bsigma', 'bsigma_phase']: + ty = 'c16' + ang = np.arctan2(data['u'], data['v']) # TODO: correct convention EofN? + if self.polrep == 'stokes': + q = data['qvis'] + u = data['uvis'] + qsig = data['qsigma'] + usig = data['usigma'] + elif self.polrep == 'circ': + q = 0.5 * (data['lrvis'] + data['rlvis']) + u = 0.5j * (data['lrvis'] - data['rlvis']) + qsig = 0.5 * np.sqrt(data['lrsigma']**2 + data['rlsigma']**2) + usig = qsig + out = (-np.sin(2 * ang) * q + np.cos(2 * ang) * u) + sig = np.sqrt(0.5 * ((np.sin(2 * ang) * qsig)**2 + (np.cos(2 * ang) * usig)**2)) + elif field in ['rrvis', 'rramp', 'rrphase', 'rrsnr', 'rrsigma', 'rrsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['vis'] + data['vvis'] + sig = np.sqrt(data['sigma']**2 + data['vsigma']**2) + elif self.polrep == 'circ': + out = data['rrvis'] + sig = data['rrsigma'] + elif field in ['llvis', 'llamp', 'llphase', 'llsnr', 'llsigma', 'llsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['vis'] - data['vvis'] + sig = np.sqrt(data['sigma']**2 + data['vsigma']**2) + elif self.polrep == 'circ': + out = data['llvis'] + sig = data['llsigma'] + elif field in ['rlvis', 'rlamp', 'rlphase', 'rlsnr', 'rlsigma', 'rlsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['qvis'] + 1j * data['uvis'] + sig = np.sqrt(data['qsigma']**2 + data['usigma']**2) + elif self.polrep == 'circ': + out = data['rlvis'] + sig = data['rlsigma'] + elif field in ['lrvis', 'lramp', 'lrphase', 'lrsnr', 'lrsigma', 'lrsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = data['qvis'] - 1j * data['uvis'] + sig = np.sqrt(data['qsigma']**2 + data['usigma']**2) + elif self.polrep == 'circ': + out = data['lrvis'] + sig = data['lrsigma'] + elif field in ['rrllvis', 'rrllamp', 'rrllphase', 'rrllsnr', + 'rrllsigma', 'rrllsigma_phase']: + ty = 'c16' + if self.polrep == 'stokes': + out = (data['vis'] + data['vvis']) / (data['vis'] - data['vvis']) + sig = (2.0**0.5 * (np.abs(data['vis'])**2 + np.abs(data['vvis'])**2)**0.5 + / np.abs(data['vis'] - data['vvis'])**2 + * (data['sigma']**2 + data['vsigma']**2)**0.5) + elif self.polrep == 'circ': + out = data['rrvis'] / data['llvis'] + sig = np.sqrt(np.abs(data['rrsigma'] / data['llvis'])**2 + + np.abs(data['llsigma'] * data['rrvis'] / data['llvis'])**2) + + else: + raise Exception("%s is not a valid field \n" % field + + "valid field values are: " + ' '.join(ehc.FIELDS)) + + if field in ["time_utc"] and self.timetype == 'GMST': + out = obsh.gmst_to_utc(out, self.mjd) + if field in ["time_gmst"] and self.timetype == 'UTC': + out = obsh.utc_to_gmst(out, self.mjd) + if field in ["time"] and self.timetype == 'GMST' and timetype == 'UTC': + out = obsh.gmst_to_utc(out, self.mjd) + if field in ["time"] and self.timetype == 'UTC' and timetype == 'GMST': + out = obsh.utc_to_gmst(out, self.mjd) + + # Compute elevation and parallactic angles + if field in ["el1", "el2", "hr_ang1", "hr_ang2", "par_ang1", "par_ang2"]: + if self.timetype == 'GMST': + times_sid = data['time'] + else: + times_sid = obsh.utc_to_gmst(data['time'], self.mjd) + + thetas = np.mod((times_sid - self.ra) * ehc.HOUR, 2 * np.pi) + coords = obsh.recarr_to_ndarr(tdata[['x', 'y', 'z']], 'f8') + el_angle = obsh.elev(obsh.earthrot(coords, thetas), self.sourcevec()) + latlon = obsh.xyz_2_latlong(coords) + hr_angles = obsh.hr_angle(times_sid * ehc.HOUR, latlon[:, 1], self.ra * ehc.HOUR) + + if field in ["el1", "el2"]: + out = el_angle / angle + ty = 'f8' + if field in ["hr_ang1", "hr_ang2"]: + out = hr_angles / angle + ty = 'f8' + if field in ["par_ang1", "par_ang2"]: + par_ang = obsh.par_angle(hr_angles, latlon[:, 0], self.dec * ehc.DEGREE) + out = par_ang / angle + ty = 'f8' + + # Get arg/amps/snr + if field in ["amp", "qamp", "uamp", "vamp", "pamp", "mamp", "bamp", "eamp", + "rramp", "llamp", "rlamp", "lramp", "rrllamp"]: + out = np.abs(out) + if debias: + out = obsh.amp_debias(out, sig) + ty = 'f8' + elif field in ["sigma", "qsigma", "usigma", "vsigma", + "psigma", "msigma", "bsigma", "esigma", + "rrsigma", "llsigma", "rlsigma", "lrsigma", "rrllsigma"]: + out = np.abs(sig) + ty = 'f8' + elif field in ["phase", "qphase", "uphase", "vphase", "pphase", "bphase", "ephase", + "mphase", "rrphase", "llphase", "lrphase", "rlphase", "rrllphase"]: + out = np.angle(out) / angle + ty = 'f8' + elif field in ["sigma_phase", "qsigma_phase", "usigma_phase", "vsigma_phase", + "psigma_phase", "msigma_phase", "bsigma_phase", "esigma_phase", + "rrsigma_phase", "llsigma_phase", "rlsigma_phase", "lrsigma_phase", + "rrllsigma_phase"]: + out = np.abs(sig) / np.abs(out) / angle + ty = 'f8' + elif field in ["snr", "qsnr", "usnr", "vsnr", "psnr", "bsnr", "esnr", + "msnr", "rrsnr", "llsnr", "rlsnr", "lrsnr", "rrllsnr"]: + out = np.abs(out) / np.abs(sig) + ty = 'f8' + + # Reshape and stack with other fields + out = np.array(out, dtype=[(field, ty)]) + + if len(allout) > 0: + allout = rec.merge_arrays((allout, out), asrecarray=True, flatten=True) + else: + allout = out + + return allout + + def sourcevec(self): + """Return the source position vector in geocentric coordinates at 0h GMST. + + Args: + + Returns: + (numpy.array): normal vector pointing to source in geocentric coordinates (m) + """ + + sourcevec = np.array([np.cos(self.dec * ehc.DEGREE), 0, np.sin(self.dec * ehc.DEGREE)]) + + return sourcevec + + def res(self): + """Return the nominal resolution (1/longest baseline) of the observation in radians. + + Args: + + Returns: + (float): normal array resolution in radians + """ + + res = 1.0 / np.max(self.unpack('uvdist')['uvdist']) + + return res + + + def chisq(self, im_or_mov, dtype='vis', pol='I', ttype='nfft', mask=[], **kwargs): + """Give the reduced chi^2 of the observation for the specified image and datatype. + + Args: + im_or_mov (Image or Movie): image or movie object on which to test chi^2 + dtype (str): data type of chi^2 (e.g., 'vis', 'amp', 'bs', 'cphase') + pol (str): polarization type ('I', 'Q', 'U', 'V', 'LL', 'RR', 'LR', or 'RL' + mask (arr): mask of same dimension as im.imvec + ttype (str): "fast" or "nfft" or "direct" + fft_pad_factor (float): zero pad the image to (fft_pad_factor * image size) in FFT + conv_func ('str'): The convolving function for gridding; 'gaussian', 'pill','cubic' + p_rad (int): The pixel radius for the convolving function + order ('str'): Interpolation order for sampling the FFT + + systematic_noise (float): adds a fractional systematic noise tolerance to sigmas + snrcut (float): a snr cutoff for including data in the chi^2 sum + debias (bool): if True then apply debiasing to amplitudes/closure amplitudes + weighting (str): 'natural' or 'uniform' + + systematic_cphase_noise (float): a value in degrees to add to closure phase sigmas + cp_uv_min (float): flag short baselines before forming closure quantities + maxset (bool): if True, use maximal set instead of minimal for closure quantities + + Returns: + (float): image chi^2 + """ + if dtype not in ['vis', 'bs', 'amp', 'cphase', + 'cphase_diag', 'camp', 'logcamp', 'logcamp_diag', 'm']: + raise Exception("%s is not a supported dterms!" % dtype) + + # TODO -- should import this at top, but the circular dependencies create a mess... + import ehtim.imaging.imager_utils as iu + import ehtim.modeling.modeling_utils as mu + + # Movie -- weighted sum of all frame chi^2 values + if hasattr(im_or_mov, 'get_image'): + mov = im_or_mov + obs_list = self.split_obs() + + chisq_list = [] + num_list = [] + for ii, obs in enumerate(obs_list): + + im = mov.get_image(obs.data[0]['time']) # Get image at the observation time + + if pol not in im._imdict.keys(): + raise Exception(pol + ' is not in the current image.' + + ' Consider changing the polarization basis of the image.') + + try: + (data, sigma, A) = iu.chisqdata(obs, im, mask, dtype, + pol=pol, ttype=ttype, **kwargs) + + except IndexError: # Not enough data to form closures! + continue + + imvec = im._imdict[pol] + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = imvec[mask] + + chisq_list.append(iu.chisq(imvec, A, data, sigma, dtype, ttype=ttype, mask=mask)) + num_list.append(len(data)) + + chisq = np.sum(np.array(num_list) * np.array(chisq_list)) / np.sum(num_list) + + # Model -- single chi^2 + elif hasattr(im_or_mov,'N_models'): + (data, sigma, uv, jonesdict) = mu.chisqdata(self, dtype, pol, **kwargs) + chisq = mu.chisq(im_or_mov, uv, data, sigma, dtype, jonesdict) + + # Image -- single chi^2 + else: + im = im_or_mov + if pol not in im._imdict.keys(): + raise Exception(pol + ' is not in the current image.' + + ' Consider changing the polarization basis of the image.') + + (data, sigma, A) = iu.chisqdata(self, im, mask, dtype, pol=pol, ttype=ttype, **kwargs) + + imvec = im._imdict[pol] + if len(mask) > 0 and np.any(np.invert(mask)): + imvec = imvec[mask] + + chisq = iu.chisq(imvec, A, data, sigma, dtype, ttype=ttype, mask=mask) + + return chisq + + def polchisq(self, im, dtype='pvis', ttype='nfft', pol_trans=True, mask=[], **kwargs): + """Give the reduced chi^2 for the specified image and polarimetric datatype. + + Args: + im (Image): image to test polarimetric chi^2 + dtype (str): data type of polarimetric chi^2 ('pvis','m','pbs') + pol (str): polarization type ('I', 'Q', 'U', 'V', 'LL', 'RR', 'LR', or 'RL' + mask (arr): mask of same dimension as im.imvec + ttype (str): if "fast" or "nfft" or "direct" + pol_trans (bool): True for I,m,chi, False for IQU + fft_pad_factor (float): zero pad the image to (fft_pad_factor * image size) in FFT + conv_func ('str'): The convolving function for gridding; 'gaussian', 'pill','cubic' + p_rad (int): The pixel radius for the convolving function + order ('str'): Interpolation order for sampling the FFT + + systematic_noise (float): adds a fractional systematic noise tolerance to sigmas + snrcut (float): a snr cutoff for including data in the chi^2 sum + debias (bool): if True then apply debiasing to amplitudes/closure amplitudes + weighting (str): 'natural' or 'uniform' + + systematic_cphase_noise (float): value in degrees to add to closure phase sigmas + cp_uv_min (float): flag short baselines before forming closure quantities + maxset (bool): if True, use maximal set instead of minimal for closure quantities + + Returns: + (float): image chi^2 + """ + + if dtype not in ['pvis', 'm', 'pbs','vvis']: + raise Exception("Only supported polarimetric dterms are 'pvis','m, 'pbs','vvis'!") + + # TODO -- should import this at top, but the circular dependencies create a mess... + import ehtim.imaging.pol_imager_utils as piu + + # Unpack the necessary polarimetric data + (data, sigma, A) = piu.polchisqdata(self, im, mask, dtype, ttype=ttype, **kwargs) + + # Pack the comparison image in the proper format + imstokes = im.switch_polrep(polrep_out='stokes', pol_prim_out='I') + if pol_trans: + ivec = imstokes.imvec + mvec = (np.abs(imstokes.qvec + 1j * imstokes.uvec) / ivec) + chivec = np.angle(imstokes.qvec + 1j * imstokes.uvec) / 2 + vvec = imstokes.vvec/ivec + if len(mask) > 0 and np.any(np.invert(mask)): + ivec = ivec[mask] + mvec = mvec[mask] + chivec = chivec[mask] + vvec = vvec[mask] + imtuple = np.array((ivec, mvec, chivec,vvec)) + else: + ivec = imstokes.imvec + qvec = imstokes.qvec + uvec = imstokes.uvec + vvec = imstokes.vvec + if len(mask) > 0 and np.any(np.invert(mask)): + ivec = ivec[mask] + qvec = qvec[mask] + uvec = uvec[mask] + vvec = vvec[mask] + imtuple = np.array((ivec, qvec, uvec,vvec)) + + + # Calculate the chi^2 + chisq = piu.polchisq(imtuple, A, data, sigma, dtype, + ttype=ttype, mask=mask, pol_trans=pol_trans) + + return chisq + + def recompute_uv(self): + """Recompute u,v points using observation times and metadata + + Args: + + Returns: + (Obsdata): New Obsdata object containing the same data with recomputed u,v points + """ + + times = self.data['time'] + site1 = self.data['t1'] + site2 = self.data['t2'] + arr = ehtim.array.Array(self.tarr) + print("Recomputing U,V Points using MJD %d \n RA %e \n DEC %e \n RF %e GHz" + % (self.mjd, self.ra, self.dec, self.rf / 1.e9)) + + (timesout, uout, vout) = obsh.compute_uv_coordinates(arr, site1, site2, times, + self.mjd, self.ra, self.dec, self.rf, + timetype=self.timetype, + elevmin=0, elevmax=90, no_elevcut_space=False) + + if len(timesout) != len(times): + raise Exception( + "len(timesout) != len(times) in recompute_uv: check elevation limits!!") + + datatable = self.data.copy() + datatable['u'] = uout + datatable['v'] = vout + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = np.array(datatable) + out = Obsdata(*arglist, **argdict) + + return out + + def avg_coherent(self, inttime, scan_avg=False, moving=False): + """Coherently average data along u,v tracks in chunks of length inttime (sec) + + Args: + inttime (float): coherent integration time in seconds + scan_avg (bool): if True, average over scans in self.scans instead of intime + moving (bool): averaging with moving window (boxcar width in seconds) + Returns: + (Obsdata): Obsdata object containing averaged data + """ + + if (scan_avg) and (getattr(self.scans, "shape", None) is None or len(self.scans) == 0): + print('No scan data, ignoring scan_avg!') + scan_avg = False + + if inttime <= 0.0 and scan_avg is False: + print('No averaging done!') + return self.copy() + + if moving: + vis_avg = ehdf.coh_moving_avg_vis(self, dt=inttime, return_type='rec') + else: + vis_avg = ehdf.coh_avg_vis(self, dt=inttime, return_type='rec', + err_type='predicted', scan_avg=scan_avg) + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = vis_avg + out = Obsdata(*arglist, **argdict) + + return out + + def avg_incoherent(self, inttime, scan_avg=False, debias=True, err_type='predicted'): + """Incoherently average data along u,v tracks in chunks of length inttime (sec) + + Args: + inttime (float): incoherent integration time in seconds + scan_avg (bool): if True, average over scans in self.scans instead of intime + debias (bool): if True, debias the averaged amplitudes + err_type (str): 'predicted' or 'measured' + + Returns: + (Obsdata): Obsdata object containing averaged data + """ + + print('Incoherently averaging data, putting phases to zero!') + amp_rec = ehdf.incoh_avg_vis(self, dt=inttime, debias=debias, scan_avg=scan_avg, + return_type='rec', rec_type='vis', err_type=err_type) + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = amp_rec + out = Obsdata(*arglist, **argdict) + + return out + + def add_amp(self, avg_time=0, scan_avg=False, debias=True, err_type='predicted', + return_type='rec', round_s=0.1, snrcut=0.): + """Adds attribute self.amp: aan amplitude table with incoherently averaged amplitudes + + Args: + avg_time (float): incoherent integration time in seconds + scan_avg (bool): if True, average over scans in self.scans instead of intime + debias (bool): if True then apply debiasing + err_type (str): 'predicted' or 'measured' + return_type: data frame ('df') or recarray ('rec') + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag amplitudes with snr lower than this + + """ + + # Get the spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) * 3600. + else: + tint0 = 0.0 + + if avg_time <= tint0: + adf = ehdf.make_amp(self, debias=debias, round_s=round_s) + if return_type == 'rec': + adf = ehdf.df_to_rec(adf, 'amp') + print("Updated self.amp: no averaging") + else: + adf = ehdf.incoh_avg_vis(self, dt=avg_time, debias=debias, scan_avg=scan_avg, + return_type=return_type, rec_type='amp', err_type=err_type) + + # snr cut + adf = adf[adf['amp'] / adf['sigma'] > snrcut] + self.amp = adf + print("Updated self.amp: avg_time %f s\n" % avg_time) + + return + + def add_bispec(self, avg_time=0, return_type='rec', count='max', snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False): + """Adds attribute self.bispec: bispectra table with bispectra averaged for dt + + Args: + avg_time (float): bispectrum averaging timescale + return_type: data frame ('df') or recarray ('rec') + count (str): If 'min', return minimal set of bispectra, + if 'max' return all bispectra up to reordering + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag bispectra with snr lower than this + + """ + + # Get spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) * 3600. + else: + tint0 = 0 + + if avg_time > tint0: + cdf = ehdf.make_bsp_df(self, mode='all', round_s=round_s, count=count, + snrcut=0., uv_min=uv_min) + cdf = ehdf.average_bispectra(cdf, avg_time, return_type=return_type, + num_samples=num_samples, snrcut=snrcut) + else: + cdf = ehdf.make_bsp_df(self, mode='all', round_s=round_s, count=count, + snrcut=snrcut, uv_min=uv_min) + print("Updated self.bispec: no averaging") + if return_type == 'rec': + cdf = ehdf.df_to_rec(cdf, 'bispec') + + self.bispec = cdf + print("Updated self.bispec: avg_time %f s\n" % avg_time) + + return + + def add_cphase(self, avg_time=0, return_type='rec', count='max', snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False): + """Adds attribute self.cphase: cphase table averaged for dt + + Args: + avg_time (float): closure phase averaging timescale + return_type: data frame ('df') or recarray ('rec') + count (str): If 'min', return minimal set of phases, + if 'max' return all closure phases up to reordering + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag closure phases with snr lower than this + + """ + + # Get spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) * 3600. + else: + tint0 = 0 + + if avg_time > tint0: + cdf = ehdf.make_cphase_df(self, mode='all', round_s=round_s, count=count, + snrcut=0., uv_min=uv_min) + cdf = ehdf.average_cphases(cdf, avg_time, return_type=return_type, err_type=err_type, + num_samples=num_samples, snrcut=snrcut) + else: + cdf = ehdf.make_cphase_df(self, mode='all', round_s=round_s, count=count, + snrcut=snrcut, uv_min=uv_min) + if return_type == 'rec': + cdf = ehdf.df_to_rec(cdf, 'cphase') + print("Updated self.cphase: no averaging") + + self.cphase = cdf + print("updated self.cphase: avg_time %f s\n" % avg_time) + + return + + def add_cphase_diag(self, avg_time=0, return_type='rec', vtype='vis', count='min', snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1, uv_min=False): + """Adds attribute self.cphase_diag: cphase_diag table averaged for dt + + Args: + avg_time (float): closure phase averaging timescale + return_type: data frame ('df') or recarray ('rec') + vtype (str): Visibility type (e.g., 'vis', 'llvis', 'rrvis', etc.) + count (str): If 'min', return minimal set of phases, + If 'max' return all closure phases up to reordering + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag closure phases with snr lower than this + + """ + + # Get spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) + tint0 *= 3600 + else: + tint0 = 0 + + # Dom TODO: implement averaging during diagonal closure phase creation + if avg_time > tint0: + print("Averaging while creating diagonal closure phases is not yet implemented!") + print("Proceeding for now without averaging.") + cdf = ehdf.make_cphase_diag_df(self, vtype=vtype, round_s=round_s, + count=count, snrcut=snrcut, uv_min=uv_min) + else: + cdf = ehdf.make_cphase_diag_df(self, vtype=vtype, round_s=round_s, + count=count, snrcut=snrcut, uv_min=uv_min) + if return_type == 'rec': + cdf = ehdf.df_to_rec(cdf, 'cphase_diag') + print("Updated self.cphase_diag: no averaging") + + self.cphase_diag = cdf + print("updated self.cphase_diag: avg_time %f s\n" % avg_time) + + return + + def add_camp(self, avg_time=0, return_type='rec', ctype='camp', + count='max', debias=True, snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1): + """Adds attribute self.camp or self.logcamp: closure amplitudes table + + Args: + avg_time (float): closure amplitude averaging timescale + return_type: data frame ('df') or recarray ('rec') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + debias (bool): If True, debias the closure amplitude + count (str): If 'min', return minimal set of amplitudes, + if 'max' return all closure amplitudes up to inverses + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag closure amplitudes with snr lower than this + """ + + # Get spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) + tint0 *= 3600 + else: + tint0 = 0 + + if avg_time > tint0: + foo = self.avg_incoherent(avg_time, debias=debias, err_type=err_type) + else: + foo = self + cdf = ehdf.make_camp_df(foo, ctype=ctype, debias=False, + count=count, round_s=round_s, snrcut=snrcut) + + if ctype == 'logcamp': + print("updated self.lcamp: no averaging") + elif ctype == 'camp': + print("updated self.camp: no averaging") + if return_type == 'rec': + cdf = ehdf.df_to_rec(cdf, 'camp') + + if ctype == 'logcamp': + self.logcamp = cdf + print("updated self.logcamp: avg_time %f s\n" % avg_time) + elif ctype == 'camp': + self.camp = cdf + print("updated self.camp: avg_time %f s\n" % avg_time) + + return + + def add_logcamp(self, avg_time=0, return_type='rec', ctype='camp', + count='max', debias=True, snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1): + """Adds attribute self.logcamp: closure amplitudes table + + Args: + avg_time (float): closure amplitude averaging timescale + return_type: data frame ('df') or recarray ('rec') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + debias (bool): If True, debias the closure amplitude + count (str): If 'min', return minimal set of amplitudes, + if 'max' return all closure amplitudes up to inverses + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag closure amplitudes with snr lower than this + + """ + + self.add_camp(return_type=return_type, ctype='logcamp', + count=count, debias=debias, snrcut=snrcut, + avg_time=avg_time, err_type=err_type, + num_samples=num_samples, round_s=round_s) + + return + + def add_logcamp_diag(self, avg_time=0, return_type='rec', count='min', snrcut=0., + debias=True, err_type='predicted', num_samples=1000, round_s=0.1): + """Adds attribute self.logcamp_diag: logcamp_diag table averaged for dt + + Args: + avg_time (float): diagonal log closure amplitude averaging timescale + return_type: data frame ('df') or recarray ('rec') + debias (bool): If True, debias the diagonal log closure amplitude + count (str): If 'min', return minimal set of amplitudes, + If 'max' return all diagonal log closure amplitudes up to inverses + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag diagonal log closure amplitudes with snr lower than this + + """ + + # Get spacing between datapoints in seconds + if len(set([x[0] for x in list(self.unpack('time'))])) > 1: + tint0 = np.min(np.diff(np.asarray(sorted(list(set( + [x[0] for x in list(self.unpack('time'))])))))) + tint0 *= 3600 + else: + tint0 = 0 + + if avg_time > tint0: + foo = self.avg_incoherent(avg_time, debias=debias, err_type=err_type) + cdf = ehdf.make_logcamp_diag_df(foo, debias='False', count=count, + round_s=round_s, snrcut=snrcut) + else: + foo = self + cdf = ehdf.make_logcamp_diag_df(foo, debias=debias, count=count, + round_s=round_s, snrcut=snrcut) + + if return_type == 'rec': + cdf = ehdf.df_to_rec(cdf, 'logcamp_diag') + + self.logcamp_diag = cdf + print("updated self.logcamp_diag: avg_time %f s\n" % avg_time) + + return + + def add_all(self, avg_time=0, return_type='rec', + count='max', debias=True, snrcut=0., + err_type='predicted', num_samples=1000, round_s=0.1): + """Adds tables of all all averaged derived quantities + self.amp,self.bispec,self.cphase,self.camp,self.logcamp + + Args: + avg_time (float): closure amplitude averaging timescale + return_type: data frame ('df') or recarray ('rec') + debias (bool): If True, debias the closure amplitude + count (str): If 'min', return minimal set of closure quantities, + if 'max' return all closure quantities + err_type (str): 'predicted' or 'measured' + num_samples: number of bootstrap (re)samples if measuring error + round_s (float): accuracy of datetime object in seconds + snrcut (float): flag closure amplitudes with snr lower than this + + """ + + self.add_amp(return_type=return_type, avg_time=avg_time, debias=debias, err_type=err_type) + self.add_bispec(return_type=return_type, count=count, + avg_time=avg_time, snrcut=snrcut, err_type=err_type, + num_samples=num_samples, round_s=round_s) + self.add_cphase(return_type=return_type, count=count, + avg_time=avg_time, snrcut=snrcut, err_type=err_type, + num_samples=num_samples, round_s=round_s) + self.add_cphase_diag(return_type=return_type, count='min', + avg_time=avg_time, snrcut=snrcut, err_type=err_type, + num_samples=num_samples, round_s=round_s) + self.add_camp(return_type=return_type, ctype='camp', + count=count, debias=debias, snrcut=snrcut, + avg_time=avg_time, err_type=err_type, + num_samples=num_samples, round_s=round_s) + self.add_camp(return_type=return_type, ctype='logcamp', + count=count, debias=debias, snrcut=snrcut, + avg_time=avg_time, err_type=err_type, + num_samples=num_samples, round_s=round_s) + self.add_logcamp_diag(return_type=return_type, count='min', + debias=debias, avg_time=avg_time, + snrcut=snrcut, err_type=err_type, + num_samples=num_samples, round_s=round_s) + + return + + def add_scans(self, info='self', filepath='', dt=0.0165, margin=0.0001): + """Compute scans and add self.scans to Obsdata object. + + Args: + info (str): 'self' to infer from data, 'txt' for text file, + 'vex' for vex schedule file + filepath (str): path to txt/vex file with scans info + dt (float): minimal time interval between scans in hours + margin (float): padding scans by that time margin in hours + + """ + + # infer scans directly from data + if info == 'self': + times_uni = np.asarray(sorted(list(set(self.data['time'])))) + scans = np.zeros_like(times_uni) + scan_id = 0 + for cou in range(len(times_uni) - 1): + scans[cou] = scan_id + if (times_uni[cou + 1] - times_uni[cou] > dt): + scan_id += 1 + scans[-1] = scan_id + scanlist = np.asarray([np.asarray([ + np.min(times_uni[scans == cou]) - margin, + np.max(times_uni[scans == cou]) + margin]) + for cou in range(int(scans[-1]) + 1)]) + + # read in scans from a text file + elif info == 'txt': + scanlist = np.loadtxt(filepath) + + # read in scans from a vex file + elif info == 'vex': + vex0 = ehtim.vex.Vex(filepath) + t_min = [vex0.sched[x]['start_hr'] for x in range(len(vex0.sched))] + duration = [] + for x in range(len(vex0.sched)): + duration_foo = max([vex0.sched[x]['scan'][y]['scan_sec'] + for y in range(len(vex0.sched[x]['scan']))]) + duration.append(duration_foo) + t_max = [tmin + dur / 3600. for (tmin, dur) in zip(t_min, duration)] + scanlist = np.array([[tmin, tmax] for (tmin, tmax) in zip(t_min, t_max)]) + + else: + print("Parameter 'info' can only assume values 'self', 'txt' or 'vex'! ") + scanlist = None + + self.scans = scanlist + + return + + def cleanbeam(self, npix, fov, pulse=ehc.PULSE_DEFAULT): + """Make an image of the observation clean beam. + + Args: + npix (int): The pixel size of the square output image. + fov (float): The field of view of the square output image in radians. + pulse (function): The function convolved with the pixel values for continuous image. + + Returns: + (Image): an Image object with the clean beam. + """ + + im = ehtim.image.make_square(self, npix, fov, pulse=pulse) + beamparams = self.fit_beam() + im = im.add_gauss(1.0, beamparams) + + return im + + def fit_beam(self, weighting='uniform', units='rad'): + """Fit a Gaussian to the dirty beam and return the parameters (fwhm_maj, fwhm_min, theta). + + Args: + weighting (str): 'uniform' or 'natural'. + units (string): 'rad' returns values in radians, + 'natural' returns FWHMs in uas and theta in degrees + + Returns: + (tuple): a tuple (fwhm_maj, fwhm_min, theta) of the dirty beam parameters in radians. + """ + + # Define the fit function that compares the quadratic expansion of the dirty image + # with the quadratic expansion of an elliptical gaussian + def fit_chisq(beamparams, db_coeff): + + (fwhm_maj2, fwhm_min2, theta) = beamparams + a = 4 * np.log(2) * (np.cos(theta)**2 / fwhm_min2 + np.sin(theta)**2 / fwhm_maj2) + b = 4 * np.log(2) * (np.cos(theta)**2 / fwhm_maj2 + np.sin(theta)**2 / fwhm_min2) + c = 8 * np.log(2) * np.cos(theta) * np.sin(theta) * (1.0 / fwhm_maj2 - 1.0 / fwhm_min2) + gauss_coeff = np.array((a, b, c)) + + chisq = np.sum((np.array(db_coeff) - gauss_coeff)**2) + + return chisq + + # These are the coefficients (a,b,c) of a quadratic expansion of the dirty beam + # For a point (x,y) in the image plane, the dirty beam expansion is 1-ax^2-by^2-cxy + u = self.unpack('u')['u'] + v = self.unpack('v')['v'] + sigma = self.unpack('sigma')['sigma'] + + weights = np.ones(u.shape) + if weighting == 'natural': + weights = 1. / sigma**2 + + abc = np.array([np.sum(weights * u**2), + np.sum(weights * v**2), + 2 * np.sum(weights * u * v)]) + abc *= (2. * np.pi**2 / np.sum(weights)) + abc *= 1e-20 # Decrease size of coefficients + + # Fit the beam + guess = [(50)**2, (50)**2, 0.0] + params = opt.minimize(fit_chisq, guess, args=(abc,), method='Powell') + + # Return parameters, adjusting fwhm_maj and fwhm_min if necessary + if params.x[0] > params.x[1]: + fwhm_maj = 1e-10 * np.sqrt(params.x[0]) + fwhm_min = 1e-10 * np.sqrt(params.x[1]) + theta = np.mod(params.x[2], np.pi) + else: + fwhm_maj = 1e-10 * np.sqrt(params.x[1]) + fwhm_min = 1e-10 * np.sqrt(params.x[0]) + theta = np.mod(params.x[2] + np.pi / 2.0, np.pi) + + gparams = np.array((fwhm_maj, fwhm_min, theta)) + + if units == 'natural': + gparams[0] /= ehc.RADPERUAS + gparams[1] /= ehc.RADPERUAS + gparams[2] *= 180. / np.pi + + return gparams + + def dirtybeam(self, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform'): + """Make an image of the observation dirty beam. + + Args: + npix (int): The pixel size of the square output image. + fov (float): The field of view of the square output image in radians. + pulse (function): The function convolved with the pixel values for continuous image. + weighting (str): 'uniform' or 'natural' + Returns: + (Image): an Image object with the dirty beam. + """ + + pdim = fov / npix + sigma = self.unpack('sigma')['sigma'] + u = self.unpack('u')['u'] + v = self.unpack('v')['v'] + if weighting == 'natural': + weights = 1. / sigma**2 + else: + weights = np.ones(u.shape) + + xlist = np.arange(0, -npix, -1) * pdim + (pdim * npix) / 2.0 - pdim / 2.0 + + # TODO -- use NFFT + # TODO -- different beam weightings + im = np.array([[np.mean(weights * np.cos(-2 * np.pi * (i * u + j * v))) + for i in xlist] + for j in xlist]) + + im = im[0:npix, 0:npix] + im = im / np.sum(im) # Normalize to a total beam power of 1 + + src = self.source + "_DB" + outim = ehtim.image.Image(im, pdim, self.ra, self.dec, + rf=self.rf, source=src, mjd=self.mjd, pulse=pulse) + + return outim + + def dirtyimage(self, npix, fov, pulse=ehc.PULSE_DEFAULT, weighting='uniform'): + """Make the observation dirty image (direct Fourier transform). + + Args: + npix (int): The pixel size of the square output image. + fov (float): The field of view of the square output image in radians. + pulse (function): The function convolved with the pixel values for continuous image. + weighting (str): 'uniform' or 'natural' + Returns: + (Image): an Image object with dirty image. + """ + + pdim = fov / npix + u = self.unpack('u')['u'] + v = self.unpack('v')['v'] + sigma = self.unpack('sigma')['sigma'] + xlist = np.arange(0, -npix, -1) * pdim + (pdim * npix) / 2.0 - pdim / 2.0 + if weighting == 'natural': + weights = 1. / sigma**2 + else: + weights = np.ones(u.shape) + + dim = np.array([[np.mean(weights * np.cos(-2 * np.pi * (i * u + j * v))) + for i in xlist] + for j in xlist]) + normfac = 1. / np.sum(dim) + + for label in ['vis1', 'vis2', 'vis3', 'vis4']: + visname = self.poldict[label] + + vis = self.unpack(visname)[visname] + + # TODO -- use NFFT + # TODO -- different beam weightings + im = np.array([[np.mean(weights * (np.real(vis) * np.cos(-2 * np.pi * (i * u + j * v)) - + np.imag(vis) * np.sin(-2 * np.pi * (i * u + j * v)))) + for i in xlist] + for j in xlist]) + + # Final normalization + im = im * normfac + im = im[0:npix, 0:npix] + + if label == 'vis1': + out = ehtim.image.Image(im, pdim, self.ra, self.dec, polrep=self.polrep, + rf=self.rf, source=self.source, mjd=self.mjd, pulse=pulse) + else: + pol = {ehc.vis_poldict[key]: key for key in ehc.vis_poldict.keys()}[visname] + out.add_pol_image(im, pol) + + return out + + def rescale_zbl(self, totflux, uv_max, debias=True): + """Rescale the short baselines to a new level of total flux. + + Args: + totflux (float): new total flux to rescale to + uv_max (float): maximum baseline length to rescale + debias (bool): Debias amplitudes before computing original total flux from short bls + + Returns: + (Obsdata): An Obsdata object with the inflated noise values. + """ + + # estimate the original total flux + obs_zerobl = self.flag_uvdist(uv_max=uv_max) + obs_zerobl.add_amp(debias=True) + orig_totflux = np.sum(obs_zerobl.amp['amp'] * (1 / obs_zerobl.amp['sigma']**2)) + orig_totflux /= np.sum(1 / obs_zerobl.amp['sigma']**2) + + print('Rescaling zero baseline by ' + str(orig_totflux - totflux) + ' Jy' + + ' to ' + str(totflux) + ' Jy') + + # Rescale short baselines to excise contributions from extended flux + # Note: this does not do the proper thing for fractional polarization) + obs = self.copy() + for j in range(len(obs.data)): + if (obs.data['u'][j]**2 + obs.data['v'][j]**2)**0.5 < uv_max: + obs.data['vis'][j] *= totflux / orig_totflux + obs.data['qvis'][j] *= totflux / orig_totflux + obs.data['uvis'][j] *= totflux / orig_totflux + obs.data['vvis'][j] *= totflux / orig_totflux + obs.data['sigma'][j] *= totflux / orig_totflux + obs.data['qsigma'][j] *= totflux / orig_totflux + obs.data['usigma'][j] *= totflux / orig_totflux + obs.data['vsigma'][j] *= totflux / orig_totflux + + return obs + + def add_leakage_noise(self, Dterm_amp=0.1, min_noise=0.01, debias=False): + """Add estimated systematic noise from leakage at quadrature to thermal noise. + Requires cross-hand visibilities. + !! this operation is not currently tracked and should be applied with extreme caution!! + + Args: + Dterm_amp (float): Estimated magnitude of leakage terms + min_noise (float): Minimum fractional systematic noise to add + debias (bool): Debias amplitudes before computing fractional noise + + Returns: + (Obsdata): An Obsdata object with the inflated noise values. + """ + + # Extract visibility amplitudes + # Switch to Stokes for graceful handling of circular basis products missing RR or LL + amp = self.switch_polrep('stokes').unpack('amp', debias=debias)['amp'] + rlamp = np.nan_to_num(self.switch_polrep('circ').unpack('rlamp', debias=debias)['rlamp']) + lramp = np.nan_to_num(self.switch_polrep('circ').unpack('lramp', debias=debias)['lramp']) + + frac_noise = (Dterm_amp * rlamp / amp)**2 + (Dterm_amp * lramp / amp)**2 + frac_noise = frac_noise * (frac_noise > min_noise) + min_noise * (frac_noise < min_noise) + + out = self.copy() + for sigma in ['sigma1', 'sigma2', 'sigma3', 'sigma4']: + try: + field = self.poldict[sigma] + out.data[field] = (self.data[field]**2 + np.abs(frac_noise * amp)**2)**0.5 + except KeyError: + continue + + return out + + def add_fractional_noise(self, frac_noise, debias=False): + """Add a constant fraction of amplitude at quadrature to thermal noise. + Effectively imposes a maximal signal-to-noise ratio. + !! this operation is not currently tracked and should be applied with extreme caution!! + + Args: + frac_noise (float): The fraction of noise to add. + debias (bool): Whether or not to add frac_noise of debiased amplitudes. + + Returns: + (Obsdata): An Obsdata object with the inflated noise values. + """ + + # Extract visibility amplitudes + # Switch to Stokes for graceful handling of circular basis products missing RR or LL + amp = self.switch_polrep('stokes').unpack('amp', debias=debias)['amp'] + + out = self.copy() + for sigma in ['sigma1', 'sigma2', 'sigma3', 'sigma4']: + try: + field = self.poldict[sigma] + out.data[field] = (self.data[field]**2 + np.abs(frac_noise * amp)**2)**0.5 + except KeyError: + continue + + return out + + def find_amt_fractional_noise(self, im, dtype='vis', target=1.0, debias=False, + maxiter=200, ftol=1e-20, gtol=1e-20): + """Returns the amount of fractional sys error necessary + to make the image have a chisq close to the targeted value (1.0) + """ + + obs = self.copy() + + def objfunc(frac_noise): + obs_tmp = obs.add_fractional_noise(frac_noise, debias=debias) + chisq = obs_tmp.chisq(im, dtype=dtype) + return np.abs(target - chisq) + + optdict = {'maxiter': maxiter, 'ftol': ftol, 'gtol': gtol} + res = opt.minimize(objfunc, 0.0, method='L-BFGS-B', options=optdict) + + return res.x + + def rescale_noise(self, noise_rescale_factor=1.0): + """Rescale the thermal noise on all Stokes parameters by a constant factor. + This is useful for AIPS data, which has a missing factor relating 'weights' to noise. + + Args: + noise_rescale_factor (float): The number to multiple the existing sigmas by. + + Returns: + (Obsdata): An Obsdata object with the rescaled noise values. + """ + + datatable = self.data.copy() + + for d in datatable: + d[-4] = d[-4] * noise_rescale_factor + d[-3] = d[-3] * noise_rescale_factor + d[-2] = d[-2] * noise_rescale_factor + d[-1] = d[-1] * noise_rescale_factor + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = np.array(datatable) + out = Obsdata(*arglist, **argdict) + + return out + + def estimate_noise_rescale_factor(self, max_diff_sec=0.0, min_num=10, median_snr_cut=0, + count='max', vtype='vis', print_std=False): + """Estimate a constant noise rescaling factor on all baselines, times, and polarizations. + Uses pairwise differences of closure phases relative to the expected scatter. + This is useful for AIPS data, which has a missing factor relating 'weights' to noise. + + Args: + max_diff_sec (float): The maximum difference of adjacent closure phases (in seconds) + If 0, auto-estimates to twice the median scan length. + min_num (int): The minimum number of closure phase differences for a triangle + to be included in the set of estimators. + median_snr_cut (float): Do not include a triangle if its median SNR is below this + count (str): If 'min', use minimal set of phases, + if 'max' use all closure phases up to reordering + vtype (str): Visibility type (e.g., 'vis', 'llvis', 'rrvis', etc.) + print_std (bool): Whether or not to print the std dev. for each closure triangle. + + Returns: + (float): The rescaling factor. + """ + + if max_diff_sec == 0.0: + max_diff_sec = 5 * np.median(self.unpack('tint')['tint']) + print("estimated max_diff_sec: ", max_diff_sec) + + # Now check the noise statistics on all closure phase triangles + c_phases = self.c_phases(vtype=vtype, mode='time', count=count, ang_unit='') + + # First, just determine the set of closure phase triangles + all_triangles = [] + for scan in c_phases: + for cphase in scan: + all_triangles.append((cphase[1], cphase[2], cphase[3])) + std_list = [] + print("Estimating noise rescaling factor from %d triangles...\n" % len(set(all_triangles))) + + # Now determine the differences of adjacent samples on each triangle, + # relative to the expected thermal noise + i_count = 0 + for tri in set(all_triangles): + i_count = i_count + 1 + if print_std: + sys.stdout.write('\rGetting noise for triangles %i/%i ' % + (i_count, len(set(all_triangles)))) + sys.stdout.flush() + all_tri = np.array([[]]) + for scan in c_phases: + for cphase in scan: + if (cphase[1] == tri[0] and cphase[2] == tri[1] and cphase[3] == tri[2] and + not np.isnan(cphase[-2]) and not np.isnan(cphase[-2])): + + all_tri = np.append(all_tri, ((cphase[0], cphase[-2], cphase[-1]))) + + all_tri = all_tri.reshape(int(len(all_tri) / 3), 3) + + # See whether the triangle has sufficient SNR + if np.median(np.abs(all_tri[:, 1] / all_tri[:, 2])) < median_snr_cut: + if print_std: + print(tri, 'median snr too low (%6.4f)' % + np.median(np.abs(all_tri[:, 1] / all_tri[:, 2]))) + continue + + # Now go through and find studentized differences of adjacent points + s_list = np.array([]) + for j in range(len(all_tri) - 1): + if (all_tri[j + 1, 0] - all_tri[j, 0]) * 3600.0 < max_diff_sec: + diff = (all_tri[j + 1, 1] - all_tri[j, 1]) % (2.0 * np.pi) + if diff > np.pi: + diff -= 2.0 * np.pi + s_list = np.append( + s_list, diff / (all_tri[j, 2]**2 + all_tri[j + 1, 2]**2)**0.5) + + if len(s_list) > min_num: + std_list.append(np.std(s_list)) + if print_std: + print(tri, '%6.4f [%d differences]' % (np.std(s_list), len(s_list))) + else: + if print_std and len(all_tri) > 0: + print(tri, '%d cphases found [%d differences < min_num = %d]' % + (len(all_tri), len(s_list), min_num)) + + if len(std_list) == 0: + print("No suitable closure phase differences! Try using a larger max_diff_sec.") + median = 1.0 + else: + median = np.median(std_list) + + return median + + def flag_elev(self, elev_min=0.0, elev_max=90, output='kept'): + """Flag visibilities for which either station is outside a stated elevation range + + Args: + elev_min (float): Minimum elevation (deg) + elev_max (float): Maximum elevation (deg) + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + el_pairs = self.unpack(['el1', 'el2']) + mask = (np.min((el_pairs['el1'], el_pairs['el2']), axis=0) > elev_min) + mask *= (np.max((el_pairs['el1'], el_pairs['el2']), axis=0) < elev_max) + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_large_fractional_pol(self, max_fractional_pol=1.0, output='kept'): + """Flag visibilities for which the fractional polarization is above a specified threshold + + Args: + max_fractional_pol (float): Maximum fractional polarization + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + m = np.nan_to_num(self.unpack(['mamp'])['mamp']) + mask = m < max_fractional_pol + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_uvdist(self, uv_min=0.0, uv_max=1e12, output='kept'): + """Flag data points outside a given uv range + + Args: + uv_min (float): remove points with uvdist less than this + uv_max (float): remove points with uvdist greater than this + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + uvdist_list = self.unpack('uvdist')['uvdist'] + mask = np.array([uv_min <= uvdist_list[j] <= uv_max for j in range(len(uvdist_list))]) + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('U-V flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_sites(self, sites, output='kept'): + """Flag data points that include the specified sites + + Args: + sites (list): list of sites to remove from the data + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + # This will remove all visibilities that include any of the specified sites + + t1_list = self.unpack('t1')['t1'] + t2_list = self.unpack('t2')['t2'] + mask = np.array([t1_list[j] not in sites and t2_list[j] not in sites + for j in range(len(t1_list))]) + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_bl(self, sites, output='kept'): + """Flag data points that include the specified baseline + + Args: + sites (list): baseline to remove from the data + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + # This will remove all visibilities that include any of the specified baseline + obs_out = self.copy() + t1_list = obs_out.unpack('t1')['t1'] + t2_list = obs_out.unpack('t2')['t2'] + mask = np.array([not(t1_list[j] in sites and t2_list[j] in sites) + for j in range(len(t1_list))]) + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('Flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_low_snr(self, snr_cut=3, output='kept'): + """Flag low snr data points + + Args: + snr_cut (float): remove points with snr lower than this + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + mask = self.unpack('snr')['snr'] > snr_cut + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('snr flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_high_sigma(self, sigma_cut=.005, sigma_type='sigma', output='kept'): + """Flag high sigma (thermal noise on Stoke I) data points + + Args: + sigma_cut (float): remove points with sigma higher than this + sigma_type (str): sigma type (sigma, rrsigma, llsigma, etc.) + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + mask = self.unpack(sigma_type)[sigma_type] < sigma_cut + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('sigma flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flag_UT_range(self, UT_start_hour=0., UT_stop_hour=0., + flag_type='all', flag_what='', output='kept'): + """Flag data points within a certain UT range + + Args: + UT_start_hour (float): start of time window + UT_stop_hour (float): end of time window + flag_type (str): 'all', 'baseline', or 'station' + flag_what (str): baseline or station to flag + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + # This drops (or only keeps) points within a specified UT range + + UT_mask = self.unpack('time')['time'] <= UT_start_hour + UT_mask = UT_mask + (self.unpack('time')['time'] >= UT_stop_hour) + if flag_type != 'all': + t1_list = self.unpack('t1')['t1'] + t2_list = self.unpack('t2')['t2'] + if flag_type == 'station': + station = flag_what + what_mask = np.array([not (t1_list[j] == station or t2_list[j] == station) + for j in range(len(t1_list))]) + elif flag_type == 'baseline': + station1 = flag_what.split('-')[0] + station2 = flag_what.split('-')[1] + stations = [station1, station2] + what_mask = np.array([not ((t1_list[j] in stations) and (t2_list[j] in stations)) + for j in range(len(t1_list))]) + else: + what_mask = np.array([False for j in range(len(UT_mask))]) + mask = UT_mask | what_mask + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('time flagged %d/%d visibilities' % (len(datatable_flagged), len(self.data))) + + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def flags_from_file(self, flagfile, flag_type='station'): + """Flagging data based on csv file + + Args: + flagfile (str): path to csv file with mjds of flagging start / stop time, + and optionally baseline / stations + flag_type (str): 'all', 'baseline', or 'station' + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + df = pd.read_csv(flagfile) + mjd_start = list(df['mjd_start']) + mjd_stop = list(df['mjd_stop']) + if flag_type == 'station': + whatL = list(df['station']) + elif flag_type == 'baseline': + whatL = list(df['baseline']) + elif flag_type == 'all': + whatL = ['' for cou in range(len(mjd_start))] + obs = self.copy() + for cou in range(len(mjd_start)): + what = whatL[cou] + starth = (mjd_start[cou] % 1) * 24. + stoph = (mjd_stop[cou] % 1) * 24. + obs = obs.flag_UT_range(UT_start_hour=starth, UT_stop_hour=stoph, + flag_type=flag_type, flag_what=what, output='kept') + + return obs + + def flag_anomalous(self, field='snr', max_diff_seconds=100, robust_nsigma_cut=5, output='kept'): + """Flag anomalous data points + + Args: + field (str): The quantity to test for + max_diff_seconds (float): The moving window size for testing outliers + robust_nsigma_cut (float): Outliers further than this from the mean are removed + output (str): returns 'kept', 'flagged', or 'both' (a dictionary) + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + stats = dict() + for t1 in set(self.data['t1']): + for t2 in set(self.data['t2']): + vals = self.unpack_bl(t1, t2, field) + if len(vals) > 0: + # nans will all be dropped, which can be problematic for polarimetric values + vals[field] = np.nan_to_num(vals[field]) + for j in range(len(vals)): + near_vals_mask = np.abs(vals['time'] - vals['time'] + [j]) < max_diff_seconds / 3600.0 + fields = vals[field][near_vals_mask] + + # Here, we use median absolute deviation as a robust proxy for standard + # deviation + dfields = np.median(np.abs(fields - np.median(fields))) + # Avoid problems when the MAD is zero (e.g., a single sample) + if dfields == 0.0: + dfields = 1.0 + stat = np.abs(vals[field][j] - np.median(fields)) / dfields + stats[(vals['time'][j][0], tuple(sorted((t1, t2))))] = stat + + mask = np.array([stats[(rec[0], tuple(sorted((rec[2], rec[3]))))][0] < robust_nsigma_cut + for rec in self.data]) + + datatable_kept = self.data.copy() + datatable_flagged = self.data.copy() + + datatable_kept = datatable_kept[mask] + datatable_flagged = datatable_flagged[np.invert(mask)] + print('anomalous %s flagged %d/%d visibilities' % + (field, len(datatable_flagged), len(self.data))) + + # Make new observations with all data first to avoid problems with empty arrays + obs_kept = self.copy() + obs_flagged = self.copy() + obs_kept.data = datatable_kept + obs_flagged.data = datatable_flagged + + if output == 'flagged': # return only the points flagged as anomalous + return obs_flagged + elif output == 'both': + return {'kept': obs_kept, 'flagged': obs_flagged} + else: + return obs_kept + + def filter_subscan_dropouts(self, perc=0, return_type='rec'): + """Filtration to drop data and ensure that we only average parts with same timestamp. + Potentially this could reduce risk of non-closing errors. + + Args: + perc (float): drop baseline from scan if it has less than this fraction + of median baseline observation time during the scan + return_type (str): data frame ('df') or recarray ('rec') + + Returns: + (Obsdata): a observation object with flagged data points removed + """ + + if not isinstance(self.scans, np.ndarray): + print('List of scans in ndarray format required! Add it with add_scans') + + else: + # make df and add scan_id to data + df = ehdf.make_df(self) + tot_points = np.shape(df)[0] + bins, labs = ehdf.get_bins_labels(self.scans) + df['scan_id'] = list(pd.cut(df.time, bins, labels=labs)) + + # first flag baselines that are working for short part of scan + df['count_samples'] = 1 + hm1 = df.groupby(['scan_id', 'baseline', 'polarization']) + hm1 = hm1.agg({'count_samples': np.sum}).reset_index() + hm1['count_baselines_before'] = 1 + hm2 = hm1.groupby(['scan_id', 'polarization']) + hm2 = hm2.agg({'count_samples': lambda x: perc * np.median(x), + 'count_baselines_before': np.sum}).reset_index() + + # dictionary with minimum acceptable number of samples per scan + dict_elem_in_scan = dict(zip(hm2.scan_id, hm2.count_samples)) + + # list of acceptable scans and baselines + hm1 = hm1[list(map(lambda x: x[1] >= dict_elem_in_scan[x[0]], + list(zip(hm1.scan_id, hm1.count_samples))))] + list_good_scans_baselines = list(zip(hm1.scan_id, hm1.baseline)) + + # filter out data + df_filtered = df[list(map(lambda x: x in list_good_scans_baselines, + list(zip(df.scan_id, df.baseline))))] + + # how many baselines present during scan? + df_filtered['count_samples'] = 1 + hm3 = df_filtered.groupby(['scan_id', 'baseline', 'polarization']) + hm3 = hm3.agg({'count_samples': np.sum}).reset_index() + hm3['count_baselines_after'] = 1 + hm4 = hm3.groupby(['scan_id', 'polarization']) + hm4 = hm4.agg({'count_baselines_after': np.sum}).reset_index() + dict_how_many_baselines = dict(zip(hm4.scan_id, hm4.count_baselines_after)) + + # how many baselines present during each time? + df_filtered['count_baselines_per_time'] = 1 + hm5 = df_filtered.groupby(['datetime', 'scan_id', 'polarization']) + hm5 = hm5.agg({'count_baselines_per_time': np.sum}).reset_index() + dict_datetime_num_baselines = dict(zip(hm5.datetime, hm5.count_baselines_per_time)) + + # only keep times when all baselines available + df_filtered2 = df_filtered[list(map(lambda x: dict_datetime_num_baselines[x[1]] == dict_how_many_baselines[x[0]], list( + zip(df_filtered.scan_id, df_filtered.datetime))))] + + remaining_points = np.shape(df_filtered2)[0] + print('Flagged out {} of {} datapoints'.format( + tot_points - remaining_points, tot_points)) + if return_type == 'rec': + out_vis = ehdf.df_to_rec(df_filtered2, 'vis') + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = out_vis + out = Obsdata(*arglist, **argdict) + + return out + + def reverse_taper(self, fwhm): + """Reverse taper the observation with a circular Gaussian kernel + + Args: + fwhm (float): real space fwhm size of convolution kernel in radian + + Returns: + (Obsdata): a new reverse-tapered observation object + """ + datatable = self.data.copy() + vis1 = datatable[self.poldict['vis1']] + vis2 = datatable[self.poldict['vis2']] + vis3 = datatable[self.poldict['vis3']] + vis4 = datatable[self.poldict['vis4']] + sigma1 = datatable[self.poldict['sigma1']] + sigma2 = datatable[self.poldict['sigma2']] + sigma3 = datatable[self.poldict['sigma3']] + sigma4 = datatable[self.poldict['sigma4']] + u = datatable['u'] + v = datatable['v'] + + fwhm_sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) + ker = np.exp(-2 * np.pi**2 * fwhm_sigma**2 * (u**2 + v**2)) + + datatable[self.poldict['vis1']] = vis1 / ker + datatable[self.poldict['vis2']] = vis2 / ker + datatable[self.poldict['vis3']] = vis3 / ker + datatable[self.poldict['vis4']] = vis4 / ker + datatable[self.poldict['sigma1']] = sigma1 / ker + datatable[self.poldict['sigma2']] = sigma2 / ker + datatable[self.poldict['sigma3']] = sigma3 / ker + datatable[self.poldict['sigma4']] = sigma4 / ker + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = datatable + obstaper = Obsdata(*arglist, **argdict) + + return obstaper + + def taper(self, fwhm): + """Taper the observation with a circular Gaussian kernel + + Args: + fwhm (float): real space fwhm size of convolution kernel in radian + + Returns: + (Obsdata): a new tapered observation object + """ + datatable = self.data.copy() + + vis1 = datatable[self.poldict['vis1']] + vis2 = datatable[self.poldict['vis2']] + vis3 = datatable[self.poldict['vis3']] + vis4 = datatable[self.poldict['vis4']] + sigma1 = datatable[self.poldict['sigma1']] + sigma2 = datatable[self.poldict['sigma2']] + sigma3 = datatable[self.poldict['sigma3']] + sigma4 = datatable[self.poldict['sigma4']] + u = datatable['u'] + v = datatable['v'] + + fwhm_sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) + ker = np.exp(-2 * np.pi**2 * fwhm_sigma**2 * (u**2 + v**2)) + + datatable[self.poldict['vis1']] = vis1 * ker + datatable[self.poldict['vis2']] = vis2 * ker + datatable[self.poldict['vis3']] = vis3 * ker + datatable[self.poldict['vis4']] = vis4 * ker + datatable[self.poldict['sigma1']] = sigma1 * ker + datatable[self.poldict['sigma2']] = sigma2 * ker + datatable[self.poldict['sigma3']] = sigma3 * ker + datatable[self.poldict['sigma4']] = sigma4 * ker + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = datatable + obstaper = Obsdata(*arglist, **argdict) + + return obstaper + + def deblur(self): + """Deblur the observation obs by dividing by the Sgr A* scattering kernel. + + Args: + + Returns: + (Obsdata): a new deblurred observation object. + """ + + # make a copy of observation data + datatable = self.data.copy() + + vis1 = datatable[self.poldict['vis1']] + vis2 = datatable[self.poldict['vis2']] + vis3 = datatable[self.poldict['vis3']] + vis4 = datatable[self.poldict['vis4']] + sigma1 = datatable[self.poldict['sigma1']] + sigma2 = datatable[self.poldict['sigma2']] + sigma3 = datatable[self.poldict['sigma3']] + sigma4 = datatable[self.poldict['sigma4']] + u = datatable['u'] + v = datatable['v'] + + # divide visibilities by the scattering kernel + for i in range(len(vis1)): + ker = obsh.sgra_kernel_uv(self.rf, u[i], v[i]) + vis1[i] = vis1[i] / ker + vis2[i] = vis2[i] / ker + vis2[i] = vis3[i] / ker + vis4[i] = vis4[i] / ker + sigma1[i] = sigma1[i] / ker + sigma2[i] = sigma2[i] / ker + sigma3[i] = sigma3[i] / ker + sigma4[i] = sigma4[i] / ker + + datatable[self.poldict['vis1']] = vis1 + datatable[self.poldict['vis2']] = vis2 + datatable[self.poldict['vis3']] = vis3 + datatable[self.poldict['vis4']] = vis4 + datatable[self.poldict['sigma1']] = sigma1 + datatable[self.poldict['sigma2']] = sigma2 + datatable[self.poldict['sigma3']] = sigma3 + datatable[self.poldict['sigma4']] = sigma4 + + arglist, argdict = self.obsdata_args() + arglist[DATPOS] = datatable + obsdeblur = Obsdata(*arglist, **argdict) + + return obsdeblur + + def reweight(self, uv_radius, weightdist=1.0): + """Reweight the sigmas based on the local density of uv points + + Args: + uv_radius (float): radius in uv-plane to look for nearby points + weightdist (float): ?? + + Returns: + (Obsdata): a new reweighted observation object. + """ + + obs_new = self.copy() + npts = len(obs_new.data) + + uvpoints = np.vstack((obs_new.data['u'], obs_new.data['v'])).transpose() + uvpoints_tree1 = spatial.cKDTree(uvpoints) + uvpoints_tree2 = spatial.cKDTree(-uvpoints) + + for i in range(npts): + matches1 = uvpoints_tree1.query_ball_point(uvpoints[i, :], uv_radius) + matches2 = uvpoints_tree2.query_ball_point(uvpoints[i, :], uv_radius) + nmatches = len(matches1) + len(matches2) + + for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']: + obs_new.data[sigma][i] = np.sqrt(nmatches) + + scale = np.mean(self.data['sigma']) / np.mean(obs_new.data['sigma']) + for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']: + obs_new.data[sigma] *= scale * weightdist + + if weightdist < 1.0: + for i in range(npts): + for sigma in ['sigma', 'qsigma', 'usigma', 'vsigma']: + obs_new.data[sigma][i] += (1 - weightdist) * self.data[sigma][i] + + return obs_new + + def fit_gauss(self, flux=1.0, fittype='amp', paramguess=( + 100 * ehc.RADPERUAS, 100 * ehc.RADPERUAS, 0.)): + """Fit a gaussian to either Stokes I complex visibilities or Stokes I visibility amplitudes. + + Args: + flux (float): total flux in the fitted gaussian + fitttype (str): "amp" to fit to visibilty amplitudes + paramguess (tuble): initial guess of fit Gaussian (fwhm_maj, fwhm_min, theta) + + Returns: + (tuple) : (fwhm_maj, fwhm_min, theta) of the fit Gaussian parameters in radians. + """ + + # TODO this fit doesn't work very well!! + vis = self.data['vis'] + u = self.data['u'] + v = self.data['v'] + sig = self.data['sigma'] + + # error function + if fittype == 'amp': + def errfunc(p): + vismodel = obsh.gauss_uv(u, v, flux, p, x=0., y=0.) + err = np.sum((np.abs(vis) - np.abs(vismodel))**2 / sig**2) + return err + else: + def errfunc(p): + vismodel = obsh.gauss_uv(u, v, flux, p, x=0., y=0.) + err = np.sum(np.abs(vis - vismodel)**2 / sig**2) + return err + + optdict = {'maxiter': 5000} # minimizer params + res = opt.minimize(errfunc, paramguess, method='Powell', options=optdict) + gparams = res.x + + return gparams + + def ClosureInvariants(self): + """ + Calculates copolar closure invariants for visibilities assuming an n element + interferometer array using method 1. + + Nithyanandan, T., Rajaram, N., Joseph, S. 2022 “Invariants in copolar + interferometry: An Abelian gauge theory”, PHYS. REV. D 105, 043019. + https://doi.org/10.1103/PhysRevD.105.043019 + + Args: + vis (np.ndarray): visibility data sampled by the interferometer array + n (int): number of antenna as part of the interferometer array + + Returns: + ci (np.ndarray): closure invariants + """ + tlist = self.tlist() + out_ci = np.array([]) + for tdata in tlist: + num_antenna = len(np.unique(tdata['t1'])) + 1 + if num_antenna < 3: + continue + vis = tdata['vis'].reshape(1,1,-1) + _, btriads = self.Triads(num_antenna) + C_oa = vis[:, :, btriads[:, 0]] + C_ab = vis[:, :, btriads[:, 1]] + C_bo = np.conjugate(vis[:, :, btriads[:, 2]]) + A_oab = C_oa / np.conjugate(C_ab) * C_bo + A_oab = np.dstack((A_oab.real, A_oab.imag)) + A_max = np.nanmax(np.abs(A_oab), axis=-1, keepdims=True) + ci = A_oab / A_max + ci = ci.reshape(-1) + out_ci = np.concatenate([out_ci, ci], axis=0) + + return out_ci + + def Triads(self, n:int): + """ + Generates arrays of antenna and baseline indicies that form triangular + loops pivoted around the 0th antenna. Used to calculate closure invariants + whereby specific baseline correlations need to be indexed according + to those triangular loops. + Baseline array format [ant1, ant2]: + [[0, 1], [0, 2], [0, 3], [0, 4], [0, 5], [0, 6] ... + [1, 2], [1, 3], [1, 4], [1, 5], [1, 6] ... + [2, 3], [2, 4], [2, 5], [2, 6] ... + [3, 4], [3, 5], [3, 6] ... + [4, 5], [4, 6] ... + [5, 6] ... + + Args: + n (int): number of antenna in the array + + Returns: + atriads (np.ndarray): antenna triangular loop indicies + btriads (np.ndarray): baseline triangular loop indicies + """ + ntriads = (n-1)*(n-2)//2 + ant1 = np.zeros(ntriads, dtype=np.uint8) + ant2 = np.arange(1, n, dtype=np.uint8).reshape(n-1, 1) + np.zeros(n-2, dtype=np.uint8).reshape(1, n-2) + ant3 = np.arange(2, n, dtype=np.uint8).reshape(1, n-2) + np.zeros(n-1, dtype=np.uint8).reshape(n-1, 1) + anti = np.where(ant3 > ant2) + ant2, ant3 = ant2[anti], ant3[anti] + atriads = np.concatenate([ant1.reshape(-1, 1), ant2.reshape(-1, 1), ant3.reshape(-1, 1)], axis=-1) + + ant_pairs_01 = list(zip(ant1, ant2)) + ant_pairs_12 = list(zip(ant2, ant3)) + ant_pairs_20 = list(zip(ant3, ant1)) + + t1 = np.arange(n, dtype=int).reshape(n, 1) + np.zeros(n, dtype=int).reshape(1, n) + t2 = np.arange(n, dtype=int).reshape(1, n) + np.zeros(n, dtype=int).reshape(n, 1) + bli = np.where(t2 > t1) + t1, t2 = t1[bli], t2[bli] + bl_pairs = list(zip(t1, t2)) + + bl_01 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_01]) + bl_12 = np.asarray([bl_pairs.index(apair) for apair in ant_pairs_12]) + bl_20 = np.asarray([bl_pairs.index(tuple(reversed(apair))) for apair in ant_pairs_20]) + btriads = np.concatenate([bl_01.reshape(-1, 1), bl_12.reshape(-1, 1), bl_20.reshape(-1, 1)], axis=-1) + return atriads, btriads + + def bispectra(self, vtype='vis', mode='all', count='min', + timetype=False, uv_min=False, snrcut=0.): + """Return a recarray of the equal time bispectra. + + Args: + vtype (str): The visibilty type from which to assemble bispectra + ('vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis') + mode (str): If 'time', return phases in a list of equal time arrays, + if 'all', return all phases in a single array + count (str): If 'min', return minimal set of bispectra, + if 'max' return all bispectra up to reordering + timetype (str): 'GMST' or 'UTC' + uv_min (float): flag baselines shorter than this before forming closure quantities + snrcut (float): flag bispectra with snr lower than this + + Returns: + (numpy.recarry): A recarray of the bispectra values with datatype DTBIS + """ + + if timetype is False: + timetype = self.timetype + if mode not in ('time', 'all'): + raise Exception("possible options for mode are 'time' and 'all'") + if count not in ('max', 'min', 'min-cut0bl'): + raise Exception("possible options for count are 'max', 'min', or 'min-cut0bl'") + if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'): + raise Exception("possible options for vtype are" + + " 'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'") + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + # Flag zero baselines + obsdata = self.copy() + if uv_min: + obsdata = obsdata.flag_uvdist(uv_min=uv_min) + # get which sites were flagged + obsdata_flagged = self.copy() + obsdata_flagged = obsdata_flagged.flag_uvdist(uv_max=uv_min) + + # Generate the time-sorted data with conjugate baselines + tlist = obsdata.tlist(conj=True) + out = [] + bis = [] + tt = 1 + for tdata in tlist: + + # sys.stdout.write('\rGetting bispectra:: type %s, count %s, scan %i/%i ' % + # (vtype, count, tt, len(tlist))) + # sys.stdout.flush() + + tt += 1 + + time = tdata[0]['time'] + if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + time = obsh.utc_to_gmst(time, self.mjd) + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + time = obsh.gmst_to_utc(time, self.mjd) + sites = list(set(np.hstack((tdata['t1'], tdata['t2'])))) + + # Create a dictionary of baselines at the current time incl. conjugates; + l_dict = {} + for dat in tdata: + l_dict[(dat['t1'], dat['t2'])] = dat + + # Determine the triangles in the time step + # Minimal Set + if count == 'min': + tris = obsh.tri_minimal_set(sites, self.tarr, self.tkey) + + # Maximal Set + elif count == 'max': + tris = np.sort(list(it.combinations(sites, 3))) + + elif count == 'min-cut0bl': + tris = obsh.tri_minimal_set(sites, self.tarr, self.tkey) + + # if you cut the 0 baselines, add in triangles that now are not in the minimal set + if uv_min: + # get the reference site + sites_ordered = [x for x in self.tarr['site'] if x in sites] + ref = sites_ordered[0] + + # check if the reference site was in a zero baseline + zerobls = np.vstack([obsdata_flagged.data['t1'], obsdata_flagged.data['t2']]) + if np.sum(zerobls == ref): + + # determine which sites were cut out of the minimal set + cutsites = np.unique(np.hstack([zerobls[1][zerobls[0] == ref], + zerobls[0][zerobls[1] == ref]])) + + # we can only handle if there was 1 connecting site that was cut + if len(cutsites) > 1: + raise Exception("Cannot have the root node be in a clique" + + "with more than 2 sites sharing 0 baselines'") + + # get the remaining sites + cutsite = cutsites[0] + sites_remaining = np.array(sites_ordered)[np.array(sites_ordered) != ref] + sites_remaining = sites_remaining[np.array(sites_remaining) != cutsite] + # get the next site in the list, ideally sorted by snr + second_ref = sites_remaining[0] + + # add in additional triangles + for s2 in range(1, len(sites_remaining)): + tris.append((cutsite, second_ref, sites_remaining[s2])) + + # Generate bispectra for each triangle + for tri in tris: + + # Select triangle entries in the data dictionary + try: + l1 = l_dict[(tri[0], tri[1])] + l2 = l_dict[(tri[1], tri[2])] + l3 = l_dict[(tri[2], tri[0])] + except KeyError: + continue + + (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep) + + # Cut out low snr points + if np.abs(bi) / bisig < snrcut: + continue + + # Append to the equal-time list + bis.append(np.array((time, + tri[0], tri[1], tri[2], + l1['u'], l1['v'], + l2['u'], l2['v'], + l3['u'], l3['v'], + bi, bisig), dtype=ehc.DTBIS)) + + # Append to outlist + if mode == 'time' and len(bis) > 0: + out.append(np.array(bis)) + bis = [] + + if mode == 'all': + out = np.array(bis) + + return out + + def c_phases(self, vtype='vis', mode='all', count='min', ang_unit='deg', + timetype=False, uv_min=False, snrcut=0.): + """Return a recarray of the equal time closure phases. + + Args: + vtype (str): The visibilty type from which to assemble closure phases + ('vis','qvis','uvis','vvis','pvis') + mode (str): If 'time', return phases in a list of equal time arrays, + if 'all', return all phases in a single array + count (str): If 'min', return minimal set of phases, + if 'max' return all closure phases up to reordering + ang_unit (str): If 'deg', return closure phases in degrees, else return in radians + timetype (str): 'UTC' or 'GMST' + uv_min (float): flag baselines shorter than this before forming closure quantities + snrcut (float): flag bispectra with snr lower than this + + Returns: + (numpy.recarry): A recarray of the closure phases with datatype DTCPHASE + """ + + if timetype is False: + timetype = self.timetype + if mode not in ('time', 'all'): + raise Exception("possible options for mode are 'time' and 'all'") + if count not in ('max', 'min', 'min-cut0bl'): + raise Exception("possible options for count are 'max', 'min', or 'min-cut0bl'") + if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'): + raise Exception("possible options for vtype are" + + " 'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'") + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + if ang_unit == 'deg': + angle = ehc.DEGREE + else: + angle = 1.0 + + # Get the bispectra data + bispecs = self.bispectra(vtype=vtype, mode='time', count=count, + timetype=timetype, uv_min=uv_min, snrcut=snrcut) + + # Reformat into a closure phase list/array + out = [] + cps = [] + + cpnames = ('time', 't1', 't2', 't3', 'u1', 'v1', 'u2', + 'v2', 'u3', 'v3', 'cphase', 'sigmacp') + for bis in bispecs: + for bi in bis: + if len(bi) == 0: + continue + bi.dtype.names = cpnames + bi['sigmacp'] = np.real(bi['sigmacp'] / np.abs(bi['cphase']) / angle) + bi['cphase'] = np.real((np.angle(bi['cphase']) / angle)) + cps.append(bi.astype(np.dtype(ehc.DTCPHASE))) + + if mode == 'time' and len(cps) > 0: + out.append(np.array(cps)) + cps = [] + + if mode == 'all': + out = np.array(cps) + + return out + + def c_phases_diag(self, vtype='vis', count='min', ang_unit='deg', + timetype=False, uv_min=False, snrcut=0.): + """Return a recarray of the equal time diagonalized closure phases. + + Args: + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + from which to assemble closure phases + count (str): If 'min', return minimal set of phases, + If 'min-cut0bl' return minimal set after flagging zero-baselines + ang_unit (str): If 'deg', return closure phases in degrees, else return in radians + timetype (str): 'UTC' or 'GMST' + uv_min (float): flag baselines shorter than this before forming closure quantities + snrcut (float): flag bispectra with snr lower than this + + Returns: + (numpy.recarry): A recarray of diagonalized closure phases (datatype DTCPHASEDIAG), + along with associated triangles and transformation matrices + """ + + if timetype is False: + timetype = self.timetype + if count not in ('min', 'min-cut0bl'): + raise Exception( + "possible options for count are 'min' or 'min-cut0bl' for diagonal closure phases") + if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'): + raise Exception( + "possible options for vtype are 'vis', 'qvis', " + + "'uvis','vvis','rrvis','lrvis','rlvis','llvis'") + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + if ang_unit == 'deg': + angle = ehc.DEGREE + else: + angle = 1.0 + + # determine the appropriate sigmatype + if vtype in ["vis", "qvis", "uvis", "vvis"]: + if vtype == 'vis': + sigmatype = 'sigma' + if vtype == 'qvis': + sigmatype = 'qsigma' + if vtype == 'uvis': + sigmatype = 'usigma' + if vtype == 'vvis': + sigmatype = 'vsigma' + if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]: + if vtype == 'rrvis': + sigmatype = 'rrsigma' + if vtype == 'llvis': + sigmatype = 'llsigma' + if vtype == 'rlvis': + sigmatype = 'rlsigma' + if vtype == 'lrvis': + sigmatype = 'lrsigma' + + # get the time-sorted visibility data including conjugate baselines + viss = np.concatenate(self.tlist(conj=True)) + + # get the closure phase data + cps = self.c_phases(vtype=vtype, mode='all', count=count, ang_unit=ang_unit, + timetype=timetype, uv_min=uv_min, snrcut=snrcut) + + # get the unique timestamps for the closure phases + T_cps = np.unique(cps['time']) + + # list of diagonalized closure phases and corresponding transformation matrices + dcps = [] + dcp_errs = [] + tfmats = [] + + tris = [] + us = [] + vs = [] + + # loop over the timestamps + for kk, t in enumerate(T_cps): + + sys.stdout.write('\rDiagonalizing closure phases:: type %s, count %s, scan %i/%i ' % + (vtype, count, kk + 1, len(T_cps))) + sys.stdout.flush() + + # index masks for this timestamp + mask_cp = (cps['time'] == t) + mask_vis = (viss['time'] == t) + + # closure phases for this timestamp + cps_here = cps[mask_cp] + + # visibilities for this timestamp + viss_here = viss[mask_vis] + + # initialize the design matrix + design_mat = np.zeros((mask_cp.sum(), mask_vis.sum())) + + # loop over the closure phases within this timestamp + trilist = [] + ulist = [] + vlist = [] + for ic, cp in enumerate(cps_here): + + trilist.append((cp['t1'], cp['t2'], cp['t3'])) + ulist.append((cp['u1'], cp['u2'], cp['u3'])) + vlist.append((cp['v1'], cp['v2'], cp['v3'])) + + # matrix entry for first leg of triangle + ind1 = ((viss_here['t1'] == cp['t1']) & (viss_here['t2'] == cp['t2'])) + design_mat[ic, ind1] = 1.0 + + # matrix entry for second leg of triangle + ind2 = ((viss_here['t1'] == cp['t2']) & (viss_here['t2'] == cp['t3'])) + design_mat[ic, ind2] = 1.0 + + # matrix entry for third leg of triangle + ind3 = ((viss_here['t1'] == cp['t3']) & (viss_here['t2'] == cp['t1'])) + design_mat[ic, ind3] = 1.0 + + # construct the covariance matrix + visphase_err = viss_here[sigmatype] / np.abs(viss_here[vtype]) + sigma_mat = np.diag(visphase_err**2.0) + covar_mat = np.matmul(np.matmul(design_mat, sigma_mat), np.transpose(design_mat)) + + # diagonalize via eigendecomposition + eigeninfo = np.linalg.eigh(covar_mat) + S_matrix = np.copy(eigeninfo[1]).transpose() + dcphase = np.matmul(S_matrix, cps_here['cphase']) + if ang_unit != 'deg': + dcphase *= angle + dcphase_err = np.sqrt(np.copy(eigeninfo[0])) / angle + + dcps.append(dcphase) + dcp_errs.append(dcphase_err) + tfmats.append(S_matrix) + tris.append(trilist) + us.append(ulist) + vs.append(vlist) + + # Reformat into a list + out = [] + for kk, t in enumerate(T_cps): + dcparr = [] + for idcp, dcp in enumerate(dcps[kk]): + dcparr.append((t, dcp, dcp_errs[kk][idcp])) + dcparr = np.array(dcparr, dtype=[('time', 'f8'), ('cphase', 'f8'), ('sigmacp', 'f8')]) + out.append((dcparr, + np.array(tris[kk]).astype(np.dtype([('trianges', 'U2')])), + np.array(us[kk]).astype(np.dtype([('u', 'f8')])), + np.array(vs[kk]).astype(np.dtype([('v', 'f8')])), + tfmats[kk].astype(np.dtype([('tform_matrix', 'f8')])))) + print("\n") + return out + + def bispectra_tri(self, site1, site2, site3, + vtype='vis', timetype=False, snrcut=0., method='from_maxset', + bs=[], force_recompute=False): + + """Return complex bispectrum over time on a triangle (1-2-3). + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type from which to assemble bispectra + ('vis','qvis','uvis','vvis','pvis') + timetype (str): 'UTC' or 'GMST' + snrcut (float): flag bispectra with snr lower than this + + method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust) + bs (list): optionally pass in the precomputed, time-sorted bispectra + force_recompute (bool): if True, recompute bispectra instead of using saved data + + Returns: + (numpy.recarry): A recarray of the bispectra on this triangle with datatype DTBIS + """ + if timetype is False: + timetype = self.timetype + + if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']): + print ("Warning! method='from_maxset' default in bispectra_tri() inconsistent with vtype=%s" % vtype) + print ("Switching to method='from_vis'") + method = 'from_vis' + + tri = (site1, site2, site3) + outdata = [] + + # get selected bispectra from the maximal set + # TODO: verify consistency/performance of from_vis, and delete this method + if method=='from_maxset': + + if ((len(bs) == 0) and not (self.bispec is None) and not (len(self.bispec) == 0) and + not force_recompute): + bs = self.bispec + elif (len(bs) == 0) or force_recompute: + bs = self.bispectra(mode='all', count='max', vtype=vtype, + timetype=timetype, snrcut=snrcut) + + # Get requested bispectra over time + for obs in bs: + obstri = (obs['t1'], obs['t2'], obs['t3']) + if set(obstri) == set(tri): + t1 = copy.deepcopy(obs['t1']) + t2 = copy.deepcopy(obs['t2']) + t3 = copy.deepcopy(obs['t3']) + u1 = copy.deepcopy(obs['u1']) + u2 = copy.deepcopy(obs['u2']) + u3 = copy.deepcopy(obs['u3']) + v1 = copy.deepcopy(obs['v1']) + v2 = copy.deepcopy(obs['v2']) + v3 = copy.deepcopy(obs['v3']) + + # Reorder baselines and flip the sign of the closure phase if necessary + if t1 == site1: + if t2 == site2: + pass + else: + obs['t2'] = t3 + obs['t3'] = t2 + + obs['u1'] = -u3 + obs['v1'] = -v3 + obs['u2'] = -u2 + obs['v2'] = -v2 + obs['u3'] = -u1 + obs['v3'] = -v1 + obs['bispec'] = np.conjugate(obs['bispec']) + + elif t1 == site2: + if t2 == site3: + obs['t1'] = t3 + obs['t2'] = t1 + obs['t3'] = t2 + + obs['u1'] = u3 + obs['v1'] = v3 + obs['u2'] = u1 + obs['v2'] = v1 + obs['u3'] = u2 + obs['v3'] = v2 + + else: + obs['t1'] = t2 + obs['t2'] = t1 + + obs['u1'] = -u1 + obs['v1'] = -v1 + obs['u2'] = -u3 + obs['v2'] = -v3 + obs['u3'] = -u2 + obs['v3'] = -v2 + obs['bispec'] = np.conjugate(obs['bispec']) + + elif t1 == site3: + if t2 == site1: + obs['t1'] = t2 + obs['t2'] = t3 + obs['t3'] = t1 + + obs['u1'] = u2 + obs['v1'] = v2 + obs['u2'] = u3 + obs['v2'] = v3 + obs['u3'] = u1 + obs['v3'] = v1 + + else: + obs['t1'] = t3 + obs['t3'] = t1 + + obs['u1'] = -u2 + obs['v1'] = -v2 + obs['u2'] = -u1 + obs['v2'] = -v1 + obs['u3'] = -u3 + obs['v3'] = -v3 + obs['bispec'] = np.conjugate(obs['bispec']) + + outdata.append(np.array(obs, dtype=ehc.DTBIS)) + continue + + # get selected bispectra from the visibilities directly + # taken from bispectra() method + elif method=='from_vis': + + # get all equal-time data, and loop over to construct bispectra + tlist = self.tlist(conj=True) + for tdata in tlist: + + time = tdata[0]['time'] + if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + time = obsh.utc_to_gmst(time, self.mjd) + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + time = obsh.gmst_to_utc(time, self.mjd) + + # Create a dictionary of baselines at the current time incl. conjugates; + l_dict = {} + for dat in tdata: + l_dict[(dat['t1'], dat['t2'])] = dat + + # Select triangle entries in the data dictionary + try: + l1 = l_dict[(tri[0], tri[1])] + l2 = l_dict[(tri[1], tri[2])] + l3 = l_dict[(tri[2], tri[0])] + except KeyError: + continue + + (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep) + + # Cut out low snr points + if np.abs(bi) / bisig < snrcut: + continue + + # Append to the equal-time list + outdata.append(np.array((time, + tri[0], tri[1], tri[2], + l1['u'], l1['v'], + l2['u'], l2['v'], + l3['u'], l3['v'], + bi, + bisig), + dtype=ehc.DTBIS)) + else: + raise Exception("keyword 'method' in bispectra_tri() must be either 'from_cphase' or 'from_vis'") + + outdata = np.array(outdata) + return outdata + + def cphase_tri(self, site1, site2, site3, vtype='vis', ang_unit='deg', + timetype=False, snrcut=0., method='from_maxset', + cphases=[], force_recompute=False): + """Return closure phase over time on a triangle (1-2-3). + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type from which to assemble closure phases + (e.g., 'vis','qvis','uvis','vvis','pvis') + ang_unit (str): If 'deg', return closure phases in degrees, else return in radians + timetype (str): 'GMST' or 'UTC' + snrcut (float): flag bispectra with snr lower than this + + method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust) + cphases (list): optionally pass in the precomputed time-sorted cphases + force_recompute (bool): if True, do not use save closure phase tables + + Returns: + (numpy.recarry): A recarray of the closure phases with datatype DTCPHASE + """ + + if timetype is False: + timetype = self.timetype + + if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']): + print ("Warning! method='from_maxset' default in cphase_tri() is inconsistent with vtype=%s" % vtype) + print ("Switching to method='from_vis'") + method = 'from_vis' + + tri = (site1, site2, site3) + outdata = [] + + # get selected closure phases from the maximal set + # TODO: verify consistency/performance of from_vis, and delete this method + if method=='from_maxset': + + # Get closure phases (maximal set) + if ((len(cphases) == 0) and not (self.cphase is None) and not (len(self.cphase) == 0) and + not force_recompute): + cphases = self.cphase + + elif (len(cphases) == 0) or force_recompute: + cphases = self.c_phases(mode='all', count='max', vtype=vtype, ang_unit=ang_unit, + timetype=timetype, snrcut=snrcut) + + # Get requested closure phases over time + for obs in cphases: + obstri = (obs['t1'], obs['t2'], obs['t3']) + if set(obstri) == set(tri): + t1 = copy.deepcopy(obs['t1']) + t2 = copy.deepcopy(obs['t2']) + t3 = copy.deepcopy(obs['t3']) + u1 = copy.deepcopy(obs['u1']) + u2 = copy.deepcopy(obs['u2']) + u3 = copy.deepcopy(obs['u3']) + v1 = copy.deepcopy(obs['v1']) + v2 = copy.deepcopy(obs['v2']) + v3 = copy.deepcopy(obs['v3']) + + # Reorder baselines and flip the sign of the closure phase if necessary + if t1 == site1: + if t2 == site2: + pass + else: + obs['t2'] = t3 + obs['t3'] = t2 + + obs['u1'] = -u3 + obs['v1'] = -v3 + obs['u2'] = -u2 + obs['v2'] = -v2 + obs['u3'] = -u1 + obs['v3'] = -v1 + obs['cphase'] *= -1 + + elif t1 == site2: + if t2 == site3: + obs['t1'] = t3 + obs['t2'] = t1 + obs['t3'] = t2 + + obs['u1'] = u3 + obs['v1'] = v3 + obs['u2'] = u1 + obs['v2'] = v1 + obs['u3'] = u2 + obs['v3'] = v2 + + else: + obs['t1'] = t2 + obs['t2'] = t1 + + obs['u1'] = -u1 + obs['v1'] = -v1 + obs['u2'] = -u3 + obs['v2'] = -v3 + obs['u3'] = -u2 + obs['v3'] = -v2 + obs['cphase'] *= -1 + + elif t1 == site3: + if t2 == site1: + obs['t1'] = t2 + obs['t2'] = t3 + obs['t3'] = t1 + + obs['u1'] = u2 + obs['v1'] = v2 + obs['u2'] = u3 + obs['v2'] = v3 + obs['u3'] = u1 + obs['v3'] = v1 + + else: + obs['t1'] = t3 + obs['t3'] = t1 + + obs['u1'] = -u2 + obs['v1'] = -v2 + obs['u2'] = -u1 + obs['v2'] = -v1 + obs['u3'] = -u3 + obs['v3'] = -v3 + obs['cphase'] *= -1 + + outdata.append(np.array(obs, dtype=ehc.DTCPHASE)) + continue + + # get selected closure phases from the visibilities directly + # taken from bispectra() method + elif method=='from_vis': + if ang_unit == 'deg': angle = ehc.DEGREE + else: angle = 1.0 + + # get all equal-time data, and loop over to construct closure phase + tlist = self.tlist(conj=True) + for tdata in tlist: + + time = tdata[0]['time'] + if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + time = obsh.utc_to_gmst(time, self.mjd) + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + time = obsh.gmst_to_utc(time, self.mjd) + + # Create a dictionary of baselines at the current time incl. conjugates; + l_dict = {} + for dat in tdata: + l_dict[(dat['t1'], dat['t2'])] = dat + + # Select triangle entries in the data dictionary + try: + l1 = l_dict[(tri[0], tri[1])] + l2 = l_dict[(tri[1], tri[2])] + l3 = l_dict[(tri[2], tri[0])] + except KeyError: + continue + + (bi, bisig) = obsh.make_bispectrum(l1, l2, l3, vtype, polrep=self.polrep) + + # Cut out low snr points + if np.abs(bi) / bisig < snrcut: + continue + + # Append to the equal-time list + outdata.append(np.array((time, + tri[0], tri[1], tri[2], + l1['u'], l1['v'], + l2['u'], l2['v'], + l3['u'], l3['v'], + np.real(np.angle(bi) / angle), + np.real(bisig / np.abs(bi) / angle)), + dtype=ehc.DTCPHASE)) + else: + raise Exception("keyword 'method' in cphase_tri() must be either 'from_cphase' or 'from_vis'") + + outdata = np.array(outdata) + return outdata + + def c_amplitudes(self, vtype='vis', mode='all', count='min', ctype='camp', debias=True, + timetype=False, snrcut=0.): + """Return a recarray of the equal time closure amplitudes. + + Args: + vtype (str): The visibilty type from which to assemble closure amplitudes + ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + mode (str): If 'time', return amplitudes in a list of equal time arrays, + if 'all', return all amplitudes in a single array + count (str): If 'min', return minimal set of amplitudes, + if 'max' return all closure amplitudes up to inverses + debias (bool): If True, debias the closure amplitude + timetype (str): 'GMST' or 'UTC' + snrcut (float): flag closure amplitudes with snr lower than this + + Returns: + (numpy.recarry): A recarray of the closure amplitudes with datatype DTCAMP + + """ + + if timetype is False: + timetype = self.timetype + if mode not in ('time', 'all'): + raise Exception("possible options for mode are 'time' and 'all'") + if count not in ('max', 'min'): + raise Exception("possible options for count are 'max' and 'min'") + if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'): + raise Exception("possible options for vtype are " + + "'vis', 'qvis', 'uvis','vvis','rrvis','lrvis','rlvis','llvis'") + if not (ctype in ['camp', 'logcamp']): + raise Exception("closure amplitude type must be 'camp' or 'logcamp'!") + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + # Get data sorted by time + tlist = self.tlist(conj=True) + out = [] + cas = [] + tt = 1 + for tdata in tlist: + + # sys.stdout.write('\rGetting closure amps:: type %s %s , count %s, scan %i/%i' % + # (vtype, ctype, count, tt, len(tlist))) + # sys.stdout.flush() + tt += 1 + + time = tdata[0]['time'] + if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + time = obsh.utc_to_gmst(time, self.mjd) + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + time = obsh.gmst_to_utc(time, self.mjd) + + sites = np.array(list(set(np.hstack((tdata['t1'], tdata['t2']))))) + if len(sites) < 4: + continue + + # Create a dictionary of baseline data at the current time including conjugates; + l_dict = {} + for dat in tdata: + l_dict[(dat['t1'], dat['t2'])] = dat + + # Minimal set + if count == 'min': + quadsets = obsh.quad_minimal_set(sites, self.tarr, self.tkey) + + # Maximal Set + elif count == 'max': + # Find all quadrangles + quadsets = np.sort(list(it.combinations(sites, 4))) + # Include 3 closure amplitudes on each quadrangle + quadsets = np.array([(q, [q[0], q[2], q[1], q[3]], [q[0], q[1], q[3], q[2]]) + for q in quadsets]).reshape((-1, 4)) + + # Loop over all closure amplitudes + for quad in quadsets: + # Blue is numerator, red is denominator + if (quad[0], quad[1]) not in l_dict.keys(): + continue + if (quad[2], quad[3]) not in l_dict.keys(): + continue + if (quad[1], quad[2]) not in l_dict.keys(): + continue + if (quad[0], quad[3]) not in l_dict.keys(): + continue + + try: + blue1 = l_dict[quad[0], quad[1]] + blue2 = l_dict[quad[2], quad[3]] + red1 = l_dict[quad[0], quad[3]] + red2 = l_dict[quad[1], quad[2]] + except KeyError: + continue + + # Compute the closure amplitude and the error + (camp, camperr) = obsh.make_closure_amplitude(blue1, blue2, red1, red2, vtype, + polrep=self.polrep, + ctype=ctype, debias=debias) + + if ctype == 'camp' and camp / camperr < snrcut: + continue + elif ctype == 'logcamp' and 1. / camperr < snrcut: + continue + + # Add the closure amplitudes to the equal-time list + # Our site convention is (12)(34)/(14)(23) + cas.append(np.array((time, + quad[0], quad[1], quad[2], quad[3], + blue1['u'], blue1['v'], blue2['u'], blue2['v'], + red1['u'], red1['v'], red2['u'], red2['v'], + camp, camperr), + dtype=ehc.DTCAMP)) + + # Append all equal time closure amps to outlist + if mode == 'time': + out.append(np.array(cas)) + cas = [] + + if mode == 'all': + out = np.array(cas) + + return out + + def c_log_amplitudes_diag(self, vtype='vis', mode='all', count='min', + debias=True, timetype=False, snrcut=0.): + """Return a recarray of the equal time diagonalized log closure amplitudes. + + Args: + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + From which to assemble closure amplitudes + ctype (str): The closure amplitude type ('camp' or 'logcamp') + mode (str): If 'time', return amplitudes in a list of equal time arrays, + If 'all', return all amplitudes in a single array + count (str): If 'min', return minimal set of amplitudes, + If 'max' return all closure amplitudes up to inverses + debias (bool): If True, debias the closure amplitude - + The individual visibility amplitudes are always debiased. + timetype (str): 'GMST' or 'UTC' + snrcut (float): flag closure amplitudes with snr lower than this + + Returns: + (numpy.recarry): A recarray of diagonalized closure amps with datatype DTLOGCAMPDIAG + + """ + + if timetype is False: + timetype = self.timetype + if mode not in ('time', 'all'): + raise Exception("possible options for mode are 'time' and 'all'") + if count not in ('min'): + raise Exception("count can only be 'min' for diagonal log closure amplitudes") + if vtype not in ('vis', 'qvis', 'uvis', 'vvis', 'rrvis', 'lrvis', 'rlvis', 'llvis'): + raise Exception( + "possible options for vtype are 'vis', 'qvis', 'uvis', " + + "'vvis','rrvis','lrvis','rlvis','llvis'") + if timetype not in ['GMST', 'UTC', 'gmst', 'utc']: + raise Exception("timetype should be 'GMST' or 'UTC'!") + + # determine the appropriate sigmatype + if vtype in ["vis", "qvis", "uvis", "vvis"]: + if vtype == 'vis': + sigmatype = 'sigma' + if vtype == 'qvis': + sigmatype = 'qsigma' + if vtype == 'uvis': + sigmatype = 'usigma' + if vtype == 'vvis': + sigmatype = 'vsigma' + if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]: + if vtype == 'rrvis': + sigmatype = 'rrsigma' + if vtype == 'llvis': + sigmatype = 'llsigma' + if vtype == 'rlvis': + sigmatype = 'rlsigma' + if vtype == 'lrvis': + sigmatype = 'lrsigma' + + # get the time-sorted visibility data including conjugate baselines + viss = np.concatenate(self.tlist(conj=True)) + + # get the log closure amplitude data + lcas = self.c_amplitudes(vtype=vtype, mode=mode, count=count, + ctype='logcamp', debias=debias, timetype=timetype, snrcut=snrcut) + + # get the unique timestamps for the log closure amplitudes + T_lcas = np.unique(lcas['time']) + + # list of diagonalized log closure camplitudes and corresponding transformation matrices + dlcas = [] + dlca_errs = [] + tfmats = [] + + quads = [] + us = [] + vs = [] + + # loop over the timestamps + for kk, t in enumerate(T_lcas): + + printstr = ('\rDiagonalizing log closure amplitudes:: type %s, count %s, scan %i/%i ' % + (vtype, count, kk + 1, len(T_lcas))) + sys.stdout.write(printstr) + sys.stdout.flush() + + # index masks for this timestamp + mask_lca = (lcas['time'] == t) + mask_vis = (viss['time'] == t) + + # log closure amplitudes for this timestamp + lcas_here = lcas[mask_lca] + + # visibilities for this timestamp + viss_here = viss[mask_vis] + + # initialize the design matrix + design_mat = np.zeros((mask_lca.sum(), mask_vis.sum())) + + # loop over the log closure amplitudes within this timestamp + quadlist = [] + ulist = [] + vlist = [] + for il, lca in enumerate(lcas_here): + + quadlist.append((lca['t1'], lca['t2'], lca['t3'], lca['t4'])) + ulist.append((lca['u1'], lca['u2'], lca['u3'], lca['u4'])) + vlist.append((lca['v1'], lca['v2'], lca['v3'], lca['v4'])) + + # matrix entry for first leg of quadrangle + ind1 = ((viss_here['t1'] == lca['t1']) & (viss_here['t2'] == lca['t2'])) + design_mat[il, ind1] = 1.0 + + # matrix entry for second leg of quadrangle + ind2 = ((viss_here['t1'] == lca['t3']) & (viss_here['t2'] == lca['t4'])) + design_mat[il, ind2] = 1.0 + + # matrix entry for third leg of quadrangle + ind3 = ((viss_here['t1'] == lca['t1']) & (viss_here['t2'] == lca['t4'])) + design_mat[il, ind3] = -1.0 + + # matrix entry for fourth leg of quadrangle + ind4 = ((viss_here['t1'] == lca['t2']) & (viss_here['t2'] == lca['t3'])) + design_mat[il, ind4] = -1.0 + + # construct the covariance matrix + logvisamp_err = viss_here[sigmatype] / np.abs(viss_here[vtype]) + sigma_mat = np.diag(logvisamp_err**2.0) + covar_mat = np.matmul(np.matmul(design_mat, sigma_mat), np.transpose(design_mat)) + + # diagonalize via eigendecomposition + eigeninfo = np.linalg.eigh(covar_mat) + T_matrix = np.copy(eigeninfo[1]).transpose() + dlogcamp = np.matmul(T_matrix, lcas_here['camp']) + dlogcamp_err = np.sqrt(np.copy(eigeninfo[0])) + + dlcas.append(dlogcamp) + dlca_errs.append(dlogcamp_err) + tfmats.append(T_matrix) + quads.append(quadlist) + us.append(ulist) + vs.append(vlist) + + # Reformat into a list + out = [] + for kk, t in enumerate(T_lcas): + dlcaarr = [] + for idlca, dlca in enumerate(dlcas[kk]): + dlcaarr.append((t, dlca, dlca_errs[kk][idlca])) + dlcaarr = np.array(dlcaarr, dtype=[('time', 'f8'), ('camp', 'f8'), ('sigmaca', 'f8')]) + out.append((dlcaarr, + np.array(quads[kk]).astype(np.dtype([('quadrangles', 'U2')])), + np.array(us[kk]).astype(np.dtype([('u', 'f8')])), + np.array(vs[kk]).astype(np.dtype([('v', 'f8')])), + tfmats[kk].astype(np.dtype([('tform_matrix', 'f8')])))) + print("\n") + return out + + def camp_quad(self, site1, site2, site3, site4, + vtype='vis', ctype='camp', debias=True, timetype=False, snrcut=0., + method='from_maxset', + camps=[], force_recompute=False): + """Return closure phase over time on a quadrange (1-2)(3-4)/(1-4)(2-3). + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + site4 (str): station 4 name + + vtype (str): The visibilty type from which to assemble closure amplitudes + ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + debias (bool): If True, debias the closure amplitude + timetype (str): 'UTC' or 'GMST' + snrcut (float): flag closure amplitudes with snr lower than this + + method (str): 'from_maxset' (old, default), 'from_vis' (new, more robust) + camps (list): optionally pass in the time-sorted, precomputed camps + force_recompute (bool): if True, do not use save closure amplitude data + + Returns: + (numpy.recarry): A recarray of the closure amplitudes with datatype DTCAMP + """ + + if timetype is False: + timetype = self.timetype + + + if method=='from_maxset' and (vtype in ['lrvis','pvis','rlvis']): + print ("Warning! method='from_maxset' default in camp_quad() is inconsistent with vtype=%s" % vtype) + print ("Switching to method='from_vis'") + method = 'from_vis' + + quad = (site1, site2, site3, site4) + outdata = [] + + # get selected closure amplitudes from the maximal set + # TODO: verify consistency/performance of from_vis, and delete this method + if method=='from_maxset': + if (((ctype == 'camp') and (len(camps) == 0)) and not (self.camp is None) and + not (len(self.camp) == 0) and not force_recompute): + camps = self.camp + elif (((ctype == 'logcamp') and (len(camps) == 0)) and not (self.logcamp is None) and + not (len(self.logcamp) == 0) and not force_recompute): + camps = self.logcamp + elif (len(camps) == 0) or force_recompute: + camps = self.c_amplitudes(mode='all', count='max', vtype=vtype, ctype=ctype, + debias=debias, timetype=timetype, snrcut=snrcut) + + # blue bls in numerator, red in denominator + b1 = set((site1, site2)) + b2 = set((site3, site4)) + r1 = set((site1, site4)) + r2 = set((site2, site3)) + + for obs in camps: # camps does not contain inverses! + + num = [set((obs['t1'], obs['t2'])), set((obs['t3'], obs['t4']))] + denom = [set((obs['t1'], obs['t4'])), set((obs['t2'], obs['t3']))] + + obsquad = (obs['t1'], obs['t2'], obs['t3'], obs['t4']) + if set(quad) == set(obsquad): + + # is this either the closure amplitude or inverse? + rightup = (b1 in num) and (b2 in num) and (r1 in denom) and (r2 in denom) + wrongup = (b1 in denom) and (b2 in denom) and (r1 in num) and (r2 in num) + if not (rightup or wrongup): + continue + + # flip the inverse closure amplitudes + if wrongup: + t1old = copy.deepcopy(obs['t1']) + u1old = copy.deepcopy(obs['u1']) + v1old = copy.deepcopy(obs['v1']) + t2old = copy.deepcopy(obs['t2']) + u2old = copy.deepcopy(obs['u2']) + v2old = copy.deepcopy(obs['v2']) + t3old = copy.deepcopy(obs['t3']) + u3old = copy.deepcopy(obs['u3']) + v3old = copy.deepcopy(obs['v3']) + t4old = copy.deepcopy(obs['t4']) + u4old = copy.deepcopy(obs['u4']) + v4old = copy.deepcopy(obs['v4']) + campold = copy.deepcopy(obs['camp']) + csigmaold = copy.deepcopy(obs['sigmaca']) + + obs['t1'] = t1old + obs['t2'] = t4old + obs['t3'] = t3old + obs['t4'] = t2old + + obs['u1'] = u3old + obs['v1'] = v3old + + obs['u2'] = -u4old + obs['v2'] = -v4old + + obs['u3'] = u1old + obs['v3'] = v1old + + obs['u4'] = -u2old + obs['v4'] = -v2old + + if ctype == 'logcamp': + obs['camp'] = -campold + obs['sigmaca'] = csigmaold + else: + obs['camp'] = 1. / campold + obs['sigmaca'] = csigmaold / (campold**2) + + t1old = copy.deepcopy(obs['t1']) + u1old = copy.deepcopy(obs['u1']) + v1old = copy.deepcopy(obs['v1']) + t2old = copy.deepcopy(obs['t2']) + u2old = copy.deepcopy(obs['u2']) + v2old = copy.deepcopy(obs['v2']) + t3old = copy.deepcopy(obs['t3']) + u3old = copy.deepcopy(obs['u3']) + v3old = copy.deepcopy(obs['v3']) + t4old = copy.deepcopy(obs['t4']) + u4old = copy.deepcopy(obs['u4']) + v4old = copy.deepcopy(obs['v4']) + + # this is all same closure amplitude, but the ordering of labels is different + # return the label ordering that the user requested! + if (obs['t2'], obs['t1'], obs['t4'], obs['t3']) == quad: + obs['t1'] = t2old + obs['t2'] = t1old + obs['t3'] = t4old + obs['t4'] = t3old + + obs['u1'] = -u1old + obs['v1'] = -v1old + + obs['u2'] = -u2old + obs['v2'] = -v2old + + obs['u3'] = u4old + obs['v3'] = v4old + + obs['u4'] = u3old + obs['v4'] = v3old + + elif (obs['t3'], obs['t4'], obs['t1'], obs['t2']) == quad: + obs['t1'] = t3old + obs['t2'] = t4old + obs['t3'] = t1old + obs['t4'] = t2old + + obs['u1'] = u2old + obs['v1'] = v2old + + obs['u2'] = u1old + obs['v2'] = v1old + + obs['u3'] = -u4old + obs['v3'] = -v4old + + obs['u4'] = -u3old + obs['v4'] = -v3old + + elif (obs['t4'], obs['t3'], obs['t2'], obs['t1']) == quad: + obs['t1'] = t4old + obs['t2'] = t3old + obs['t3'] = t2old + obs['t4'] = t1old + + obs['u1'] = -u2old + obs['v1'] = -v2old + + obs['u2'] = -u1old + obs['v2'] = -v1old + + obs['u3'] = -u3old + obs['v3'] = -v3old + + obs['u4'] = -u4old + obs['v4'] = -v4old + + # append to output array + outdata.append(np.array(obs, dtype=ehc.DTCAMP)) + + # get selected bispectra from the visibilities directly + # taken from c_ampitudes() method + elif method=='from_vis': + + # get all equal-time data, and loop over to construct closure amplitudes + tlist = self.tlist(conj=True) + for tdata in tlist: + + time = tdata[0]['time'] + if timetype in ['GMST', 'gmst'] and self.timetype == 'UTC': + time = obsh.utc_to_gmst(time, self.mjd) + if timetype in ['UTC', 'utc'] and self.timetype == 'GMST': + time = obsh.gmst_to_utc(time, self.mjd) + sites = np.array(list(set(np.hstack((tdata['t1'], tdata['t2']))))) + if len(sites) < 4: + continue + + # Create a dictionary of baselines at the current time incl. conjugates; + l_dict = {} + for dat in tdata: + l_dict[(dat['t1'], dat['t2'])] = dat + + # Select quadrangle entries in the data dictionary + # Blue is numerator, red is denominator + if (quad[0], quad[1]) not in l_dict.keys(): + continue + if (quad[2], quad[3]) not in l_dict.keys(): + continue + if (quad[1], quad[2]) not in l_dict.keys(): + continue + if (quad[0], quad[3]) not in l_dict.keys(): + continue + + try: + blue1 = l_dict[quad[0], quad[1]] + blue2 = l_dict[quad[2], quad[3]] + red1 = l_dict[quad[0], quad[3]] + red2 = l_dict[quad[1], quad[2]] + except KeyError: + continue + + # Compute the closure amplitude and the error + (camp, camperr) = obsh.make_closure_amplitude(blue1, blue2, red1, red2, vtype, + polrep=self.polrep, + ctype=ctype, debias=debias) + + if ctype == 'camp' and camp / camperr < snrcut: + continue + elif ctype == 'logcamp' and 1. / camperr < snrcut: + continue + + # Add the closure amplitudes to the equal-time list + # Our site convention is (12)(34)/(14)(23) + outdata.append(np.array((time, + quad[0], quad[1], quad[2], quad[3], + blue1['u'], blue1['v'], blue2['u'], blue2['v'], + red1['u'], red1['v'], red2['u'], red2['v'], + camp, camperr), + dtype=ehc.DTCAMP)) + + else: + raise Exception("keyword 'method' in camp_quad() must be either 'from_cphase' or 'from_vis'") + + outdata = np.array(outdata) + return outdata + + def plotall(self, field1, field2, + conj=False, debias=False, tag_bl=False, ang_unit='deg', timetype=False, + axis=False, rangex=False, rangey=False, snrcut=0., + color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None, + grid=True, ebar=True, axislabels=True, legend=False, + show=True, export_pdf=""): + """Plot two fields against each other. + + Args: + field1 (str): x-axis field (from FIELDS) + field2 (str): y-axis field (from FIELDS) + + conj (bool): Plot conjuage baseline data points if True + debias (bool): If True, debias amplitudes. + tag_bl (bool): if True, label each baseline + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + + color (str): color for scatterplot points + marker (str): matplotlib plot marker + markersize (int): size of plot markers + label (str): plot legend label + snrcut (float): flag closure amplitudes with snr lower than this + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + + show (bool): Display the plot if true + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + + if timetype is False: + timetype = self.timetype + + # Determine if fields are valid + field1 = field1.lower() + field2 = field2.lower() + if (field1 not in ehc.FIELDS) and (field2 not in ehc.FIELDS): + raise Exception("valid fields are " + ' '.join(ehc.FIELDS)) + + if 'amp' in [field1, field2] and not (self.amp is None): + print("Warning: plotall is not using amplitudes in Obsdata.amp array!") + + # Label individual baselines + # ANDREW TODO this is way too slow, make it faster?? + if tag_bl: + clist = ehc.SCOLORS + + # make a color coding dictionary + cdict = {} + ii = 0 + baselines = np.sort(list(it.combinations(self.tarr['site'], 2))) + for baseline in baselines: + cdict[(baseline[0], baseline[1])] = clist[ii % len(clist)] + cdict[(baseline[1], baseline[0])] = clist[ii % len(clist)] + ii += 1 + + # get unique baselines -- TODO easier way? separate function? + alldata = [] + allsigx = [] + allsigy = [] + bllist = [] + colors = [] + bldata = self.bllist(conj=conj) + for bl in bldata: + t1 = bl['t1'][0] + t2 = bl['t2'][0] + + bllist.append((t1, t2)) + colors.append(cdict[(t1, t2)]) + + # Unpack data + dat = self.unpack_dat(bl, [field1, field2], + ang_unit=ang_unit, debias=debias, timetype=timetype) + alldata.append(dat) + + # X error bars + if obsh.sigtype(field1): + allsigx.append(self.unpack_dat(bl, [obsh.sigtype(field1)], + ang_unit=ang_unit)[obsh.sigtype(field1)]) + else: + allsigx.append(None) + + # Y error bars + if obsh.sigtype(field2): + allsigy.append(self.unpack_dat(bl, [obsh.sigtype(field2)], + ang_unit=ang_unit)[obsh.sigtype(field2)]) + else: + allsigy.append(None) + + # Don't Label individual baselines + else: + bllist = [['All', 'All']] + colors = [color] + + # unpack data + alldata = [self.unpack([field1, field2], + conj=conj, ang_unit=ang_unit, debias=debias, timetype=timetype)] + + # X error bars + if obsh.sigtype(field1): + allsigx = self.unpack(obsh.sigtype(field2), conj=conj, ang_unit=ang_unit) + allsigx = [allsigx[obsh.sigtype(field1)]] + else: + allsigx = [None] + + # Y error bars + if obsh.sigtype(field2): + allsigy = self.unpack(obsh.sigtype(field2), conj=conj, ang_unit=ang_unit) + allsigy = [allsigy[obsh.sigtype(field2)]] + else: + allsigy = [None] + + # make plot(s) + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + xmins = [] + xmaxes = [] + ymins = [] + ymaxes = [] + for i in range(len(alldata)): + data = alldata[i] + sigy = allsigy[i] + sigx = allsigx[i] + color = colors[i] + bl = bllist[i] + + # Flag out nans (to avoid problems determining plotting limits) + mask = ~(np.isnan(data[field1]) + np.isnan(data[field2])) + + # Flag out due to snrcut + if snrcut > 0.: + sigs = [sigx, sigy] + for jj, field in enumerate([field1, field2]): + if field in ehc.FIELDS_AMPS: + fmask = data[field] / sigs[jj] > snrcut + elif field in ehc.FIELDS_PHASE: + fmask = sigs[jj] < (180. / np.pi / snrcut) + elif field in ehc.FIELDS_SNRS: + fmask = data[field] > snrcut + else: + fmask = np.ones(mask.shape).astype(bool) + mask *= fmask + + data = data[mask] + if sigy is not None: + sigy = sigy[mask] + if sigx is not None: + sigx = sigx[mask] + if len(data) == 0: + continue + + xmins.append(np.min(data[field1])) + xmaxes.append(np.max(data[field1])) + ymins.append(np.min(data[field2])) + ymaxes.append(np.max(data[field2])) + + # Plot the data + tolerance = len(data[field2]) + + if label is None: + labelstr = "%s-%s" % ((str(bl[0]), str(bl[1]))) + + else: + labelstr = str(label) + + if ebar and (np.any(sigy) or np.any(sigx)): + x.errorbar(data[field1], data[field2], xerr=sigx, yerr=sigy, label=labelstr, + fmt=marker, markersize=markersize, color=color, picker=tolerance) + else: + x.plot(data[field1], data[field2], marker, markersize=markersize, color=color, + label=labelstr, picker=tolerance) + + # Data ranges + if not rangex: + rangex = [np.min(xmins) - 0.2 * np.abs(np.min(xmins)), + np.max(xmaxes) + 0.2 * np.abs(np.max(xmaxes))] + if np.any(np.isnan(np.array(rangex))): + print("Warning: NaN in data x range: specifying rangex to default") + rangex = [-100, 100] + + if not rangey: + rangey = [np.min(ymins) - 0.2 * np.abs(np.min(ymins)), + np.max(ymaxes) + 0.2 * np.abs(np.max(ymaxes))] + if np.any(np.isnan(np.array(rangey))): + print("Warning: NaN in data y range: specifying rangey to default") + rangey = [-100, 100] + + x.set_xlim(rangex) + x.set_ylim(rangey) + + # label and save + if axislabels: + try: + x.set_xlabel(ehc.FIELD_LABELS[field1]) + x.set_ylabel(ehc.FIELD_LABELS[field2]) + except KeyError: + x.set_xlabel(field1.capitalize()) + x.set_ylabel(field2.capitalize()) + if legend and tag_bl: + plt.legend(ncol=2) + elif legend: + plt.legend() + if grid: + x.grid() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if export_pdf != "" and axis: + fig = plt.gcf() + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + + return x + + def plot_bl(self, site1, site2, field, + debias=False, ang_unit='deg', timetype=False, + axis=False, rangex=False, rangey=False, snrcut=0., + color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None, + grid=True, ebar=True, axislabels=True, legend=False, + show=True, export_pdf=""): + """Plot a field over time on a baseline site1-site2. + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + field (str): y-axis field (from FIELDS) + + debias (bool): If True, debias amplitudes. + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + + color (str): color for scatterplot points + marker (str): matplotlib plot marker + markersize (int): size of plot markers + label (str): plot legend label + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + + if timetype is False: + timetype = self.timetype + + field = field.lower() + if field == 'amp' and not (self.amp is None): + print("Warning: plot_bl is not using amplitudes in Obsdata.amp array!") + + if label is None: + label = str(self.source) + else: + label = str(label) + + # Determine if fields are valid + if field not in ehc.FIELDS: + raise Exception("valid fields are " + string.join(ehc.FIELDS)) + + plotdata = self.unpack_bl(site1, site2, field, ang_unit=ang_unit, + debias=debias, timetype=timetype) + sigmatype = obsh.sigtype(field) + if obsh.sigtype(field): + errdata = self.unpack_bl(site1, site2, obsh.sigtype(field), + ang_unit=ang_unit, debias=debias) + else: + errdata = None + + # Flag out nans (to avoid problems determining plotting limits) + mask = ~np.isnan(plotdata[field][:, 0]) + + # Flag out due to snrcut + if snrcut > 0.: + if field in ehc.FIELDS_AMPS: + fmask = plotdata[field] / errdata[sigmatype] > snrcut + elif field in ehc.FIELDS_PHASE: + fmask = errdata[sigmatype] < (180. / np.pi / snrcut) + elif field in ehc.FIELDS_SNRS: + fmask = plotdata[field] > snrcut + else: + fmask = np.ones(mask.shape).astype(bool) + fmask = fmask[:, 0] + mask *= fmask + + plotdata = plotdata[mask] + if errdata is not None: + errdata = errdata[mask] + + if not rangex: + rangex = [self.tstart, self.tstop] + if np.any(np.isnan(np.array(rangex))): + print("Warning: NaN in data x range: specifying rangex to default") + rangex = [0, 24] + if not rangey: + rangey = [np.min(plotdata[field]) - 0.2 * np.abs(np.min(plotdata[field])), + np.max(plotdata[field]) + 0.2 * np.abs(np.max(plotdata[field]))] + if np.any(np.isnan(np.array(rangey))): + print("Warning: NaN in data y range: specifying rangex to default") + rangey = [-100, 100] + + # Plot the data + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + if ebar and obsh.sigtype(field) is not False: + x.errorbar(plotdata['time'][:, 0], plotdata[field][:, 0], + yerr=errdata[obsh.sigtype(field)][:, 0], + fmt=marker, markersize=markersize, color=color, + linestyle='none', label=label) + else: + x.plot(plotdata['time'][:, 0], plotdata[field][:, 0], marker, markersize=markersize, + color=color, label=label, linestyle='none') + + x.set_xlim(rangex) + x.set_ylim(rangey) + + if axislabels: + x.set_xlabel(timetype + ' (hr)') + try: + x.set_ylabel(ehc.FIELD_LABELS[field]) + except KeyError: + x.set_ylabel(field.capitalize()) + x.set_title('%s - %s' % (site1, site2)) + + if grid: + x.grid() + if legend: + plt.legend() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if export_pdf != "" and axis: + fig = plt.gcf() + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + return x + + def plot_cphase(self, site1, site2, site3, + vtype='vis', cphases=[], force_recompute=False, + ang_unit='deg', timetype=False, snrcut=0., + axis=False, rangex=False, rangey=False, + color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None, + grid=True, ebar=True, axislabels=True, legend=False, + show=True, export_pdf=""): + """Plot a closure phase over time on a triangle site1-site2-site3. + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type from which to assemble closure phases + ('vis','qvis','uvis','vvis','pvis') + cphases (list): optionally pass in the prcomputed, time-sorted closure phases + force_recompute (bool): if True, do not use stored closure phase able + snrcut (float): flag closure amplitudes with snr lower than this + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + + color (str): color for scatterplot points + marker (str): matplotlib plot marker + markersize (int): size of plot markers + label (str): plot legend label + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + + show (bool): Display the plot if True + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + + if timetype is False: + timetype = self.timetype + if ang_unit == 'deg': + angle = 1.0 + else: + angle = ehc.DEGREE + + if label is None: + label = str(self.source) + else: + label = str(label) + + # Get closure phases (maximal set) + if (len(cphases) == 0) and (self.cphase is not None) and not force_recompute: + cphases = self.cphase + + cpdata = self.cphase_tri(site1, site2, site3, vtype=vtype, timetype=timetype, + cphases=cphases, force_recompute=force_recompute, snrcut=snrcut) + plotdata = np.array([[obs['time'], obs['cphase'] * angle, obs['sigmacp']] + for obs in cpdata]) + + nan_mask = np.isnan(plotdata[:, 1]) + plotdata = plotdata[~nan_mask] + + if len(plotdata) == 0: + print("%s %s %s : No closure phases on this triangle!" % (site1, site2, site3)) + return + + # Plot the data + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + # Data ranges + if not rangex: + rangex = [self.tstart, self.tstop] + if np.any(np.isnan(np.array(rangex))): + print("Warning: NaN in data x range: specifying rangex to default") + rangex = [0, 24] + + if not rangey: + if ang_unit == 'deg': + rangey = [-190, 190] + else: + rangey = [-1.1 * np.pi, 1.1 * np.pi] + + x.set_xlim(rangex) + x.set_ylim(rangey) + + if ebar and np.any(plotdata[:, 2]): + x.errorbar(plotdata[:, 0], plotdata[:, 1], yerr=plotdata[:, 2], + fmt=marker, markersize=markersize, + color=color, linestyle='none', label=label) + else: + x.plot(plotdata[:, 0], plotdata[:, 1], marker, markersize=markersize, + color=color, linestyle='none', label=label) + + if axislabels: + x.set_xlabel(self.timetype + ' (h)') + if ang_unit == 'deg': + x.set_ylabel(r'Closure Phase $(^\circ)$') + else: + x.set_ylabel(r'Closure Phase (radian)') + + x.set_title('%s - %s - %s' % (site1, site2, site3)) + + if grid: + x.grid() + if legend: + plt.legend() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if export_pdf != "" and axis: + fig = plt.gcf() + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + + return x + + def plot_camp(self, site1, site2, site3, site4, + vtype='vis', ctype='camp', camps=[], force_recompute=False, + debias=False, timetype=False, snrcut=0., + axis=False, rangex=False, rangey=False, + color=ehc.SCOLORS[0], marker='o', markersize=ehc.MARKERSIZE, label=None, + grid=True, ebar=True, axislabels=True, legend=False, + show=True, export_pdf=""): + """Plot closure amplitude over time on a quadrangle (1-2)(3-4)/(1-4)(2-3). + + Args: + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + site4 (str): station 4 name + + vtype (str): The visibilty type from which to assemble closure amplitudes + ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + camps (list): optionally pass in camps so they don't have to be recomputed + force_recompute (bool): if True, recompute camps instead of using stored data + snrcut (float): flag closure amplitudes with snr lower than this + + debias (bool): If True, debias the closure amplitude + timetype (str): 'GMST' or 'UTC' + + axis (matplotlib.axes.Axes): amake_cdd plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + color (str): color for scatterplot points + marker (str): matplotlib plot marker + markersize (int): size of plot markers + label (str): plot legend label + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + + show (bool): Display the plot if True + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + + if timetype is False: + timetype = self.timetype + if label is None: + label = str(self.source) + + else: + label = str(label) + + # Get closure amplitudes (maximal set) + if ((ctype == 'camp') and (len(camps) == 0) and (self.camp is not None) and + not (len(self.camp) == 0) and not force_recompute): + camps = self.camp + elif ((ctype == 'logcamp') and (len(camps) == 0) and (self.logcamp is not None) and + not (len(self.logcamp) == 0) and not force_recompute): + camps = self.logcamp + + # Get closure amplitudes (maximal set) + cpdata = self.camp_quad(site1, site2, site3, site4, + vtype=vtype, ctype=ctype, snrcut=snrcut, + debias=debias, timetype=timetype, + camps=camps, force_recompute=force_recompute) + + if len(cpdata) == 0: + print('No closure amplitudes on this triangle!') + return + + plotdata = np.array([[obs['time'], obs['camp'], obs['sigmaca']] for obs in cpdata]) + plotdata = np.array(plotdata) + + nan_mask = np.isnan(plotdata[:, 1]) + plotdata = plotdata[~nan_mask] + + if len(plotdata) == 0: + print("No closure amplitudes on this quadrangle!") + return + + # Data ranges + if not rangex: + rangex = [self.tstart, self.tstop] + if np.any(np.isnan(np.array(rangex))): + print("Warning: NaN in data x range: specifying rangex to default") + rangex = [0, 24] + + if not rangey: + rangey = [np.min(plotdata[:, 1]) - 0.2 * np.abs(np.min(plotdata[:, 1])), + np.max(plotdata[:, 1]) + 0.2 * np.abs(np.max(plotdata[:, 1]))] + if np.any(np.isnan(np.array(rangey))): + print("Warning: NaN in data y range: specifying rangey to default") + if ctype == 'camp': + rangey = [0, 100] + if ctype == 'logcamp': + rangey = [-10, 10] + + # Plot the data + if axis: + x = axis + else: + fig = plt.figure() + x = fig.add_subplot(1, 1, 1) + + if ebar and np.any(plotdata[:, 2]): + x.errorbar(plotdata[:, 0], plotdata[:, 1], yerr=plotdata[:, 2], + fmt=marker, markersize=markersize, + color=color, linestyle='none', label=label) + else: + x.plot(plotdata[:, 0], plotdata[:, 1], marker, markersize=markersize, + color=color, linestyle='none', label=label) + + x.set_xlim(rangex) + x.set_ylim(rangey) + + if axislabels: + x.set_xlabel(self.timetype + ' (h)') + if ctype == 'camp': + x.set_ylabel('Closure Amplitude') + elif ctype == 'logcamp': + x.set_ylabel('Log Closure Amplitude') + x.set_title('(%s - %s)(%s - %s)/(%s - %s)(%s - %s)' % (site1, site2, site3, site4, + site1, site4, site2, site3)) + if grid: + x.grid() + if legend: + plt.legend() + if export_pdf != "" and not axis: + fig.savefig(export_pdf, bbox_inches='tight') + if export_pdf != "" and axis: + fig = plt.gcf() + fig.savefig(export_pdf, bbox_inches='tight') + if show: + #plt.show(block=False) + ehc.show_noblock() + return + else: + return x + + def save_txt(self, fname): + """Save visibility data to a text file. + + Args: + fname (str): path to output text file + """ + + ehtim.io.save.save_obs_txt(self, fname) + + return + + def save_uvfits(self, fname, force_singlepol=False, polrep_out='circ'): + """Save visibility data to uvfits file. + Args: + fname (str): path to output uvfits file. + force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL' + polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file + """ + + if (force_singlepol is not False) and (self.polrep != 'stokes'): + raise Exception( + "force_singlepol is incompatible with polrep!='stokes'") + + output = ehtim.io.save.save_obs_uvfits(self, fname, + force_singlepol=force_singlepol, polrep_out=polrep_out) + + return + + def make_hdulist(self, force_singlepol=False, polrep_out='circ'): + """Returns an hdulist in the same format as in a saved .uvfits file. + Args: + force_singlepol (str): if 'R' or 'L', will interpret stokes I field as 'RR' or 'LL' + polrep_out (str): 'circ' or 'stokes': how data should be stored in the uvfits file + Returns: + hdulist (astropy.io.fits.HDUList) + """ + + if (force_singlepol is not False) and (self.polrep != 'stokes'): + raise Exception( + "force_singlepol is incompatible with polrep!='stokes'") + + hdulist = ehtim.io.save.save_obs_uvfits(self, None, + force_singlepol=force_singlepol, polrep_out=polrep_out) + return hdulist + + + def save_oifits(self, fname, flux=1.0): + """ Save visibility data to oifits. Polarization data is NOT saved. + + Args: + fname (str): path to output text file + flux (float): normalization total flux + """ + + if self.polrep != 'stokes': + raise Exception("save_oifits not yet implemented for polreps other than 'stokes'") + + # Antenna diameters are currently incorrect + # the exact times are also not correct in the datetime object + ehtim.io.save.save_obs_oifits(self, fname, flux=flux) + + return + +################################################################################################## +# Observation creation functions +################################################################################################## + + +def merge_obs(obs_List, force_merge=False): + """Merge a list of observations into a single observation file. + + Args: + obs_List (list): list of split observation Obsdata objects. + force_merge (bool): forces the observations to merge even if parameters are different + + Returns: + mergeobs (Obsdata): merged Obsdata object containing all scans in input list + """ + + if (len(set([obs.polrep for obs in obs_List])) > 1): + raise Exception("All observations must have the same polarization representation !") + return + + if np.any([obs.timetype == 'GMST' for obs in obs_List]): + raise Exception("merge_obs only works for observations with obs.timetype='UTC'!") + return + + if not force_merge: + if (len(set([obs.ra for obs in obs_List])) > 1 or + len(set([obs.dec for obs in obs_List])) > 1 or + len(set([obs.rf for obs in obs_List])) > 1 or + len(set([obs.bw for obs in obs_List])) > 1 or + len(set([obs.source for obs in obs_List])) > 1): + # or len(set([np.floor(obs.mjd) for obs in obs_List])) > 1): + + raise Exception("All observations must have the same parameters!") + return + + # the reference observation is the one with the minimum mjd + obs_idx = np.argmin([obs.mjd for obs in obs_List]) + obs_ref = obs_List[obs_idx] + + # re-reference times to new mjd + # must be in UTC! + mjd_ref = obs_ref.mjd + for obs in obs_List: + mjd_offset = obs.mjd - mjd_ref + obs.data['time'] += mjd_offset * 24 + if not(obs.scans is None or obs.scans == []): + obs.scans += mjd_offset * 24 + + # merge the data + data_merge = np.hstack([obs.data for obs in obs_List]) + + # merge the scan list + scan_merge = None + for obs in obs_List: + if (obs.scans is None or obs.scans == []): + continue + if (scan_merge is None or scan_merge == []): + scan_merge = [obs.scans] + else: + scan_merge.append(obs.scans) + + if not (scan_merge is None or scan_merge == []): + scan_merge = np.vstack(scan_merge) + _idxsort = np.argsort(scan_merge[:, 0]) + scan_merge = scan_merge[_idxsort] + + # merge the list of telescopes + tarr_merge = np.unique(np.concatenate([obs.tarr for obs in obs_List])) + + arglist, argdict = obs_ref.obsdata_args() + arglist[DATPOS] = data_merge + arglist[TARRPOS] = tarr_merge + argdict['scantable'] = scan_merge + mergeobs = Obsdata(*arglist, **argdict) + + return mergeobs + + +def load_txt(fname, polrep='stokes'): + """Read an observation from a text file. + + Args: + fname (str): path to input text file + polrep (str): load data as either 'stokes' or 'circ' + + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + return ehtim.io.load.load_obs_txt(fname, polrep=polrep) + + +def load_uvfits(fname, flipbl=False, remove_nan=False, force_singlepol=None, + channel=all, IF=all, polrep='stokes', allow_singlepol=True, + ignore_pzero_date=True, + trial_speedups=False): + """Load observation data from a uvfits file. + + Args: + fname (str or HDUList): path to input text file or HDUList object + flipbl (bool): flip baseline phases if True. + remove_nan (bool): True to remove nans from missing polarizations + polrep (str): load data as either 'stokes' or 'circ' + force_singlepol (str): 'R' or 'L' to load only 1 polarization + channel (list): list of channels to average in the import. channel=all averages all + IF (list): list of IFs to average in the import. IF=all averages all IFS + remove_nan (bool): whether or not to remove entries with nan data + + ignore_pzero_date (bool): if True, ignore the offset parameters in DATE field + TODO: what is the correct behavior per AIPS memo 117? + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + return ehtim.io.load.load_obs_uvfits(fname, flipbl=flipbl, force_singlepol=force_singlepol, + channel=channel, IF=IF, polrep=polrep, + remove_nan=remove_nan, allow_singlepol=allow_singlepol, + ignore_pzero_date=ignore_pzero_date, + trial_speedups=trial_speedups) + + +def load_oifits(fname, flux=1.0): + """Load data from an oifits file. Does NOT currently support polarization. + + Args: + fname (str): path to input text file + flux (float): normalization total flux + + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + return ehtim.io.load.load_obs_oifits(fname, flux=flux) + + +def load_maps(arrfile, obsspec, ifile, qfile=0, ufile=0, vfile=0, + src=ehc.SOURCE_DEFAULT, mjd=ehc.MJD_DEFAULT, ampcal=False, phasecal=False): + """Read an observation from a maps text file and return an Obsdata object. + + Args: + arrfile (str): path to input array file + obsspec (str): path to input obs spec file + ifile (str): path to input Stokes I data file + qfile (str): path to input Stokes Q data file + ufile (str): path to input Stokes U data file + vfile (str): path to input Stokes V data file + src (str): source name + mjd (int): integer observation MJD + ampcal (bool): True if amplitude calibrated + phasecal (bool): True if phase calibrated + + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + return ehtim.io.load.load_obs_maps(arrfile, obsspec, ifile, + qfile=qfile, ufile=ufile, vfile=vfile, + src=src, mjd=mjd, ampcal=ampcal, phasecal=phasecal) + +def load_obs( + fname, + polrep='stokes', + flipbl=False, + remove_nan=False, + force_singlepol=None, + channel=all, + IF=all, + allow_singlepol=True, + flux=1.0, + obsspec=None, + ifile=None, + qfile=None, + ufile=None, + vfile=None, + src=ehc.SOURCE_DEFAULT, + mjd=ehc.MJD_DEFAULT, + ampcal=False, + phasecal=False + ): + """Smart obs read-in, detects file type and loads appropriately. + + Args: + fname (str): path to input text file + polrep (str): load data as either 'stokes' or 'circ' + flipbl (bool): flip baseline phases if True. + remove_nan (bool): True to remove nans from missing polarizations + polrep (str): load data as either 'stokes' or 'circ' + force_singlepol (str): 'R' or 'L' to load only 1 polarization + channel (list): list of channels to average in the import. channel=all averages all + IF (list): list of IFs to average in the import. IF=all averages all IFS + flux (float): normalization total flux + obsspec (str): path to input obs spec file + ifile (str): path to input Stokes I data file + qfile (str): path to input Stokes Q data file + ufile (str): path to input Stokes U data file + vfile (str): path to input Stokes V data file + src (str): source name + mjd (int): integer observation MJD + ampcal (bool): True if amplitude calibrated + phasecal (bool): True if phase calibrated + + Returns: + obs (Obsdata): Obsdata object loaded from file + """ + + + ## grab file ending ## + fname_extension = fname.split('.')[-1] + print(f"Extension is {fname_extension}.") + + ## check extension ## + if fname_extension.lower() == 'uvfits': + return load_uvfits(fname, flipbl=flipbl, remove_nan=remove_nan, force_singlepol=force_singlepol, channel=channel, IF=IF, polrep=polrep, allow_singlepol=allow_singlepol) + + elif fname_extension.lower() in ['txt', 'text']: + return load_txt(fname, polrep=polrep) + + elif fname_extension.lower() == 'oifits': + return load_oifits(fname, flux=flux) + + + else: + if obsspec is not None and ifile is None: + print("You have provided a value for but no value for ") + return + elif obsspec is None and ifile is not None: + print("You have provided a value for but no value for ") + return + + elif obsspec is not None and ifile is not None: + return load_maps(fname, obsspec, ifile, qfile=qfile, ufile=ufile, vfile=vfile, + src=src, mjd=mjd, ampcal=ampcal, phasecal=phasecal) + diff --git a/observing/__init__.py b/observing/__init__.py new file mode 100644 index 00000000..213aefd6 --- /dev/null +++ b/observing/__init__.py @@ -0,0 +1,11 @@ +""" +.. module:: ehtim.observing + :platform: Unix + :synopsis: EHT Imaging Utilities: simulated observation functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from . import pulses +from . import obs_helpers +from . import obs_simulate diff --git a/observing/obs_helpers.py b/observing/obs_helpers.py new file mode 100644 index 00000000..6dc1dbaa --- /dev/null +++ b/observing/obs_helpers.py @@ -0,0 +1,1776 @@ +# obs_helpers.py +# helper functions for simulating and manipulating observations +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +try: + from sgp4.api import Satrec, WGS72 + import skyfield.api +except ImportError: + print("Warning: skyfield not installed: cannot simulate space VLBI") + +try: + from pynfft.nfft import NFFT +except ImportError: + print("Warning: No NFFT installed!") + +import astropy.time as at +import astropy.coordinates as coords +import numpy as np +import itertools as it +import scipy.ndimage as nd +import scipy.spatial.distance +import copy +import sys + +import ehtim.const_def as ehc + +import warnings +warnings.filterwarnings("ignore", message="divide by zero encountered in double_scalars") + +################################################################################################## +# Other Functions +################################################################################################## + + +def compute_uv_coordinates(array, site1, site2, time, mjd, ra, dec, rf, timetype='UTC', + elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, no_elevcut_space=False, + fix_theta_GMST=False): + + """Compute u,v coordinates for an array at a given time for a source at a given ra,dec,rf + """ + + if not isinstance(time, np.ndarray): + time = np.array([time]).flatten() + if not isinstance(site1, np.ndarray): + site1 = np.array([site1]).flatten() + if not isinstance(site2, np.ndarray): + site2 = np.array([site2]).flatten() + + if len(site1) == len(site2) == 1: + site1 = np.array([site1[0] for i in range(len(time))]) + site2 = np.array([site2[0] for i in range(len(time))]) + elif not (len(site1) == len(site2) == len(time)): + raise Exception("site1, site2, and time not the same dimension in compute_uv_coordinates!") + + # Source vector + sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)]) + projU = np.cross(np.array([0, 0, 1]), sourcevec) + projU = projU/np.linalg.norm(projU) + projV = -np.cross(projU, sourcevec) + + # Wavelength + wvl = ehc.C/rf + + if timetype == 'GMST': + time_sidereal = time + # time_utc = gmst_to_utc(time, mjd) + elif timetype == 'UTC': + time_sidereal = utc_to_gmst(time, mjd) + # time_utc = time + else: + raise Exception("timetype must be UTC or GMST!") + + fracmjd = np.floor(mjd) + time/24. + dto = (at.Time(fracmjd, format='mjd')).datetime + theta = np.mod((time_sidereal - ra)*ehc.HOUR, 2*np.pi) + if type(fix_theta_GMST) != bool: + theta = np.mod((fix_theta_GMST - ra)*ehc.HOUR, 2*np.pi) + + i1 = np.array([array.tkey[site] for site in site1]) + i2 = np.array([array.tkey[site] for site in site2]) + + coord1 = np.vstack((array.tarr[i1]['x'], array.tarr[i1]['y'], array.tarr[i1]['z'])).T + coord2 = np.vstack((array.tarr[i2]['x'], array.tarr[i2]['y'], array.tarr[i2]['z'])).T + + # Satellites: new method + spacemask1 = [np.all(coord == (0., 0., 0.)) for coord in coord1] + spacemask2 = [np.all(coord == (0., 0., 0.)) for coord in coord2] + + satnames = array.ephem.keys() + satdict = {satname: sat_skyfield_from_ephementry(satname, array.ephem, mjd) for satname in satnames} + for satname in satnames: + sat = satdict[satname] + + mask1 = (site1==satname) + c1 = orbit_skyfield(sat, fracmjd[mask1], whichout='itrs') + coord1[mask1] = c1.T + + mask2 = (site2==satname) + c2 = orbit_skyfield(sat, fracmjd[mask2], whichout='itrs') + coord2[mask2] = c2.T + + # Satellites: old method + """ + if np.any(spacemask1): + if timetype == 'GMST': + raise Exception("Spacecraft ephemeris only work with UTC!") + + site1space_list = site1[spacemask1] + site1space_fracmjdlist = fracmjd[spacemask1] + site1space_dtolist = dto[spacemask1] + coord1space = [] + + for k in range(len(site1space_list)): + + + # old method with pyephem + + site1space = site1space_list[k] + dto_now = site1space_dtolist[k] + sat = ephem.readtle(array.ephem[site1space][0], + array.ephem[site1space][1], array.ephem[site1space][2]) + sat.compute(dto_now) # often complains if ephemeris out of date! + elev = sat.elevation + lat = sat.sublat / ehc.DEGREE + lon = sat.sublong / ehc.DEGREE + # pyephem doesn't use an ellipsoid earth model! + c1 = coords.EarthLocation.from_geodetic(lon, lat, elev, ellipsoid=None) + c1 = np.array((c1.x.value, c1.y.value, c1.z.value)) + coord1space.append(c1) + + coord1space = np.array(coord1space) + coord1[spacemask1] = coord1space + + # use spacecraft ephemeris to get position of site 2 + spacemask2 = [np.all(coord == (0., 0., 0.)) for coord in coord2] + if np.any(spacemask2): + if timetype == 'GMST': + raise Exception("Spacecraft ephemeris only work with UTC!") + + site2space_list = site2[spacemask2] + site2space_fracmjdlist = fracmjd[spacemask2] + site2space_dtolist = dto[spacemask2] + coord2space = [] + for k in range(len(site2space_list)): + + site2space = site2space_list[k] + dto_now = site2space_dtolist[k] + sat = ephem.readtle(array.ephem[site2space][0], + array.ephem[site2space][1], + array.ephem[site2space][2]) + sat.compute(dto_now) # often complains if ephemeris out of date! + elev = sat.elevation + lat = sat.sublat / ehc.DEGREE + lon = sat.sublong / ehc.DEGREE + # pyephem doesn't use an ellipsoid earth model! + c2 = coords.EarthLocation.from_geodetic(lon, lat, elev, ellipsoid=None) + c2 = np.array((c2.x.value, c2.y.value, c2.z.value)) + coord2space.append(c2) + + coord2space = np.array(coord2space) + coord2[spacemask2] = coord2space + """ + + # rotate the station coordinates with the earth + coord1 = earthrot(coord1, theta) + coord2 = earthrot(coord2, theta) + + # u,v coordinates + u = np.dot((coord1 - coord2)/wvl, projU) # u (lambda) + v = np.dot((coord1 - coord2)/wvl, projV) # v (lambda) + + # mask out below elevation cut + mask_elev_1 = elevcut(coord1, sourcevec, elevmin=elevmin, elevmax=elevmax) + mask_elev_2 = elevcut(coord2, sourcevec, elevmin=elevmin, elevmax=elevmax) + + # do NOT apply elevation cut for space orbiters + if no_elevcut_space: + mask_elev_1[spacemask1] = 1 + mask_elev_2[spacemask2] = 1 + + # apply elevation mask + mask = mask_elev_1 * mask_elev_2 + + time = time[mask] + u = u[mask] + v = v[mask] + + # return times and uv points where we have data + return (time, u, v) + + +def make_bispectrum(l1, l2, l3, vistype, polrep='stokes'): + """Make a list of bispectra and errors + l1,l2,l3 are full datatables of visibility entries + vtype is visibility types + """ + + if vistype == 'pvis': + vtype = 'rlvis' + else: + vtype = vistype + + # Choose the appropriate polarization and compute the bs and err + if polrep == 'stokes': + if vtype in ["vis", "qvis", "uvis", "vvis"]: + if vtype == 'vis': + sigmatype = 'sigma' + elif vtype == 'qvis': + sigmatype = 'qsigma' + elif vtype == 'uvis': + sigmatype = 'usigma' + elif vtype == 'vvis': + sigmatype = 'vsigma' + + p1 = l1[vtype] + p2 = l2[vtype] + p3 = l3[vtype] + + var1 = l1[sigmatype]**2 + var2 = l2[sigmatype]**2 + var3 = l3[sigmatype]**2 + + elif vtype == "rrvis": + p1 = l1['vis'] + l1['vvis'] + p2 = l2['vis'] + l2['vvis'] + p3 = l3['vis'] + l3['vvis'] + + var1 = l1['sigma']**2 + l1['vsigma']**2 + var2 = l2['sigma']**2 + l2['vsigma']**2 + var3 = l3['sigma']**2 + l3['vsigma']**2 + + elif vtype == "llvis": + p1 = l1['vis'] - l1['vvis'] + p2 = l2['vis'] - l2['vvis'] + p3 = l3['vis'] - l3['vvis'] + + var1 = l1['sigma']**2 + l1['vsigma']**2 + var2 = l2['sigma']**2 + l2['vsigma']**2 + var3 = l3['sigma']**2 + l3['vsigma']**2 + + elif vtype == "lrvis": + p1 = l1['qvis'] - 1j*l1['uvis'] + p2 = l2['qvis'] - 1j*l2['uvis'] + p3 = l3['qvis'] - 1j*l3['uvis'] + + var1 = l1['qsigma']**2 + l1['usigma']**2 + var2 = l2['qsigma']**2 + l2['usigma']**2 + var3 = l3['qsigma']**2 + l3['usigma']**2 + + elif vtype == "rlvis": + p1 = l1['qvis'] + 1j*l1['uvis'] + p2 = l2['qvis'] + 1j*l2['uvis'] + p3 = l3['qvis'] + 1j*l3['uvis'] + + var1 = l1['qsigma']**2 + l1['usigma']**2 + var2 = l2['qsigma']**2 + l2['usigma']**2 + var3 = l3['qsigma']**2 + l3['usigma']**2 + + elif polrep == 'circ': + if vtype in ["rrvis", "llvis", "rlvis", "lrvis"]: + + + + if vtype == 'rrvis': + sigmatype = 'rrsigma' + elif vtype == 'llvis': + sigmatype = 'llsigma' + elif vtype == 'rlvis': + sigmatype = 'rlsigma' + elif vtype == 'lrvis': + sigmatype = 'lrsigma' + + p1 = l1[vtype] + p2 = l2[vtype] + p3 = l3[vtype] + + var1 = l1[sigmatype]**2 + var2 = l2[sigmatype]**2 + var3 = l3[sigmatype]**2 + + elif vtype == "vis": + p1 = 0.5*(l1['rrvis'] + l1['llvis']) + p2 = 0.5*(l2['rrvis'] + l2['llvis']) + p3 = 0.5*(l3['rrvis'] + l3['llvis']) + + var1 = 0.25*(l1['rrsigma']**2 + l1['llsigma']**2) + var2 = 0.25*(l2['rrsigma']**2 + l2['llsigma']**2) + var3 = 0.25*(l3['rrsigma']**2 + l3['llsigma']**2) + + elif vtype == "vvis": + p1 = 0.5*(l1['rrvis'] - l1['llvis']) + p2 = 0.5*(l2['rrvis'] - l2['llvis']) + p3 = 0.5*(l3['rrvis'] - l3['llvis']) + + var1 = 0.25*(l1['rrsigma']**2 + l1['llsigma']**2) + var2 = 0.25*(l2['rrsigma']**2 + l2['llsigma']**2) + var3 = 0.25*(l3['rrsigma']**2 + l3['llsigma']**2) + + elif vtype == "qvis": + p1 = 0.5*(l1['lrvis'] + l1['rlvis']) + p2 = 0.5*(l2['lrvis'] + l2['rlvis']) + p3 = 0.5*(l3['lrvis'] + l3['rlvis']) + + var1 = 0.25*(l1['lrsigma']**2 + l1['rlsigma']**2) + var2 = 0.25*(l2['lrsigma']**2 + l2['rlsigma']**2) + var3 = 0.25*(l3['lrsigma']**2 + l3['rlsigma']**2) + + elif vtype == "uvis": + p1 = 0.5j*(l1['lrvis'] - l1['rlvis']) + p2 = 0.5j*(l2['lrvis'] - l2['rlvis']) + p3 = 0.5j*(l3['lrvis'] - l3['rlvis']) + + var1 = 0.25*(l1['lrsigma']**2 + l1['rlsigma']**2) + var2 = 0.25*(l2['lrsigma']**2 + l2['rlsigma']**2) + var3 = 0.25*(l3['lrsigma']**2 + l3['rlsigma']**2) + else: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + # Make the bispectrum and its uncertainty + bi = p1*p2*p3 + bisig = np.abs(bi) * np.sqrt(var1/np.abs(p1)**2 + + var2/np.abs(p2)**2 + + var3/np.abs(p3)**2) + + return (bi, bisig) + + +def make_closure_amplitude(blue1, blue2, red1, red2, vtype, + ctype='camp', debias=True, polrep='stokes'): + """Make a list of closure amplitudes and errors + blue1 and blue2 are full datatables numerator entries + red1 and red2 are full datatables of denominator entries + vtype is the visibility type + """ + + if not (ctype in ['camp', 'logcamp']): + raise Exception("closure amplitude type must be 'camp' or 'logcamp'!") + + if polrep == 'stokes': + + if vtype in ["vis", "qvis", "uvis", "vvis"]: + if vtype == 'vis': + sigmatype = 'sigma' + if vtype == 'qvis': + sigmatype = 'qsigma' + if vtype == 'uvis': + sigmatype = 'usigma' + if vtype == 'vvis': + sigmatype = 'vsigma' + + sig1 = blue1[sigmatype] + sig2 = blue2[sigmatype] + sig3 = red1[sigmatype] + sig4 = red2[sigmatype] + + p1 = np.abs(blue1[vtype]) + p2 = np.abs(blue2[vtype]) + p3 = np.abs(red1[vtype]) + p4 = np.abs(red2[vtype]) + + elif vtype == "rrvis": + sig1 = np.sqrt(blue1['sigma']**2 + blue1['vsigma']**2) + sig2 = np.sqrt(blue2['sigma']**2 + blue2['vsigma']**2) + sig3 = np.sqrt(red1['sigma']**2 + red1['vsigma']**2) + sig4 = np.sqrt(red2['sigma']**2 + red2['vsigma']**2) + + p1 = np.abs(blue1['vis'] + blue1['vvis']) + p2 = np.abs(blue2['vis'] + blue2['vvis']) + p3 = np.abs(red1['vis'] + red1['vvis']) + p4 = np.abs(red2['vis'] + red2['vvis']) + + elif vtype == "llvis": + sig1 = np.sqrt(blue1['sigma']**2 + blue1['vsigma']**2) + sig2 = np.sqrt(blue2['sigma']**2 + blue2['vsigma']**2) + sig3 = np.sqrt(red1['sigma']**2 + red1['vsigma']**2) + sig4 = np.sqrt(red2['sigma']**2 + red2['vsigma']**2) + + p1 = np.abs(blue1['vis'] - blue1['vvis']) + p2 = np.abs(blue2['vis'] - blue2['vvis']) + p3 = np.abs(red1['vis'] - red1['vvis']) + p4 = np.abs(red2['vis'] - red2['vvis']) + + elif vtype == "lrvis": + sig1 = np.sqrt(blue1['qsigma']**2 + blue1['usigma']**2) + sig2 = np.sqrt(blue2['qsigma']**2 + blue2['usigma']**2) + sig3 = np.sqrt(red1['qsigma']**2 + red1['usigma']**2) + sig4 = np.sqrt(red2['qsigma']**2 + red2['usigma']**2) + + p1 = np.abs(blue1['qvis'] - 1j*blue1['uvis']) + p2 = np.abs(blue2['qvis'] - 1j*blue2['uvis']) + p3 = np.abs(red1['qvis'] - 1j*red1['uvis']) + p4 = np.abs(red2['qvis'] - 1j*red2['uvis']) + + elif vtype in ["pvis", "rlvis"]: + sig1 = np.sqrt(blue1['qsigma']**2 + blue1['usigma']**2) + sig2 = np.sqrt(blue2['qsigma']**2 + blue2['usigma']**2) + sig3 = np.sqrt(red1['qsigma']**2 + red1['usigma']**2) + sig4 = np.sqrt(red2['qsigma']**2 + red2['usigma']**2) + + p1 = np.abs(blue1['qvis'] + 1j*blue1['uvis']) + p2 = np.abs(blue2['qvis'] + 1j*blue2['uvis']) + p3 = np.abs(red1['qvis'] + 1j*red1['uvis']) + p4 = np.abs(red2['qvis'] + 1j*red2['uvis']) + + elif polrep == 'circ': + if vtype in ["rrvis", "llvis", "rlvis", "lrvis", 'pvis']: + if vtype == 'pvis': + vtype = 'rlvis' # p = rl + + if vtype == 'rrvis': + sigmatype = 'rrsigma' + if vtype == 'llvis': + sigmatype = 'llsigma' + if vtype == 'rlvis': + sigmatype = 'rlsigma' + if vtype == 'lrvis': + sigmatype = 'lrsigma' + + sig1 = blue1[sigmatype] + sig2 = blue2[sigmatype] + sig3 = red1[sigmatype] + sig4 = red2[sigmatype] + + p1 = np.abs(blue1[vtype]) + p2 = np.abs(blue2[vtype]) + p3 = np.abs(red1[vtype]) + p4 = np.abs(red2[vtype]) + + elif vtype == "vis": + sig1 = 0.5*np.sqrt(blue1['rrsigma']**2 + blue1['llsigma']**2) + sig2 = 0.5*np.sqrt(blue2['rrsigma']**2 + blue2['llsigma']**2) + sig3 = 0.5*np.sqrt(red1['rrsigma']**2 + red1['llsigma']**2) + sig4 = 0.5*np.sqrt(red2['rrsigma']**2 + red2['llsigma']**2) + + p1 = 0.5*np.abs(blue1['rrvis'] + blue1['llvis']) + p2 = 0.5*np.abs(blue2['rrvis'] + blue2['llvis']) + p3 = 0.5*np.abs(red1['rrvis'] + red1['llvis']) + p4 = 0.5*np.abs(red2['rrvis'] + red2['llvis']) + + elif vtype == "vvis": + sig1 = 0.5*np.sqrt(blue1['rrsigma']**2 + blue1['llsigma']**2) + sig2 = 0.5*np.sqrt(blue2['rrsigma']**2 + blue2['llsigma']**2) + sig3 = 0.5*np.sqrt(red1['rrsigma']**2 + red1['llsigma']**2) + sig4 = 0.5*np.sqrt(red2['rrsigma']**2 + red2['llsigma']**2) + + p1 = 0.5*np.abs(blue1['rrvis'] - blue1['llvis']) + p2 = 0.5*np.abs(blue2['rrvis'] - blue2['llvis']) + p3 = 0.5*np.abs(red1['rrvis'] - red1['llvis']) + p4 = 0.5*np.abs(red2['rrvis'] - red2['llvis']) + + elif vtype == "qvis": + sig1 = 0.5*np.sqrt(blue1['lrsigma']**2 + blue1['rlsigma']**2) + sig2 = 0.5*np.sqrt(blue2['lrsigma']**2 + blue2['rlsigma']**2) + sig3 = 0.5*np.sqrt(red1['lrsigma']**2 + red1['rlsigma']**2) + sig4 = 0.5*np.sqrt(red2['lrsigma']**2 + red2['rlsigma']**2) + + p1 = 0.5*np.abs(blue1['lrvis'] + blue1['rlvis']) + p2 = 0.5*np.abs(blue2['lrvis'] + blue2['rlvis']) + p3 = 0.5*np.abs(red1['lrvis'] + red1['rlvis']) + p4 = 0.5*np.abs(red2['lrvis'] + red2['rlvis']) + + elif vtype == "uvis": + sig1 = 0.5*np.sqrt(blue1['lrsigma']**2 + blue1['rlsigma']**2) + sig2 = 0.5*np.sqrt(blue2['lrsigma']**2 + blue2['rlsigma']**2) + sig3 = 0.5*np.sqrt(red1['lrsigma']**2 + red1['rlsigma']**2) + sig4 = 0.5*np.sqrt(red2['lrsigma']**2 + red2['rlsigma']**2) + + p1 = 0.5*np.abs(blue1['lrvis'] - blue1['rlvis']) + p2 = 0.5*np.abs(blue2['lrvis'] - blue2['rlvis']) + p3 = 0.5*np.abs(red1['lrvis'] - red1['rlvis']) + p4 = 0.5*np.abs(red2['lrvis'] - red2['rlvis']) + else: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + # Debias the amplitude + if debias: + p1 = amp_debias(p1, sig1, force_nonzero=True) + p2 = amp_debias(p2, sig2, force_nonzero=True) + p3 = amp_debias(p3, sig3, force_nonzero=True) + p4 = amp_debias(p4, sig4, force_nonzero=True) + else: + p1 = np.abs(p1) + p2 = np.abs(p2) + p3 = np.abs(p3) + p4 = np.abs(p4) + + # Get snrs + snr1 = p1/sig1 + snr2 = p2/sig2 + snr3 = p3/sig3 + snr4 = p4/sig4 + + # Compute the closure amplitude and its uncertainty + if ctype == 'camp': + camp = np.abs((p1*p2)/(p3*p4)) + camperr = camp * np.sqrt(1./(snr1**2) + 1./(snr2**2) + 1./(snr3**2) + 1./(snr4**2)) + + # Debias + if debias: + camp = camp_debias(camp, snr3, snr4) + + elif ctype == 'logcamp': + camp = np.log(np.abs(p1)) + np.log(np.abs(p2)) - np.log(np.abs(p3)) - np.log(np.abs(p4)) + camperr = np.sqrt(1./(snr1**2) + 1./(snr2**2) + 1./(snr3**2) + 1./(snr4**2)) + + # Debias + if debias: + camp = logcamp_debias(camp, snr1, snr2, snr3, snr4) + + return (camp, camperr) + + +def amp_debias(amp, sigma, force_nonzero=False): + """Return debiased visibility amplitudes + """ + + deb2 = np.abs(amp)**2 - np.abs(sigma)**2 + + # puts amplitude at 0 if snr < 1 + deb2 *= (np.nan_to_num(np.abs(amp)) > np.nan_to_num(np.abs(sigma))) + + # raises amplitude to sigma to force nonzero + if force_nonzero: + deb2 += (np.nan_to_num(np.abs(amp)) < np.nan_to_num(np.abs(sigma))) * np.abs(sigma)**2 + out = np.sqrt(deb2) + + return out + + +def camp_debias(camp, snr3, snr4): + """Debias closure amplitudes + snr3 and snr4 are snr of visibility amplitudes #3 and 4. + """ + + camp_debias = camp / (1 + 1./(snr3**2) + 1./(snr4**2)) + return camp_debias + + +def logcamp_debias(log_camp, snr1, snr2, snr3, snr4): + """Debias log closure amplitudes + The snrs are the snr of the component visibility amplitudes + """ + + log_camp_debias = log_camp + 0.5*(1./(snr1**2) + 1./(snr2**2) - 1./(snr3**2) - 1./(snr4**2)) + return log_camp_debias + + +def gauss_uv(u, v, flux, beamparams, x=0., y=0.): + """Return the value of the Gaussian FT with + beamparams is [FWHMmaj, FWHMmin, theta, x, y], all in radian + x,y are the center coordinates + """ + + sigma_maj = beamparams[0]/(2*np.sqrt(2*np.log(2))) + sigma_min = beamparams[1]/(2*np.sqrt(2*np.log(2))) + theta = -beamparams[2] # theta needs to be negative in this convention! + + # Covariance matrix + a = (sigma_min * np.cos(theta))**2 + (sigma_maj*np.sin(theta))**2 + b = (sigma_maj * np.cos(theta))**2 + (sigma_min*np.sin(theta))**2 + c = (sigma_min**2 - sigma_maj**2) * np.cos(theta) * np.sin(theta) + m = np.array([[a, c], [c, b]]) + + uv = np.array([[u[i], v[i]] for i in range(len(u))]) + x2 = np.array([np.dot(uvi, np.dot(m, uvi)) for uvi in uv]) + + g = np.exp(-2 * np.pi**2 * x2) + p = np.exp(-2j * np.pi * (u*x + v*y)) + + return flux * g * p + + +def rbf_kernel_covariance(x, sigma): + """Compute a covariance matrix from an RBF kernel + + Args: + x (ndarray): 1D data points for which to compute the covariance + sigma (float): std for the covariance. Controls correlation length / time. + + Returns: + cov (ndarray): Covariance matrix + """ + x = np.expand_dims(x, 1) if x.ndim == 1 else x + norm = -0.5 * scipy.spatial.distance.cdist(x, x, 'sqeuclidean') / sigma**2 + cov = np.exp(norm) + cov *= 1.0 / cov.sum(axis=0) + return cov + + +def sgra_kernel_uv(rf, u, v): + """Return the value of the Sgr A* scattering kernel at a given u,v (in lambda) + + Args: + rf (float): The observation frequency in Hz + u (float or ndarray): an array of u coordinates + v (float or ndarray): an array of v coordinates + + Returns: + g (float ndarray): Sgr A* scattering kernel + """ + u = np.array(u) + v = np.array(v) + assert u.size == v.size, 'u and v should have the same size' + + lcm = (ehc.C / rf) * 100 # in cm + sigma_maj = ehc.FWHM_MAJ * (lcm ** 2) / (2 * np.sqrt(2 * np.log(2))) * ehc.RADPERUAS + sigma_min = ehc.FWHM_MIN * (lcm ** 2) / (2 * np.sqrt(2 * np.log(2))) * ehc.RADPERUAS + theta = -ehc.POS_ANG * ehc.DEGREE # theta needs to be negative in this convention! + + # Covariance matrix + a = (sigma_min * np.cos(theta)) ** 2 + (sigma_maj * np.sin(theta)) ** 2 + b = (sigma_maj * np.cos(theta)) ** 2 + (sigma_min * np.sin(theta)) ** 2 + c = (sigma_min ** 2 - sigma_maj ** 2) * np.cos(theta) * np.sin(theta) + m = np.array([[a, c], [c, b]]) + uv = np.array([u, v]) + + x2 = (uv * np.dot(m, uv)).sum(axis=0) + g = np.exp(-2 * np.pi ** 2 * x2) + + return g + + +def sgra_kernel_params(rf): + """Return elliptical gaussian parameters in radian for the Sgr A* scattering ellipse + at a given frequency rf + """ + + lcm = (ehc.C/rf) * 100 # in cm + fwhm_maj_rf = ehc.FWHM_MAJ * (lcm**2) * ehc.RADPERUAS + fwhm_min_rf = ehc.FWHM_MIN * (lcm**2) * ehc.RADPERUAS + theta = ehc.POS_ANG * ehc.DEGREE + + return np.array([fwhm_maj_rf, fwhm_min_rf, theta]) + + +def blnoise(sefd1, sefd2, tint, bw): + """Determine the standard deviation of Gaussian thermal noise on a baseline + This is the noise on the rr/ll/rl/lr product, not the Stokes parameter + 2-bit quantization is responsible for the 0.88 factor + """ + + noise = np.sqrt(sefd1*sefd2/(2*bw*tint))/0.88 + # noise = np.sqrt(sefd1*sefd2/(bw*tint))/0.88 + + return noise + + +def merr(sigma, qsigma, usigma, I, m): + """Return the error in mbreve real and imaginary parts given stokes input + """ + + err = np.sqrt((qsigma**2 + usigma**2 + (sigma*np.abs(m))**2)/(np.abs(I) ** 2)) + + return err + + +def merr2(rlsigma, rrsigma, llsigma, I, m): + """Return the error in mbreve real and imaginary parts given polprod input + """ + + err = np.sqrt((rlsigma**2 + (rrsigma**2 + llsigma**2)*np.abs(m)**2)/(np.abs(I) ** 2)) + + return err + + +def cerror(sigma): + """Return a complex number drawn from a circular complex Gaussian of zero mean + """ + + noise = np.random.normal(loc=0, scale=sigma) + 1j*np.random.normal(loc=0, scale=sigma) + return noise + + +def cerror_hash(sigma, *args): + """Return a complex number drawn from a circular complex Gaussian of zero mean + """ + + reargs = list(args) + reargs.append('re') + np.random.seed(hash(",".join(map(repr, reargs))) % 4294967295) + re = np.random.randn() + + imargs = list(args) + imargs.append('im') + np.random.seed(hash(",".join(map(repr, imargs))) % 4294967295) + im = np.random.randn() + + err = sigma * (re + 1j*im) + + return err + + +def hashmultivariaterandn(size, cov, *args): + """set the seed according to a collection of arguments and return random multivariate gaussian var + """ + np.random.seed(hash(",".join(map(repr, args))) % 4294967295) + mean = np.zeros(size) + noise = np.random.multivariate_normal(mean, cov, check_valid='ignore') + return noise + + +def hashrandn(*args): + """set the seed according to a collection of arguments and return random gaussian var + """ + + np.random.seed(hash(",".join(map(repr, args))) % 4294967295) + noise = np.random.randn() + return noise + + +def hashrand(*args): + """set the seed according to a collection of arguments and return random number in 0,1 + """ + + np.random.seed(hash(",".join(map(repr, args))) % 4294967295) + noise = np.random.rand() + return noise + + +def image_centroid(im): + """Return the image centroid (in radians) + """ + + xlist = np.arange(0, -im.xdim, -1)*im.psize + (im.psize*im.xdim)/2.0 - im.psize/2.0 + ylist = np.arange(0, -im.ydim, -1)*im.psize + (im.psize*im.ydim)/2.0 - im.psize/2.0 + + x0 = np.sum(np.outer(0.0*ylist+1.0, xlist).ravel()*im.imvec)/np.sum(im.imvec) + y0 = np.sum(np.outer(ylist, 0.0*xlist+1.0).ravel()*im.imvec)/np.sum(im.imvec) + + return np.array([x0, y0]) + + +def ftmatrix(pdim, xdim, ydim, uvlist, pulse=ehc.PULSE_DEFAULT, mask=[]): + """Return a DFT matrix for the xdim*ydim image with pixel width pdim + that extracts spatial frequencies of the uv points in uvlist. + """ + + xlist = np.arange(0, -xdim, -1)*pdim + (pdim*xdim)/2.0 - pdim/2.0 + ylist = np.arange(0, -ydim, -1)*pdim + (pdim*ydim)/2.0 - pdim/2.0 + + # original sign convention + # ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") * + # np.outer(np.exp(-2j*np.pi*ylist*uv[1]), np.exp(-2j*np.pi*xlist*uv[0])) + # for uv in uvlist] + + # changed the sign convention to agree with BU data (Jan 2017) + # this is correct for a u,v definition from site 1-2 as (x1-x2)/lambda + ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") * + np.outer(np.exp(2j*np.pi*ylist*uv[1]), np.exp(2j*np.pi*xlist*uv[0])) + for uv in uvlist] + ftmatrices = np.reshape(np.array(ftmatrices), (len(uvlist), xdim*ydim)) + + if len(mask): + ftmatrices = ftmatrices[:, mask] + + return ftmatrices + + +def ftmatrix_centered(im, pdim, xdim, ydim, uvlist, pulse=ehc.PULSE_DEFAULT): + """Return a DFT matrix for the xdim*ydim image with pixel width pdim + that extracts spatial frequencies of the uv points in uvlist. + in this version, it puts the image centroid at the origin + """ + + # TODO : there is a residual value for the center being around 0, + # maybe we should chop this off to be exactly 0? + xlist = np.arange(0, -xdim, -1)*pdim + (pdim*xdim)/2.0 - pdim/2.0 + ylist = np.arange(0, -ydim, -1)*pdim + (pdim*ydim)/2.0 - pdim/2.0 + x0 = np.sum(np.outer(0.0*ylist+1.0, xlist).ravel()*im)/np.sum(im) + y0 = np.sum(np.outer(ylist, 0.0*xlist+1.0).ravel()*im)/np.sum(im) + + # Now shift the lists + xlist = xlist - x0 + ylist = ylist - y0 + + # list of matrices at each spatial freq + ftmatrices = [pulse(2*np.pi*uv[0], 2*np.pi*uv[1], pdim, dom="F") * + np.outer(np.exp(-2j*np.pi*ylist*uv[1]), np.exp(-2j*np.pi*xlist*uv[0])) + for uv in uvlist] + ftmatrices = np.reshape(np.array(ftmatrices), (len(uvlist), xdim*ydim)) + return ftmatrices + + +def ticks(axisdim, psize, nticks=8): + """Return a list of ticklocs and ticklabels + psize should be in desired units + """ + + axisdim = int(axisdim) + nticks = int(nticks) + if not axisdim % 2: + axisdim += 1 + if nticks % 2: + nticks -= 1 + tickspacing = float((axisdim-1))/nticks + ticklocs = np.arange(0, axisdim+1, tickspacing) - 0.5 + ticklabels = np.around(psize * np.arange((axisdim-1)/2.0, - + (axisdim)/2.0, -tickspacing), decimals=1) + + return (ticklocs, ticklabels) + + +def power_of_two(target): + """Finds the next greatest power of two + """ + cur = 1 + if target > 1: + for i in range(0, int(target)): + if (cur >= target): + return cur + else: + cur *= 2 + else: + return 1 + + +def paritycompare(perm1, perm2): + """Compare the parity of two permutations. + Assume both lists are equal length and with same elements + Copied from: http://stackoverflow.com/questions/1503072/how-to-check-if-permutations-have-equal-parity + """ + + perm2 = list(perm2) + perm2_map = dict((v, i) for i, v in enumerate(perm2)) + transCount = 0 + for loc, p1 in enumerate(perm1): + p2 = perm2[loc] + if p1 != p2: + sloc = perm2_map[p1] + perm2[loc], perm2[sloc] = p1, p2 + perm2_map[p1], perm2_map[p2] = sloc, loc + transCount += 1 + + if not (transCount % 2): + return 1 + else: + return -1 + + +def sigtype(datatype): + """Return the type of noise corresponding to the data type + """ + + datatype = str(datatype) + if datatype in ['vis', 'amp']: + sigmatype = 'sigma' + elif datatype in ['qvis', 'qamp']: + sigmatype = 'qsigma' + elif datatype in ['uvis', 'uamp']: + sigmatype = 'usigma' + elif datatype in ['vvis', 'vamp']: + sigmatype = 'vsigma' + elif datatype in ['pvis', 'pamp']: + sigmatype = 'psigma' + elif datatype in ['evis', 'eamp']: + sigmatype = 'esigma' + elif datatype in ['bvis', 'bamp']: + sigmatype = 'esigma' + elif datatype in ['rrvis', 'rramp']: + sigmatype = 'rrsigma' + elif datatype in ['llvis', 'llamp']: + sigmatype = 'llsigma' + elif datatype in ['rlvis', 'rlamp']: + sigmatype = 'rlsigma' + elif datatype in ['lrvis', 'lramp']: + sigmatype = 'lrsigma' + elif datatype in ['rrllvis', 'rrllamp']: + sigmatype = 'rrllsigma' + elif datatype in ['m', 'mamp']: + sigmatype = 'msigma' + elif datatype in ['phase']: + sigmatype = 'sigma_phase' + elif datatype in ['qphase']: + sigmatype = 'qsigma_phase' + elif datatype in ['uphase']: + sigmatype = 'usigma_phase' + elif datatype in ['vphase']: + sigmatype = 'vsigma_phase' + elif datatype in ['pphase']: + sigmatype = 'psigma_phase' + elif datatype in ['ephase']: + sigmatype = 'esigma_phase' + elif datatype in ['bphase']: + sigmatype = 'bsigma_phase' + elif datatype in ['mphase']: + sigmatype = 'msigma_phase' + elif datatype in ['rrphase']: + sigmatype = 'rrsigma_phase' + elif datatype in ['llphase']: + sigmatype = 'llsigma_phase' + elif datatype in ['rlphase']: + sigmatype = 'rlsigma_phase' + elif datatype in ['lrphase']: + sigmatype = 'lrsigma_phase' + elif datatype in ['rrllphase']: + sigmatype = 'rrllsigma_phase' + + else: + sigmatype = False + + return sigmatype + + +def rastring(ra): + """Convert a ra in fractional hours to formatted string + """ + h = int(ra) + m = int((ra-h)*60.) + s = (ra-h-m/60.)*3600. + out = "%2i h %2i m %2.4f s" % (h, m, s) + + return out + + +def decstring(dec): + """Convert a dec in fractional degrees to formatted string + """ + + deg = int(dec) + m = int((abs(dec)-abs(deg))*60.) + s = (abs(dec)-abs(deg)-m/60.)*3600. + out = "%2i deg %2i m %2.4f s" % (deg, m, s) + + return out + + +def gmtstring(gmt): + """Convert a gmt in fractional hours to formatted string + """ + + if gmt > 24.0: + gmt = gmt-24.0 + h = int(gmt) + m = int((gmt-h)*60.) + s = (gmt-h-m/60.)*3600. + out = "%02i:%02i:%2.4f" % (h, m, s) + + return out + +# TODO fix this hacky way to do it!! + + +def gmst_to_utc(gmst, mjd): + """Convert gmst times in hours to utc hours using astropy + """ + + mjd = int(mjd) + time_obj_ref = at.Time(mjd, format='mjd', scale='utc') + time_sidereal_ref = time_obj_ref.sidereal_time('mean', 'greenwich').hour + time_utc = (gmst - time_sidereal_ref) * 0.9972695601848 + + return time_utc + + +def utc_to_gmst(utc, mjd): + """Convert utc times in hours to gmst using astropy + """ + + mjd = int(mjd) # MJD should always be an integer, but was float in older versions of the code + time_obj = at.Time(utc/24.0 + np.floor(mjd), format='mjd', scale='utc') + time_sidereal = time_obj.sidereal_time('mean', 'greenwich').hour + + return time_sidereal + + +def earthrot(vecs, thetas): + """Rotate a vector / array of vectors about the z-direction by theta / array of thetas (radian) + """ + + if len(vecs.shape) == 1: + vecs = np.array([vecs]) + if np.isscalar(thetas): + thetas = np.array([thetas for i in range(len(vecs))]) + + # equal numbers of sites and angles + if len(thetas) == len(vecs): + rotvec = np.array([np.dot(np.array(((np.cos(thetas[i]), -np.sin(thetas[i]), 0), + (np.sin(thetas[i]), np.cos(thetas[i]), 0), + (0, 0, 1))), + vecs[i]) + for i in range(len(vecs))]) + + # only one rotation angle, many sites + elif len(thetas) == 1: + rotvec = np.array([np.dot(np.array(((np.cos(thetas[0]), -np.sin(thetas[0]), 0), + (np.sin(thetas[0]), np.cos(thetas[0]), 0), + (0, 0, 1))), + vecs[i]) + for i in range(len(vecs))]) + + # only one site, many angles + elif len(vecs) == 1: + rotvec = np.array([np.dot(np.array(((np.cos(thetas[i]), -np.sin(thetas[i]), 0), + (np.sin(thetas[i]), np.cos(thetas[i]), 0), + (0, 0, 1))), + vecs[0]) + for i in range(len(thetas))]) + else: + raise Exception("Unequal numbers of vectors and angles in earthrot(vecs, thetas)!") + + return rotvec + + +def elev(obsvecs, sourcevec): + """Return the elevation of a source with respect to an observer/observers in radians + obsvec can be an array of vectors but sourcevec can ONLY be a single vector + """ + + if len(obsvecs.shape) == 1: + obsvecs = np.array([obsvecs]) + + anglebtw = np.array([np.dot(obsvec, sourcevec)/np.linalg.norm(obsvec) / + np.linalg.norm(sourcevec) for obsvec in obsvecs]) + el = 0.5*np.pi - np.arccos(anglebtw) + + return el + + +def elevcut(obsvecs, sourcevec, elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH): + """Return True if a source is observable by a telescope vector + """ + + angles = elev(obsvecs, sourcevec)/ehc.DEGREE + + return (angles > elevmin) * (angles < elevmax) + + +def hr_angle(gst, lon, ra): + """Computes the hour angle for a source at RA, observer at longitude long, and GMST time gst + gst in hours, ra & lon ALL in radian, longitude positive east + """ + + hr_angle = np.mod(gst + lon - ra, 2*np.pi) + + return hr_angle + + +def par_angle(hr_angle, lat, dec): + """Compute the parallactic angle for a source at hr_angle and dec + for an observer with latitude lat. + All angles in radian + """ + + num = np.sin(hr_angle)*np.cos(lat) + denom = np.sin(lat)*np.cos(dec) - np.cos(lat)*np.sin(dec)*np.cos(hr_angle) + + return np.arctan2(num, denom) + + +def xyz_2_latlong(obsvecs): + """Compute the (geocentric) latitude and longitude of a site at geocentric position x,y,z + The output is in radians + """ + + if len(obsvecs.shape) == 1: + obsvecs = np.array([obsvecs]) + out = [] + for obsvec in obsvecs: + x = obsvec[0] + y = obsvec[1] + z = obsvec[2] + lon = np.array(np.arctan2(y, x)) + lat = np.array(np.arctan2(z, np.sqrt(x**2+y**2))) + out.append([lat, lon]) + + out = np.array(out) + + return out + + +def tri_minimal_set(sites, tarr, tkey): + """returns a minimal set of triangles for bispectra and closure phase""" + + # determine ordering and reference site based on order of self.tarr + sites_ordered = [x for x in tarr['site'] if x in sites] + ref = sites_ordered[0] + sites_ordered.remove(ref) + + # Find all triangles that contain the ref + tris = list(it.combinations(sites_ordered, 2)) + tris = [(ref, t[0], t[1]) for t in tris] + + return tris + + +def quad_minimal_set(sites, tarr, tkey): + """returns a minimal set of quadrangels for closure amplitude""" + + # determine ordering and reference site based on order of self.tarr + sites_ordered = np.array([x for x in tarr['site'] if x in sites]) + ref = sites_ordered[0] + + # Loop over other sites >=3 and form minimal closure amplitude set + quads = [] + for i in range(3, len(sites_ordered)): + for j in range(1, i): + if j == i-1: + k = 1 + else: + k = j+1 + + # convetion is (12)(34)/(14)(23) + quad = (ref, sites_ordered[i], sites_ordered[j], sites_ordered[k]) + quads.append(quad) + + return quads + + +# TODO This returns A minimal set if input is maximal, but it is not necessarily the same +# minimal set as we would from calling c_phases(count='min'). This is because of sign flips. +def reduce_tri_minimal(obs, datarr): + """reduce a bispectrum or closure phase data array to a minimal set + datarr can be either a bispectrum array of type DTBIS + or a closure phase array of type DTCPHASE, or a time sorted list of either + """ + + # time sort or not + if not (type(datarr) is list): + datalist = [] + dtype = datarr.dtype + for key, group in it.groupby(datarr, lambda x: x['time']): + datalist.append(np.array([gp for gp in group], dtype=dtype)) + returnType = 'all' + else: + dtype = datarr[0].dtype + datalist = datarr + returnType = 'time' + + out = [] + + for timegroup in datalist: + if returnType == 'all': + outgroup = out + else: + outgroup = [] + + # determine a minimal set of trinagles + sites = list(set(np.hstack((timegroup['t1'], timegroup['t2'], timegroup['t3'])))) + tris = tri_minimal_set(sites, obs.tarr, obs.tkey) + tris = [set(tri) for tri in tris] + + # add data points from original array to new array if in minimal set + for dp in timegroup: + # TODO: sign flips? + if set((dp['t1'], dp['t2'], dp['t3'])) in tris: + outgroup.append(dp) + + if returnType == 'time': + out.append(np.array(outgroup, dtype=dtype)) + else: + out = outgroup + + if returnType == 'all': + out = np.array(out, dtype=dtype) + return out + +# TODO This returns A minimal set if input is maximal, but it is not necessarily the same +# minimal set as we would from calling c_amplitudes(count='min'). This is because of inverses. + + +def reduce_quad_minimal(obs, datarr, ctype='camp'): + """Reduce a closure amplitude or log closure amplitude array + FROM a maximal set TO a minimal set + """ + + if ctype not in ['camp', 'logcamp']: + raise Exception("ctype must be 'camp' or 'logcamp'") + + # time sort or not + if not (type(datarr) is list): + datalist = [] + dtype = datarr.dtype + for key, group in it.groupby(datarr, lambda x: x['time']): + datalist.append(np.array([x for x in group])) + returnType = 'all' + else: + dtype = datarr[0].dtype + datalist = datarr + returnType = 'time' + + out = [] + for timegroup in datalist: + if returnType == 'all': + outgroup = out + else: + outgroup = [] + + # determine a minimal set of quadrangles + sites = np.array(list(set(np.hstack((timegroup['t1'], + timegroup['t2'], + timegroup['t3'], + timegroup['t4']))))) + if len(sites) < 4: + continue + quads = quad_minimal_set(sites, obs.tarr, obs.tkey) + + # add data points from original camp array to new array if in minimal set + # ANDREW TODO: do we need to change the ordering ?? + for dp in timegroup: + + # this is all same closure amplitude, but the ordering of labels is different + if ((dp['t1'], dp['t2'], dp['t3'], dp['t4']) in quads or + (dp['t2'], dp['t1'], dp['t4'], dp['t3']) in quads or + (dp['t3'], dp['t4'], dp['t1'], dp['t2']) in quads or + (dp['t4'], dp['t3'], dp['t2'], dp['t1']) in quads): + + outgroup.append(np.array(dp, dtype=ehc.DTCAMP)) + + # flip the inverse closure amplitude + elif ((dp['t1'], dp['t4'], dp['t3'], dp['t2']) in quads or + (dp['t2'], dp['t3'], dp['t4'], dp['t1']) in quads or + (dp['t3'], dp['t2'], dp['t1'], dp['t4']) in quads or + (dp['t4'], dp['t1'], dp['t2'], dp['t3']) in quads): + + dp2 = copy.deepcopy(dp) + campold = dp['camp'] + sigmaold = dp['sigmaca'] + t1old = dp['t1'] + t2old = dp['t2'] + t3old = dp['t3'] + t4old = dp['t4'] + u1old = dp['u1'] + u2old = dp['u2'] + u3old = dp['u3'] + u4old = dp['u4'] + v1old = dp['v1'] + v2old = dp['v2'] + v3old = dp['v3'] + v4old = dp['v4'] + + dp2['t1'] = t1old + dp2['t2'] = t4old + dp2['t3'] = t3old + dp2['t4'] = t2old + + dp2['u1'] = u3old + dp2['v1'] = v3old + + dp2['u2'] = -u4old + dp2['v2'] = -v4old + + dp2['u3'] = u1old + dp2['v3'] = v1old + + dp2['u4'] = -u2old + dp2['v4'] = -v2old + + if ctype == 'camp': + dp2['camp'] = 1./campold + dp2['sigmaca'] = sigmaold/(campold**2) + + elif ctype == 'logcamp': + dp2['camp'] = -campold + dp2['sigmaca'] = sigmaold + + outgroup.append(dp2) + + if returnType == 'time': + out.append(np.array(outgroup, dtype=dtype)) + else: + out = outgroup + + if returnType == 'all': + out = np.array(out, dtype=dtype) + return out + + +def qimage(iimage, mimage, chiimage): + """Return the Q image from m and chi""" + return iimage * mimage * np.cos(2*chiimage) + + +def uimage(iimage, mimage, chiimage): + """Return the U image from m and chi""" + return iimage * mimage * np.sin(2*chiimage) + + +################################################################################################## +# FFT & NFFT helper functions +################################################################################################## +class NFFTInfo(object): + def __init__(self, xdim, ydim, psize, pulse, npad, p_rad, uv): + self.xdim = int(xdim) + self.ydim = int(ydim) + self.psize = psize + self.pulse = pulse + + self.npad = int(npad) + self.p_rad = int(p_rad) + self.uv = uv + self.uvdim = len(uv) + + # set nfft plan + uv_scaled = uv*psize + nfft_plan = NFFT([xdim, ydim], self.uvdim, m=p_rad, n=[npad, npad]) + nfft_plan.x = uv_scaled + nfft_plan.precompute() + self.plan = nfft_plan + + # compute phase and pulsefac + phases = np.exp(-1j*np.pi*(uv_scaled[:, 0]+uv_scaled[:, 1])) + pulses = np.fromiter((pulse(2*np.pi*uv_scaled[i, 0], 2*np.pi*uv_scaled[i, 1], 1., dom="F") + for i in range(self.uvdim)), 'c16') + self.pulsefac = (pulses*phases) + +class SamplerInfo(object): + def __init__(self, order, uv, pulsefac): + self.order = int(order) + self.uv = uv + self.pulsefac = pulsefac + + +class GridderInfo(object): + def __init__(self, npad, func, p_rad, coords, weights): + self.npad = int(npad) + self.conv_func = func + self.p_rad = int(p_rad) + self.coords = coords + self.weights = weights + + +class ImInfo(object): + def __init__(self, xdim, ydim, npad, psize, pulse): + self.xdim = int(xdim) + self.ydim = int(ydim) + self.npad = int(npad) + self.psize = psize + self.pulse = pulse + + padvalx1 = padvalx2 = int(np.floor((npad - xdim)/2.0)) + if xdim % 2: + padvalx2 += 1 + padvaly1 = padvaly2 = int(np.floor((npad - ydim)/2.0)) + if ydim % 2: + padvaly2 += 1 + + self.padvalx1 = padvalx1 + self.padvalx2 = padvalx2 + self.padvaly1 = padvaly1 + self.padvaly2 = padvaly2 + + +def conv_func_pill(x, y): + if abs(x) < 0.5 and abs(y) < 0.5: + out = 1. + else: + out = 0. + return out + + +def conv_func_gauss(x, y): + return np.exp(-(x**2 + y**2)) + + +def conv_func_cubicspline(x, y): + if abs(x) <= 1: + fx = 1.5*abs(x)**3 - 2.5*abs(x)**2 + 1 + elif abs(x) < 2: + fx = -0.5*abs(x)**3 + 2.5*abs(x)**2 - 4*abs(x) + 2 + else: + fx = 0 + + if abs(y) <= 1: + fy = 1.5*abs(y)**3 - 2.5*abs(y)**2 + 1 + elif abs(y) < 2: + fy = -0.5*abs(y)**3 + 2.5*abs(y)**2 - 4*abs(y) + 2 + else: + fy = 0 + + return fx*fy + +# There's a bug in scipy spheroidal function of order 0! - gives nans for eta<1 +# def conv_func_spheroidal(x,y,p,m): +# etax = 2.*x/float(p) +# etay = 2.*x/float(p) +# psix = abs(1-etax**2)**m * scipy.special.pro_rad1(m,0,0.5*np.pi*p,etax)[0] +# psiy = abs(1-etay**2)**m * scipy.special.pro_rad1(m,0,0.5*np.pi*p,etay)[0] +# return psix*psiy + + +def fft_imvec(imvec, im_info): + """ + Returns fft of imvec on grid + im_info = (xdim, ydim, npad, psize, pulse) + order is the order of the spline interpolation + """ + + xdim = im_info.xdim + ydim = im_info.ydim + padvalx1 = im_info.padvalx1 + padvalx2 = im_info.padvalx2 + padvaly1 = im_info.padvaly1 + padvaly2 = im_info.padvaly2 + + imarr = imvec.reshape(ydim, xdim) + imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)), + 'constant', constant_values=0.0) + + if imarr.shape[0] != imarr.shape[1]: + raise Exception("FFT padding did not return a square image!") + + # FFT for visibilities + vis_im = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(imarr))) + + return vis_im + + +def sampler(griddata, sampler_info_list, sample_type="vis"): + """ + Samples griddata (e.g. the FFT of an image) at uv points + the griddata should already be rotated so u,v = 0,0 in the center + sampler_info_list is an appropriately ordered list of 4 sampler_info objects + order is the order of the spline interpolation + """ + if sample_type not in ["vis", "bs", "camp"]: + raise Exception("sampler sample_type should be either 'vis','bs',or 'camp'!") + if griddata.shape[0] != griddata.shape[1]: + raise Exception("griddata should be a square array!") + + dataset = [] + for sampler_info in sampler_info_list: + + vu2 = sampler_info.uv + pulsefac = sampler_info.pulsefac + + datare = nd.map_coordinates(np.real(griddata), vu2, order=sampler_info.order) + dataim = nd.map_coordinates(np.imag(griddata), vu2, order=sampler_info.order) + + data = datare + 1j*dataim + data = data * pulsefac + + dataset.append(data) + + if sample_type == "vis": + out = dataset[0] + if sample_type == "bs": + out = dataset[0]*dataset[1]*dataset[2] + if sample_type == "camp": + out = np.abs((dataset[0]*dataset[1])/(dataset[2]*dataset[3])) + return out + + +def gridder(data_list, gridder_info_list): + """ + Grid the data sampled at uv points on a square array + gridder_info_list is an list of gridder_info objects + """ + + if len(data_list) != len(gridder_info_list): + raise Exception("length of data_list in gridder() " + + "is not equal to length of gridder_info_list!") + + npad = gridder_info_list[0].npad + datagrid = np.zeros((npad, npad)).astype('c16') + + for k in range(len(gridder_info_list)): + gridder_info = gridder_info_list[k] + data = data_list[k] + + if gridder_info.npad != npad: + raise Exception("npad values not consistent in gridder_info_list!") + + p_rad = gridder_info.p_rad + coords = gridder_info.coords + weights = gridder_info.weights + + p_rad = int(p_rad) + for i in range(2*p_rad+1): + dy = i - p_rad + for j in range(2*p_rad+1): + dx = j - p_rad + weight = weights[i][j] + np.add.at(datagrid, tuple(map(tuple, (coords + [dy, dx]).transpose())), data*weight) + + return datagrid + + +def make_gridder_and_sampler_info(im_info, uv, conv_func=ehc.GRIDDER_CONV_FUNC_DEFAULT, + p_rad=ehc.GRIDDER_P_RAD_DEFAULT, order=ehc.FFT_INTERP_DEFAULT): + """ + Prep norms and weights for gridding data sampled at uv points on a square array + im_info tuple contains (xdim, ydim, npad, psize, pulse) of the grid + conv_func is the convolution function: current options are "pillbox", "gaussian" + p_rad is the pixel radius inside wich the conv_func is nonzero + """ + + if not (conv_func in ['pillbox', 'gaussian', 'cubic']): + raise Exception("conv_func must be either 'pillbox', 'gaussian', or, 'cubic'") + + npad = im_info.npad + psize = im_info.psize + pulse = im_info.pulse + + # compute grid u,v coordinates + vu2 = np.hstack((uv[:, 1].reshape(-1, 1), uv[:, 0].reshape(-1, 1))) + du = 1.0/(npad*psize) + vu2 = (vu2/du + 0.5*npad) + + coords = np.round(vu2).astype(int) + dcoords = vu2 - np.round(vu2).astype(int) + vu2 = vu2.T + + # TODO: phase rotations should be done separately for x and y if the image isn't square + # e.g., + phase = np.exp(-1j*np.pi*psize*((1+im_info.xdim % 2)*uv[:, 0] + (1+im_info.ydim % 2)*uv[:, 1])) + + pulsefac = np.fromiter( + (pulse(2*np.pi*uvpt[0], 2*np.pi*uvpt[1], psize, dom="F") for uvpt in uv), 'c16') + pulsefac = pulsefac * phase + + # compute gridder norm + weights = [] + norm = np.zeros_like(len(coords)) + for i in range(2*p_rad+1): + weights.append([]) + dy = i - p_rad + for j in range(2*p_rad+1): + dx = j - p_rad + if conv_func == 'gaussian': + norm = norm + conv_func_gauss(dy - dcoords[:, 0], dx - dcoords[:, 1]) + elif conv_func == 'pillbox': + norm = norm + conv_func_pill(dy - dcoords[:, 0], dx - dcoords[:, 1]) + elif conv_func == 'cubic': + norm = norm + conv_func_cubicspline(dy - dcoords[:, 0], dx - dcoords[:, 1]) + + weights[i].append(None) + + # compute weights for gridding + for i in range(2*p_rad+1): + dy = i - p_rad + for j in range(2*p_rad+1): + dx = j - p_rad + if conv_func == 'gaussian': + weight = conv_func_gauss(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm + elif conv_func == 'pillbox': + weight = conv_func_pill(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm + elif conv_func == 'cubic': + weight = conv_func_cubicspline(dy - dcoords[:, 0], dx - dcoords[:, 1])/norm + + weights[i][j] = weight + + # output the coordinates, norms, and weights + sampler_info = SamplerInfo(order, vu2, pulsefac) + gridder_info = GridderInfo(npad, conv_func, p_rad, coords, weights) + return (sampler_info, gridder_info) + + +################################################################################################## +# miscellaneous functions +################################################################################################## + +# TODO this makes a copy -- is there a faster robust way? +def recarr_to_ndarr(x, typ): + """converts a record array x to a normal ndarray with all fields converted to datatype typ + """ + + fields = x.dtype.names + shape = x.shape + (len(fields),) + dt = [(name, typ) for name in fields] + y = x.astype(dt).view(typ).reshape(shape) + return y + + +def prog_msg(nscan, totscans, msgtype='bar', nscan_last=0): + """print a progress method for calibration + """ + complete_percent_last = int(100*float(nscan_last)/float(totscans)) + complete_percent = int(100*float(nscan)/float(totscans)) + ndigit = str(len(str(totscans))) + + if msgtype == 'bar': + bar_width = 30 + progress = int(bar_width * complete_percent/float(100)) + barparams = (nscan, totscans, ("-"*progress) + + (" " * (bar_width-progress)), complete_percent) + + printstr = "\rScan %0"+ndigit+"i/%i : [%s]%i%%" + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + + elif msgtype == 'bar2': + bar_width = 30 + progress = int(bar_width * complete_percent/float(100)) + barparams = (nscan, totscans, ("/"*progress) + + (" " * (bar_width-progress)), complete_percent) + + printstr = "\rScan %0"+ndigit+"i/%i : [%s]%i%%" + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + + elif msgtype == 'casa': + message_list = [".", ".", ".", "10", ".", ".", ".", "20", + ".", ".", ".", "30", ".", ".", ".", "40", + ".", ".", ".", "50", ".", ".", ".", "60", + ".", ".", ".", "70", ".", ".", ".", "80", + ".", ".", ".", "90", ".", ".", ".", "DONE"] + bar_width = len(message_list) + progress = int(bar_width * complete_percent/float(100)) + message = ''.join(message_list[:progress]) + + barparams = (nscan, totscans, message) + printstr = "\rScan %0"+ndigit+"i/%i : %s" + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + + elif msgtype == 'itcrowd': + message_list = ["0", "1", "1", "8", " ", "9", "9", "9", " ", "8", "8", "1", "9", "9", " ", + "9", "1", "1", "9", " ", "7", "2", "5", " ", " ", " ", "3"] + bar_width = len(message_list) + progress = int(bar_width * complete_percent/float(100)) + message = ''.join(message_list[:progress]) + if complete_percent < 100: + message += "." + message += " "*(bar_width-progress-1) + + barparams = (nscan, totscans, message) + + printstr = "\rScan %0"+ndigit+"i/%i : [%s]" + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + + elif msgtype == 'bh': + message_all = ehc.BHIMAGE + bar_width = len(message_all) + progress = int(np.floor(bar_width * complete_percent/float(100)))-1 + progress_last = int(np.floor(bar_width * complete_percent_last/float(100)))-1 + if progress > progress_last: + for i in range(progress_last+1, progress+1): + message_line = ''.join(message_all[i]) + message_line = '%03i' % int(complete_percent) + message_line + print(message_line) + + elif msgtype == 'eht': + message_all = ehc.EHTIMAGE + bar_width = len(message_all) + progress = int(np.floor(bar_width * complete_percent/float(100)))-1 + progress_last = int(np.floor(bar_width * complete_percent_last/float(100)))-1 + if progress > progress_last: + for i in range(progress_last+1, progress+1): + message_line = ''.join(message_all[i]) + message_line = '%03i' % int(complete_percent) + message_line + print(message_line) + + elif msgtype == 'dots': + sys.stdout.write('.') + sys.stdout.flush() + + else: # msgtype=='default': + barparams = (nscan, totscans, complete_percent) + printstr = "\rScan %0"+ndigit+"i/%i : %i%% done . . ." + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + +def sat_skyfield_from_elements(satname, epoch_mjd, perigee_mjd, + period_days, eccentricity, + inclination, arg_perigee, long_ascending): + + """skyfield EarthSatellite object from keplerian orbital elements + perfect keplerian orbit is assumed, no derivatives + epoch, pericenter given in mjd + period given in days + inclination, arg_perigee, long_ascending given in degrees + """ + + if not(0<=eccentricity<1): + raise Exception("eccentricity must be between 0 and 1") + if not(0<=inclination<=180): + raise Exception("inclination must be between 0 and 180") + if not(0<=arg_perigee<=180): + raise Exception("arg_perigee must be between 0 and 180") + if not(0<=long_ascending<=360): + raise Exception("arg_perigee must be between 0 and 360") + + satrec = Satrec() + ts = skyfield.api.load.timescale() + ref_mjd = 33281. # mjd 1949 December 31 00:00 UT + epoch_wrt_ref = epoch_mjd - ref_mjd + + inclination_rad = inclination*ehc.DEGREE + arg_perigee_rad = arg_perigee*ehc.DEGREE + long_ascending_rad = long_ascending*ehc.DEGREE + + mean_motion = 2*np.pi/(period_days*24.*60.) # radians/minute + + mean_anomaly = mean_motion*(epoch_mjd - perigee_mjd) + mean_anomaly = np.mod(mean_anomaly, 2*np.pi) + + satrec.sgp4init( + WGS72, # gravity model + 'i', # 'a' = old AFSPC mode, 'i' = improved mode + 1, # satnum: Satellite number + epoch_wrt_ref, # epoch: days since 1949 December 31 00:00 UT + 0.0, # bstar: drag coefficient (/earth radii) + 0.0, # ndot: ballistic coefficient (revs/day) + 0.0, # nddot: second derivative of mean motion (revs/day^3) + eccentricity, # ecco: eccentricity + arg_perigee_rad, # argpo: argument of perigee (radians) + inclination_rad, # inclo: inclination (radians) + mean_anomaly, # mo: mean anomaly (radians) + mean_motion, # no_kozai: mean motion (radians/minute) + long_ascending_rad, # nodeo: right ascension of ascending node (radians) + ) + sat_skyfield = skyfield.api.EarthSatellite.from_satrec(satrec, ts) + + return sat_skyfield + +def sat_skyfield_from_tle(satname, line1, line2): + ts = skyfield.api.load.timescale() + sat_skyfield = skyfield.api.EarthSatellite(line1, line2, satname, ts) + return sat_skyfield + +def sat_skyfield_from_ephementry(satname, ephem, epoch_mjd): + if len(ephem[satname])==3: # TLE + line1 = ephem[satname][1] + line2 = ephem[satname][2] + sat = sat_skyfield_from_tle(satname, line1, line2) + elif len(ephem[satname])==6: #keplerian elements + elements = ephem[satname] + sat = sat_skyfield_from_elements(satname, epoch_mjd, + elements[0],elements[1],elements[2],elements[3],elements[4],elements[5]) + else: + raise Exception("ephemeris format not recognized for %s"%satellite) + + return sat + +def orbit_skyfield(sat, fracmjds, whichout='itrs'): + + """uses skyfield to propagate a earth satellite orbit and return x,y,z coordinates in co-rotating earth frame + sat is a skyfield.sgp4lib.EarthSatellite object + times is a list of fractional mjds + whichout is 'itrs' (co-rotating) or 'gcrs' (fixed x-axis to equinox) + """ + + #fractional days of orbit in jd + mjd_to_jd = 2400000.5 + ts = skyfield.api.load.timescale() + t = ts.ut1_jd(mjd_to_jd+fracmjds) + + # propagate orbit + time_data = sat.at(t) + + if whichout=='gcrs': + # GCRS coordinates in km + positions = time_data.xyz.m + + elif whichout=='itrs': + # get coordinates in earth frame (WGS84 ellipsiod) + geographic_position = skyfield.api.wgs84.geographic_position_of(time_data) + positions = geographic_position.itrs_xyz.m + + else: + raise Excption("orbit_skyfield whichout must be 'itrs' or 'gcrs'") + + return positions + diff --git a/observing/obs_simulate.py b/observing/obs_simulate.py new file mode 100644 index 00000000..1afaf71c --- /dev/null +++ b/observing/obs_simulate.py @@ -0,0 +1,1500 @@ +# obs_simulate.py +# functions to simulate interferometric observations +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import time as ttime +import scipy.ndimage as nd +from scipy.interpolate import interp1d +import numpy as np +import copy +try: + from pynfft.nfft import NFFT +except ImportError: + pass + #print("Warning: No NFFT installed!") + +from . import obs_helpers as obsh +import ehtim.const_def as ehc + +################################################################################################## +# Generate U-V Points +################################################################################################## + + +def make_uvpoints(array, ra, dec, rf, bw, tint, tadv, tstart, tstop, + polrep='stokes', + mjd=ehc.MJD_DEFAULT, tau=ehc.TAUDEF, + elevmin=ehc.ELEV_LOW, elevmax=ehc.ELEV_HIGH, + no_elevcut_space=False, + timetype='UTC', fix_theta_GMST=False): + """Generate u,v points and baseline sigmas for a given array. + + Args: + array (Array): the array object + ra (float): The source Right Ascension in fractional hours + dec (float): The source declination in fractional degrees + rf (float): The observation frequency in Hz + bw (float): The observation bandwidth in Hz + tint (float): the scan integration time in seconds + tadv (float): the uniform cadence between scans in seconds + tstart (float): the start time of the observation in hours + tstop (float): the end time of the observation in hours + polrep (str): 'stokes' or 'circ' sets the data polarimetric representtion + mjd (int): the mjd of the observation, if different from the image mjd + tau (float): the base opacity at all sites, or a dict giving one opacity per site + elevmin (float): station minimum elevation in degrees + elevmax (float): station maximum elevation in degrees + no_elevcut_space (bool): if True, do not apply elevation cut to orbiters + timetype (str): how to interpret tstart and tstop; either 'GMST' or 'UTC' + fix_theta_GMST (bool): if True, stops earth rotation to sample fixed u,v points + Returns: + (Obsdata): an observation object with all visibilities zeroed + """ + + if polrep == 'stokes': + poltype = ehc.DTPOL_STOKES + elif polrep == 'circ': + poltype = ehc.DTPOL_CIRC + else: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + # Set up time start and steps + tstep = tadv/3600.0 + if tstop < tstart: + tstop = tstop + 24.0 + + # Observing times + times = np.arange(tstart, tstop, tstep) + if timetype not in ['UTC', 'GMST']: + print("Time Type Not Recognized! Assuming UTC!") + timetype = 'UTC' + + # Generate uv points at all times + outlist = [] + blpairs = [] + + for i1 in range(len(array.tarr)): + for i2 in range(len(array.tarr)): + if (i1 != i2 and + i1 < i2 and # This is the right condition for uvfits save order + not ((i2, i1) in blpairs)): # This cuts out the conjugate baselines + + blpairs.append((i1, i2)) + + # sites + site1 = array.tarr[i1]['site'] + site2 = array.tarr[i2]['site'] + coord1 = ((array.tarr[i1]['x'], array.tarr[i1]['y'], array.tarr[i1]['z'])) + coord2 = ((array.tarr[i2]['x'], array.tarr[i2]['y'], array.tarr[i2]['z'])) + + # Optical Depth + if type(tau) == dict: + try: + tau1 = tau[site1] + tau2 = tau[site2] + except KeyError: + tau1 = tau2 = ehc.TAUDEF + else: + tau1 = tau2 = tau + # no optical depth for space sites + if coord1 == (0., 0., 0.): + tau1 = 0. + if coord2 == (0., 0., 0.): + tau2 = 0. + + # Noise on the correlations + if np.any(array.tarr['sefdr'] <= 0) or np.any(array.tarr['sefdl'] <= 0): + print("Warning!: in make_uvpoints, some SEFDs are <= 0!") + + sig_rr = obsh.blnoise(array.tarr[i1]['sefdr'], array.tarr[i2]['sefdr'], tint, bw) + sig_ll = obsh.blnoise(array.tarr[i1]['sefdl'], array.tarr[i2]['sefdl'], tint, bw) + sig_rl = obsh.blnoise(array.tarr[i1]['sefdr'], array.tarr[i2]['sefdl'], tint, bw) + sig_lr = obsh.blnoise(array.tarr[i1]['sefdl'], array.tarr[i2]['sefdr'], tint, bw) + if polrep == 'stokes': + sig_iv = 0.5*np.sqrt(sig_rr**2 + sig_ll**2) + sig_qu = 0.5*np.sqrt(sig_rl**2 + sig_lr**2) + sig1 = sig_iv + sig2 = sig_qu + sig3 = sig_qu + sig4 = sig_iv + elif polrep == 'circ': + sig1 = sig_rr + sig2 = sig_ll + sig3 = sig_rl + sig4 = sig_lr + + uvdat = obsh.compute_uv_coordinates(array, site1, site2, times, mjd, + ra, dec, rf, timetype=timetype, + elevmin=elevmin, elevmax=elevmax, + no_elevcut_space=no_elevcut_space, + fix_theta_GMST=fix_theta_GMST) + + (timesout, uout, vout) = uvdat + for k in range(len(timesout)): + outlist.append(np.array(( + timesout[k], + tint, # Integration + site1, # Station 1 + site2, # Station 2 + tau1, # Station 1 zenith optical depth + tau2, # Station 1 zenith optical depth + uout[k], # u (lambda) + vout[k], # v (lambda) + 0.0, # 1st Visibility (Jy) + 0.0, # 2nd Visibility + 0.0, # 3rd Visibility + 0.0, # 4th Visibility + sig1, # 1st Sigma (Jy) + sig2, # 2nd Sigma + sig3, # 3rd Sigma + sig4 # 4th Sigma + ), dtype=poltype + )) + + obsarr = np.array(outlist) + + if not len(obsarr): + raise Exception("No mutual visibilities in the specified time range!") + + return obsarr + +################################################################################################## +# Observe w/o noise +################################################################################################## + + +def sample_vis(im_org, uv, sgrscat=False, polrep_obs='stokes', + ttype="nfft", cache=False, fft_pad_factor=2, zero_empty_pol=True, verbose=True): + """Observe a image on given baselines with no noise. + + Args: + im (Image): the image to be observed + uv (ndarray): an array of u,v coordinates + sgrscat (bool): if True, the visibilites are blurred by the Sgr A* scattering kernel + polrep_obs (str): 'stokes' or 'circ' sets the data polarimetric representtion + ttype (str): 'direct' or 'fast' or 'nfft' + fft_pad_factor (float): zero pad the image to fft_pad_factor * image size in FFT + zero_empty_pol (bool): if True, returns zero vec if the polarization doesn't exist. + Otherwise return None + verbose (bool): Boolean value controls output prints. + + Returns: + (Obsdata): an observation object + """ + + if polrep_obs == 'stokes': + im = im_org.switch_polrep('stokes', 'I') + pollist = ['I', 'Q', 'U', 'V'] # TODO what if we have to I image? + elif polrep_obs == 'circ': + im = im_org.switch_polrep('circ', 'RR') + pollist = ['RR', 'LL', 'RL', 'LR'] # TODO what if we have to RR image? + else: + raise Exception("only 'stokes' and 'circ' are supported polreps!") + + uv = np.array(uv) + if uv.shape[1] != 2: + raise Exception("When given as a list of uv points, " + + "the obs should be a list of pairs of u-v coordinates!") + if im.pa != 0.0: + c = np.cos(im.pa) + s = np.sin(im.pa) + u = uv[:, 0] + v = uv[:, 1] + uv = np.column_stack([c * u - s * v, + s * u + c * v]) + +# umin = np.min(np.sqrt(uv[:,0]**2 + uv[:,1]**2)) +# umax = np.max(np.sqrt(uv[:,0]**2 + uv[:,1]**2)) +# if not im.psize < 1.0/(2.0*umax): +# print(" Warning!: longest baseline > 1/2 x maximum image spatial wavelength!") +# if not im.psize*np.sqrt(im.xdim*im.ydim) > 1.0/(0.5*umin): +# print(" Warning!: shortest baseline < 2 x minimum image spatial wavelength!") + + obsdata = [] + + # Get visibilities from straightforward FFT + if ttype == "fast": + + # Padded image size + npad = fft_pad_factor * np.max((im.xdim, im.ydim)) + npad = obsh.power_of_two(npad) + + padvalx1 = padvalx2 = int(np.floor((npad - im.xdim)/2.0)) + if im.xdim % 2: + padvalx2 += 1 + padvaly1 = padvaly2 = int(np.floor((npad - im.ydim)/2.0)) + if im.ydim % 2: + padvaly2 += 1 + + imarr = im.imvec.reshape(im.ydim, im.xdim) + imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)), + 'constant', constant_values=0.0) + npad = imarr.shape[0] + if imarr.shape[0] != imarr.shape[1]: + raise Exception("FFT padding did not return a square image!") + + # Scaled uv points + du = 1.0/(npad*im.psize) + uv2 = np.hstack((uv[:, 1].reshape(-1, 1), uv[:, 0].reshape(-1, 1))) + uv2 = (uv2/du + 0.5*npad).T + + # Extra phase to match centroid convention + phase = np.exp(-1j*np.pi*im.psize*((1+im.xdim % 2)*uv[:, 0] + (1+im.ydim % 2)*uv[:, 1])) + + # Pulse function + pulsefac = np.fromiter( + (im.pulse(2*np.pi*uvpt[0], 2*np.pi*uvpt[1], im.psize, dom="F") for uvpt in uv), 'c16') + + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + # FFT for visibilities + if pol in im_org.cached_fft: + vis_im = im_org.cached_fft[pol] + else: + imarr = imvec.reshape(im.ydim, im.xdim) + imarr = np.pad(imarr, ((padvalx1, padvalx2), (padvaly1, padvaly2)), + 'constant', constant_values=0.0) + vis_im = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(imarr))) + if cache == 'auto': + im_org.cached_fft[pol] = vis_im + + # Sample the visibilities + # default is cubic spline interpolation + visre = nd.map_coordinates(np.real(vis_im), uv2) + visim = nd.map_coordinates(np.imag(vis_im), uv2) + vis = visre + 1j*visim + + # Extra phase and pulse factor + vis = vis * phase * pulsefac + + # Return visibilities + obsdata.append(vis) + + # Get visibilities from the NFFT + elif ttype == "nfft": + + uvdim = len(uv) + if (im.xdim % 2 or im.ydim % 2): + raise Exception("NFFT doesn't work with odd image dimensions!") + + npad = fft_pad_factor * np.max((im.xdim, im.ydim)) + + # TODO what is a good kernel size?? + nker = np.floor(np.min((im.xdim, im.ydim))/5) + if (nker > 50): + nker = 50 + elif (im.xdim < 50 or im.ydim < 50): + nker = np.min((im.xdim, im.ydim))/2 + + # TODO are y & x reversed? + plan = NFFT([im.xdim, im.ydim], uvdim, m=nker, n=[npad, npad]) + + # Sampled uv points + uvlist = uv*im.psize + + # Precompute + plan.x = uvlist + plan.precompute() + + # Extra phase and pulsefac + phase = np.exp(-1j*np.pi*(uvlist[:, 0] + uvlist[:, 1])) + pulsefac = np.fromiter((im.pulse(2*np.pi*uvlist[i, 0], 2*np.pi*uvlist[i, 1], 1., dom="F") + for i in range(uvdim)), 'c16') + + # Compute the uniform --> nonuniform transform for different polarizations + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + plan.f_hat = imvec.copy().reshape((im.ydim, im.xdim)).T + plan.trafo() + vis = plan.f.copy()*phase*pulsefac + + obsdata.append(vis) + + elif ttype == "DFT": + xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12 + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + imarr = imvec.reshape(im.ydim, im.xdim) + vis = DFT(imarr, uv, xfov=xfov, yfov=yfov) + obsdata.append(vis) + + elif ttype == "DFT_i": + xfov, yfov = im.xdim*im.psize/4.84813681109536e-12, im.ydim*im.psize/4.84813681109536e-12 + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + uv = np.array([uv[:,1], uv[:,0]]).T # uv swap hack + imarr = imvec.reshape(im.ydim, im.xdim) + vis = DFT(imarr, uv, xfov=xfov, yfov=yfov) + obsdata.append(vis) + + + # Get visibilities from DTFT + else: + # Construct Fourier matrix + mat = obsh.ftmatrix(im.psize, im.xdim, im.ydim, uv, pulse=im.pulse) + + # Compute DTFT for different polarizations + for i in range(4): + pol = pollist[i] + imvec = im._imdict[pol] + if imvec is None or len(imvec) == 0: + if zero_empty_pol: + obsdata.append(np.zeros(len(uv))) + else: + obsdata.append(None) + else: + vis = np.dot(mat, imvec) + obsdata.append(vis) + + # Scatter the visibilities with the SgrA* kernel + if sgrscat: + if verbose: + print('Scattering Visibilities with Sgr A* kernel!') + ker = obsh.sgra_kernel_uv(im.rf, uv[:, 0], uv[:, 1]) + for data in obsdata: + if data is None: + continue + data *= ker + + return obsdata + + +def DFT(data, uv, xfov=225, yfov=225): + if data.ndim == 2: + data = data[np.newaxis,...] + out_shape = (uv.shape[0],) + elif data.ndim > 2: + data = data.reshape((-1,) + data.shape[-2:]) + out_shape = data.shape[:-2] + (uv.shape[0],) + ny, nx = data.shape[-2:] + dx = xfov*4.84813681109536e-12 / nx + dy = yfov*4.84813681109536e-12 / ny + angx = (np.arange(nx) - nx//2) * dx + angy = (np.arange(ny) - ny//2) * dy + lvect = np.sin(angx) + mvect = np.sin(angy) + l, m = np.meshgrid(lvect, mvect) + lm = np.concatenate([l.reshape(1,-1), m.reshape(1,-1)], axis=0) + imgvect = data.reshape((data.shape[0],-1)) + x = -2*np.pi*np.dot(uv,lm)[np.newaxis, ...] + visr = np.sum(imgvect[:, np.newaxis, :] * np.cos(x, dtype=np.float32), axis=-1) + visi = np.sum(imgvect[:, np.newaxis, :] * np.sin(x, dtype=np.float32), axis=-1) + if data.ndim == 2: + vis = visr.ravel() + 1j*visi.ravel() + else: + vis = visr.ravel() + 1j*visi.ravel() + vis = vis.reshape(out_shape) + return vis + + +################################################################################################## +# Noise + miscalibration funcitons +################################################################################################## + + +def make_jones(obs, opacitycal=True, ampcal=True, phasecal=True, dcal=True, + frcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, neggains=False, + taup=ehc.GAINPDEF, + gainp=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF, + phase_std=-1, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0., rlphase_std=0., + sigmat=None,phasesigmat=None,rlgsigmat=None,rlpsigmat=None, + caltable_path=None, seed=False): + """Computes Jones Matrices for a list of times (non repeating), with gain and dterm errors. + + Args: + obs (Obsdata): the observation with scans for the Jones matrices to be computed + opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities + ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains + phasecal (bool): if False, time-dependent random phases are added to complex station gains + dcal (bool): if False, time-dependent gaussian errors are added to D-terms. + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + phase_std (float): std. dev. of LCP phase, + or a dict giving one std. dev. per site + a negative value samples from uniform + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + sigmat (float): temporal std for a Gaussian Process used to generate gains. + If sigmat=None then an iid gain noise is applied. + phasesigmat (float): temporal std for a Gaussian Process used to generate phases. + If phasesigmat=None then an iid gain noise is applied. + rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios. + If rlgsigmat=None then an iid gain noise is applied. + rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff. + If rlpsigmat=None then an iid gain noise is applied. + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed : a seed for the random number generators, uses system time if false + Returns: + (dict): a nested dictionary of matrices indexed by the site, then by the time + """ + + obs_tmp = obs.copy() + tlist = obs_tmp.tlist() + tarr = obs_tmp.tarr + ra = obs_tmp.ra + dec = obs_tmp.dec + sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)]) + + # Create a dictionary of taus and a list of unique times + nsites = len(obs_tmp.tarr['site']) + taudict = {site: np.array([]) for site in obs_tmp.tarr['site']} + times = np.array([]) + for scan in tlist: + time = scan['time'][0] + times = np.append(times, time) + sites_in = np.array([]) + for bl in scan: + + # Should we screen for conflicting same-time measurements of tau? + if len(sites_in) >= nsites: + break + + if (not len(sites_in)) or (not bl['t1'] in sites_in): + taudict[bl['t1']] = np.append(taudict[bl['t1']], bl['tau1']) + sites_in = np.append(sites_in, bl['t1']) + + if (not len(sites_in)) or (not bl['t2'] in sites_in): + taudict[bl['t2']] = np.append(taudict[bl['t2']], bl['tau2']) + sites_in = np.append(sites_in, bl['t2']) + + if len(sites_in) < nsites: + for site in obs_tmp.tarr['site']: + if site not in sites_in: + taudict[site] = np.append(taudict[site], 0.0) + + # Now define a list that accounts for periods where the phase or amplitude errors + # are stable (e.g., over scans if stabilize_scan_phase==True) + times_stable_phase = times.copy() + times_stable_amp = times.copy() + times_stable = times.copy() + if stabilize_scan_phase is True or stabilize_scan_amp is True: + scans = obs_tmp.scans + if np.all(scans) is None or len(scans) == 0: + obs_scans = obs.copy() + obs_scans.add_scans() + scans = obs_scans.scans + for j in range(len(times_stable)): + for scan in scans: + if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]: + times_stable[j] = scan[0] + break + + if stabilize_scan_phase is True: + times_stable_phase = times_stable.copy() + if stabilize_scan_amp is True: + times_stable_amp = times_stable.copy() + + # Compute Sidereal Times + if obs.timetype == 'GMST': + times_sid = times + else: + times_sid = obsh.utc_to_gmst(times, obs.mjd) + + # Seed for random number generators + if seed is False: + seed = str(ttime.time()) + + # Generate Jones Matrices at each time for each telescope + out = {} + datatables = {} + for i in range(len(tarr)): + site = tarr[i]['site'] + coords = np.array([tarr[i]['x'], tarr[i]['y'], tarr[i]['z']]) + latlon = obsh.xyz_2_latlong(coords) + + # Elevation Angles + thetas = np.mod((times_sid - ra)*ehc.HOUR, 2*np.pi) + el_angles = obsh.elev(obsh.earthrot(coords, thetas), sourcevec) + + # Parallactic Angles + hr_angles = obsh.hr_angle(times_sid*ehc.HOUR, latlon[:, 1], ra*ehc.HOUR) + par_angles = obsh.par_angle(hr_angles, latlon[:, 0], dec*ehc.DEGREE) + + # gain offset: time independent part + if type(gain_offset) == dict: + goff = gain_offset[site] + else: + goff = gain_offset + + # gain_mult: time dependent part + if type(gainp) == dict: + gain_mult = gainp[site] + else: + gain_mult = gainp + + # phase mult - phase std deviation (-1 = uniform) + if type(phase_std) == dict: + phase_mult = phase_std[site] + else: + phase_mult = phase_std + + # gainratio_mult: time dependent R/L gain offset + if type(rlratio_std) == dict: + gainratio_mult = rlratio_std[site] + else: + gainratio_mult = rlratio_std + + # phasediff_mult: time dependent R-L phase offset + if type(rlphase_std) == dict: + phasediff_mult = rlphase_std[site] + else: + phasediff_mult = rlphase_std + + # correlation timescales + if type(sigmat) == dict: + sigt_g = sigmat[site] + else: + sigt_g = sigmat + + if type(phasesigmat) == dict: + sigt_p = phasesigmat[site] + else: + sigt_p = phasesigmat + + if type(rlgsigmat) == dict: + sigt_rlg = rlgsigmat[site] + else: + sigt_rlg = rlgsigmat + + if type(rlpsigmat) == dict: + sigt_rlp = rlpsigmat[site] + else: + sigt_rlp = rlpsigmat + + # Amplitude gains + gainR = gainL = np.ones(len(times)) + if not ampcal: + + # mean LCP gain + gainL_constant = goff * obsh.hashrandn(site, 'gain', str(goff), seed) + + # Enforce mean log gain < 1 + if neggains: + gainL_constant = -np.abs(gainL_constant) + + # LCP gain + if sigt_g is None: # iid sampling in time + + gainL = np.sqrt(np.abs(np.fromiter(( + (1.0 + gainL_constant) * + (1.0 + gain_mult * obsh.hashrandn(site, 'gain', str(time), str(gain_mult), seed)) + for time in times_stable_amp + ), float))) + + elif sigt_g <=0: # single sample in time + + gainL = np.sqrt(np.abs(np.fromiter(((1.0 + gainL_constant) + for time in times_stable_amp), float))) + + else: # correlated sampling in time + scan_start_times = scans[:, 0] + cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_g) + randLx = obsh.hashmultivariaterandn(len(scan_start_times), cov, site, + 'gain', str(time), str(gain_mult), seed) + gainL = np.sqrt(np.abs((1.0 + gainL_constant) * (1.0 + gain_mult * randLx))) + + + gainL_interpolateor = interp1d(scan_start_times, gainL, kind='zero') + gainL = gainL_interpolateor(times_stable_amp) + + # R/L gain offset (if present) + if rlgaincal: + gain_RLratio = 1. + else: + if sigt_rlg is None: #iid sampling in time + + gain_RLratio = np.abs(np.fromiter(( + (1.0 + gainratio_mult * obsh.hashrandn(site, 'gainratio', str(time), + str(gainratio_mult), seed)) + for time in times_stable_amp), float)) + + elif sigt_rlg <=0: # single sample in time + gain_RLratio = np.abs(np.fromiter(( + (1.0 + gainratio_mult * obsh.hashrandn(site, 'gainratio', + str(gainratio_mult), seed)) + for time in times_stable_amp), float)) + + else: #correlated sampling in time + scan_start_times = scans[:, 0] + cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_rlg) + randRLx = obsh.hashmultivariaterandn(len(scan_start_times), cov, site, + 'gainratio', str(time), str(gainratio_mult), + seed) + gain_RLratio = np.abs(1.0 + gainratio_mult * randRLx) + gainRLratio_interpolateor = interp1d(scan_start_times, gain_RLratio, kind='zero') + gain_RLratio = gainRLratio_interpolateor(times_stable_amp) + + # RCP gain + gainR = gain_RLratio * gainL + + # enforce gains < 1 + # TODO -- will this mess up gain offset priors? + if neggains: + gainR = np.exp(-np.abs(np.log(gainR))) + gainL = np.exp(-np.abs(np.log(gainL))) + + # Opacity attenuation of amplitude gain + if not opacitycal: + taus = np.abs(np.fromiter(( + (taudict[site][j]) * + (1.0 + taup * obsh.hashrandn(site, 'tau', times_stable_amp[j], seed)) + for j in range(len(times))), float)) + atten = np.exp(-taus/(ehc.EP + 2.0*np.sin(el_angles))) + + gainR = gainR * atten + gainL = gainL * atten + + # Atmospheric Phase + if not phasecal: + + # Gaussian distribution of LCP phase + if phase_mult >=0: + + if sigt_p is None: #iid sampling in time + phaseL = np.fromiter((phase_mult * obsh.hashrandn(site, 'phase', str(time), + str(phase_mult), seed) + for time in times_stable_phase), float) + + elif sigt_p <=0: # single sample in time + phaseL = np.fromiter((phase_mult * obsh.hashrandn(site, 'phase', + str(phase_mult), seed) + for time in times_stable_phase), float) + else: #correlated sampling in time + scan_start_times = scans[:, 0] + cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_p) + phaseL = phase_mult* obsh.hashmultivariaterandn(len(scan_start_times), cov, site, + 'phase', str(time), str(phase_mult), + seed) + phaseL_interpolateor = interp1d(scan_start_times, phaseL, kind='zero') + phaseL = phaseL_interpolateor(times_stable_phase) + + # flat distribution of LCP phase, iid in time + # TODO correlated sampling with flat phases? + else: + + phaseL = np.fromiter((2 * np.pi * obsh.hashrand(site, 'phase', time, seed) + for time in times_stable_phase), float) + + # R-L phase offset + if rlgaincal: + phaseRLdiff = 0. + else: + # Gaussian distributed phase difference + if phasediff_mult >=0: + if sigt_rlp is None: #iid sampling in time + phaseRLdiff = np.fromiter((phasediff_mult * + obsh.hashrandn(site, 'phasediff', str(time), + str(phasediff_mult), seed) + for time in times_stable_phase), float) + elif sigt_rlp <=0: # single sample in time + phaseRLdiff = np.fromiter((phasediff_mult * + obsh.hashrandn(site, 'phasediff', + str(phasediff_mult), seed) + for time in times_stable_phase), float) + else: #correlated sampling in time + scan_start_times = scans[:, 0] + cov = obsh.rbf_kernel_covariance(scan_start_times, sigt_rlp) + phaseRLdiff = phasediff_mult * obsh.hashmultivariaterandn(len(scan_start_times), + cov, site, 'phasediff', + str(time), str(phasediff_mult), + seed) + phaseRL_interpolateor = interp1d(scan_start_times, phaseRLdiff, kind='zero') + phaseRLdiff = phaseRL_interpolateor(times_stable_phase) + + # flat distribution phase difference, iid in time + # TODO correlated sampling with flat phases? + else: + phaseRLdiff = np.fromiter((2 * np.pi * obsh.hashrand(site, 'phase', time, seed) + for time in times_stable_phase), float) + phaseRLdiff -= np.pi + + # Complex gains + gainL = gainL * np.exp(1j*phaseL) + gainR = gainR * np.exp(1j*(phaseL + phaseRLdiff)) + + + # D Term errors + dR = dL = 0.0 + if not dcal: + + # D terms are always time-independent + if type(dterm_offset) == dict: + doff = dterm_offset[site] + else: + doff = dterm_offset + + dR = tarr[i]['dr'] + dL = tarr[i]['dl'] + + dR += doff * (obsh.hashrandn(site, 'dRre', seed) + + 1j * obsh.hashrandn(site, 'dRim', seed)) + dL += doff * (obsh.hashrandn(site, 'dLre', seed) + + 1j * obsh.hashrandn(site, 'dLim', seed)) + + # Feed Rotation Angles + fr_angle = np.zeros(len(times)) + fr_angle_D = np.zeros(len(times)) + + # Field rotation has not been corrected + if not frcal: + fr_angle = tarr[i]['fr_elev']*el_angles + \ + tarr[i]['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE + + # If field rotation has been corrected, but leakage has NOT been corrected, + # the leakage needs to rotate doubly + elif frcal and not dcal: + fr_angle_D = 2.0*(tarr[i]['fr_elev']*el_angles + tarr[i] + ['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE) + + # Assemble the Jones Matrices and save to dictionary + j_matrices = {times[j]: np.array([ + [np.exp(-1j*fr_angle[j])*gainR[j], + np.exp(1j*(fr_angle[j]+fr_angle_D[j]))*dR*gainR[j]], + [np.exp(-1j*(fr_angle[j]+fr_angle_D[j]))*dL*gainL[j], + np.exp(1j*fr_angle[j])*gainL[j]] + ]) + for j in range(len(times)) + } + + out[site] = j_matrices + + if caltable_path: + obs_tmp.tarr[i]['dr'] = dR + obs_tmp.tarr[i]['dl'] = dL + datatable = [] + for j in range(len(times)): + datatable.append(np.array((times[j], gainR[j], gainL[j]), dtype=ehc.DTCAL)) + datatables[site] = np.array(datatable) + + # Save a calibration table with the synthetic gains and dterms added + if caltable_path and len(datatables) > 0: + from ehtim.caltable import Caltable # TODO blah circular imports + caltable = Caltable(obs_tmp.ra, obs_tmp.dec, obs_tmp.rf, obs_tmp.bw, + datatables, obs_tmp.tarr, source=obs_tmp.source, + mjd=obs_tmp.mjd, timetype=obs_tmp.timetype) + + caltable.save_txt(obs_tmp, datadir=caltable_path+'_simdata_caltable') + + return out + + +def make_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True): + """Computes inverse Jones Matrices for a list of times (non repeating), + with NO gain and dterm errors. + + Args: + obs (Obsdata): the observation with scans for the inverse Jones matrices to be computed + opacitycal (bool): if False, estimated opacity terms are applied in the inverse gains + dcal (bool): if False, estimated d-terms are applied to the inverse Jones matrices + frcal (bool): if False, inverse feed rotation angle terms are applied to Jones matrices. + + Returns: + (dict): a nested dictionary of matrices indexed by the site, then by the time + """ + + # Get data + tlist = obs.tlist() + tarr = obs.tarr + ra = obs.ra + dec = obs.dec + sourcevec = np.array([np.cos(dec*ehc.DEGREE), 0, np.sin(dec*ehc.DEGREE)]) + + # Create a dictionary of taus and a list of unique times + nsites = len(obs.tarr['site']) + taudict = {site: np.array([]) for site in obs.tarr['site']} + times = np.array([]) + for scan in tlist: + time = scan['time'][0] + times = np.append(times, time) + sites_in = np.array([]) + for bl in scan: + + # Should we screen for conflicting same-time measurements of tau? + if len(sites_in) >= nsites: + break + + if (not len(sites_in)) or (not bl['t1'] in sites_in): + taudict[bl['t1']] = np.append(taudict[bl['t1']], bl['tau1']) + sites_in = np.append(sites_in, bl['t1']) + + if (not len(sites_in)) or (not bl['t2'] in sites_in): + taudict[bl['t2']] = np.append(taudict[bl['t2']], bl['tau2']) + sites_in = np.append(sites_in, bl['t2']) + if len(sites_in) < nsites: + for site in obs.tarr['site']: + if site not in sites_in: + taudict[site] = np.append(taudict[site], 0.0) + + # Compute Sidereal Times + if obs.timetype == 'GMST': + times_sid = times + else: + times_sid = obsh.utc_to_gmst(times, obs.mjd) + + # Make inverse Jones Matrices + out = {} + for i in range(len(tarr)): + site = tarr[i]['site'] + coords = np.array([tarr[i]['x'], tarr[i]['y'], tarr[i]['z']]) + latlon = obsh.xyz_2_latlong(coords) + + # Elevation Angles + thetas = np.mod((times_sid - ra)*ehc.HOUR, 2*np.pi) + el_angles = obsh.elev(obsh.earthrot(coords, thetas), sourcevec) + + # Parallactic Angles (positive longitude EAST) + hr_angles = obsh.hr_angle(times_sid*ehc.HOUR, latlon[:, 1], ra*ehc.HOUR) + par_angles = obsh.par_angle(hr_angles, latlon[:, 0], dec*ehc.DEGREE) + + # Amplitude gain assumed 1 + gainR = gainL = np.ones(len(times)) + + # Opacity attenuation of amplitude gain + if not opacitycal: + taus = np.abs(np.array(taudict[site])) + atten = np.exp(-taus/(ehc.EP + 2.0*np.sin(el_angles))) + + gainR = gainR * atten + gainL = gainL * atten + + # D Terms + dR = dL = 0.0 + if not dcal: + dR = tarr[i]['dr'] + dL = tarr[i]['dl'] + + # Feed Rotation Angles + fr_angle = np.zeros(len(times)) + # This is for when field rotation is corrected but not leakage + fr_angle_D = np.zeros(len(times)) + if not frcal: + # Total Angle (Radian) + fr_angle = (tarr[i]['fr_elev']*el_angles + tarr[i]['fr_par'] + * par_angles + tarr[i]['fr_off']*ehc.DEGREE) + + elif frcal and not dcal: + # If the field rotation angle has been removed but leakage hasn't, + # we still need to rotate the leakage terms appropriately + # by *twice* the field rotation angle + fr_angle_D = 2.0*(tarr[i]['fr_elev']*el_angles + tarr[i] + ['fr_par']*par_angles + tarr[i]['fr_off']*ehc.DEGREE) + + # Assemble the inverse Jones Matrices and save to dictionary + pref = 1.0/(gainL*gainR*(1.0 - dL*dR)) + j_matrices_inv = {times[j]: pref[j]*np.array([ + [np.exp(1j*fr_angle[j])*gainL[j], + -np.exp(1j*(fr_angle[j] + fr_angle_D[j]))*dR*gainR[j]], + [-np.exp(-1j*(fr_angle[j] + fr_angle_D[j]))*dL*gainL[j], + np.exp(-1j*fr_angle[j])*gainR[j]] + ]) for j in range(len(times)) + } + + out[site] = j_matrices_inv + + return out + + +def add_jones_and_noise(obs, add_th_noise=True, + opacitycal=True, ampcal=True, phasecal=True, dcal=True, + frcal=True, rlgaincal=True, + stabilize_scan_phase=False, stabilize_scan_amp=False, + neggains=False, + taup=ehc.GAINPDEF, + gainp=ehc.GAINPDEF,gain_offset=ehc.GAINPDEF, + phase_std=-1, + dterm_offset=ehc.DTERMPDEF, + rlratio_std=0., rlphase_std=0., + sigmat=None, phasesigmat=None, rlgsigmat=None,rlpsigmat=None, + caltable_path=None, seed=False, verbose=True): + """Corrupt visibilities in obs with jones matrices and add thermal noise + + Args: + obs (Obsdata): the original observation + add_th_noise (bool): if True, baseline-dependent thermal noise is added to each data point + opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities + ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains + phasecal (bool): if False, time-dependent random phases are added to complex station gains + dcal (bool): if False, time-dependent gaussian errors are added to D-terms. + frcal (bool): if False, feed rotation angle terms are added to Jones matrices. + rlgaincal (bool): if False, time-dependent gains are not equal for R and L pol + + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + + taup (float): the fractional std. dev. of the random error on the opacities + gainp (float): the fractional std. dev. of the random error on the gains + or a dict giving one std. dev. per site + + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + phase_std (float): std. dev. of LCP phase, + or a dict giving one std. dev. per site + a negative value samples from uniform + dterm_offset (float): the base std. dev. of random additive error at all sites, + or a dict giving one std. dev. per site + + rlratio_std (float): the fractional std. dev. of the R/L gain offset + or a dict giving one std. dev. per site + rlphase_std (float): std. dev. of R/L phase offset, + or a dict giving one std. dev. per site + a negative value samples from uniform + + sigmat (float): temporal std for a Gaussian Process used to generate gains. + If sigmat=None then an iid gain noise is applied. + phasesigmat (float): temporal std for a Gaussian Process used to generate phases. + If phasesigmat=None then an iid gain noise is applied. + rlgsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L gain ratios. + If rlgsigmat=None then an iid gain noise is applied. + rlpsigmat (float): temporal std deviation for a Gaussian Process used to generate R/L phase diff. + If rlpsigmat=None then an iid gain noise is applied. + + caltable_path (string): If not None, path and prefix for saving the applied caltable + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + verbose (bool): print updates and warnings + Returns: + (np.array): an observation data array + """ + + if verbose: + print("Applying Jones Matrices to data . . . ") + + # Build Jones Matrices + jm_dict = make_jones(obs, + ampcal=ampcal, opacitycal=opacitycal, phasecal=phasecal, + dcal=dcal, frcal=frcal, rlgaincal=rlgaincal, + stabilize_scan_phase=stabilize_scan_phase, + stabilize_scan_amp=stabilize_scan_amp, neggains=neggains, + taup=taup, + gainp=gainp, gain_offset=gain_offset, + phase_std=phase_std, + dterm_offset=dterm_offset, + rlratio_std=rlratio_std, rlphase_std=rlphase_std, + sigmat=sigmat,phasesigmat=phasesigmat, + rlgsigmat=rlgsigmat,rlpsigmat=rlpsigmat, + caltable_path=caltable_path, seed=seed) + + # Change pol rep: + obs_circ = obs.switch_polrep('circ') + obsdata = copy.copy(obs_circ.data) + + times = obsdata['time'] + t1 = obsdata['t1'] + t2 = obsdata['t2'] + tints = obsdata['tint'] + rr = obsdata['rrvis'] + ll = obsdata['llvis'] + rl = obsdata['rlvis'] + lr = obsdata['lrvis'] + + # Recompute the noise std. deviations from the SEFDs + if np.any(obs.tarr['sefdr'] <= 0) or np.any(obs.tarr['sefdl'] <= 0): + if verbose: + print("Warning!: in add_jones_and_noise, some SEFDs are <= 0!") + print("Resorting to data point sigmas, which may add too much systematic noise!") + sig_rr = obsdata['rrsigma'] + sig_ll = obsdata['llsigma'] + sig_rl = obsdata['rlsigma'] + sig_lr = obsdata['lrsigma'] + else: + sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'], + obs.tarr[obs.tkey[t2[i]]]['sefdr'], + tints[i], obs.bw) + for i in range(len(rr))), float) + sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'], + obs.tarr[obs.tkey[t2[i]]]['sefdl'], + tints[i], obs.bw) + for i in range(len(ll))), float) + sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'], + obs.tarr[obs.tkey[t2[i]]]['sefdl'], + tints[i], obs.bw) + for i in range(len(rl))), float) + sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'], + obs.tarr[obs.tkey[t2[i]]]['sefdr'], + tints[i], obs.bw) + for i in range(len(lr))), float) + + if verbose and not opacitycal: + print(" Applying opacity attenuation: opacitycal-->False") + if verbose and not ampcal: + print(" Applying gain corruption: ampcal-->False") + if verbose and not phasecal: + print(" Applying atmospheric phase corruption: phasecal-->False") + if verbose and not dcal: + print(" Applying D Term mixing: dcal-->False") + if verbose and not frcal: + print(" Applying Field Rotation: frcal-->False") + if verbose and add_th_noise: + print("Adding thermal noise to data . . . ") + + # Corrupt each IQUV visibility set with the jones matrices and add noise + for i in range(len(times)): + # Form the visibility correlation matrix + corr_matrix = np.array([[rr[i], rl[i]], [lr[i], ll[i]]]) + + # Get the jones matrices and corrupt the corr_matrix + j1 = jm_dict[t1[i]][times[i]] + j2 = jm_dict[t2[i]][times[i]] + + corr_matrix_corrupt = np.dot(j1, np.dot(corr_matrix, np.conjugate(j2.T))) + + # Add noise + if add_th_noise: + noise_matrix = np.array([[obsh.cerror(sig_rr[i]), obsh.cerror(sig_rl[i])], + [obsh.cerror(sig_lr[i]), obsh.cerror(sig_ll[i])]]) + corr_matrix_corrupt += noise_matrix + + # Put the corrupted data back into the data table + obsdata['rrvis'][i] = corr_matrix_corrupt[0][0] + obsdata['llvis'][i] = corr_matrix_corrupt[1][1] + obsdata['rlvis'][i] = corr_matrix_corrupt[0][1] + obsdata['lrvis'][i] = corr_matrix_corrupt[1][0] + + # Put the recomputed sigmas back into the data table + obsdata['rrsigma'][i] = sig_rr[i] + obsdata['llsigma'][i] = sig_ll[i] + obsdata['rlsigma'][i] = sig_rl[i] + obsdata['lrsigma'][i] = sig_lr[i] + + # put back into input polvec + obs_circ.data = obsdata + obs_back = obs_circ.switch_polrep(obs.polrep) + obsdata_back = obs_back.data + + # Return observation data + return obsdata_back + + +def apply_jones_inverse(obs, opacitycal=True, dcal=True, frcal=True, verbose=True): + """Apply inverse jones matrices to an observation + + Args: + opacitycal (bool): if False, estimated opacity terms are applied in the inverse gains + dcal (bool): if False, estimated d-terms applied to the inverse Jones matrices + frcal (bool): if False, feed rotation angle terms are applied to Jones matrices. + verbose (bool): print updates and warnings + + Returns: + (np.array): an observation data array + """ + + if verbose: + print("Applying a priori calibration with estimated Jones matrices . . . ") + + # Build Inverse Jones Matrices + jm_dict = make_jones_inverse(obs, opacitycal=opacitycal, dcal=dcal, frcal=frcal) + + # Change pol rep: + obs_circ = obs.switch_polrep('circ') + + # Get data + obsdata = copy.deepcopy(obs_circ.data) + times = obsdata['time'] + t1 = obsdata['t1'] + t2 = obsdata['t2'] + tints = obsdata['tint'] + rr = obsdata['rrvis'] + ll = obsdata['llvis'] + rl = obsdata['rlvis'] + lr = obsdata['lrvis'] + + # Recompute the noise std. deviations from the SEFDs + if np.any(obs.tarr['sefdr'] <= 0) or np.any(obs.tarr['sefdl'] <= 0): + if verbose: + print("Warning!: in add_jones_and_noise, some SEFDs are <= 0!") + print("resorting to data point sigmas, which may add too much systematic noise!") + sig_rr = obsdata['rrsigma'] + sig_ll = obsdata['llsigma'] + sig_rl = obsdata['rlsigma'] + sig_lr = obsdata['lrsigma'] + else: + # TODO why are there sqrt(2)s here and not below? + sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'], + obs.tarr[obs.tkey[t2[i]]]['sefdr'], + tints[i], obs.bw) + for i in range(len(rr))), float) + sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'], + obs.tarr[obs.tkey[t2[i]]]['sefdl'], + tints[i], obs.bw) + for i in range(len(ll))), float) + sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdr'], + obs.tarr[obs.tkey[t2[i]]]['sefdl'], + tints[i], obs.bw) + for i in range(len(rl))), float) + sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[t1[i]]]['sefdl'], + obs.tarr[obs.tkey[t2[i]]]['sefdr'], + tints[i], obs.bw) + for i in range(len(lr))), float) + + if not opacitycal: + if verbose: + print(" Applying opacity corrections: opacitycal-->True") + opacitycal = True + if not dcal: + if verbose: + print(" Applying D Term corrections: dcal-->True") + dcal = True + if not frcal: + if verbose: + print(" Applying Field Rotation corrections: frcal-->True") + frcal = True + + # Apply the inverse Jones matrices to each visibility + for i in range(len(times)): + + # Get the inverse jones matrices + inv_j1 = jm_dict[t1[i]][times[i]] + inv_j2 = jm_dict[t2[i]][times[i]] + + # Form the visibility correlation matrix + corr_matrix = np.array([[rr[i], rl[i]], [lr[i], ll[i]]]) + + # Form the sigma matrices + sig_rr_matrix = np.array([[sig_rr[i], 0.0], [0.0, 0.0]]) + sig_ll_matrix = np.array([[0.0, 0.0], [0.0, sig_ll[i]]]) + sig_rl_matrix = np.array([[0.0, sig_rl[i]], [0.0, 0.0]]) + sig_lr_matrix = np.array([[0.0, 0.0], [sig_lr[i], 0.0]]) + + # Apply the inverse Jones Matrices to the visibility correlation matrix and sigma matrices + corr_matrix_new = np.dot(inv_j1, np.dot(corr_matrix, np.conjugate(inv_j2.T))) + + sig_rr_matrix_new = np.dot(inv_j1, np.dot(sig_rr_matrix, np.conjugate(inv_j2.T))) + sig_ll_matrix_new = np.dot(inv_j1, np.dot(sig_ll_matrix, np.conjugate(inv_j2.T))) + sig_rl_matrix_new = np.dot(inv_j1, np.dot(sig_rl_matrix, np.conjugate(inv_j2.T))) + sig_lr_matrix_new = np.dot(inv_j1, np.dot(sig_lr_matrix, np.conjugate(inv_j2.T))) + + # Get the final sigma matrix as a quadrature sum + sig_matrix_new = np.sqrt(np.abs(sig_rr_matrix_new)**2 + np.abs(sig_ll_matrix_new)**2 + + np.abs(sig_rl_matrix_new)**2 + np.abs(sig_lr_matrix_new)**2) + + # Put the corrupted data back into the data table + obsdata['rrvis'][i] = corr_matrix_new[0][0] + obsdata['llvis'][i] = corr_matrix_new[1][1] + obsdata['rlvis'][i] = corr_matrix_new[0][1] + obsdata['lrvis'][i] = corr_matrix_new[1][0] + + # Put the recomputed sigmas back into the data table + obsdata['rrsigma'][i] = sig_matrix_new[0][0] + obsdata['llsigma'][i] = sig_matrix_new[1][1] + obsdata['rlsigma'][i] = sig_matrix_new[0][1] + obsdata['lrsigma'][i] = sig_matrix_new[1][0] + + # put back into input polvec + obs_circ.data = obsdata + obs_back = obs_circ.switch_polrep(obs.polrep) + obsdata_back = obs_back.data + + # Return observation data + return obsdata_back + +# The old noise generating function. + + +def add_noise(obs, add_th_noise=True, th_noise_factor=1, opacitycal=True, ampcal=True, phasecal=True, + stabilize_scan_amp=False, stabilize_scan_phase=False, + neggains=False, + taup=ehc.GAINPDEF, gain_offset=ehc.GAINPDEF, gainp=ehc.GAINPDEF, + caltable_path=None, seed=False, sigmat=None, + verbose=True): + """Add thermal noise and gain & phase calibration errors to a dataset. + Old routine replaced by add_jones_and_noise. + + Args: + obs (Obsdata): the original observation + add_th_noise (bool): if True, baseline-dependent thermal noise is added to each data point + opacitycal (bool): if False, time-dependent gaussian errors are added to station opacities + ampcal (bool): if False, time-dependent gaussian errors are added to complex station gains + phasecal (bool): if False, time-dependent random phases are added to complex station gains + stabilize_scan_phase (bool): if True, random phase errors are constant over scans + stabilize_scan_amp (bool): if True, random amplitude errors are constant over scans + neggains (bool): if True, force the applied gains to be <1 + taup (float): the fractional std. dev. of the random error on the opacities + gain_offset (float): the base gain offset at all sites, + or a dict giving one gain offset per site + gainp (float): the fractional std. dev. of the random error on the gains + + caltable_path (string): If not None, path and prefix for saving the applied caltable + NOT SUPPORTED for add_noise. + seed (int): seeds the random component of the noise terms. DO NOT set to 0! + sigmat (float): temporal std for a Gaussian Process used to generate gains. + NOT SUPPORTED for add_noise + + verbose (bool): print updates and warnings + Returns: + (np.array): an observation data array + """ + + if caltable_path: + print("caltable saving not implemented for old add_noise function!") + if verbose: + print("Adding gain + phase errors to data and applying a priori calibration . . . ") + + if verbose and not opacitycal: + print(" Applying opacity attenuation AND estimated opacity corrections: opacitycal-->True") + if verbose and not ampcal: + print(" Applying gain corruption: ampcal-->False") + if verbose and not phasecal: + print(" Applying atmospheric phase corruption: phasecal-->False") + if verbose and add_th_noise: + print("Adding thermal noise to data . . . ") + + # Get data + obsdata = copy.deepcopy(obs.data) + + sites = obsh.recarr_to_ndarr(obsdata[['t1', 't2']], 'U32') + taus = np.abs(obsh.recarr_to_ndarr(obsdata[['tau1', 'tau2']], 'f8')) + elevs = obsh.recarr_to_ndarr(obs.unpack(['el1', 'el2'], ang_unit='deg'), 'f8') + times = obsdata['time'] + tint = obsdata['tint'] + vis1 = obsdata[obs.poldict['vis1']] + vis2 = obsdata[obs.poldict['vis2']] + vis3 = obsdata[obs.poldict['vis3']] + vis4 = obsdata[obs.poldict['vis4']] + + times_stable_phase = times.copy() + times_stable_amp = times.copy() + times_stable = times.copy() + + if stabilize_scan_phase is True or stabilize_scan_amp is True: + scans = obs.scans + if np.all(scans) is None or len(scans) == 0: + if verbose: + print("Adding scan table") + obs_scans = obs.copy() + obs_scans.add_scans() + scans = obs_scans.scans + for j in range(len(times_stable)): + for scan in scans: + if scan[0] <= times_stable[j] and scan[1] >= times_stable[j]: + times_stable[j] = scan[0] + break + + if stabilize_scan_phase is True: + times_stable_phase = times_stable.copy() + if stabilize_scan_amp is True: + times_stable_amp = times_stable.copy() + + # Recompute perfect sigmas from SEFDs + bw = obs.bw + if np.any(obs.tarr['sefdr'] <= 0): + if verbose: + print("Warning!: in add_noise, some SEFDs are <= 0!") + print("NOT recomputing sigmas, which may result in double systematic noise") + sigma_perf1 = obsdata[obs.poldict['sigma1']] + sigma_perf2 = obsdata[obs.poldict['sigma2']] + sigma_perf3 = obsdata[obs.poldict['sigma3']] + sigma_perf4 = obsdata[obs.poldict['sigma4']] + else: + sig_rr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'], + obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw) + for i in range(len(tint))), float) + sig_ll = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'], + obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw) + for i in range(len(tint))), float) + sig_rl = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdr'], + obs.tarr[obs.tkey[sites[i][1]]]['sefdl'], tint[i], bw) + for i in range(len(tint))), float) + sig_lr = np.fromiter((obsh.blnoise(obs.tarr[obs.tkey[sites[i][0]]]['sefdl'], + obs.tarr[obs.tkey[sites[i][1]]]['sefdr'], tint[i], bw) + for i in range(len(tint))), float) + if obs.polrep == 'stokes': + sig_iv = 0.5*np.sqrt(sig_rr**2 + sig_ll**2) + sig_qu = 0.5*np.sqrt(sig_rl**2 + sig_lr**2) + sigma_perf1 = sig_iv + sigma_perf2 = sig_qu + sigma_perf3 = sig_qu + sigma_perf4 = sig_iv + elif obs.polrep == 'circ': + sigma_perf1 = sig_rr + sigma_perf2 = sig_ll + sigma_perf3 = sig_rl + sigma_perf4 = sig_lr + + # Seed for random number generators + if seed is False: + seed = str(ttime.time()) + + # Add gain and opacity fluctuations to the TRUE noise + if not ampcal: + # Amplitude gain + if type(gain_offset) == dict: + goff1 = np.fromiter((gain_offset[sites[i, 0]] for i in range(len(times))), float) + goff2 = np.fromiter((gain_offset[sites[i, 1]] for i in range(len(times))), float) + else: + goff1 = np.fromiter((gain_offset for i in range(len(times))), float) + goff2 = np.fromiter((gain_offset for i in range(len(times))), float) + + if type(gainp) is dict: + gain_mult_1 = np.fromiter((gainp[sites[i, 0]] for i in range(len(times))), float) + gain_mult_2 = np.fromiter((gainp[sites[i, 1]] for i in range(len(times))), float) + else: + gain_mult_1 = np.fromiter((gainp for i in range(len(times))), float) + gain_mult_2 = np.fromiter((gainp for i in range(len(times))), float) + + gain1_constant = np.fromiter((goff1[i] * obsh.hashrandn(sites[i, 0], 'gain', str(goff1[i]), seed) + for i in range(len(times))), float) + gain2_constant = np.fromiter((goff2[i] * obsh.hashrandn(sites[i, 1], 'gain', str(goff2[i]), seed) + for i in range(len(times))), float) + + if neggains: + gain1_constant = -np.abs(gain1_constant) + gain2_constant = -np.abs(gain2_constant) + + if sigmat is None: + gain1_var = np.fromiter((gain_mult_1[i] * obsh.hashrandn(sites[i, 0], 'gain', + times_stable_amp[i], + str(gain_mult_1[i]), seed) + for i in range(len(times))), float) + gain2_var = np.fromiter((gain_mult_2[i] * obsh.hashrandn(sites[i, 1], 'gain', + times_stable_amp[i], + str(gain_mult_2[i]), seed) + for i in range(len(times))), float) + else: + raise Exception("correlated gains not supported in old add_noise! Use jones=True") + + gain1 = np.abs((1.0 + gain1_constant)*(1.0 + gain1_var)) + gain2 = np.abs((1.0 + gain2_constant)*(1.0 + gain2_var)) + if neggains: + gain1 = np.exp(-np.abs(np.log(gain1))) + gain2 = np.exp(-np.abs(np.log(gain2))) + + gain_true = np.sqrt(gain1 * gain2) + else: + gain_true = 1 + + if not opacitycal: + + # Use estimated opacity to compute the ESTIMATED noise + tau_est = np.sqrt(np.exp(taus[:, 0]/(ehc.EP+np.sin(elevs[:, 0]*ehc.DEGREE)) + + taus[:, 1]/(ehc.EP+np.sin(elevs[:, 1]*ehc.DEGREE)))) + + # Opacity Errors + tau1 = np.abs(np.fromiter((taus[i, 0] * (1.0 + taup * obsh.hashrandn(sites[i, 0], 'tau', times_stable_amp[i], seed)) + for i in range(len(times))), float)) + tau2 = np.abs(np.fromiter((taus[i, 1] * (1.0 + taup * obsh.hashrandn(sites[i, 1], 'tau', times_stable_amp[i], seed)) + for i in range(len(times))), float)) + + # Correct noise RMS for opacity + tau_true = np.sqrt(np.exp(tau1/(ehc.EP+np.sin(elevs[:, 0]*ehc.DEGREE)) + + tau2/(ehc.EP+np.sin(elevs[:, 1]*ehc.DEGREE)))) + else: + tau_true = tau_est = 1 + + # Add the noise + sigma_true1 = sigma_perf1 + sigma_true2 = sigma_perf2 + sigma_true3 = sigma_perf3 + sigma_true4 = sigma_perf4 + + sigma_est1 = sigma_perf1 * gain_true * tau_est + sigma_est2 = sigma_perf2 * gain_true * tau_est + sigma_est3 = sigma_perf3 * gain_true * tau_est + sigma_est4 = sigma_perf4 * gain_true * tau_est + + if add_th_noise: + vis1 = (vis1 + th_noise_factor*obsh.cerror(sigma_true1)) + vis2 = (vis2 + th_noise_factor*obsh.cerror(sigma_true2)) + vis3 = (vis3 + th_noise_factor*obsh.cerror(sigma_true3)) + vis4 = (vis4 + th_noise_factor*obsh.cerror(sigma_true4)) + + # Add the gain error to the true visibilities + vis1 = vis1 * gain_true * tau_est / tau_true + vis2 = vis2 * gain_true * tau_est / tau_true + vis3 = vis3 * gain_true * tau_est / tau_true + vis4 = vis4 * gain_true * tau_est / tau_true + + # Add random atmospheric phases + if not phasecal: + phase1 = np.fromiter((2 * np.pi * obsh.hashrand(sites[i, 0], 'phase', times_stable_phase[i], seed) + for i in range(len(times))), float) + phase2 = np.fromiter((2 * np.pi * obsh.hashrand(sites[i, 1], 'phase', times_stable_phase[i], seed) + for i in range(len(times))), float) + + vis1 *= np.exp(1j * (phase2-phase1)) + vis2 *= np.exp(1j * (phase2-phase1)) + vis3 *= np.exp(1j * (phase2-phase1)) + vis4 *= np.exp(1j * (phase2-phase1)) + + # Put the visibilities estimated errors back in the obsdata array + obsdata[obs.poldict['vis1']] = vis1 + obsdata[obs.poldict['vis2']] = vis2 + obsdata[obs.poldict['vis3']] = vis3 + obsdata[obs.poldict['vis4']] = vis4 + + obsdata[obs.poldict['sigma1']] = sigma_est1 + obsdata[obs.poldict['sigma2']] = sigma_est2 + obsdata[obs.poldict['sigma3']] = sigma_est3 + obsdata[obs.poldict['sigma4']] = sigma_est4 + + # Return observation data + return obsdata diff --git a/observing/pulses.py b/observing/pulses.py new file mode 100644 index 00000000..35935b16 --- /dev/null +++ b/observing/pulses.py @@ -0,0 +1,197 @@ +# pulses.py +# image restoring pulse functions +# +# Copyright (C) 2018 Katie Bouman +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# If dom="I", we are in real space, if dom="F" we are in Fourier (uv) space +# Coordinates in real space are in radian, coordinates in Fourier space are in lambda + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import math +import numpy as np +# import scipy.special as spec + +################################################################################################### +# Delta Function Pulse +################################################################################################### + + +def deltaPulse2D(x, y, pdim, dom='F'): + if dom == 'I': + if x == y == 0.0: + return 1.0 + else: + return 0.0 + elif dom == 'F': + return 1.0 + +################################################################################################### +# Square Wave Pulse +################################################################################################### + + +def rectPulse2D(x, y, pdim, dom='F'): + if dom == 'I': + return rectPulse_I(x, pdim) * rectPulse_I(y, pdim) + elif dom == 'F': + return rectPulse_F(x, pdim) * rectPulse_F(y, pdim) + + +def rectPulse_I(x, pdim): + if abs(x) >= pdim/2.0: + return 0.0 + else: + return 1.0/pdim + + +def rectPulse_F(omega, pdim): + if (omega == 0): + return 1.0 + else: + return (2.0/(pdim*omega)) * math.sin((pdim*omega)/2.0) + +################################################################################################### +# Triangle Wave Pulse +################################################################################################### + + +def trianglePulse2D(x, y, pdim, dom='F'): + if dom == 'I': + return trianglePulse_I(x, pdim) * trianglePulse_I(y, pdim) + elif dom == 'F': + return trianglePulse_F(x, pdim)*trianglePulse_F(y, pdim) + + +def trianglePulse_I(x, pdim): + if abs(x) > pdim: + return 0.0 + else: + return -(1.0/(pdim**2))*abs(x) + 1.0/pdim + + +def trianglePulse_F(omega, pdim): + if (omega == 0): + return 1.0 + else: + return (4.0/(pdim**2 * omega**2)) * (math.sin((pdim * omega)/2.0))**2 + +################################################################################################### +# Gaussian Pulse +################################################################################################### + + +def GaussPulse2D(x, y, pdim, dom='F'): + sigma = pdim/3. # Gaussian SD (sigma) vs pixelwidth (pdim) + a = 1./2./sigma/sigma + if dom == 'I': + return (a/np.pi)*np.exp(-a*(x**2 + y**2)) + + elif dom == 'F': + return np.exp(-(x**2 + y**2)/4./a) + +################################################################################################### +# Cubic Pulse +################################################################################################### + + +def cubicPulse2D(x, y, pdim, dom='F'): + if dom == 'I': + return cubicPulse_I(x, pdim) * cubicPulse_I(y, pdim) + + elif dom == 'F': + return cubicPulse_F(x, pdim)*cubicPulse_F(y, pdim) + + +def cubicPulse_I(x, pdim): + if abs(x) < pdim: + return (1.5*(abs(x)/pdim)**3 - 2.5*(x/pdim)**2 + 1.)/pdim + elif abs(abs(x)-1.5*pdim) <= 0.5*pdim: + return (-0.5*(abs(x)/pdim)**3 + 2.5*(x/pdim)**2 - 4.*(abs(x)/pdim) + 2.)/pdim + else: + return 0. + + +def cubicPulse_F(omega, pdim): + if (omega == 0): + return 1.0 + else: + p1 = ((3./omega/pdim)*math.sin(omega*pdim/2.)-math.cos(omega*pdim/2.)) + p2 = ((2./omega/pdim)*math.sin(omega*pdim/2.))**3 + return 2.*p1*p2 + + +################################################################################################### +# Sinc Pulse +################################################################################################### + +def sincPulse2D(x, y, pdim, dom='F'): + if dom == 'I': + return sincPulse_I(x, pdim) * sincPulse_I(y, pdim) + + elif dom == 'F': + return sincPulse_F(x, pdim) * sincPulse_F(y, pdim) + + +def sincPulse_I(x, pdim): + if (x == 0): + return 1./pdim + else: + return (1./pdim)*math.sin(np.pi*x/pdim)/(np.pi*x/pdim) + + +def sincPulse_F(omega, pdim): + if (abs(omega) < np.pi/pdim): + return 1.0 + else: + return 0. + +################################################################################################### +# Circular Disk Pulse +################################################################################################### + +# def circPulse2D(x, y, pdim, dom='F'): +# rm = 0.5*pdim #max radius of the disk +# if dom=='I': +# if x**2 + y**2 <= rm**2: +# return 1./np.pi/rm**2 +# else: return 0. +# elif dom=='F': +# return 2.*spec.j1(rm*np.sqrt(x**2 + y**2))/np.sqrt(x**2 + y**2)/rm**2 + +################################################################################################### +# Cubic Spline Pulse +################################################################################################### + +# def cubicsplinePulse2D_F(omegaX, omegaY, pdim): +# return cubicsplinePulse(omegaX, pdim)*cubicsplinePulse(omegaY,pdim) +# +# def cubicsplinePulse_F(omega, delta): +# if (omega == 0): +# coeff = delta +# else: +# omega_delta = omega*delta +# +# coeff = delta * ( +# (4.0/omega_delta**3)*math.sin(omega_delta)*(2.0*math.cos(omega_delta) + 1.0) + +# (24.0/omega_delta**4)*math.cos(omega_delta)*(math.cos(omega_delta) - 1.0) ) +# +# return coeff / pdim # TODO : CHECK IF YOU DIVIDE BY PDIM FOR CLUBIC SPLINE PULSE diff --git a/parloop.py b/parloop.py new file mode 100644 index 00000000..af1a638f --- /dev/null +++ b/parloop.py @@ -0,0 +1,156 @@ +# parloop.py +# Wraps up some helper functions for parallel loops +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import sys +import os + +from multiprocessing import cpu_count +from multiprocessing import Pool, Value, Lock + +TIMEOUT = 31536000 + + +class Parloop(object): + + """A simple parallel loop with a counter + """ + + def __init__(self, func): + """Initialize the loop + """ + + self.func = func + + def run_loop(self, arglist, processes=-1): + """Run the loop on the list of arguments with multiple processes + """ + + n = len(arglist) + + if not type(arglist[0]) is list: + arglist = [[arg] for arg in arglist] + + if processes > 0: + print("Set up loop with %d Processes" % processes) + elif processes == 0: # maximum number of processes -- different argument? + processes = int(cpu_count()) + print("Set up loop with all available (%d) Processes" % processes) + else: + print("Set up loop with no multiprocessing") + + out = -1 + if processes > 0: # run on multiple cores with multiprocessing + counter = Counter(initval=0, maxval=n) + pool = Pool(processes=processes, initializer=self._initcount, initargs=(counter,)) + try: + print('Running the loop') + self.prog_msg(0, n, 0) + out = pool.map_async(self, arglist).get(TIMEOUT) + pool.close() + except KeyboardInterrupt: + print('\ngot ^C while pool mapping, terminating') + pool.terminate() + print('pool terminated') + except Exception as e: + print('\ngot exception: %r, terminating' % (e,)) + pool.terminate() + print('pool terminated') + finally: + pool.join() + + else: # run on a single core + out = [] + for i in range(n): + self.prog_msg(i, n, i-1) + args = arglist[i] + out.append(self.func(*args)) + + return out + + def _initcount(self, x): + """Initialize the counter + """ + global counter + counter = x + + def __call__(self, args): + """Call the loop function + """ + + try: + outval = self.func(*args) + counter.increment() + self.prog_msg(counter.value(), counter.maxval, counter.value()-1) + return outval + except KeyboardInterrupt: + raise KeyboardInterruptError() + + def prog_msg(self, i, n, i_last=0): + """Print a progress bar + """ + + # complete_percent_last = int(100*float(i_last)/float(n)) + complete_percent = int(100*float(i)/float(n)) + ndigit = str(len(str(n))) + + bar_width = 30 + progress = int(bar_width * complete_percent/float(100)) + barparams = (i, n, ("-"*progress) + (" " * (bar_width-progress)), complete_percent) + + printstr = "\rProcessed %0"+ndigit+"i/%i : [%s]%i%%" + sys.stdout.write(printstr % barparams) + sys.stdout.flush() + + +class Counter(object): + """Counter object for sharing among multiprocessing jobs + """ + + def __init__(self, initval=0, maxval=0): + self.val = Value('i', initval) + self.maxval = maxval + self.lock = Lock() + + def increment(self): + with self.lock: + self.val.value += 1 + + def value(self): + with self.lock: + return self.val.value + + +class HiddenPrints: + """Suppresses printing from the loop function + """ + + def __enter__(self): + self._original_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + + def __exit__(self, exc_type, exc_val, exc_tb): + sys.stdout.close() + sys.stdout = self._original_stdout diff --git a/plotting/__init__.py b/plotting/__init__.py new file mode 100644 index 00000000..c63279a7 --- /dev/null +++ b/plotting/__init__.py @@ -0,0 +1,13 @@ +""" +.. module:: ehtim.plotting + :platform: Unix + :synopsis: EHT Imaging Utilities: plotting functions + +.. moduleauthor:: Andrew Chael (achael@cfa.harvard.edu) + +""" +from . import comp_plots +from . import comparisons +from . import summary_plots + +from ..const_def import * diff --git a/plotting/comp_plots.py b/plotting/comp_plots.py new file mode 100644 index 00000000..9862a44b --- /dev/null +++ b/plotting/comp_plots.py @@ -0,0 +1,914 @@ +# comp_plots.py +# Make data plots with multiple observations,images etc. +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import numpy.matlib as matlib +import matplotlib.pyplot as plt +import itertools as it +import copy + +from ehtim.obsdata import merge_obs +import ehtim.const_def as ehc + +COLORLIST = ehc.SCOLORS + +################################################################################################## +# Plotters +################################################################################################## + + +def plotall_compare(obslist, imlist, field1, field2, + conj=False, debias=False, sgrscat=False, + ang_unit='deg', timetype='UTC', ttype='nfft', + axis=False, rangex=False, rangey=False, snrcut=0., + clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE, + export_pdf="", grid=False, ebar=True, + axislabels=True, legend=True, show=True): + """Plot data from observations compared to ground truth from an image on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + field1 (str): x-axis field (from FIELDS) + field2 (str): y-axis field (from FIELDS) + + conj (bool): Plot conjuage baseline data points if True + debias (bool): If True, debias amplitudes. + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels, + sgrscat=sgrscat, ttype=ttype) + (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata + for i in range(len(obslist_plot)): + obs = obslist_plot[i] + axis = obs.plotall(field1, field2, + conj=conj, debias=debias, + ang_unit=ang_unit, timetype=timetype, + axis=axis, rangex=rangex, rangey=rangey, + grid=grid, ebar=ebar, axislabels=axislabels, + show=False, tag_bl=False, legend=False, snrcut=snrcut, + label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)], + marker=markers[i], markersize=markersize) + + if legend: + plt.legend() + if grid: + axis.grid() + if show: + #plt.show(block=False) + ehc.show_noblock() + + if export_pdf != "": + plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0) + + return axis + + +def plot_bl_compare(obslist, imlist, site1, site2, field, + debias=False, sgrscat=False, + ang_unit='deg', timetype='UTC', ttype='nfft', + axis=False, rangex=False, rangey=False, snrcut=0., + clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE, + export_pdf="", grid=False, ebar=True, + axislabels=True, legend=True, show=True): + """Plot data from multiple observations vs time on a single baseline on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + site1 (str): station 1 name + site2 (str): station 2 name + field (str): y-axis field (from FIELDS) + + debias (bool): If True, debias amplitudes. + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels, + sgrscat=sgrscat, ttype=ttype) + (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata + for i in range(len(obslist_plot)): + obs = obslist_plot[i] + axis = obs.plot_bl(site1, site2, field, + debias=debias, ang_unit=ang_unit, timetype=timetype, + axis=axis, rangex=rangex, rangey=rangey, + grid=grid, ebar=ebar, axislabels=axislabels, + show=False, legend=False, snrcut=snrcut, + label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)], + marker=markers[i], markersize=markersize) + if legend: + plt.legend() + if grid: + axis.grid() + if show: + #plt.show(block=False) + ehc.show_noblock() + + if export_pdf != "": + plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0) + + return axis + + +def plot_cphase_compare(obslist, imlist, site1, site2, site3, + vtype='vis', cphases=[], force_recompute=False, + ang_unit='deg', timetype='UTC', ttype='nfft', + axis=False, rangex=False, rangey=False, snrcut=0., + clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE, + export_pdf="", grid=False, ebar=True, + axislabels=True, legend=True, show=True): + """Plot closure phase on a triangle compared to ground truth from an image on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + cphases (list): optionally pass in a list of cphases so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + try: + len(obslist) + except TypeError: + obslist = [obslist] + + if len(cphases) == 0: + cphases = matlib.repmat([], len(obslist), 1) + + if len(cphases) != len(obslist): + raise Exception("cphases list must be same length as obslist!") + + cphases_back = [] + for i in range(len(obslist)): + cphases_back.append(obslist[i].cphase) + obslist[i].cphase = cphases[i] + + prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels, + sgrscat=False, ttype=ttype) + (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata + for i in range(len(obslist_plot)): + obs = obslist_plot[i] + axis = obs.plot_cphase(site1, site2, site3, + vtype=vtype, force_recompute=force_recompute, + ang_unit=ang_unit, timetype=timetype, + axis=axis, rangex=rangex, rangey=rangey, + grid=grid, ebar=ebar, axislabels=axislabels, + show=False, legend=False, snrcut=snrcut, + label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)], + marker=markers[i], markersize=markersize) + + # return to original cphase attribute + for i in range(len(obslist)): + obslist[i].cphase = cphases_back[i] + + if legend: + plt.legend() + if grid: + axis.grid() + if show: + #plt.show(block=False) + ehc.show_noblock() + if export_pdf != "": + plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0) + + return axis + + +def plot_camp_compare(obslist, imlist, site1, site2, site3, site4, + vtype='vis', ctype='camp', camps=[], force_recompute=False, + debias=False, sgrscat=False, timetype='UTC', ttype='nfft', + axis=False, rangex=False, rangey=False, snrcut=0., + clist=COLORLIST, legendlabels=None, markersize=ehc.MARKERSIZE, + export_pdf="", grid=False, ebar=True, + axislabels=True, legend=True, show=True): + """Plot closure amplitude on a triangle vs time from multiple observations on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + site4 (str): station 4 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + camps (list): optionally pass in a list of camp so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + debias (bool): If True, debias amplitudes. + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + + try: + len(obslist) + except TypeError: + obslist = [obslist] + + if len(camps) == 0: + camps = matlib.repmat([], len(obslist), 1) + + if len(camps) != len(obslist): + raise Exception("camps list must be same length as obslist!") + + camps_back = [] + for i in range(len(obslist)): + if ctype == 'camp': + camps_back.append(obslist[i].camp) + obslist[i].camp = camps[i] + elif ctype == 'logcamp': + camps_back.append(obslist[i].logcamp) + obslist[i].logcamp = camps[i] + + prepdata = prep_plot_lists(obslist, imlist, clist=clist, legendlabels=legendlabels, + sgrscat=sgrscat, ttype=ttype) + (obslist_plot, clist_plot, legendlabels_plot, markers) = prepdata + + for i in range(len(obslist_plot)): + obs = obslist_plot[i] + axis = obs.plot_camp(site1, site2, site3, site4, + vtype=vtype, ctype=ctype, force_recompute=force_recompute, + debias=debias, timetype=timetype, + axis=axis, rangex=rangex, rangey=rangey, + grid=grid, ebar=ebar, axislabels=axislabels, + show=False, legend=False, snrcut=0., + label=legendlabels_plot[i], color=clist_plot[i % len(clist_plot)], + marker=markers[i], markersize=markersize) + + for i in range(len(obslist)): + if ctype == 'camp': + obslist[i].camp = camps_back[i] + elif ctype == 'logcamp': + obslist[i].logcamp = camps_back[i] + + if legend: + plt.legend() + if grid: + axis.grid() + if show: + #plt.show(block=False) + ehc.show_noblock() + if export_pdf != "": + plt.savefig(export_pdf, bbox_inches='tight', pad_inches=0) + + return axis + + +################################################################################################## +# Aliases +################################################################################################## +def plotall_obs_compare(obslist, field1, field2, **kwargs): + """Plot data from observations compared to ground truth from an image on the same axes. + + Args: + obslist (list): list of observations to plot + field1 (str): x-axis field (from FIELDS) + field2 (str): y-axis field (from FIELDS) + + conj (bool): Plot conjuage baseline data points if True + debias (bool): If True, debias amplitudes. + ang_unit (str): phase unit 'deg' or 'rad' + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + legendlabels (str): should be a list of labels of the same length of obslist or imlist + snrcut (float): a snr cutoff + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + axis (matplotlib.axes.Axes): add plot to this axis + + clist (list): list of colors scatterplot points + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + axis = plotall_compare(obslist, [], field1, field2, **kwargs) + return axis + + +def plotall_obs_im_compare(obslist, imlist, field1, field2, **kwargs): + """Plot data from observations compared to ground truth from an image on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of images to plot + field1 (str): x-axis field (from FIELDS) + field2 (str): y-axis field (from FIELDS) + + conj (bool): Plot conjuage baseline data points if True + debias (bool): If True, debias amplitudes. + ang_unit (str): phase unit 'deg' or 'rad' + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + legendlabels (str): should be a list of labels of the same length of obslist + snrcut (float): a snr cutoff + + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + axis (matplotlib.axes.Axes): add plot to this axis + + clist (list): list of colors scatterplot points + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + """ + axis = plotall_compare(obslist, imlist, field1, field2, **kwargs) + return axis + + +def plot_bl_obs_compare(obslist, site1, site2, field, **kwargs): + """Plot data from multiple observations vs time on a single baseline on the same axes. + + Args: + obslist (list): list of observations to plot + site1 (str): station 1 name + site2 (str): station 2 name + field (str): y-axis field (from FIELDS) + + debias (bool): If True and plotting vis amplitudes, debias them + axislabels (bool): Show axis labels if True + legendlabels (str): should be a list of labels of the same length of obslist or imlist + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + snrcut (float): a snr cutoff + + rangex (list): [xmin, xmax] x-axis (time) limits + rangey (list): [ymin, ymax] y-axis limits + + legend (bool): Show legend if True + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + show (bool): Display the plot if true + axis (matplotlib.axes.Axes): add plot to this axis + clist (list): list of color strings of scatterplot points + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_bl_compare(obslist, [], site1, site2, field, **kwargs) + return axis + + +def plot_bl_obs_im_compare(obslist, imlist, site1, site2, field, **kwargs): + """Plot data from multiple observations vs time on a single baseline on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + site1 (str): station 1 name + site2 (str): station 2 name + field (str): y-axis field (from FIELDS) + + debias (bool): If True and plotting vis amplitudes, debias them + axislabels (bool): Show axis labels if True + legendlabels (str): should be a list of labels of the same length of obslist or imlist + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + snrcut (float): a snr cutoff + + rangex (list): [xmin, xmax] x-axis (time) limits + rangey (list): [ymin, ymax] y-axis limits + + legend (bool): Show legend if True + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + show (bool): Display the plot if true + axis (matplotlib.axes.Axes): add plot to this axis + clist (list): list of color strings of scatterplot points + export_pdf (str): path to pdf file to save figure + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_bl_compare(obslist, imlist, site1, site2, field, **kwargs) + return axis + + +def plot_cphase_obs_compare(obslist, site1, site2, site3, **kwargs): + """Plot closure phase on a triangle vs time from multiple observations on the same axes. + + Args: + obslist (list): list of observations to plot + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + cphases (list): optionally pass in a list of cphases so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_cphase_compare(obslist, [], site1, site2, site3, **kwargs) + return axis + + +def plot_cphase_obs_im_compare(obslist, imlist, site1, site2, site3, **kwargs): + """Plot closure phase on a triangle vs time from multiple observations on the same axes. + + Args: + obslist (list): list of observations to plot + imlist (list): list of ground truth images to compare to + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + cphases (list): optionally pass in a list of cphases so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_cphase_compare(obslist, imlist, site1, site2, site3, **kwargs) + return axis + + +def plot_camp_obs_compare(obslist, site1, site2, site3, site4, **kwargs): + """Plot closure amplitude on a triangle vs time from multiple observations on the same axes. + + Args: + obslist (list): list of observations to plot + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + site4 (str): station 4 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + camps (list): optionally pass in a list of camps so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + debias (bool): If True, debias amplitudes. + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_camp_compare(obslist, [], site1, site2, site3, site4, **kwargs) + return axis + + +def plot_camp_obs_im_compare(obslist, imlist, site1, site2, site3, site4, **kwargs): + """Plot closure amplitude on a triangle vs time from multiple observations on the same axes. + + Args: + obslist (list): list of observations to plot + image (Image): ground truth image to compare to + site1 (str): station 1 name + site2 (str): station 2 name + site3 (str): station 3 name + site4 (str): station 4 name + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + ctype (str): The closure amplitude type ('camp' or 'logcamp') + camps (list): optionally pass in a list of camps so they don't have to be recomputed + force_recompute (bool): if True, recompute closure phases instead of using stored data + + debias (bool): If True, debias amplitudes. + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + snrcut (float): a snr cutoff + + axis (matplotlib.axes.Axes): add plot to this axis + rangex (list): [xmin, xmax] x-axis limits + rangey (list): [ymin, ymax] y-axis limits + clist (list): list of colors scatterplot points + legendlabels (list): list of labels of the same length of obslist or imlist + markersize (int): size of plot markers + export_pdf (str): path to pdf file to save figure + grid (bool): Plot gridlines if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + legend (bool): Show legend if True + show (bool): Display the plot if true + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + axis = plot_camp_compare(obslist, imlist, site1, site2, site3, site4, **kwargs) + return axis + +################################################################################################## +# Plotters: Compare Observations to Image +################################################################################################## + + +def plotall_obs_im_cphases(obs, imlist, + vtype='vis', ang_unit='deg', timetype='UTC', + ttype='nfft', sgrscat=False, + rangex=False, rangey=[-180, 180], legend=False, legendlabels=None, + show=True, ebar=True, axislabels=False, print_chisqs=True, + display_mode='all'): + """Plot all observation closure phases on top of image ground truth values. + Works with ONE obs and MULTIPLE images. + + Args: + obs (Obsdata): observation to plot + imlist (list): list of ground truth images to compare to + + vtype (str): The visibilty type ('vis','qvis','uvis','vvis','pvis') + ang_unit (str): phase unit 'deg' or 'rad' + timetype (str): 'GMST' or 'UTC' + ttype (str): if "fast" or "nfft" use FFT to produce visibilities. Else "direct" for DTFT + sgrscat (bool): if True, the visibilites will be blurred by the Sgr A* scattering kernel + + rangex (list): [xmin, xmax] x-axis (time) limits + rangey (list): [ymin, ymax] y-axis (phase) limits + + show (bool): Display the plot if True + ebar (bool): Plot error bars if True + axislabels (bool): Show axis labels if True + print_chisqs (bool): print individual chisqs if True + + display_mode (str): 'all' or 'individual' + + Returns: + (matplotlib.axes.Axes): Axes object with data plot + + """ + + try: + len(imlist) + except TypeError: + imlist = [imlist] + + # get closure triangle combinations + sites = [] + for i in range(0, len(obs.tarr)): + sites.append(obs.tarr[i][0]) + uniqueclosure_tri = list(it.combinations(sites, 3)) + + # generate data + cphases_obs = obs.c_phases(mode='all', count='max', vtype=vtype) + obs_all = [obs] + cphases_all = [cphases_obs] + for image in imlist: + obs_model = image.observe_same(obs, sgrscat=sgrscat, add_th_noise=False, ttype=ttype) + cphases_model = obs_model.c_phases(mode='all', count='max', vtype=vtype) + obs_all.append(obs_model) + cphases_all.append(cphases_model) + + # display as individual plots or as a huge sheet + if display_mode == 'individual': + show = True + else: + nplots = len(uniqueclosure_tri) + ncols = 4 + nrows = np.ceil(nplots / ncols).astype(int) + show = False + fig = plt.figure(figsize=(nrows*20, 40)) + + # plot closure phases + print("\n") + + nplot = 0 + for c in range(0, len(uniqueclosure_tri)): + cphases_obs_tri = obs.cphase_tri(uniqueclosure_tri[c][0], + uniqueclosure_tri[cs][1], + uniqueclosure_tri[c][2], + vtype=vtype, ang_unit='deg', cphases=cphases_obs) + + if len(cphases_obs_tri) > 0: + if print_chisqs: + printstr = "%s %s %s :" % ( + uniqueclosure_tri[c][0], uniqueclosure_tri[c][1], uniqueclosure_tri[c][2]) + for i in range(1, len(obs_all)): + cphases_model_tri = obs_all[i].cphase_tri(uniqueclosure_tri[c][0], + uniqueclosure_tri[c][1], + uniqueclosure_tri[c][2], + vtype=vtype, ang_unit='deg', + cphases=cphases_all[i]) + resids = (cphases_obs_tri['cphase'] - cphases_model_tri['cphase'])*ehc.DEGREE + chisq_tri = np.sum((1.0 - np.cos(resids)) / + ((cphases_obs_tri['sigmacp']*ehc.DEGREE)**2)) + chisq_tri *= (2.0/len(cphases_obs_tri)) + printstr += " chisq%i: %0.2f" % (i, chisq_tri) + print(printstr) + + if display_mode == 'individual': + ax = False + + else: + ax = plt.subplot2grid((nrows, ncols), (nplot//ncols, nplot % ncols), fig=fig) + axislabels = False + + f = plot_cphase_obs_compare(obs_all, + uniqueclosure_tri[c][0], + uniqueclosure_tri[c][1], + uniqueclosure_tri[c][2], + vtype=vtype, rangex=rangex, rangey=rangey, ebar=ebar, + show=show, legend=legend, legendlabels=legendlabels, + cphases=cphases_all, axis=ax, axislabels=axislabels) + nplot += 1 + + if display_mode != 'individual': + plt.ion() + f = fig + f.subplots_adjust(wspace=0.1, hspace=0.5) + f.show() + + return f + +################################################################################################## +# Misc +################################################################################################## + + +def prep_plot_lists(obslist, imlist, clist=ehc.SCOLORS, legendlabels=None, + sgrscat=False, ttype='nfft'): + """Return observation, color, marker, legend lists for comp plots""" + + if imlist is None or imlist is False: + imlist = [] + + try: + len(obslist) + except TypeError: + obslist = [obslist] + + try: + len(imlist) + except TypeError: + imlist = [imlist] + + if not((len(imlist) == len(obslist)) or len(imlist) <= 1 or len(obslist) <= 1): + raise Exception("imlist and obslist must be the same length, or either must have length 1") + + if not (legendlabels is None) and (len(legendlabels) != max(len(imlist), len(obslist))): + raise Exception("legendlabels should be the same length of the longer of imlist, obslist!") + + if legendlabels is None: + legendlabels = [str(i+1) for i in range(max(len(imlist), len(obslist)))] + + obslist_plot = [] + clist_plot = copy.copy(clist) + legendlabels_plot = copy.copy(legendlabels) + + # one image, multiple observations + if len(imlist) == 0: + markers = [] + for i in range(len(obslist)): + obslist_plot.append(obslist[i]) + markers.append('o') + + elif len(imlist) == 1 and len(obslist) > 1: + obslist_true = [] + markers = ['s'] + clist_plot = ['k'] + for i in range(len(obslist)): + obslist_plot.append(obslist[i]) + obstrue = imlist[0].observe_same( + obslist[i], sgrscat=sgrscat, add_th_noise=False, ttype=ttype) + for sigma_type in obstrue.data.dtype.names[-4:]: + obstrue.data[sigma_type] *= 0 + obslist_true.append(obstrue) + markers.append('o') + clist_plot.append(clist[i]) + + obstrue = merge_obs(obslist_true) + obslist_plot.insert(0, obstrue) + legendlabels_plot.insert(0, 'Image') + + # one observation, multiple images + elif len(obslist) == 1 and len(imlist) > 1: + obslist_plot.append(obslist[0]) + markers = ['o'] + for i in range(len(imlist)): + obstrue = imlist[i].observe_same( + obslist[0], sgrscat=sgrscat, add_th_noise=False, ttype=ttype) + for sigma_type in obstrue.data.dtype.names[-4:]: + obstrue.data[sigma_type] *= 0 + obslist_plot.append(obstrue) + markers.append('s') + + clist_plot.insert(0, 'k') + legendlabels_plot.insert(0, 'Observation') + + # same number of images and observations + elif len(obslist) == 1 and len(imlist) == 1: + obslist_plot.append(obslist[0]) + + obstrue = imlist[0].observe_same( + obslist[0], sgrscat=sgrscat, add_th_noise=False, ttype=ttype) + for sigma_type in obstrue.data.dtype.names[-4:]: + obstrue.data[sigma_type] *= 0 + obslist_plot.append(obstrue) + + markers = ['o', 's'] + clist_plot = ['k', clist[0]] + legendlabels_plot = [legendlabels[0]+'_obs', legendlabels[0]+'_im'] + + else: + markers = [] + legendlabels_plot = [] + clist_plot = [] + for i in range(len(obslist)): + obstrue = imlist[i].observe_same( + obslist[i], sgrscat=sgrscat, add_th_noise=False, ttype=ttype) + for sigma_type in obstrue.data.dtype.names[-4:]: + obstrue.data[sigma_type] *= 0 + obslist_plot.append(obstrue) + clist_plot.append(clist[i]) + legendlabels_plot.append(legendlabels[i]+'_im') + markers.append('s') + + obslist_plot.append(obslist[i]) + clist_plot.append(clist[i]) + legendlabels_plot.append(legendlabels[i]+'_obs') + markers.append('o') + + if len(obslist_plot) > len(clist): + Exception("More observations than colors -- Add more colors to clist!") + + return (obslist_plot, clist_plot, legendlabels_plot, markers) diff --git a/plotting/comparisons.py b/plotting/comparisons.py new file mode 100644 index 00000000..38c055ff --- /dev/null +++ b/plotting/comparisons.py @@ -0,0 +1,238 @@ +# comparisons.py +# Image Consistency Comparisons +# +# Copyright (C) 2018 Katie Bouman +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import sys +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.offsetbox import (OffsetImage, AnnotationBbox) +from itertools import cycle + +try: + import networkx as nx +except ImportError: + print("Warning: networkx not installed! Cannot use image_agreements()") + + +import ehtim.const_def as ehc + + +def image_consistency(imarr, beamparams, metric='nxcorr', + blursmall=True, beam_max=1.0, beam_steps=5, savepath=[]): + + # get the pixel sizes and fov to compare images at + (min_psize, max_fov) = get_psize_fov(imarr) + + # initialize matrix matrix + metric_mtx = np.zeros([len(imarr), len(imarr), beam_steps]) + + # get the different fracsteps + fracsteps = np.linspace(0, beam_max, beam_steps) + + # loop over the different beam sizes + for fracidx in range(beam_steps): + # print(fracidx) + # look at every pair of images and compute their beam convolved metrics + for i in range(len(imarr)): + img1 = imarr[i] + if fracsteps[fracidx] > 0: + img1 = img1.blur_gauss(beamparams, fracsteps[fracidx]) + + for j in range(i+1, len(imarr)): + img2 = imarr[j] + if fracsteps[fracidx] > 0: + img2 = img2.blur_gauss(beamparams, fracsteps[fracidx]) + + # compute image comparision under a specified blur_frac + (error, im1_pad, im2_shift) = img1.compare_images(img2, metric=[metric], + psize=min_psize, + target_fov=max_fov, + blur_frac=0.0, + beamparams=beamparams) + + # if specified save the shifted images used for comparision + if savepath: + im1_pad.save_fits(savepath + '/' + str(i) + '_' + str(fracidx) + '.fits') + im2_shift.save_fits(savepath + '/' + str(j) + '_' + str(fracidx) + '.fits') + + # save the metric value in a matrix + metric_mtx[i, j, fracidx] = error[0] + + return (metric_mtx, fracsteps) + + +def get_psize_fov(imarr): + """Look over an array of images and determine the min pixel size and max fov + that can be used consistently across them + """ + + min_psize = 100 + for i in range(0, len(imarr)): + if i == 0: + max_fov = np.max([imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim]) + min_psize = imarr[i].psize + else: + max_fov = np.max([max_fov, imarr[i].psize*imarr[i].xdim, imarr[i].psize*imarr[i].ydim]) + min_psize = np.min([min_psize, imarr[i].psize]) + return (min_psize, max_fov) + + +def image_agreements(imarr, beamparams, metric_mtx, fracsteps, cutoff=0.95): + + if 'networkx' not in sys.modules: + raise Exception("networkx not installed: cannot use image_agreements()!") + + (min_psize, max_fov) = get_psize_fov(imarr) + + im_cliques_fraclevels = [] + cliques_fraclevels = [] + for fracidx in range(len(fracsteps)): + # print(fracidx) + + slice_metric_mtx = metric_mtx[:, :, fracidx] + cuttoffidx = np.where(slice_metric_mtx >= cutoff) + consistant = zip(*cuttoffidx) + + # make graph + G = nx.Graph() + for i in range(len(consistant)): + G.add_edge(consistant[i][0], consistant[i][1]) + + # find all cliques + cliques = list(nx.find_cliques(G)) + # print(cliques) + + cliques_fraclevels.append(cliques) + + im_clique = [] + for c in range(len(cliques)): + clique = cliques[c] + im_avg = imarr[clique[0]].blur_gauss(beamparams, fracsteps[fracidx]) + + for n in range(1, len(clique)): + imcomp = imarr[clique[n]].blur_gauss(beamparams, fracsteps[fracidx]) + (error, im_avg, im2_shift) = im_avg.compare_images(imcomp, metric=['xcorr'], + psize=min_psize, + target_fov=max_fov, + blur_frac=0.0, + beamparams=beamparams) + im_avg.imvec = (im_avg.imvec + im2_shift.imvec) / 2.0 + + im_clique.append(im_avg.copy()) + + im_cliques_fraclevels.append(im_clique) + + return(cliques_fraclevels, im_cliques_fraclevels) + + +def change_cut_off(metric_mtx, fracsteps, imarr, beamparams, cutoff=0.95, zoom=0.1, fov=1): + (cliques_fraclevels, im_cliques_fraclevels) = image_agreements( + imarr, beamparams, metric_mtx, fracsteps, cutoff=cutoff) + generate_consistency_plot(cliques_fraclevels, im_cliques_fraclevels, metric_mtx=metric_mtx, + fracsteps=fracsteps, beamparams=beamparams, zoom=zoom, fov=fov) + + +def generate_consistency_plot(clique_fraclevels, im_clique_fraclevels, zoom=0.1, fov=1, show=True, + framesize=(20, 10), fracsteps=None, cutoff=None, r_offset=1): + + fig, ax = plt.subplots(figsize=framesize) + cycol = cycle('bgrcmk') + + x_loc = [] + + for c, column in enumerate(clique_fraclevels): + colorc = cycol.next() + x_loc.append(((20./len(clique_fraclevels))*c)) + if len(column) == 0: + continue + + for r, row in enumerate(column): + + # adding the images + lenx = len(clique_fraclevels) + leny = 0 + for li in clique_fraclevels: + if len(li) > leny: + leny = len(li) + sample_image = im_clique_fraclevels[c][r].regrid_image( + fov*im_clique_fraclevels[c][r].fovx(), 512) + arr_img = sample_image.imvec.reshape(sample_image.xdim, sample_image.ydim) + imagebox = OffsetImage(arr_img, zoom=zoom, cmap='afmhot') + + imagebox.image.axes = ax + + ab = AnnotationBbox(imagebox, ((20./lenx)*c+r_offset, (20./leny)*r), + xycoords='data', + pad=0.0, + arrowprops=None) + + ax.add_artist(ab) + + # adding the arrows + if c+1 != len(clique_fraclevels): + for a, ro in enumerate(clique_fraclevels[c+1]): + if set(row).issubset(ro): + px = c+1 + px = ((20./lenx)*px) + r_offset + py = a + py = (20./leny)*py + break + + xx = (20./lenx)*c + (8./lenx) + r_offset + yy = (20./leny)*r + ax.arrow(xx, yy, + px - xx - (9./lenx), py - yy, + head_width=0.05, + head_length=0.1, + color=colorc + ) + row.sort() + # adding the text + txtstring = str(row) + ax.text((20./lenx)*c, (20./leny)*(r-0.5), txtstring, fontsize=10, + horizontalalignment='center', color='black', zorder=1000) + + ax.set_xlim(0, 22) + ax.set_ylim(-10, 22) + + ax.set_xticks(x_loc) + ax.set_xticklabels(fracsteps) + + ax.set_yticks([]) + ax.set_yticklabels([]) + + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['left'].set_visible(False) + ax.spines['bottom'].set_visible(False) + + ax.set_title('Blurred comparison of all images; cutoff={0}, fov (uas)={1}'.format( + str(cutoff), str(im_clique_fraclevels[-1][-1].fovx()/ehc.RADPERUAS))) + +# for item in [fig, ax]: +# item.patch.set_visible(False) +# fig.patch.set_visible(False) +# ax.axis('off') + if show is True: + plt.show() diff --git a/plotting/summary_plots.py b/plotting/summary_plots.py new file mode 100644 index 00000000..c8de71c5 --- /dev/null +++ b/plotting/summary_plots.py @@ -0,0 +1,1733 @@ +# summary_plots.py +# Make data plots with multiple observations,images etc. +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +# TODO add systematic noise to individual closure quantities? + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.gridspec as gridspec +from matplotlib.backends.backend_pdf import PdfPages +import datetime + +from ehtim.plotting.comp_plots import plotall_obs_compare +from ehtim.plotting.comp_plots import plot_bl_obs_compare +from ehtim.plotting.comp_plots import plot_cphase_obs_compare +from ehtim.plotting.comp_plots import plot_camp_obs_compare +from ehtim.calibrating.self_cal import self_cal as selfcal +from ehtim.calibrating.pol_cal import leakage_cal, plot_leakage +import ehtim.const_def as ehc + +FONTSIZE = 22 +WSPACE = 0.8 +HSPACE = 0.3 +MARGINS = 0.5 +PROCESSES = 4 +MARKERSIZE = 5 + + +def imgsum(im_or_mov, obs, obs_uncal, outname, outdir='.', title='imgsum', commentstr="", + fontsize=FONTSIZE, cfun='afmhot', snrcut=0., maxset=False, ttype='nfft', + gainplots=True, ampplots=True, cphaseplots=True, campplots=True, ebar=True, + debias=True, cp_uv_min=False, force_extrapolate=True, processes=PROCESSES, + sysnoise=0, syscnoise=0): + """Produce an image summary plot for an image and uvfits file. + + Args: + im_or_mov (Image or Movie): an Image object or Movie + obs (Obsdata): the self-calibrated Obsdata object + obs_uncal (Obsdata): the original Obsdata object + outname (str): output pdf file name + + outdir (str): directory for output file + title (str): the pdf file title + commentstr (str): a comment for the top line of the pdf + fontsize (float): the font size for text in the sheet + cfun (float): matplotlib color function + + maxset (bool): True to use a maximal set of closure quantities + + gainplots (bool): include gain plots or not + ampplots (bool): include amplitude consistency plots or not + cphaseplots (bool): include closure phase consistency plots or not + campplots (bool): include closure amplitude consistency plots or not + ebar (bool): include error bars or not + debias (bool): debias visibility amplitudes before computing chisq or not + cp_uv_min (bool): minimum uv-distance cutoff for including a baseline in closure phase + + sysnoise (float): percent systematic noise added in quadrature + syscnoise (float): closure phase systematic noise in degrees added in quadrature + + snrcut (dict): a dictionary of snrcut values for each quantity + + ttype (str): "fast" or "nfft" or "direct" + force_extrapolate (bool): if True, always extrapolate movie start/stop frames + processes (int): number of cores to use in multiprocessing + Returns: + + """ + + plt.close('all') # close conflicting plots + plt.rc('font', family='serif') + plt.rc('text', usetex=True) + plt.rc('font', size=FONTSIZE) + plt.rc('axes', titlesize=FONTSIZE) + plt.rc('axes', labelsize=FONTSIZE) + plt.rc('xtick', labelsize=FONTSIZE) + plt.rc('ytick', labelsize=FONTSIZE) + plt.rc('legend', fontsize=FONTSIZE) + plt.rc('figure', titlesize=FONTSIZE) + + if fontsize == 0: + fontsize = FONTSIZE + + if maxset: + count = 'max' + else: + count = 'min' + + snrcut_dict = {key: 0. for key in ['vis', 'amp', 'cphase', 'logcamp', 'camp']} + + if type(snrcut) is dict: + for key in snrcut.keys(): + snrcut_dict[key] = snrcut[key] + else: + for key in snrcut_dict.keys(): + snrcut_dict[key] = snrcut + + with PdfPages(outname) as pdf: + titlestr = 'Summary Sheet for %s on MJD %s' % (im_or_mov.source, im_or_mov.mjd) + + # pdf metadata + d = pdf.infodict() + d['Title'] = title + d['Author'] = u'EHT Team 1' + d['Subject'] = titlestr + d['CreationDate'] = datetime.datetime.today() + d['ModDate'] = datetime.datetime.today() + + # define the figure + fig = plt.figure(1, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + + # user comments + if len(commentstr) > 1: + titlestr = titlestr+'\n'+str(commentstr) + + plt.suptitle(titlestr, y=.9, va='center', fontsize=int(1.2*fontsize)) + + ################################################################################ + print("===========================================") + print("displaying the image") + ax = plt.subplot(gs[0:2, 0:2]) + ax.set_title('Submitted Image') + + movie = hasattr(im_or_mov, 'get_image') + if movie: + im_display = im_or_mov.avg_frame() + + # TODO --- ok to always extrapolate? + if force_extrapolate: + im_or_mov.reset_interp(bounds_error=False) + elif hasattr(im_or_mov, 'make_image'): + im_display = im_or_mov.make_image(obs.res() * 10., 512) + else: + im_display = im_or_mov.copy() + + ax = _display_img(im_display, axis=ax, show=False, + has_title=False, cfun=cfun, fontsize=fontsize) + + print("===========================================") + print("displaying the blurred image") + ax = plt.subplot(gs[0:2, 2:5]) + ax.set_title('Image blurred to nominal resolution') + fwhm = obs.res() + print("blur_FWHM: ", fwhm/ehc.RADPERUAS) + beamparams = [fwhm, fwhm, 0] + + imblur = im_display.blur_gauss(beamparams, frac=1.0) + ax = _display_img(imblur, beamparams=beamparams, axis=ax, show=False, + has_title=False, cfun=cfun, fontsize=fontsize) + + ################################################################################ + print("===========================================") + print("calculating statistics") + # display the overall chi2 + ax = plt.subplot(gs[2, 0:2]) + ax.set_title('Image statistics') + # ax.axis('off') + ax.set_yticks([]) + ax.set_xticks([]) + + flux = im_display.total_flux() + + # SNR ordering + # obs.reorder_tarr_snr() + # obs_uncal.reorder_tarr_snr() + + maxset = False + # compute chi^2 + chi2vis = obs.chisq(im_or_mov, dtype='vis', ttype=ttype, + systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['vis']) + chi2amp = obs.chisq(im_or_mov, dtype='amp', ttype=ttype, + systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['amp']) + chi2cphase = obs.chisq(im_or_mov, dtype='cphase', ttype=ttype, systematic_noise=sysnoise, + systematic_cphase_noise=syscnoise, + maxset=maxset, cp_uv_min=cp_uv_min, snrcut=snrcut_dict['cphase']) + chi2logcamp = obs.chisq(im_or_mov, dtype='logcamp', ttype=ttype, systematic_noise=sysnoise, + maxset=maxset, snrcut=snrcut_dict['logcamp']) + chi2camp = obs.chisq(im_or_mov, dtype='camp', ttype=ttype, + systematic_noise=sysnoise, maxset=maxset, snrcut=snrcut_dict['camp']) + + chi2vis_uncal = obs_uncal.chisq(im_or_mov, dtype='vis', ttype=ttype, systematic_noise=0, + maxset=maxset, snrcut=snrcut_dict['vis']) + chi2amp_uncal = obs_uncal.chisq(im_or_mov, dtype='amp', ttype=ttype, systematic_noise=0, + maxset=maxset, snrcut=snrcut_dict['amp']) + chi2cphase_uncal = obs_uncal.chisq(im_or_mov, dtype='cphase', ttype=ttype, + systematic_noise=0, + systematic_cphase_noise=0, maxset=maxset, + cp_uv_min=cp_uv_min, snrcut=snrcut_dict['cphase']) + chi2logcamp_uncal = obs_uncal.chisq(im_or_mov, dtype='logcamp', ttype=ttype, + systematic_noise=0, maxset=maxset, + snrcut=snrcut_dict['logcamp']) + chi2camp_uncal = obs_uncal.chisq(im_or_mov, dtype='camp', ttype=ttype, systematic_noise=0, + maxset=maxset, snrcut=snrcut_dict['camp']) + + print("chi^2 vis: %0.2f %0.2f" % (chi2vis, chi2vis_uncal)) + print("chi^2 amp: %0.2f %0.2f" % (chi2amp, chi2amp_uncal)) + print("chi^2 cphase: %0.2f %0.2f" % (chi2cphase, chi2cphase_uncal)) + print("chi^2 logcamp: %0.2f %0.2f" % (chi2logcamp, chi2logcamp_uncal)) + print("chi^2 camp: %0.2f %0.2f" % (chi2logcamp, chi2logcamp_uncal)) + + fs = int(1*fontsize) + fs2 = int(.8*fontsize) + ax.text(.05, .9, "Source:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .7, "MJD:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .5, "FREQ:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .3, "FOV:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .1, "FLUX:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.23, .9, "%s" % im_or_mov.source, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .7, "%i" % im_or_mov.mjd, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .5, "%0.0f GHz" % (im_or_mov.rf/1.e9), fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .3, "%0.1f $\mu$as" % (im_display.fovx()/ehc.RADPERUAS), fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .1, "%0.2f Jy" % flux, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.5, .9, "$\chi^2_{vis}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .7, "$\chi^2_{amp}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .5, "$\chi^2_{cphase}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .3, "$\chi^2_{log camp}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .1, "$\chi^2_{camp}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.72, .9, "%0.2f" % chi2vis, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .7, "%0.2f" % chi2amp, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .5, "%0.2f" % chi2cphase, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .3, "%0.2f" % chi2logcamp, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .1, "%0.2f" % chi2camp, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.85, .9, "(%0.2f)" % chi2vis_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .7, "(%0.2f)" % chi2amp_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .5, "(%0.2f)" % chi2cphase_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .3, "(%0.2f)" % chi2logcamp, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .1, "(%0.2f)" % chi2camp_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + + ################################################################################ + print("===========================================") + print("calculating cphase statistics") + # display the closure phase chi2 + ax = plt.subplot(gs[3:6, 0:2]) + ax.set_title('Closure phase statistics') + ax.set_yticks([]) + ax.set_xticks([]) + + # get closure triangle combinations + # ANDREW -- hacky, fix this! + cp = obs.c_phases(mode="all", count=count, uv_min=cp_uv_min, snrcut=snrcut_dict['cphase']) + n_cphase = len(cp) + alltris = [(str(cpp['t1']), str(cpp['t2']), str(cpp['t3'])) for cpp in cp] + uniqueclosure_tri = [] + for tri in alltris: + if tri not in uniqueclosure_tri: + uniqueclosure_tri.append(tri) + + # generate data + obs_model = im_or_mov.observe_same(obs, add_th_noise=False, ttype=ttype) + + # TODO: check SNR cut + cphases_obs = obs.c_phases(mode='all', count='max', vtype='vis', + uv_min=cp_uv_min, snrcut=snrcut_dict['cphase']) + if snrcut_dict['cphase'] > 0: + cphases_obs_all = obs.c_phases(mode='all', count='max', + vtype='vis', uv_min=cp_uv_min, snrcut=0.) + cphases_model_all = obs_model.c_phases( + mode='all', count='max', vtype='vis', uv_min=cp_uv_min, snrcut=0.) + mask = [cphase in cphases_obs for cphase in cphases_obs_all] + cphases_model = cphases_model_all[mask] + print('cphase snr cut', snrcut_dict['cphase'], ' : kept', len( + cphases_obs), '/', len(cphases_obs_all)) + else: + cphases_model = obs_model.c_phases( + mode='all', count='max', vtype='vis', uv_min=cp_uv_min, snrcut=0.) + + # generate chi^2 -- NO SYSTEMATIC NOISES + cphase_chisq_data = [] + for c in range(0, len(uniqueclosure_tri)): + cphases_obs_tri = obs.cphase_tri(uniqueclosure_tri[c][0], + uniqueclosure_tri[c][1], + uniqueclosure_tri[c][2], + vtype='vis', ang_unit='deg', cphases=cphases_obs) + + if len(cphases_obs_tri) > 0: + cphases_model_tri = obs_model.cphase_tri(uniqueclosure_tri[c][0], + uniqueclosure_tri[c][1], + uniqueclosure_tri[c][2], + vtype='vis', ang_unit='deg', + cphases=cphases_model) + + resids = (cphases_obs_tri['cphase'] - cphases_model_tri['cphase'])*ehc.DEGREE + chisq_tri = 2*np.sum((1.0 - np.cos(resids)) / + ((cphases_obs_tri['sigmacp']*ehc.DEGREE)**2)) + + npts = len(cphases_obs_tri) + data = [uniqueclosure_tri[c][0], uniqueclosure_tri[c] + [1], uniqueclosure_tri[c][2], npts, chisq_tri] + cphase_chisq_data.append(data) + + # sort by decreasing chi^2 + idx = np.argsort([data[-1] for data in cphase_chisq_data]) + idx = list(reversed(idx)) + + chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Triangle " + + r"& $N_{tri}$ & $\chi^2_{tri}/N_{tri}$ & $\chi^2_{tri}/N_{tot}$" + + r"\\ \hline \hline") + first = True + for i in range(len(cphase_chisq_data)): + if i > 30: + break + data = cphase_chisq_data[idx[i]] + tristr = r"%s-%s-%s" % (data[0], data[1], data[2]) + nstr = r"%i" % data[3] + rchisqstr = r"%0.1f" % (float(data[4])/float(data[3])) + rrchisqstr = r"%0.3f" % (float(data[4])/float(n_cphase)) + if first: + chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + first = False + else: + chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + chisqtab += r" \end{tabular}" + + ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + ################################################################################ + print("===========================================") + print("calculating camp statistics") + # display the log closure amplitude chi2 + ax = plt.subplot(gs[2:6, 2::]) + ax.set_title('Log Closure amplitude statistics') + # ax.axis('off') + ax.set_yticks([]) + ax.set_xticks([]) + + # get closure amplitude combinations + # TODO -- hacky, fix this! + cp = obs.c_amplitudes(mode="all", count=count, ctype='logcamp', debias=debias) + n_camps = len(cp) + allquads = [(str(cpp['t1']), str(cpp['t2']), str(cpp['t3']), str(cpp['t4'])) for cpp in cp] + uniqueclosure_quad = [] + for quad in allquads: + if quad not in uniqueclosure_quad: + uniqueclosure_quad.append(quad) + + # generate data + # TODO: check SNR cut + camps_obs = obs.c_amplitudes(mode='all', count='max', ctype='logcamp', + debias=debias, snrcut=snrcut_dict['logcamp']) + if snrcut_dict['logcamp'] > 0: + camps_obs_all = obs.c_amplitudes( + mode='all', count='max', ctype='logcamp', debias=debias, snrcut=0.) + camps_model_all = obs_model.c_amplitudes( + mode='all', count='max', ctype='logcamp', debias=False, snrcut=0.) + mask = [camp['camp'] in camps_obs['camp'] for camp in camps_obs_all] + camps_model = camps_model_all[mask] + print('closure amp snrcut', snrcut_dict['logcamp'], + ': kept', len(camps_obs), '/', len(camps_obs_all)) + else: + camps_model = obs_model.c_amplitudes( + mode='all', count='max', ctype='logcamp', debias=False, snrcut=0.) + + # generate chi2 -- NO SYSTEMATIC NOISES + camp_chisq_data = [] + for c in range(0, len(uniqueclosure_quad)): + camps_obs_quad = obs.camp_quad(uniqueclosure_quad[c][0], uniqueclosure_quad[c][1], + uniqueclosure_quad[c][2], uniqueclosure_quad[c][3], + vtype='vis', camps=camps_obs, ctype='logcamp') + + if len(camps_obs_quad) > 0: + camps_model_quad = obs.camp_quad(uniqueclosure_quad[c][0], uniqueclosure_quad[c][1], + uniqueclosure_quad[c][2], uniqueclosure_quad[c][3], + vtype='vis', camps=camps_model, ctype='logcamp') + + resids = camps_obs_quad['camp'] - camps_model_quad['camp'] + chisq_quad = np.sum(np.abs(resids/camps_obs_quad['sigmaca'])**2) + npts = len(camps_obs_quad) + + data = (uniqueclosure_quad[c][0], uniqueclosure_quad[c][1], + uniqueclosure_quad[c][2], uniqueclosure_quad[c][3], + npts, + chisq_quad) + camp_chisq_data.append(data) + + # sort by decreasing chi^2 + idx = np.argsort([data[-1] for data in camp_chisq_data]) + idx = list(reversed(idx)) + + chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Quadrangle " + + r"& $N_{quad}$ & $\chi^2_{quad}/N_{quad}$ & $\chi^2_{quad}/N_{tot}$ " + + r"\\ \hline \hline") + for i in range(len(camp_chisq_data)): + if i > 45: + break + data = camp_chisq_data[idx[i]] + tristr = r"%s-%s-%s-%s" % (data[0], data[1], data[2], data[3]) + nstr = r"%i" % data[4] + rchisqstr = r"%0.1f" % (data[5]/float(data[4])) + rrchisqstr = r"%0.3f" % (data[5]/float(n_camps)) + if i == 0: + chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + + chisqtab += r" \end{tabular}" + + ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + # save the first page of the plot + print('saving pdf page 1') + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + ################################################################################ + # plot the vis amps + fig = plt.figure(2, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + + print("===========================================") + print("plotting vis amps") + ax = plt.subplot(gs[0:2, 0:2]) + obs_tmp = obs_model.copy() + obs_tmp.data['sigma'] *= 0. + ax = plotall_obs_compare([obs, obs_tmp], + 'uvdist', 'amp', axis=ax, legend=False, + clist=['k', ehc.SCOLORS[1]], + ttype=ttype, show=False, debias=debias, + snrcut=snrcut_dict['amp'], + ebar=ebar, markersize=MARKERSIZE) + # modify the labels + ax.set_title('Calibrated Visiblity Amplitudes') + ax.set_xlabel('u-v distance (G$\lambda$)') + ax.set_xlim([0, 1.e10]) + ax.set_xticks([0, 2.e9, 4.e9, 6.e9, 8.e9, 10.e9]) + ax.set_xticklabels(["0", "2", "4", "6", "8", "10"]) + ax.set_xticks([1.e9, 3.e9, 5.e9, 7.e9, 9.e9], minor=True) + ax.set_xticklabels([], minor=True) + + ax.set_ylabel('Amplitude (Jy)') + ax.set_ylim([0, 1.2*flux]) + yticks_maj = np.array([0, .2, .4, .6, .8, 1])*flux + ax.set_yticks(yticks_maj) + ax.set_yticklabels(["%0.2f" % fl for fl in yticks_maj]) + yticks_min = np.array([.1, .3, .5, .7, .9])*flux + ax.set_yticks(yticks_min, minor=True) + ax.set_yticklabels([], minor=True) + + # plot the caltable gains + if gainplots: + print("===========================================") + print("plotting gains") + ax2 = plt.subplot(gs[0:2, 2:6]) + obs_tmp = obs_uncal.copy() + for i in range(1): + ct = selfcal(obs_tmp, im_or_mov, + method='amp', ttype=ttype, + caltable=True, gain_tol=.2, + processes=processes) + ct = ct.pad_scans() + obs_tmp = ct.applycal(obs_tmp, interp='nearest', extrapolate=True) + if np.any(np.isnan(obs_tmp.data['vis'])): + print("Warning: NaN in applycal vis table!") + break + if i > 0: + ct_out = ct_out.merge([ct]) + else: + ct_out = ct + + ax2 = ct_out.plot_gains('all', rangey=[.1, 10], + yscale='log', axis=ax2, legend=True, show=False) + + # median gains + ax = plt.subplot(gs[3:6, 2:5]) + ax.set_title('Station gain statistics') + ax.set_yticks([]) + ax.set_xticks([]) + + gain_data = [] + for station in ct_out.tarr['site']: + try: + gain = np.median(np.abs(ct_out.data[station]['lscale'])) + except: + continue + pdiff = np.abs(gain-1)*100 + data = (station, gain, pdiff) + gain_data.append(data) + + # sort by decreasing chi^2 + idx = np.argsort([data[-1] for data in gain_data]) + idx = list(reversed(idx)) + + chisqtab = (r"\begin{tabular}{ l|l|l } \hline Site & " + + r"Median Gain & Percent diff. \\ \hline \hline") + for i in range(len(gain_data)): + if i > 45: + break + data = gain_data[idx[i]] + sitestr = r"%s" % (data[0]) + gstr = r"%0.2f" % data[1] + pstr = r"%0.0f" % data[2] + if i == 0: + chisqtab += r" " + sitestr + " & " + gstr + " & " + pstr + else: + chisqtab += r" \\" + sitestr + " & " + gstr + " & " + pstr + + chisqtab += r" \end{tabular}" + ax.text(0.5, .975, chisqtab, ha="center", va="top", + transform=ax.transAxes, size=fontsize) + + # baseline amplitude chi2 + print("===========================================") + print("baseline vis amps chisq") + ax = plt.subplot(gs[3:6, 0:2]) + ax.set_title('Visibility amplitude statistics') + ax.set_yticks([]) + ax.set_xticks([]) + + bl_unpk = obs.unpack(['t1', 't2'], debias=debias) + n_bl = len(bl_unpk) + allbl = [(str(bl['t1']), str(bl['t2'])) for bl in bl_unpk] + uniquebl = [] + for bl in allbl: + if bl not in uniquebl: + uniquebl.append(bl) + + # generate chi2 -- NO SYSTEMATIC NOISES + bl_chisq_data = [] + for ii in range(0, len(uniquebl)): + bl = uniquebl[ii] + + amps_bl = obs.unpack_bl(bl[0], bl[1], ['amp', 'sigma'], debias=debias) + if len(amps_bl) > 0: + + amps_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['amp', 'sigma'], debias=False) + + if snrcut_dict['amp'] > 0: + amask = amps_bl['amp']/amps_bl['sigma'] > snrcut_dict['amp'] + amps_bl = amps_bl[amask] + amps_bl_model = amps_bl_model[amask] + + chisq_bl = np.sum( + np.abs((amps_bl['amp'] - amps_bl_model['amp'])/amps_bl['sigma'])**2) + npts = len(amps_bl_model) + + data = (bl[0], bl[1], + npts, + chisq_bl) + bl_chisq_data.append(data) + + # sort by decreasing chi^2 + idx = np.argsort([data[-1] for data in bl_chisq_data]) + idx = list(reversed(idx)) + + chisqtab = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & " + + r"$N_{amp}$ & $\chi^2_{amp}/N_{amp}$ & $\chi^2_{amp}/N_{total}$ " + + r"\\ \hline \hline") + for i in range(len(bl_chisq_data)): + if i > 45: + break + data = bl_chisq_data[idx[i]] + tristr = r"%s-%s" % (data[0], data[1]) + nstr = r"%i" % data[2] + rchisqstr = r"%0.1f" % (data[3]/float(data[2])) + rrchisqstr = r"%0.3f" % (data[3]/float(n_bl)) + if i == 0: + chisqtab += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + + chisqtab += r" \end{tabular}" + + ax.text(0.5, .975, chisqtab, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + # save the first page of the plot + print('saving pdf page 2') + # plt.tight_layout() + # plt.subplots_adjust(wspace=1,hspace=1) + # plt.savefig(outname, pad_inches=MARGINS,bbox_inches='tight') + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + ################################################################################ + # plot the visibility amplitudes + page = 3 + if ampplots: + print("===========================================") + print("plotting amplitudes") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Amplitude Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + + obs_model.data['sigma'] *= 0 + amax = 1.1*np.max(np.abs(np.abs(obs_model.data['vis']))) + obs_all = [obs, obs_model] + for bl in uniquebl: + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'amp', rangey=[0, amax], + markersize=MARKERSIZE, debias=debias, + snrcut=snrcut_dict['amp'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype=ttype, show=False, ebar=ebar) + if ax is None: + continue + if switch: + i += 1 + j = 0 + switch = False + else: + j = 1 + switch = True + + ax.set_xlabel('') + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + ################################################################################ + # plot the closure phases + if cphaseplots: + print("===========================================") + print("plotting closure phases") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Closure Phase Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + obs_all = [obs, obs_model] + cphases_model['sigmacp'] *= 0 + cphases_all = [cphases_obs, cphases_model] + for tri in uniqueclosure_tri: + + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_cphase_obs_compare(obs_all, tri[0], tri[1], tri[2], rangey=[-185, 185], + cphases=cphases_all, markersize=MARKERSIZE, + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype=ttype, show=False, ebar=ebar) + if ax is None: + continue + if switch: + i += 1 + j = 0 + switch = False + else: + j = 1 + switch = True + + ax.set_xlabel('') + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + ################################################################################ + # plot the log closure amps + if campplots: + print("===========================================") + print("plotting closure amplitudes") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Closure Amplitude Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + obs_all = [obs, obs_model] + camps_model['sigmaca'] *= 0 + camps_all = [camps_obs, camps_model] + cmax = 1.1*np.max(np.abs(camps_obs['camp'])) + for quad in uniqueclosure_quad: + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_camp_obs_compare(obs_all, quad[0], quad[1], quad[2], quad[3], + markersize=MARKERSIZE, + ctype='logcamp', rangey=[-cmax, cmax], camps=camps_all, + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype=ttype, show=False, ebar=ebar) + if ax is None: + continue + if switch: + i += 1 + j = 0 + switch = False + else: + j = 1 + switch = True + + ax.set_xlabel('') + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + switch = False + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + +def imgsum_pol(im, obs, obs_uncal, outname, + leakage_arr=False, nvec=False, + outdir='.', title='imgsum_pol', commentstr="", + fontsize=FONTSIZE, cfun='afmhot', snrcut=0., + dtermplots=True, pplots=True, mplots=True, qplots=True, uplots=True, ebar=True, + sysnoise=0): + """Produce a polarimetric image summary plot for an image and uvfits file. + + Args: + im (Image): an Image object + obs (Obsdata): the calibrated Obsdata object + obs_uncal (Obsdata): the original Obsdata object + outname (str): output pdf file name + + leakage_arr (bool): array with calibrated d-terms + nvec (int): number of polarimetric vectors to plot in each direction + + outdir (str): directory for output file + title (str): the pdf file title + commentstr (str): a comment for the top line of the pdf + fontsize (float): the font size for text in the sheet + cfun (float): matplotlib color function + snrcut (dict): a dictionary of snrcut values for each quantity + + dtermplots (bool): plot the d-terms or not + mplots (bool): plot the fractional polarizations or not + pplots (bool): plot the P=RL polarization or not + mplots (bool): plot the Q data or not + pplots (bool): plot the U data or not + + ebar (bool): include error bars or not + sysnoise (float): percent systematic noise added in quadrature + + Returns: + + """ + + # switch polreps and mask nan data + im = im.switch_polrep(polrep_out='stokes') + obs = obs.switch_polrep(polrep_out='stokes') + obs_uncal = obs_uncal.switch_polrep(polrep_out='stokes') + + mask_nan = (np.isnan(obs_uncal.data['vis']) + + np.isnan(obs_uncal.data['qvis']) + + np.isnan(obs_uncal.data['uvis']) + + np.isnan(obs_uncal.data['vvis'])) + obs_uncal.data = obs_uncal.data[~mask_nan] + + mask_nan = (np.isnan(obs.data['vis']) + + np.isnan(obs.data['qvis']) + + np.isnan(obs.data['uvis']) + + np.isnan(obs.data['vvis'])) + obs.data = obs.data[~mask_nan] + + if len(im.qvec) == 0 or len(im.uvec) == 0: + raise Exception("the image isn't polarized!") + + plt.close('all') # close conflicting plots + plt.rc('font', family='serif') + plt.rc('text', usetex=True) + plt.rc('font', size=FONTSIZE) + plt.rc('axes', titlesize=FONTSIZE) + plt.rc('axes', labelsize=FONTSIZE) + plt.rc('xtick', labelsize=FONTSIZE) + plt.rc('ytick', labelsize=FONTSIZE) + plt.rc('legend', fontsize=FONTSIZE) + plt.rc('figure', titlesize=FONTSIZE) + + if fontsize == 0: + fontsize = FONTSIZE + + snrcut_dict = {key: 0. for key in ['m', 'pvis', 'qvis', 'uvis']} + + if type(snrcut) is dict: + for key in snrcut.keys(): + snrcut_dict[key] = snrcut[key] + else: + for key in snrcut_dict.keys(): + snrcut_dict[key] = snrcut + + # TODO -- ok? prevent errors in divisition + if(np.any(im.ivec == 0)): + im.ivec += 1.e-50*np.max(im.ivec) + + with PdfPages(outname) as pdf: + titlestr = 'Summary Sheet for %s on MJD %s' % (im.source, im.mjd) + + # pdf metadata + d = pdf.infodict() + d['Title'] = title + d['Author'] = u'EHT Team 1' + d['Subject'] = titlestr + d['CreationDate'] = datetime.datetime.today() + d['ModDate'] = datetime.datetime.today() + + # define the figure + fig = plt.figure(1, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + + # user comments + if len(commentstr) > 1: + titlestr = titlestr+'\n'+str(commentstr) + + plt.suptitle(titlestr, y=.9, va='center', fontsize=int(1.2*fontsize)) + + ################################################################################ + print("===========================================") + print("displaying the images") + + # unblurred image IQU + ax = plt.subplot(gs[0:2, 0:2]) + ax.set_title('I') + + ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=cfun, + pol='I', polticks=True, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ax = plt.subplot(gs[2:4, 0:2]) + ax.set_title('Q') + ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=plt.get_cmap('bwr'), + pol='Q', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=True, + fontsize=fontsize) + + ax = plt.subplot(gs[4:6, 0:2]) + ax.set_title('U') + ax = _display_img_pol(im, axis=ax, show=False, has_title=False, cfun=plt.get_cmap('bwr'), + pol='U', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=True, + fontsize=fontsize) + + # blurred image IQU + ax = plt.subplot(gs[0:2, 2:5]) + beamparams = obs_uncal.fit_beam() + fwhm = np.min((np.abs(beamparams[0]), np.abs(beamparams[1]))) + print("blur_FWHM: ", fwhm/ehc.RADPERUAS) + + imblur = im.blur_gauss(beamparams, frac=1.0, frac_pol=1.) + + ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, cfun=cfun, + pol='I', polticks=True, beamparams=beamparams, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ax = plt.subplot(gs[2:4, 2:5]) + ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, + cfun=plt.get_cmap('bwr'), + pol='Q', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=True, + fontsize=fontsize) + + ax = plt.subplot(gs[4:6, 2:5]) + ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, + cfun=plt.get_cmap('bwr'), + pol='U', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=True, + fontsize=fontsize) + + print('saving pdf page 1') + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + # unblurred image m chi + fig = plt.figure(2, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + + ax = plt.subplot(gs[0:2, 0:2]) + ax.set_title('m') + ax = _display_img_pol(im, axis=ax, show=True, has_title=False, + cfun=plt.get_cmap('jet'), + pol='m', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ax = plt.subplot(gs[2:4, 0:2]) + ax.set_title('chi') + ax = _display_img_pol(im, axis=ax, show=False, has_title=False, + cfun=plt.get_cmap('jet'), + pol='chi', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ax = plt.subplot(gs[0:2, 2:5]) + ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, + cfun=plt.get_cmap('jet'), + pol='m', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ax = plt.subplot(gs[2:4, 2:5]) + ax = _display_img_pol(imblur, axis=ax, show=False, has_title=False, + cfun=plt.get_cmap('jet'), + pol='chi', polticks=False, + nvec=nvec, pcut=0.1, mcut=0.01, contour=False, + fontsize=fontsize) + + ################################################################################ + print("===========================================") + print("calculating statistics") + # display the overall chi2 + ax = plt.subplot(gs[4, 0:2]) + ax.set_title('Image statistics') + # ax.axis('off') + ax.set_yticks([]) + ax.set_xticks([]) + + flux = im.total_flux() + + # SNR ordering + # obs.reorder_tarr_snr() + # obs_uncal.reorder_tarr_snr() + + # compute chi^2 + chi2pvis = obs.polchisq(im, dtype='m', ttype='nfft', + systematic_noise=sysnoise, pol_trans=False) + chi2m = obs.polchisq(im, dtype='m', ttype='nfft', + systematic_noise=sysnoise, pol_trans=False) + chi2qvis = obs.chisq(im, dtype='vis', ttype='nfft', + systematic_noise=sysnoise, pol='Q') + chi2uvis = obs.chisq(im, dtype='vis', ttype='nfft', + systematic_noise=sysnoise, pol='U') + + chi2pvis_uncal = obs_uncal.polchisq(im, dtype='m', ttype='nfft', + systematic_noise=sysnoise, pol_trans=False) + chi2m_uncal = obs_uncal.polchisq(im, dtype='m', ttype='nfft', + systematic_noise=sysnoise, pol_trans=False) + chi2qvis_uncal = obs_uncal.chisq(im, dtype='vis', ttype='nfft', + systematic_noise=sysnoise, pol='Q') + chi2uvis_uncal = obs_uncal.chisq(im, dtype='vis', ttype='nfft', + systematic_noise=sysnoise, pol='U') + + print("chi^2 m: %0.2f %0.2f" % (chi2m, chi2m_uncal)) + print("chi^2 pvis: %0.2f %0.2f" % (chi2pvis, chi2pvis_uncal)) + print("chi^2 qvis: %0.2f %0.2f" % (chi2qvis, chi2qvis_uncal)) + print("chi^2 uvis: %0.2f %0.2f" % (chi2uvis, chi2uvis_uncal)) + + fs = int(1*fontsize) + fs2 = int(.8*fontsize) + ax.text(.05, .9, "Source:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .7, "MJD:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .5, "FREQ:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .3, "FOV:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.05, .1, "FLUX:", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.23, .9, "%s" % im.source, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .7, "%i" % im.mjd, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .5, "%0.0f GHz" % (im.rf/1.e9), fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .3, "%0.1f $\mu$as" % (im.fovx()/ehc.RADPERUAS), fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.23, .1, "%0.2f Jy" % flux, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.5, .9, "$\chi^2_{m}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .7, "$\chi^2_{P}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .5, "$\chi^2_{Q}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.5, .3, "$\chi^2_{U}$", fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.72, .9, "%0.2f" % chi2m, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .7, "%0.2f" % chi2pvis, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .5, "%0.2f" % chi2qvis, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + ax.text(.72, .3, "%0.2f" % chi2uvis, fontsize=fs, + ha='left', va='center', transform=ax.transAxes) + + ax.text(.85, .9, "(%0.2f)" % chi2m_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .7, "(%0.2f)" % chi2pvis_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .5, "(%0.2f)" % chi2qvis_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + ax.text(.85, .3, "(%0.2f)" % chi2uvis_uncal, fontsize=fs2, + ha='left', va='center', transform=ax.transAxes) + + ################################################################################ + # plot the D terms + + if dtermplots: + print("===========================================") + print("plotting d terms") + ax = plt.subplot(gs[4:6, 2:5]) + + if leakage_arr: + obs_polcal = obs_uncal.copy() + obs_polcal.tarr = leakage_arr.tarr + else: + obs_polcal = leakage_cal(obs_uncal, im, leakage_tol=1e6, ttype='nfft') + + ax = plot_leakage(obs_polcal, axis=ax, show=False, + rangex=[-20, 20], rangey=[-20, 20], markersize=5) + + print('saving pdf page 2') + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + # 3 + # baseline amplitude chi2 + fig = plt.figure(2, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + + print("===========================================") + print("baseline m&p chisq") + + bl_unpk = obs.unpack(['t1', 't2']) + n_bl = len(bl_unpk) + allbl = [(str(bl['t1']), str(bl['t2'])) for bl in bl_unpk] + uniquebl = [] + for bl in allbl: + if bl not in uniquebl: + uniquebl.append(bl) + + # generate data + obs_model = im.observe_same(obs, add_th_noise=False, ttype='nfft') + + # generate chi2 -- NO SYSTEMATIC NOISES + bl_chisq_data_m = [] + bl_chisq_data_pvis = [] + bl_chisq_data_qvis = [] + bl_chisq_data_uvis = [] + + for ii in range(0, len(uniquebl)): + bl = uniquebl[ii] + + m_bl = obs.unpack_bl(bl[0], bl[1], ['m', 'msigma'], debias=False) + pvis_bl = obs.unpack_bl(bl[0], bl[1], ['pvis', 'psigma'], debias=False) + qvis_bl = obs.unpack_bl(bl[0], bl[1], ['qvis', 'qsigma'], debias=False) + uvis_bl = obs.unpack_bl(bl[0], bl[1], ['uvis', 'usigma'], debias=False) + + if len(m_bl) > 0: + + m_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['m', 'msigma'], debias=False) + pvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['pvis', 'psigma'], debias=False) + qvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['qvis', 'qsigma'], debias=False) + uvis_bl_model = obs_model.unpack_bl(bl[0], bl[1], ['uvis', 'usigma'], debias=False) + + if snrcut_dict['m'] > 0: + amask = np.abs(m_bl['m'])/m_bl['msigma'] > snrcut_dict['m'] + m_bl = m_bl[amask] + m_bl_model = m_bl_model[amask] + if snrcut_dict['pvis'] > 0: + amask = np.abs(pvis_bl['pvis'])/pvis_bl['psigma'] > snrcut_dict['pvis'] + pvis_bl = pvis_bl[amask] + pvis_bl_model = pvis_bl_model[amask] + if snrcut_dict['qvis'] > 0: + amask = np.abs(qvis_bl['qvis'])/qvis_bl['qsigma'] > snrcut_dict['qvis'] + qvis_bl = qvis_bl[amask] + qvis_bl_model = qvis_bl_model[amask] + if snrcut_dict['uvis'] > 0: + amask = np.abs(uvis_bl['uvis'])/uvis_bl['usigma'] > snrcut_dict['uvis'] + uvis_bl = uvis_bl[amask] + uvis_bl_model = uvis_bl_model[amask] + + chisq_m_bl = np.sum(np.abs((m_bl['m'] - m_bl_model['m'])/m_bl['msigma'])**2) + npts_m = len(m_bl_model) + data_m = (bl[0], bl[1], npts_m, chisq_m_bl) + bl_chisq_data_m.append(data_m) + + chisq_pvis_bl = np.sum( + np.abs((pvis_bl['pvis'] - pvis_bl_model['pvis'])/pvis_bl['psigma'])**2) + npts_pvis = len(pvis_bl_model) + data_pvis = (bl[0], bl[1], npts_pvis, chisq_pvis_bl) + bl_chisq_data_pvis.append(data_pvis) + + chisq_qvis_bl = np.sum( + np.abs((qvis_bl['qvis'] - qvis_bl_model['qvis'])/qvis_bl['qsigma'])**2) + npts_qvis = len(qvis_bl_model) + data_qvis = (bl[0], bl[1], npts_qvis, chisq_qvis_bl) + bl_chisq_data_qvis.append(data_qvis) + + chisq_uvis_bl = np.sum( + np.abs((uvis_bl['uvis'] - uvis_bl_model['uvis'])/uvis_bl['usigma'])**2) + npts_uvis = len(uvis_bl_model) + data_uvis = (bl[0], bl[1], npts_uvis, chisq_uvis_bl) + bl_chisq_data_uvis.append(data_uvis) + + # sort by decreasing chi^2 + idx_m = np.argsort([data[-1] for data in bl_chisq_data_m]) + idx_m = list(reversed(idx_m)) + idx_p = np.argsort([data[-1] for data in bl_chisq_data_pvis]) + idx_p = list(reversed(idx_p)) + idx_q = np.argsort([data[-1] for data in bl_chisq_data_qvis]) + idx_q = list(reversed(idx_q)) + idx_u = np.argsort([data[-1] for data in bl_chisq_data_uvis]) + idx_u = list(reversed(idx_u)) + + chisqtab_m = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{m}$ & " + + r"$\chi^2_{m}/N_{m}$ & $\chi^2_{m}/N_{total}$ \\ " + + r"\hline \hline") + chisqtab_p = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{p}$ & " + + r"$\chi^2_{p}/N_{p}$ & $\chi^2_{p}/N_{total}$ \\ " + + r"\hline \hline") + chisqtab_q = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{Q}$ & " + + r"$\chi^2_{Q}/N_{Q}$ & $\chi^2_{Q}/N_{total}$ \\ " + + r"\hline \hline") + chisqtab_u = (r"\begin{tabular}{ l|l|l|l } \hline Baseline & $N_{U}$ & " + + r"$\chi^2_{U}/N_{U}$ & $\chi^2_{U}/N_{total}$ \\ " + + r"\hline \hline") + + for i in range(len(bl_chisq_data_m)): + if i > 45: + break + data = bl_chisq_data_m[idx_m[i]] + tristr = r"%s-%s" % (data[0], data[1]) + nstr = r"%i" % data[2] + chisqstr = r"%0.1f" % data[3] + rchisqstr = r"%0.1f" % (data[3]/float(data[2])) + rrchisqstr = r"%0.3f" % (data[3]/float(n_bl)) + if i == 0: + chisqtab_m += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab_m += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + for i in range(len(bl_chisq_data_pvis)): + if i > 45: + break + data = bl_chisq_data_pvis[idx_p[i]] + tristr = r"%s-%s" % (data[0], data[1]) + nstr = r"%i" % data[2] + rchisqstr = r"%0.1f" % (data[3]/float(data[2])) + rrchisqstr = r"%0.3f" % (data[3]/float(n_bl)) + if i == 0: + chisqtab_p += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab_p += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + for i in range(len(bl_chisq_data_qvis)): + if i > 45: + break + data = bl_chisq_data_qvis[idx_q[i]] + tristr = r"%s-%s" % (data[0], data[1]) + nstr = r"%i" % data[2] + rchisqstr = r"%0.1f" % (data[3]/float(data[2])) + rrchisqstr = r"%0.3f" % (data[3]/float(n_bl)) + if i == 0: + chisqtab_q += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab_q += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + for i in range(len(bl_chisq_data_uvis)): + if i > 45: + break + data = bl_chisq_data_qvis[idx_u[i]] + tristr = r"%s-%s" % (data[0], data[1]) + nstr = r"%i" % data[2] + rchisqstr = r"%0.1f" % (data[3]/float(data[2])) + rrchisqstr = r"%0.3f" % (data[3]/float(n_bl)) + if i == 0: + chisqtab_u += r" " + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + else: + chisqtab_u += r" \\" + tristr + " & " + nstr + " & " + rchisqstr + " & " + rrchisqstr + + chisqtab_m += r" \end{tabular}" + chisqtab_p += r" \end{tabular}" + chisqtab_q += r" \end{tabular}" + chisqtab_u += r" \end{tabular}" + + ax = plt.subplot(gs[0:3, 0:2]) + ax.set_title('baseline m statistics') + ax.set_yticks([]) + ax.set_xticks([]) + ax.text(0.5, .975, chisqtab_m, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + ax = plt.subplot(gs[0:3, 2:5]) + ax.set_title('baseline P statistics') + ax.set_yticks([]) + ax.set_xticks([]) + ax.text(0.5, .975, chisqtab_p, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + ax = plt.subplot(gs[3:6, 0:2]) + ax.set_title('baseline Q statistics') + ax.set_yticks([]) + ax.set_xticks([]) + ax.text(0.5, .975, chisqtab_q, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + ax = plt.subplot(gs[3:6, 2:5]) + ax.set_title('baseline U statistics') + ax.set_yticks([]) + ax.set_xticks([]) + ax.text(0.5, .975, chisqtab_u, ha="center", va="top", transform=ax.transAxes, size=fontsize) + + # save the first page of the plot + print('saving pdf page 3') + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + ################################################################################ + # plot the baseline pol amps and phases + page = 4 + if mplots: + print("===========================================") + print("plotting fractional polarizatons") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Fractional Polarization Plots", y=.9, + va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + obs_model.data['sigma'] *= 0 + obs_model.data['qsigma'] *= 0 + obs_model.data['usigma'] *= 0 + obs_model.data['vsigma'] *= 0 + + amax = 1.1*np.max(np.abs(np.abs(obs_model.unpack(['mamp'])['mamp']))) + obs_all = [obs, obs_model] + for nbl, bl in enumerate(uniquebl): + j = 0 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'mamp', rangey=[0, amax], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['m'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + ax.set_xlabel('') + j = 1 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'mphase', rangey=[-180, 180], + + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['m'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + i += 1 + ax.set_xlabel('') + + if ax is None: + continue + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + if nbl == len(uniquebl): + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + if pplots: + print("===========================================") + print("plotting total polarizaton") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Total Polarization Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + obs_model.data['sigma'] *= 0 + obs_model.data['qsigma'] *= 0 + obs_model.data['usigma'] *= 0 + obs_model.data['vsigma'] *= 0 + + amax = 1.1*np.max(np.abs(np.abs(obs_model.unpack(['pamp'])['pamp']))) + obs_all = [obs, obs_model] + for nbl, bl in enumerate(uniquebl): + j = 0 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'pamp', rangey=[0, amax], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['pvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + ax.set_xlabel('') + j = 1 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'pphase', rangey=[-180, 180], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['pvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + i += 1 + ax.set_xlabel('') + + if ax is None: + continue + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + if nbl == len(uniquebl): + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + if qplots: + print("===========================================") + print("plotting Q fit") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("Q Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + obs_model.data['sigma'] *= 0 + obs_model.data['qsigma'] *= 0 + obs_model.data['usigma'] *= 0 + obs_model.data['vsigma'] *= 0 + + amax = 1.1*np.max(np.abs(np.abs(obs_model.unpack(['qamp'])['qamp']))) + obs_all = [obs, obs_model] + for nbl, bl in enumerate(uniquebl): + j = 0 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'qamp', rangey=[0, amax], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['qvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + ax.set_xlabel('') + j = 1 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'qphase', rangey=[-180, 180], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['qvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + i += 1 + ax.set_xlabel('') + + if ax is None: + continue + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + if nbl == len(uniquebl): + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + if uplots: + print("===========================================") + print("plotting U fit") + fig = plt.figure(3, figsize=(18, 28), dpi=200) + plt.suptitle("U Plots", y=.9, va='center', fontsize=int(1.2*fontsize)) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + obs_model.data['sigma'] *= 0 + obs_model.data['qsigma'] *= 0 + obs_model.data['usigma'] *= 0 + obs_model.data['vsigma'] *= 0 + + amax = 1.1*np.max(np.abs(np.abs(obs_model.unpack(['uamp'])['uamp']))) + obs_all = [obs, obs_model] + for nbl, bl in enumerate(uniquebl): + j = 0 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'uamp', rangey=[0, amax], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['uvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + ax.set_xlabel('') + j = 1 + ax = plt.subplot(gs[2*i:2*(i+1), 2*j:2*(j+1)]) + ax = plot_bl_obs_compare(obs_all, bl[0], bl[1], 'uphase', rangey=[-180, 180], + markersize=MARKERSIZE, debias=False, + snrcut=snrcut_dict['uvis'], + axis=ax, legend=False, clist=['k', ehc.SCOLORS[1]], + ttype='nfft', show=False, ebar=ebar) + i += 1 + ax.set_xlabel('') + + if ax is None: + continue + + if i == 3: + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + fig = plt.figure(3, figsize=(18, 28), dpi=200) + gs = gridspec.GridSpec(6, 4, wspace=WSPACE, hspace=HSPACE) + i = 0 + j = 0 + + if nbl == len(uniquebl): + print('saving pdf page %i' % page) + page += 1 + pdf.savefig(pad_inches=MARGINS, bbox_inches='tight') + plt.close() + + +def _display_img(im, beamparams=None, scale='linear', gamma=0.5, cbar_lims=False, + has_cbar=True, has_title=True, cfun='afmhot', dynamic_range=100, + axis=False, show=False, fontsize=FONTSIZE): + """display the figure on a given axis + cannot use im.display because it makes a new figure + """ + + interp = 'gaussian' + + if axis: + ax = axis + else: + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + imvec = np.array(im.imvec).reshape(-1) + # flux unit is mJy/uas^2 + imvec = imvec * 1.e3 + fovfactor = im.xdim*im.psize*(1/ehc.RADPERUAS) + factor = (1./fovfactor)**2 / (1./im.xdim)**2 + imvec = imvec * factor + + imarr = (imvec).reshape(im.ydim, im.xdim) + unit = 'mJy/$\mu$ as$^2$' + if scale == 'log': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = np.log(imarr + np.max(imarr)/dynamic_range) + unit = 'log(' + unit + ')' + + if scale == 'gamma': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma) + unit = '(' + unit + ')^gamma' + + if cbar_lims: + imarr[imarr > cbar_lims[1]] = cbar_lims[1] + imarr[imarr < cbar_lims[0]] = cbar_lims[0] + + if cbar_lims: + ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp, + vmin=cbar_lims[0], vmax=cbar_lims[1]) + else: + ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp) + + if has_cbar: + cbar = plt.colorbar(ax, fraction=0.046, pad=0.04, format='%1.2g') + cbar.set_label(unit, fontsize=fontsize) + cbar.ax.xaxis.set_label_position('top') + cbar.ax.tick_params(labelsize=16) + if cbar_lims: + plt.clim(cbar_lims[0], cbar_lims[1]) + + if not(beamparams is None): + beamparams = [beamparams[0], beamparams[1], beamparams[2], + -.35*im.fovx(), -.35*im.fovy()] + beamimage = im.copy() + beamimage.imvec *= 0 + beamimage = beamimage.add_gauss(1, beamparams) + halflevel = 0.5*np.max(beamimage.imvec) + beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim) + plt.contour(beamimarr, levels=[halflevel], colors='w', linewidths=3) + ax = plt.gca() + + plt.axis('off') + fov_uas = im.xdim * im.psize / ehc.RADPERUAS # get the fov in uas + roughfactor = 1./3. # make the bar about 1/3 the fov + fov_scale = 40 + start = im.xdim * roughfactor / 3.0 # select the start location + end = start + fov_scale/fov_uas * im.xdim # determine the end location + plt.plot([start, end], [im.ydim-start, im.ydim-start], color="white", lw=1) # plot line + plt.text(x=(start+end)/2.0, y=im.ydim-start-im.ydim/20, s=str(fov_scale) + " $\mu$as", + color="white", ha="center", va="center", + fontsize=int(1.2*fontsize), fontweight='bold') + + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + + if show: + #plt.show(block=False) + ehc.show_noblock() + + return ax + + +def _display_img_pol(im, beamparams=None, scale='linear', gamma=0.5, cbar_lims=False, + has_cbar=True, has_title=True, cfun='afmhot', pol=None, polticks=False, + nvec=False, pcut=0.1, mcut=0.01, contour=False, dynamic_range=100, + axis=False, show=False, fontsize=FONTSIZE): + """display the polarimetric figure on a given axis + cannot use im.display because it makes a new figure + """ + + interp = 'gaussian' + + if axis: + ax = axis + else: + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + + if pol == 'm': + imvec = im.mvec + unit = r'$|\breve{m}|$' + factor = 1 + cbar_lims = [0, 1] + elif pol == 'chi': + imvec = im.chivec / ehc.DEGREE + unit = r'$\chi (^\circ)$' + factor = 1 + cbar_lims = [0, 180] + else: + # flux unit is Tb + factor = 3.254e13/(im.rf**2 * im.psize**2) + unit = 'Tb (K)' + try: + imvec = np.array(im._imdict[pol]).reshape(-1) + except KeyError: + try: + if im.polrep == 'stokes': + im2 = im.switch_polrep('circ') + elif im.polrep == 'circ': + im2 = im.switch_polrep('stokes') + imvec = np.array(im2._imdict[pol]).reshape(-1) + except KeyError: + raise Exception("Cannot make pol %s image in display()!" % pol) + + # flux unit is Tb + imvec = imvec * factor + imarr = (imvec).reshape(im.ydim, im.xdim) + + if scale == 'log': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = np.log(imarr + np.max(imarr)/dynamic_range) + unit = 'log(' + unit + ')' + + if scale == 'gamma': + if (imarr < 0.0).any(): + print('clipping values less than 0') + imarr[imarr < 0.0] = 0.0 + imarr = (imarr + np.max(imarr)/dynamic_range)**(gamma) + unit = '(' + unit + ')^gamma' + + if cbar_lims: + imarr[imarr > cbar_lims[1]] = cbar_lims[1] + imarr[imarr < cbar_lims[0]] = cbar_lims[0] + + if cbar_lims: + ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp, + vmin=cbar_lims[0], vmax=cbar_lims[1]) + else: + ax = ax.imshow(imarr, cmap=plt.get_cmap(cfun), interpolation=interp) + + if contour: + plt.contour(imarr, colors='k', linewidths=.25) + + if polticks: + im_stokes = im.switch_polrep(polrep_out='stokes') + ivec = np.array(im_stokes.imvec).reshape(-1) + qvec = np.array(im_stokes.qvec).reshape(-1) + uvec = np.array(im_stokes.uvec).reshape(-1) + vvec = np.array(im_stokes.vvec).reshape(-1) + + if len(ivec) == 0: + ivec = np.zeros(im_stokes.ydim*im_stokes.xdim) + if len(qvec) == 0: + qvec = np.zeros(im_stokes.ydim*im_stokes.xdim) + if len(uvec) == 0: + uvec = np.zeros(im_stokes.ydim*im_stokes.xdim) + if len(vvec) == 0: + vvec = np.zeros(im_stokes.ydim*im_stokes.xdim) + + if not nvec: + nvec = im.xdim // 2 + + thin = im.xdim//nvec + maska = (ivec).reshape(im.ydim, im.xdim) > pcut * np.max(ivec) + maskb = (np.abs(qvec + 1j*uvec)/ivec).reshape(im.ydim, im.xdim) > mcut + mask = maska * maskb + mask2 = mask[::thin, ::thin] + x = (np.array([[i for i in range(im.xdim)] for j in range(im.ydim)])[::thin, ::thin]) + x = x[mask2] + y = (np.array([[j for i in range(im.xdim)] for j in range(im.ydim)])[::thin, ::thin]) + y = y[mask2] + a = (-np.sin(np.angle(qvec+1j*uvec)/2).reshape(im.ydim, im.xdim)[::thin, ::thin]) + a = a[mask2] + b = (np.cos(np.angle(qvec+1j*uvec)/2).reshape(im.ydim, im.xdim)[::thin, ::thin]) + b = b[mask2] + + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.01*im.xdim, units='x', pivot='mid', color='k', angles='uv', + scale=1.0/thin) + plt.quiver(x, y, a, b, + headaxislength=20, headwidth=1, headlength=.01, minlength=0, minshaft=1, + width=.005*im.xdim, units='x', pivot='mid', color='w', angles='uv', + scale=1.1/thin) + + if has_cbar: + cbar = plt.colorbar(ax, fraction=0.046, pad=0.04, format='%1.2g') + cbar.set_label(unit, fontsize=fontsize) + cbar.ax.xaxis.set_label_position('top') + cbar.ax.tick_params(labelsize=16) + if cbar_lims: + ax.set_clim(cbar_lims[0], cbar_lims[1]) + + if not(beamparams is None): + beamparams = [beamparams[0], beamparams[1], beamparams[2], + -.35*im.fovx(), -.35*im.fovy()] + beamimage = im.copy() + beamimage.imvec *= 0 + beamimage = beamimage.add_gauss(1, beamparams) + halflevel = 0.5*np.max(beamimage.imvec) + beamimarr = (beamimage.imvec).reshape(beamimage.ydim, beamimage.xdim) + plt.contour(beamimarr, levels=[halflevel], colors='w', linewidths=3) + ax = plt.gca() + + plt.axis('off') + if has_cbar: + fov_uas = im.xdim * im.psize / ehc.RADPERUAS # get the fov in uas + roughfactor = 1./3. # make the bar about 1/3 the fov + fov_scale = 40 + # round around 1/3 the fov to nearest 10 + start = im.xdim * roughfactor / 3.0 # select the start location + end = start + fov_scale/fov_uas * im.xdim + # determine the end location based on the size of the bar + plt.plot([start, end], [im.ydim-start, im.ydim-start], color="white", lw=1) # plot line + plt.text(x=(start+end)/2.0, y=im.ydim-start-im.ydim/20, s=str(fov_scale) + " $\mu$as", + color="white", ha="center", va="center", + fontsize=int(1.2*fontsize), fontweight='bold') + + ax.axes.get_xaxis().set_visible(False) + ax.axes.get_yaxis().set_visible(False) + + if show: + #plt.show(block=False) + ehc.show_noblock() + + return ax diff --git a/scattering/__init__.py b/scattering/__init__.py new file mode 100644 index 00000000..3ad5e415 --- /dev/null +++ b/scattering/__init__.py @@ -0,0 +1,11 @@ +""" +.. module:: ehtim.stochastic_optics + :platform: Unix + :synopsis: EHT Imaging Utilities: imaging functions + +.. moduleauthor:: Michael Johnson (mjohnson@cfa.harvard.edu) + +""" + +from .stochastic_optics import * +from ..const_def import * diff --git a/scattering/stochastic_optics.py b/scattering/stochastic_optics.py new file mode 100644 index 00000000..d328d921 --- /dev/null +++ b/scattering/stochastic_optics.py @@ -0,0 +1,817 @@ +# Michael Johnson, 2/15/2017 +# See http://adsabs.harvard.edu/abs/2016ApJ...833...74J for details about this module + +from __future__ import print_function +from builtins import range +from builtins import object +import numpy as np +import scipy.signal +import scipy.special as sps +import scipy.integrate as integrate +from scipy.optimize import minimize + +import matplotlib.pyplot as plt + +import ehtim.image as image +import ehtim.movie as movie +import ehtim.obsdata as obsdata +from ehtim.observing.obs_helpers import * +from ehtim.const_def import * #Note: C is m/s rather than cm/s. + +from multiprocessing import cpu_count +from multiprocessing import Pool + +import math +import cmath + +################################################################################ +# The class ScatteringModel enscompasses a generic scattering model, determined by the power spectrum Q and phase structure function Dphi +################################################################################ + +class ScatteringModel(object): + """A scattering model based on a thin-screen approximation. + + Models include: + ('von_Mises', 'boxcar', 'dipole'): These scattering models are motivated by observations of Sgr A*. + Each gives a Gaussian at long wavelengths that matches the model defined + by {theta_maj_mas_ref, theta_min_mas_ref, POS_ANG} at the reference wavelength wavelength_reference_cm + with a lambda^2 scaling. The source sizes {theta_maj, theta_min} are the image FWHM in milliarcseconds + at the reference wavelength. Note that this may not match the ensemble-average kernel at the reference wavelength, + if the reference wavelength is short enough to be beyond the lambda^2 regime! + This model also includes an inner and outer scale and will thus transition to scattering with scatt_alpha at shorter wavelengths + Note: This model *requires* a finite inner scale + 'power-law': This scattering model gives a pure power law at all wavelengths. There is no inner scale, but there can be an outer scale. + The ensemble-average image is given by {theta_maj_mas_ref, theta_min_mas_ref, POS_ANG} at the reference wavelength wavelength_reference_cm. + The ensemble-average image size is proportional to wavelength^(1+2/scatt_alpha) = wavelength^(11/5) for Kolmogorov + + Attributes: + model (string): The type of scattering model (determined by the power spectrum of phase fluctuations). + scatt_alpha (float): The power-law index of the phase fluctuations (Kolmogorov is 5/3). + observer_screen_distance (float): The distance from the observer to the scattering screen in cm. + source_screen_distance (float): The distance from the source to the scattering screen in cm. + theta_maj_mas_ref (float): FWHM in mas of the major axis angular broadening at the specified reference wavelength. + theta_min_mas_ref (float): FWHM in mas of the minor axis angular broadening at the specified reference wavelength. + POS_ANG (float): The position angle of the major axis of the scattering. + wavelength_reference_cm (float): The reference wavelength for the scattering model in cm. + r_in (float): The inner scale of the scattering screen in cm. + r_out (float): The outer scale of the scattering screen in cm. + rF (function): The Fresnel scale of the scattering screen at the specific wavelength. + """ + + def __init__(self, model = 'dipole', scatt_alpha = 1.38, observer_screen_distance = 2.82 * 3.086e21, source_screen_distance = 5.53 * 3.086e21, theta_maj_mas_ref = 1.380, theta_min_mas_ref = 0.703, POS_ANG = 81.9, wavelength_reference_cm = 1.0, r_in = 800e5, r_out = 1e20): + """To initialize the scattering model, specify: + + Attributes: + model (string): The type of scattering model (determined by the power spectrum of phase fluctuations). Options are 'von_Mises', 'boxcar', 'dipole', and 'power-law' + scatt_alpha (float): The power-law index of the phase fluctuations (Kolmogorov is 5/3). + observer_screen_distance (float): The distance from the observer to the scattering screen in cm. + source_screen_distance (float): The distance from the source to the scattering screen in cm. + theta_maj_mas_ref (float): FWHM in mas of the major axis angular broadening at the specified reference wavelength. + theta_min_mas_ref (float): FWHM in mas of the minor axis angular broadening at the specified reference wavelength. + POS_ANG (float): The position angle of the major axis of the scattering. + wavelength_reference_cm (float): The reference wavelength for the scattering model in cm. + r_in (float): The inner scale of the scattering screen in cm. + r_out (float): The outer scale of the scattering screen in cm. + """ + + self.model = model + self.POS_ANG = POS_ANG #Major axis position angle [degrees, east of north] + self.observer_screen_distance = observer_screen_distance #cm + self.source_screen_distance = source_screen_distance #cm + M = observer_screen_distance/source_screen_distance + self.wavelength_reference = wavelength_reference_cm #Reference wavelength [cm] + self.r_in = r_in #inner scale [cm] + self.r_out = r_out #outer scale [cm] + self.scatt_alpha = scatt_alpha + + FWHM_fac = (2.0 * np.log(2.0))**0.5/np.pi + self.Qbar = 2.0/sps.gamma((2.0 - self.scatt_alpha)/2.0) * (self.r_in**2*(1.0 + M)/(FWHM_fac*(self.wavelength_reference/(2.0*np.pi))**2) )**2 * ( (theta_maj_mas_ref**2 + theta_min_mas_ref**2)*(1.0/1000.0/3600.0*np.pi/180.0)**2) + self.C_scatt_0 = (self.wavelength_reference/(2.0*np.pi))**2 * self.Qbar*sps.gamma(1.0 - self.scatt_alpha/2.0)/(8.0*np.pi**2*self.r_in**2) + A = theta_maj_mas_ref/theta_min_mas_ref # Anisotropy, >=1, as lambda->infinity + self.phi0 = (90 - self.POS_ANG) * np.pi/180.0 + + # Parameters for the approximate phase structure function + theta_maj_rad_ref = theta_maj_mas_ref/1000.0/3600.0*np.pi/180.0 + theta_min_rad_ref = theta_min_mas_ref/1000.0/3600.0*np.pi/180.0 + self.Amaj_0 = ( self.r_in*(1.0 + M) * theta_maj_rad_ref/(FWHM_fac * (self.wavelength_reference/(2.0*np.pi)) * 2.0*np.pi ))**2 + self.Amin_0 = ( self.r_in*(1.0 + M) * theta_min_rad_ref/(FWHM_fac * (self.wavelength_reference/(2.0*np.pi)) * 2.0*np.pi ))**2 + + if model == 'von_Mises': + def avM_Anisotropy(kzeta): + return np.abs( (kzeta*sps.i0(kzeta)/sps.i1(kzeta) - 1.0)**0.5 - A ) + + self.kzeta = minimize(avM_Anisotropy, A**2, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x + self.P_phi_prefac = 1.0/(2.0*np.pi*sps.i0(self.kzeta)) + elif model == 'boxcar': + def boxcar_Anisotropy(kzeta): + return np.abs( np.sin(np.pi/(1.0 + kzeta))/(np.pi/(1.0 + kzeta)) - (theta_maj_mas_ref**2 - theta_min_mas_ref**2)/(theta_maj_mas_ref**2 + theta_min_mas_ref**2) ) + + self.kzeta = minimize(boxcar_Anisotropy, A, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x + self.P_phi_prefac = (1.0 + self.kzeta)/(2.0*np.pi) + elif model == 'dipole': + def dipole_Anisotropy(kzeta): + return np.abs( sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 0.5, 2.0, -kzeta)/sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 1.5, 2.0, -kzeta) - A**2 ) + + self.kzeta = minimize(dipole_Anisotropy, A, method='nelder-mead', options={'xtol': 1e-8, 'disp': False}).x + self.P_phi_prefac = 1.0/(2.0*np.pi*sps.hyp2f1((self.scatt_alpha + 2.0)/2.0, 0.5, 1.0, -self.kzeta)) + else: + print("Scattering Model Not Recognized!") + + # More parameters for the approximate phase structure function + int_maj = integrate.quad(lambda phi_q: np.abs( np.cos( self.phi0 - phi_q ) )**self.scatt_alpha * self.P_phi(phi_q), 0, 2.0*np.pi, limit=250)[0] + int_min = integrate.quad(lambda phi_q: np.abs( np.sin( self.phi0 - phi_q ) )**self.scatt_alpha * self.P_phi(phi_q), 0, 2.0*np.pi, limit=250)[0] + B_prefac = self.C_scatt_0 * 2.0**(2.0 - self.scatt_alpha) * np.pi**0.5/(self.scatt_alpha * sps.gamma((self.scatt_alpha + 1.0)/2.0)) + self.Bmaj_0 = B_prefac*int_maj + self.Bmin_0 = B_prefac*int_min + + #Check normalization: + #print("Checking Normalization:",integrate.quad(lambda phi_q: self.P_phi(phi_q), 0, 2.0*np.pi)[0]) + + return + + def P_phi(self, phi): + if self.model == 'von_Mises': + return self.P_phi_prefac * np.cosh(self.kzeta*np.cos(phi - self.phi0)) + elif self.model == 'boxcar': + return self.P_phi_prefac * (1.0 - ((np.pi/(2.0*(1.0 + self.kzeta)) < (phi - self.phi0) % np.pi) & ((phi - self.phi0) % np.pi < np.pi - np.pi/(2.0*(1.0 + self.kzeta))))) + elif self.model == 'dipole': + return self.P_phi_prefac * (1.0 + self.kzeta*np.sin(phi - self.phi0)**2)**(-(self.scatt_alpha + 2.0)/2.0) + + def rF(self, wavelength): + """Returns the Fresnel scale [cm] of the scattering screen at the specified wavelength [cm]. + + Args: + wavelength (float): The desired wavelength [cm] + + Returns: + rF (float): The Fresnel scale [cm] + """ + return (self.source_screen_distance*self.observer_screen_distance/(self.source_screen_distance + self.observer_screen_distance)*wavelength/(2.0*np.pi))**0.5 + + def Mag(self): + """Returns the effective magnification the scattering screen: (observer-screen distance)/(source-screen distance). + + Returns: + M (float): The effective magnification of the scattering screen. + """ + return self.observer_screen_distance/self.source_screen_distance + + def dDphi_dz(self, r, phi, phi_q, wavelength): + """differential contribution to the phase structure function + """ + return 4.0 * (wavelength/self.wavelength_reference)**2 * self.C_scatt_0/self.scatt_alpha * (sps.hyp1f1(-self.scatt_alpha/2.0, 0.5, -r**2/(4.0*self.r_in**2)*np.cos(phi - phi_q)**2) - 1.0) + + def Dphi_exact(self, x, y, wavelength_cm): + r = (x**2 + y**2)**0.5 + phi = np.arctan2(y, x) + + return integrate.quad(lambda phi_q: self.dDphi_dz(r, phi, phi_q, wavelength_cm)*self.P_phi(phi_q), 0, 2.0*np.pi)[0] + + def Dmaj(self, r, wavelength_cm): + return (wavelength_cm/self.wavelength_reference)**2 * self.Bmaj_0 * (2.0 * self.Amaj_0/(self.scatt_alpha * self.Bmaj_0))**(-self.scatt_alpha/(2.0 - self.scatt_alpha)) * ((1.0 + (2.0*self.Amaj_0/(self.scatt_alpha * self.Bmaj_0))**(2.0/(2.0 - self.scatt_alpha)) * (r/self.r_in)**2 )**(self.scatt_alpha/2.0) - 1.0) + + def Dmin(self, r, wavelength_cm): + return (wavelength_cm/self.wavelength_reference)**2 * self.Bmin_0 * (2.0 * self.Amin_0/(self.scatt_alpha * self.Bmin_0))**(-self.scatt_alpha/(2.0 - self.scatt_alpha)) * ((1.0 + (2.0*self.Amin_0/(self.scatt_alpha * self.Bmin_0))**(2.0/(2.0 - self.scatt_alpha)) * (r/self.r_in)**2 )**(self.scatt_alpha/2.0) - 1.0) + + def Dphi_approx(self, x, y, wavelength_cm): + r = (x**2 + y**2)**0.5 + phi = np.arctan2(y, x) + + Dmaj_eval = self.Dmaj(r, wavelength_cm) + Dmin_eval = self.Dmin(r, wavelength_cm) + + return (Dmaj_eval + Dmin_eval)/2.0 + (Dmaj_eval - Dmin_eval)/2.0*np.cos(2.0*(phi - self.phi0)) + + def Q(self, qx, qy): + """Computes the power spectrum of the scattering model at a wavenumber {qx,qy} (in 1/cm). + The power spectrum is part of what defines the scattering model (along with Dphi). + Q(qx,qy) is independent of the observing wavelength. + + Args: + qx (float): x coordinate of the wavenumber in 1/cm. + qy (float): y coordinate of the wavenumber in 1/cm. + Returns: + (float): The power spectrum Q(qx,qy) + """ + + q = (qx**2 + qy**2)**0.5 + 1e-12/self.r_in #Add a small offset to avoid division by zero + phi_q = np.arctan2(qy, qx) + + return self.Qbar * (q*self.r_in)**(-(self.scatt_alpha + 2.0)) * np.exp(-(q * self.r_in)**2) * self.P_phi(phi_q) + + + def sqrtQ_Matrix(self, Reference_Image, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0): + """Computes the square root of the power spectrum on a discrete grid. Because translation of the screen is done most conveniently in Fourier space, a screen translation can also be included. + + Args: + Reference_Image (Image): Reference image to determine image and pixel dimensions and wavelength. + Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s. + Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s. + t_hr (float): The current time of the scattering in hours. + Returns: + sqrtQ (2D complex ndarray): The square root of the power spectrum of the screen with an additional phase for rotation of the screen. + """ + + #Derived parameters + FOV = Reference_Image.psize * Reference_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen + N = Reference_Image.xdim + dq = 2.0*np.pi/FOV #this is the spacing in wavenumber + screen_x_offset_pixels = (Vx_km_per_s * 1.e5) * (t_hr*3600.0) / (FOV/float(N)) + screen_y_offset_pixels = (Vy_km_per_s * 1.e5) * (t_hr*3600.0) / (FOV/float(N)) + + s, t = np.meshgrid(np.fft.fftfreq(N, d=1.0/N), np.fft.fftfreq(N, d=1.0/N)) + sqrtQ = np.sqrt(self.Q(dq*s, dq*t)) * np.exp(2.0*np.pi*1j*(s*screen_x_offset_pixels + + t*screen_y_offset_pixels)/float(N)) + sqrtQ[0][0] = 0.0 #A DC offset doesn't affect scattering + + return sqrtQ + + def Ensemble_Average_Kernel(self, Reference_Image, wavelength_cm = None, use_approximate_form=True): + """The ensemble-average convolution kernel for images; returns a 2D array corresponding to the image dimensions of the reference image + + Args: + Reference_Image (Image): Reference image to determine image and pixel dimensions and wavelength. + wavelength_cm (float): The observing wavelength for the scattering kernel in cm. If unspecified, this will default to the wavelength of the Reference image. + + Returns: + ker (2D ndarray): The ensemble-average scattering kernel in the image domain. + """ + + if wavelength_cm == None: + wavelength_cm = C/Reference_Image.rf*100.0 #Observing wavelength [cm] + + uvlist = np.fft.fftfreq(Reference_Image.xdim)/Reference_Image.psize # assume square kernel. FIXME: create ulist and vlist, and construct u_grid and v_grid with the correct dimension + if use_approximate_form == True: + u_grid, v_grid = np.meshgrid(uvlist, uvlist) + ker_uv = self.Ensemble_Average_Kernel_Visibility(u_grid, v_grid, wavelength_cm, use_approximate_form=use_approximate_form) + else: + ker_uv = np.array([[self.Ensemble_Average_Kernel_Visibility(u, v, wavelength_cm, use_approximate_form=use_approximate_form) for u in uvlist] for v in uvlist]) + + ker = np.real(np.fft.fftshift(np.fft.fft2(ker_uv))) + ker = ker / np.sum(ker) # normalize to 1 + return ker + + def Ensemble_Average_Kernel_Visibility(self, u, v, wavelength_cm, use_approximate_form=True): + """The ensemble-average multiplicative scattering kernel for visibilities at a particular {u,v} coordinate + + Args: + u (float): u baseline coordinate (dimensionless) + v (float): v baseline coordinate (dimensionless) + wavelength_cm (float): The observing wavelength for the scattering kernel in cm. + + Returns: + float: The ensemble-average kernel at the specified {u,v} point and observing wavelength. + """ + if use_approximate_form == True: + return np.exp(-0.5*self.Dphi_approx(u*wavelength_cm/(1.0+self.Mag()), v*wavelength_cm/(1.0+self.Mag()), wavelength_cm)) + else: + return np.exp(-0.5*self.Dphi_exact(u*wavelength_cm/(1.0+self.Mag()), v*wavelength_cm/(1.0+self.Mag()), wavelength_cm)) + + def Ensemble_Average_Blur(self, im, wavelength_cm = None, ker = None, use_approximate_form=True): + """Blurs an input Image with the ensemble-average scattering kernel. + + Args: + im (Image): The unscattered image. + wavelength_cm (float): The observing wavelength for the scattering kernel in cm. If unspecified, this will default to the wavelength of the input image. + ker (2D ndarray): The user can optionally pass a pre-computed ensemble-average blurring kernel. + + Returns: + out (Image): The ensemble-average scattered image. + """ + + # Inputs an unscattered image and an ensemble-average blurring kernel (2D array); returns the ensemble-average image + # The pre-computed kernel can optionally be specified (ker) + + if wavelength_cm == None: + wavelength_cm = C/im.rf*100.0 #Observing wavelength [cm] + + if ker is None: + ker = self.Ensemble_Average_Kernel(im, wavelength_cm, use_approximate_form) + + Iim = Wrapped_Convolve((im.imvec).reshape(im.ydim, im.xdim), ker) + out = image.Image(Iim, im.psize, im.ra, im.dec, rf=C/(wavelength_cm/100.0), source=im.source, mjd=im.mjd, pulse=im.pulse) + if len(im.qvec): + Qim = Wrapped_Convolve((im.qvec).reshape(im.ydim, im.xdim), ker) + Uim = Wrapped_Convolve((im.uvec).reshape(im.ydim, im.xdim), ker) + out.add_qu(Qim, Uim) + if len(im.vvec): + Vim = Wrapped_Convolve((im.vvec).reshape(im.ydim, im.xdim), ker) + out.add_v(Vim) + + return out + + def Deblur_obs(self, obs, use_approximate_form=True): + """Deblurs the observation obs by dividing visibilities by the ensemble-average scattering kernel. See Fish et al. (2014): arXiv:1409.4690. + + Args: + obs (Obsdata): The observervation data (including scattering). + + Returns: + obsdeblur (Obsdata): The deblurred observation. + """ + + # make a copy of observation data + datatable = (obs.copy()).data + + vis = datatable['vis'] + qvis = datatable['qvis'] + uvis = datatable['uvis'] + vvis = datatable['vvis'] + sigma = datatable['sigma'] + qsigma = datatable['qsigma'] + usigma = datatable['usigma'] + vsigma = datatable['vsigma'] + u = datatable['u'] + v = datatable['v'] + + # divide visibilities by the scattering kernel + for i in range(len(vis)): + ker = self.Ensemble_Average_Kernel_Visibility(u[i], v[i], wavelength_cm = C/obs.rf*100.0, use_approximate_form=use_approximate_form) + vis[i] = vis[i] / ker + qvis[i] = qvis[i] / ker + uvis[i] = uvis[i] / ker + vvis[i] = vvis[i] / ker + sigma[i] = sigma[i] / ker + qsigma[i] = qsigma[i] / ker + usigma[i] = usigma[i] / ker + vsigma[i] = vsigma[i] / ker + + datatable['vis'] = vis + datatable['qvis'] = qvis + datatable['uvis'] = uvis + datatable['vvis'] = vvis + datatable['sigma'] = sigma + datatable['qsigma'] = qsigma + datatable['usigma'] = usigma + datatable['vsigma'] = vsigma + + obsdeblur = obsdata.Obsdata(obs.ra, obs.dec, obs.rf, obs.bw, datatable, obs.tarr, source=obs.source, mjd=obs.mjd, + ampcal=obs.ampcal, phasecal=obs.phasecal, opacitycal=obs.opacitycal, dcal=obs.dcal, frcal=obs.frcal) + return obsdeblur + + def MakePhaseScreen(self, EpsilonScreen, Reference_Image, obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0, sqrtQ_init=None): + """Create a refractive phase screen from standardized Fourier components (the EpsilonScreen). + All lengths should be specified in centimeters + If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Reference_Image + Note: an odd image dimension is required! + + Args: + EpsilonScreen (2D ndarray): Optionally, the scattering screen can be specified. If none is given, a random one will be generated. + Reference_Image (Image): The reference image. + obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Unscattered_Image. + Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s. + Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s. + t_hr (float): The current time of the scattering in hours. + ea_ker (2D ndarray): The used can optionally pass a precomputed array of the ensemble-average blurring kernel. + sqrtQ_init (2D ndarray): The used can optionally pass a precomputed array of the square root of the power spectrum. + + Returns: + phi_Image (Image): The phase screen. + """ + + #Observing wavelength + if obs_frequency_Hz == 0.0: + obs_frequency_Hz = Reference_Image.rf + + wavelength = C/obs_frequency_Hz*100.0 #Observing wavelength [cm] + wavelengthbar = wavelength/(2.0*np.pi) #lambda/(2pi) [cm] + + #Derived parameters + FOV = Reference_Image.psize * Reference_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen + rF = self.rF(wavelength) + Nx = EpsilonScreen.shape[1] + Ny = EpsilonScreen.shape[0] + +# if Nx%2 == 0: +# print("The image dimension should really be odd...") + + #Now we'll calculate the power spectrum for each pixel in Fourier space + screen_x_offset_pixels = (Vx_km_per_s*1.e5) * (t_hr*3600.0) / (FOV/float(Nx)) + screen_y_offset_pixels = (Vy_km_per_s*1.e5) * (t_hr*3600.0) / (FOV/float(Nx)) + + if sqrtQ_init is None: + sqrtQ = self.sqrtQ_Matrix(Reference_Image, Vx_km_per_s=Vx_km_per_s, Vy_km_per_s=Vy_km_per_s, t_hr=t_hr) + else: + #If a matrix for sqrtQ_init is passed, we still potentially need to rotate it + + if screen_x_offset_pixels != 0.0 or screen_y_offset_pixels != 0.0: + s, t = np.meshgrid(np.fft.fftfreq(Nx, d=1.0/Nx), np.fft.fftfreq(Ny, d=1.0/Ny)) + sqrtQ = sqrtQ_init * np.exp(2.0*np.pi*1j*(s*screen_x_offset_pixels + + t*screen_y_offset_pixels)/float(Nx)) + else: + sqrtQ = sqrtQ_init + + #Now calculate the phase screen + phi = np.real(wavelengthbar/FOV*EpsilonScreen.shape[0]*EpsilonScreen.shape[1]*np.fft.ifft2(sqrtQ*EpsilonScreen)) + phi_Image = image.Image(phi, Reference_Image.psize, Reference_Image.ra, Reference_Image.dec, rf=Reference_Image.rf, source=Reference_Image.source, mjd=Reference_Image.mjd) + + return phi_Image + + def Scatter2(self, args, kwargs): + """Call self.Scatter with expanded args and kwargs.""" + return self.Scatter(*args, **kwargs) + + def Scatter(self, Unscattered_Image, Epsilon_Screen=np.array([]), obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, t_hr=0.0, ea_ker=None, sqrtQ=None, Linearized_Approximation=False, DisplayImage=False, Force_Positivity=False, use_approximate_form=True): + """Scatter an image using the specified epsilon screen. + All lengths should be specified in centimeters + If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Unscattered_Image + Note: an odd image dimension is required! + + Args: + Unscattered_Image (Image): The unscattered image. + Epsilon_Screen (2D ndarray): Optionally, the scattering screen can be specified. If none is given, a random one will be generated. + obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Unscattered_Image. + Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s. + Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s. + t_hr (float): The current time of the scattering in hours. + ea_ker (2D ndarray): The used can optionally pass a precomputed array of the ensemble-average blurring kernel. + sqrtQ (2D ndarray): The used can optionally pass a precomputed array of the square root of the power spectrum. + Linearized_Approximation (bool): If True, uses a linearized approximation for the scattering (Eq. 10 of Johnson & Narayan 2016). If False, uses Eq. 9 of that paper. + DisplayImage (bool): If True, show a plot of the unscattered, ensemble-average, and scattered images as well as the phase screen. + Force_Positivity (bool): If True, eliminates negative flux from the scattered image from the linearized approximation. + Return_Image_List (bool): If True, returns a list of the scattered frames. If False, returns a movie object. + + Returns: + AI_Image (Image): The scattered image. + """ + + #Observing wavelength + if obs_frequency_Hz == 0.0: + obs_frequency_Hz = Unscattered_Image.rf + + wavelength = C/obs_frequency_Hz*100.0 #Observing wavelength [cm] + wavelengthbar = wavelength/(2.0*np.pi) #lambda/(2pi) [cm] + + #Derived parameters + FOV = Unscattered_Image.psize * Unscattered_Image.xdim * self.observer_screen_distance #Field of view, in cm, at the scattering screen + rF = self.rF(wavelength) + Nx = Unscattered_Image.xdim + Ny = Unscattered_Image.ydim + + #First we need to calculate the ensemble-average image by blurring the unscattered image with the correct kernel + EA_Image = self.Ensemble_Average_Blur(Unscattered_Image, wavelength, ker = ea_ker, use_approximate_form=use_approximate_form) + + # If no epsilon screen is specified, then generate a random realization + if Epsilon_Screen.shape[0] == 0: + Epsilon_Screen = MakeEpsilonScreen(Nx, Ny) + + #We'll now calculate the phase screen. + phi_Image = self.MakePhaseScreen(Epsilon_Screen, Unscattered_Image, obs_frequency_Hz, Vx_km_per_s=Vx_km_per_s, Vy_km_per_s=Vy_km_per_s, t_hr=t_hr, sqrtQ_init=sqrtQ) + phi = phi_Image.imvec.reshape(Ny,Nx) + + #Next, we need the gradient of the ensemble-average image + phi_Gradient = Wrapped_Gradient(phi/(FOV/Nx)) + #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top) + phi_Gradient_x = -phi_Gradient[1] + phi_Gradient_y = -phi_Gradient[0] + + if Linearized_Approximation == True: #Use Equation 10 of Johnson & Narayan (2016) + #Calculate the gradient of the ensemble-average image + EA_Gradient = Wrapped_Gradient((EA_Image.imvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim)) + #The gradient signs don't actually matter, but let's make them match intuition (i.e., right to left, bottom to top) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + #Now we can patch together the average image + AI = (EA_Image.imvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y ) + if len(Unscattered_Image.qvec): + # Scatter the Q image + EA_Gradient = Wrapped_Gradient((EA_Image.qvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim)) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + AI_Q = (EA_Image.qvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y ) + # Scatter the U image + EA_Gradient = Wrapped_Gradient((EA_Image.uvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim)) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + AI_U = (EA_Image.uvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y ) + if len(Unscattered_Image.vvec): + # Scatter the V image + EA_Gradient = Wrapped_Gradient((EA_Image.vvec/(FOV/Nx)).reshape(EA_Image.ydim, EA_Image.xdim)) + EA_Gradient_x = -EA_Gradient[1] + EA_Gradient_y = -EA_Gradient[0] + AI_V = (EA_Image.vvec).reshape(Ny,Nx) + rF**2.0 * ( EA_Gradient_x*phi_Gradient_x + EA_Gradient_y*phi_Gradient_y ) + else: #Use Equation 9 of Johnson & Narayan (2016) + EA_im = (EA_Image.imvec).reshape(Ny,Nx) + AI = np.copy((EA_Image.imvec).reshape(Ny,Nx)) + if len(Unscattered_Image.qvec): + AI_Q = np.copy((EA_Image.imvec).reshape(Ny,Nx)) + AI_U = np.copy((EA_Image.imvec).reshape(Ny,Nx)) + EA_im_Q = (EA_Image.qvec).reshape(Ny,Nx) + EA_im_U = (EA_Image.uvec).reshape(Ny,Nx) + if len(Unscattered_Image.vvec): + AI_V = np.copy((EA_Image.imvec).reshape(Ny,Nx)) + EA_im_V = (EA_Image.vvec).reshape(Ny,Nx) + for rx in range(Nx): + for ry in range(Ny): + # Annoyingly, the signs here must be negative to match the other approximation. I'm not sure which is correct, but it really shouldn't matter anyway because -phi has the same power spectrum as phi. However, getting the *relative* sign for the x- and y-directions correct is important. + rxp = int(np.round(rx - rF**2.0 * phi_Gradient_x[ry,rx]/self.observer_screen_distance/Unscattered_Image.psize))%Nx + ryp = int(np.round(ry - rF**2.0 * phi_Gradient_y[ry,rx]/self.observer_screen_distance/Unscattered_Image.psize))%Ny + AI[ry,rx] = EA_im[ryp,rxp] + if len(Unscattered_Image.qvec): + AI_Q[ry,rx] = EA_im_Q[ryp,rxp] + AI_U[ry,rx] = EA_im_U[ryp,rxp] + if len(Unscattered_Image.vvec): + AI_V[ry,rx] = EA_im_V[ryp,rxp] + + #Optional: eliminate negative flux + if Force_Positivity == True: + AI = abs(AI) + + #Make it into a proper image format + AI_Image = image.Image(AI, EA_Image.psize, EA_Image.ra, EA_Image.dec, rf=EA_Image.rf, source=EA_Image.source, mjd=EA_Image.mjd) + if len(Unscattered_Image.qvec): + AI_Image.add_qu(AI_Q, AI_U) + if len(Unscattered_Image.vvec): + AI_Image.add_v(AI_V) + + if DisplayImage: + plot_scatt(Unscattered_Image.imvec, EA_Image.imvec, AI_Image.imvec, phi_Image.imvec, Unscattered_Image, 0, 0, ipynb=False) + + return AI_Image + + def Scatter_Movie(self, Unscattered_Movie, Epsilon_Screen=np.array([]), obs_frequency_Hz=0.0, Vx_km_per_s=50.0, Vy_km_per_s=0.0, framedur_sec=None, N_frames = None, ea_ker=None, sqrtQ=None, Linearized_Approximation=False, Force_Positivity=False, Return_Image_List=False, processes=0): + """Scatter a movie using the specified epsilon screen. The movie can either be a movie object, an image list, or a static image + If scattering a list of images or static image, the frame duration in seconds (framedur_sec) must be specified + If scattering a static image, the total number of frames must be specified (N_frames) + All lengths should be specified in centimeters + If the observing frequency (obs_frequency_Hz) is not specified, then it will be taken to be equal to the frequency of the Unscattered_Movie + Note: an odd image dimension is required! + + Args: + Unscattered_Movie: This can be a movie object, an image list, or a static image + Epsilon_Screen (2D ndarray): Optionally, the scattering screen can be specified. If none is given, a random one will be generated. + obs_frequency_Hz (float): The observing frequency, in Hz. By default, it will be taken to be equal to the frequency of the Unscattered_Movie. + Vx_km_per_s (float): Velocity of the scattering screen in the x direction (toward East) in km/s. + Vy_km_per_s (float): Velocity of the scattering screen in the y direction (toward North) in km/s. + framedur_sec (float): Duration of each frame, in seconds. Only needed if Unscattered_Movie is not a movie object. + N_frames (int): Total number of frames. Only needed if Unscattered_Movie is a static image object. + ea_ker (2D ndarray): The used can optionally pass a precomputed array of the ensemble-average blurring kernel. + sqrtQ (2D ndarray): The used can optionally pass a precomputed array of the square root of the power spectrum. + Linearized_Approximation (bool): If True, uses a linearized approximation for the scattering (Eq. 10 of Johnson & Narayan 2016). If False, uses Eq. 9 of that paper. + Force_Positivity (bool): If True, eliminates negative flux from the scattered image from the linearized approximation. + Return_Image_List (bool): If True, returns a list of the scattered frames. If False, returns a movie object. + processes (int): Number of cores to use in multiprocessing. Default value (0) means no multiprocessing. Uses all available cores if processes < 0. + + Returns: + Scattered_Movie: Either a movie object or a list of images, depending on the flag Return_Image_List. + """ + + print("Warning!! assuming a constant frame duration, but Movie objects now support unequally spaced frames!") + + if type(Unscattered_Movie) != movie.Movie and framedur_sec is None: + print("If scattering a list of images or static image, the framedur must be specified!") + return + + if type(Unscattered_Movie) == image.Image and N_frames is None: + print("If scattering a static image, the total number of frames must be specified (N_frames)!") + return + + # time list in hr + if hasattr(Unscattered_Movie, 'times'): + tlist_hr = Unscattered_Movie.times + else: + tlist_hr = [framedur_sec/3600.0*j for j in range(N_frames)] + + if type(Unscattered_Movie) == movie.Movie: + N = Unscattered_Movie.xdim + N_frames = len(Unscattered_Movie.frames) + psize = Unscattered_Movie.psize + ra = Unscattered_Movie.ra + dec = Unscattered_Movie.dec + rf = Unscattered_Movie.rf + pulse=Unscattered_Movie.pulse + source=Unscattered_Movie.source + mjd=Unscattered_Movie.mjd + start_hr=Unscattered_Movie.start_hr + has_pol = len(Unscattered_Movie.qframes) + has_circ_pol = len(Unscattered_Movie.vframes) + elif type(Unscattered_Movie) == list: + N = Unscattered_Movie[0].xdim + N_frames = len(Unscattered_Movie) + psize = Unscattered_Movie[0].psize + ra = Unscattered_Movie[0].ra + dec = Unscattered_Movie[0].dec + rf = Unscattered_Movie[0].rf + pulse=Unscattered_Movie[0].pulse + source=Unscattered_Movie[0].source + mjd=Unscattered_Movie[0].mjd + start_hr=0.0 + has_pol = len(Unscattered_Movie[0].qvec) + has_circ_pol = len(Unscattered_Movie[0].vvec) + else: + N = Unscattered_Movie.xdim + psize = Unscattered_Movie.psize + ra = Unscattered_Movie.ra + dec = Unscattered_Movie.dec + rf = Unscattered_Movie.rf + pulse=Unscattered_Movie.pulse + source=Unscattered_Movie.source + mjd=Unscattered_Movie.mjd + start_hr=0.0 + has_pol = len(Unscattered_Movie.qvec) + has_circ_pol = len(Unscattered_Movie.vvec) + + def get_frame(j): + if type(Unscattered_Movie) == movie.Movie: + im = image.Image(Unscattered_Movie.frames[j].reshape((N,N)), psize=psize, ra=ra, dec=dec, rf=rf, pulse=pulse, source=source, mjd=mjd) + if len(Unscattered_Movie.qframes) > 0: + im.add_qu(Unscattered_Movie.qframes[j].reshape((N,N)), Unscattered_Movie.uframes[j].reshape((N,N))) + if len(Unscattered_Movie.vframes) > 0: + im.add_v(Unscattered_Movie.vframes[j].reshape((N,N))) + return im + elif type(Unscattered_Movie) == list: + return Unscattered_Movie[j] + else: + return Unscattered_Movie + + #If it isn't specified, calculate the matrix sqrtQ for efficiency + if sqrtQ is None: + sqrtQ = self.sqrtQ_Matrix(get_frame(0)) + + # If no epsilon screen is specified, then generate a random realization + if Epsilon_Screen.shape[0] == 0: + Epsilon_Screen = MakeEpsilonScreen(N, N) + + # possibly parallelize + if processes < 0: + processes = cpu_count() + processes = min(processes, N_frames) + + # generate scattered images + if processes > 0: + pool = Pool(processes=processes) + args = [ + ( + [get_frame(j), Epsilon_Screen], + dict(obs_frequency_Hz = obs_frequency_Hz, Vx_km_per_s = Vx_km_per_s, Vy_km_per_s = Vy_km_per_s, t_hr=tlist_hr[j], sqrtQ=sqrtQ, Linearized_Approximation=Linearized_Approximation, Force_Positivity=Force_Positivity) + ) for j in range(N_frames) + ] + scattered_im_List = pool.starmap(self.Scatter2, args) + pool.close() + pool.join() + else: + scattered_im_List = [self.Scatter(get_frame(j), Epsilon_Screen, obs_frequency_Hz = obs_frequency_Hz, Vx_km_per_s = Vx_km_per_s, Vy_km_per_s = Vy_km_per_s, t_hr=tlist_hr[j], ea_ker=ea_ker, sqrtQ=sqrtQ, Linearized_Approximation=Linearized_Approximation, Force_Positivity=Force_Positivity) for j in range(N_frames)] + + if Return_Image_List == True: + return scattered_im_List + + Scattered_Movie = movie.Movie( [im.imvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List], + times=tlist_hr, psize = psize, ra = ra, dec = dec, rf=rf, pulse=pulse, source=source, mjd=mjd) + + if has_pol: + Scattered_Movie_Q = [im.qvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List] + Scattered_Movie_U = [im.uvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List] + Scattered_Movie.add_qu(Scattered_Movie_Q, Scattered_Movie_U) + if has_circ_pol: + Scattered_Movie_V = [im.vvec.reshape((im.xdim,im.ydim)) for im in scattered_im_List] + Scattered_Movie.add_v(Scattered_Movie_V) + return Scattered_Movie + +################################################################################ +# These are helper functions +################################################################################ + +def Wrapped_Convolve(sig,ker): + N = sig.shape[0] + return scipy.signal.fftconvolve(np.pad(sig,((N, N), (N, N)), 'wrap'), np.pad(ker,((N, N), (N, N)), 'constant'),mode='same')[N:(2*N),N:(2*N)] + +def Wrapped_Gradient(M): + G = np.gradient(np.pad(M,((1, 1), (1, 1)), 'wrap')) + Gx = G[0][1:-1,1:-1] + Gy = G[1][1:-1,1:-1] + return (Gx, Gy) + +def MakeEpsilonScreenFromList(EpsilonList, N): + epsilon = np.zeros((N,N),dtype=np.complex) + #If N is odd: there are (N^2-1)/2 real elements followed by their corresponding (N^2-1)/2 imaginary elements + #If N is even: there are (N^2+2)/2 of each, although 3 of these must be purely real, also giving a total of N^2-1 degrees of freedom + #This is because of conjugation symmetry in Fourier space to ensure a real Fourier transform + + #The first (N-1)/2 are the top row + N_re = (N*N-1)//2 # FIXME: check logic if N is even + i = 0 + for x in range(1,(N+1)//2): # FIXME: check logic if N is even + epsilon[0][x] = EpsilonList[i] + 1j * EpsilonList[i+N_re] + epsilon[0][N-x] = np.conjugate(epsilon[0][x]) + i=i+1 + + #The next N(N-1)/2 are filling the next N rows + for y in range(1,(N+1)//2): # FIXME: check logic if N is even + for x in range(N): + epsilon[y][x] = EpsilonList[i] + 1j * EpsilonList[i+N_re] + + x2 = N - x + y2 = N - y + if x2 == N: + x2 = 0 + if y2 == N: + y2 = 0 + + epsilon[y2][x2] = np.conjugate(epsilon[y][x]) + i=i+1 + + return epsilon + +def MakeEpsilonScreen(Nx, Ny, rngseed = 0): + """Create a standardized Fourier representation of a scattering screen + + Args: + Nx (int): Number of pixels in the x direction + Ny (int): Number of pixels in the y direction + rngseed (int): Seed for the random number generator + + Returns: + epsilon: A 2D numpy ndarray. + """ + + if rngseed != 0: + np.random.seed( rngseed ) + + epsilon = np.random.normal(loc=0.0, scale=1.0/math.sqrt(2), size=(Ny,Nx)) + 1j * np.random.normal(loc=0.0, scale=1.0/math.sqrt(2), size=(Ny,Nx)) + + # The zero frequency doesn't affect scattering + epsilon[0][0] = 0.0 + + #Now let's ensure that it has the necessary conjugation symmetry + if Nx%2 == 0: + epsilon[0][Nx//2] = np.real(epsilon[0][Nx//2]) + if Ny%2 == 0: + epsilon[Ny//2][0] = np.real(epsilon[Ny//2][0]) + if Nx%2 == 0 and Ny%2 == 0: + epsilon[Ny//2][Nx//2] = np.real(epsilon[Ny//2][Nx//2]) + + for x in range(Nx): + if x > (Nx-1)//2: + epsilon[0][x] = np.conjugate(epsilon[0][Nx-x]) + for y in range((Ny-1)//2, Ny): + x2 = Nx - x + y2 = Ny - y + if x2 == Nx: + x2 = 0 + if y2 == Ny: + y2 = 0 + epsilon[y][x] = np.conjugate(epsilon[y2][x2]) + + return epsilon + +################################################################################################## +# Plotting Functions +################################################################################################## + +def plot_scatt(im_unscatt, im_ea, im_scatt, im_phase, Prior, nit, chi2, ipynb=False): + # Get vectors and ratio from current image + x = np.array([[i for i in range(Prior.xdim)] for j in range(Prior.ydim)]) + y = np.array([[j for i in range(Prior.xdim)] for j in range(Prior.ydim)]) + + # Create figure and title + plt.ion() + plt.clf() + if chi2 > 0.0: + plt.suptitle("step: %i $\chi^2$: %f " % (nit, chi2), fontsize=20) + + # Unscattered Image + plt.subplot(141) + plt.imshow(im_unscatt.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0) + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Unscattered') + + # Ensemble Average + plt.subplot(142) + plt.imshow(im_ea.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0) + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Ensemble Average') + + # Scattered + plt.subplot(143) + plt.imshow(im_scatt.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian', vmin=0) + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Average Image') + + # Phase + plt.subplot(144) + plt.imshow(im_phase.reshape(Prior.ydim, Prior.xdim), cmap=plt.get_cmap('afmhot'), interpolation='gaussian') + xticks = ticks(Prior.xdim, Prior.psize/RADPERAS/1e-6) + yticks = ticks(Prior.ydim, Prior.psize/RADPERAS/1e-6) + plt.xticks(xticks[0], xticks[1]) + plt.yticks(yticks[0], yticks[1]) + plt.xlabel('Relative RA ($\mu$as)') + plt.ylabel('Relative Dec ($\mu$as)') + plt.title('Phase Screen') + + # Display + plt.draw() diff --git a/statistics/__init__.py b/statistics/__init__.py new file mode 100644 index 00000000..aa83fb1a --- /dev/null +++ b/statistics/__init__.py @@ -0,0 +1,12 @@ +""" +.. module:: ehtim.stats + :platform: Unix + :synopsis: EHT Imaging Utilities: statistics and DataFrame format + +.. moduleauthor:: Maciek Wielgus (mwielgus@cfa.harvard.edu) + +""" +from . import stats +from . import dataframes + + diff --git a/statistics/dataframes.py b/statistics/dataframes.py new file mode 100644 index 00000000..81781872 --- /dev/null +++ b/statistics/dataframes.py @@ -0,0 +1,1044 @@ +# DataFrames.py +# variety of statistical functions useful for +# +# Copyright (C) 2018 Maciek Wielgus +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function +from builtins import str +from builtins import map +from builtins import range + +import numpy as np + +try: + import pandas as pd +except ImportError: + print("Warning: pandas not installed!") + print("Please install pandas to use statistics package!") + +import datetime as datetime +from astropy.time import Time +from ehtim.statistics.stats import * + +def make_df(obs,polarization='unknown',band='unknown',round_s=0.1): + + """converts visibilities from obs.data to DataFrame format + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: observation visibility data in DataFrame format + """ + sour=obs.source + df = pd.DataFrame(data=obs.data) + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + telescopes = list(zip(df['t1'],df['t2'])) + telescopes = [(x[0],x[1]) for x in telescopes] + df['baseline'] = [x[0]+'-'+x[1] for x in telescopes] + if obs.polrep=='stokes': + vis1='vis'; sig1='sigma' + elif obs.polrep=='circ': + vis1='rrvis'; sig1='rrsigma' + df['vis']=df[vis1] + df['sigma']=df[sig1] + df['rramp']=np.abs(df['rrvis']) + df['llamp']=np.abs(df['llvis']) + df['rlamp']=np.abs(df['rlvis']) + df['lramp']=np.abs(df['lrvis']) + df['rrsnr']=df['rramp']/df['rrsigma'] + df['llsnr']=df['llamp']/df['llsigma'] + df['rlsnr']=df['rlamp']/df['rlsigma'] + df['lrsnr']=df['lramp']/df['lrsigma'] + #df = df.dropna(subset=['rrvis', 'llvis','rrsigma','llsigma']) + df['amp'] = list(map(np.abs,df[vis1])) + df['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),df[vis1])) + df['snr'] = df['amp']/df[sig1] + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + df['baselength'] = np.sqrt(np.asarray(df.u)**2+np.asarray(df.v)**2) + return df + + +def make_amp(obs,debias=True,polarization='unknown',band='unknown',round_s=0.1): + + """converts visibilities from obs.data to amplitudes inDataFrame format + + Args: + obs: ObsData object + debias (str): whether to debias the amplitudes + round_s: accuracy of datetime object in seconds + + Returns: + df: observation visibility data in DataFrame format + """ + sour=obs.source + df = pd.DataFrame(data=obs.data) + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + telescopes = list(zip(df['t1'],df['t2'])) + telescopes = [(x[0],x[1]) for x in telescopes] + df['baseline'] = [x[0]+'-'+x[1] for x in telescopes] + df['amp'] = list(map(np.abs,df['vis'])) + if debias==True: + amp2 = np.maximum(np.asarray(df['amp'])**2-np.asarray(df['sigma'])**2,np.asarray(df['sigma'])**2) + df['amp'] = np.sqrt(amp2) + df['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),df['vis'])) + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['snr'] = df['amp']/df['sigma'] + df['source'] = sour + df['baselength'] = np.sqrt(np.asarray(df.u)**2+np.asarray(df.v)**2) + return df + +def coh_avg_vis(obs,dt=0,scan_avg=False,return_type='rec',err_type='predicted',num_samples=int(1e3)): + """coherently averages visibilities + Args: + obs: ObsData object + dt (float): integration time in seconds + return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame + err_type (str): 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator + num_samples: 'bootstrap' resample set size for measured error + scan_avg (bool): should scan-long averaging be performed. If True, overrides dt + Returns: + vis_avg: coherently averaged visibilities + """ + if (dt<=0)&(scan_avg==False): + return obs.data + else: + vis = make_df(obs) + if scan_avg==False: + #TODO + #we don't have to work on datetime products at all + #change it to only use 'time' in mjd + t0 = datetime.datetime(1960,1,1) + vis['round_time'] = list(map(lambda x: np.floor((x- t0).total_seconds()/float(dt)),vis.datetime)) + grouping=['tau1','tau2','polarization','band','baseline','t1','t2','round_time'] + else: + bins, labs = get_bins_labels(obs.scans) + vis['scan'] = list(pd.cut(vis.time, bins,labels=labs)) + grouping=['tau1','tau2','polarization','band','baseline','t1','t2','scan'] + #column just for counting the elements + vis['number'] = 1 + aggregated = {'datetime': np.min, 'time': np.min, + 'number': lambda x: len(x), 'u':np.mean, 'v':np.mean,'tint': np.sum} + + if err_type not in ['measured', 'predicted']: + print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.") + err_type='predicted' + + if obs.polrep=='stokes': + vis1='vis'; vis2='qvis'; vis3='uvis'; vis4='vvis' + sig1='sigma'; sig2='qsigma'; sig3='usigma'; sig4='vsigma' + elif obs.polrep=='circ': + vis1='rrvis'; vis2='llvis'; vis3='rlvis'; vis4='lrvis' + sig1='rrsigma'; sig2='llsigma'; sig3='rlsigma'; sig4='lrsigma' + + #AVERAGING------------------------------- + if err_type=='measured': + vis['dummy'] = vis[vis1] + vis['qdummy'] = vis[vis2] + vis['udummy'] = vis[vis3] + vis['vdummy'] = vis[vis4] + meanF = lambda x: np.nanmean(np.asarray(x)) + meanerrF = lambda x: bootstrap(np.abs(x), np.mean, num_samples=num_samples,wrapping_variable=False) + aggregated[vis1] = meanF + aggregated[vis2] = meanF + aggregated[vis3] = meanF + aggregated[vis4] = meanF + aggregated['dummy'] = meanerrF + aggregated['udummy'] = meanerrF + aggregated['vdummy'] = meanerrF + aggregated['qdummy'] = meanerrF + + elif err_type=='predicted': + meanF = lambda x: np.nanmean(np.asarray(x)) + #meanerrF = lambda x: bootstrap(np.abs(x), np.mean, num_samples=num_samples,wrapping_variable=False) + def meanerrF(x): + x = np.asarray(x) + x = x[x==x] + + if len(x)>0: ret = np.sqrt(np.sum(x**2)/len(x)**2) + else: ret = np.nan +1j*np.nan + return ret + + aggregated[vis1] = meanF + aggregated[vis2] = meanF + aggregated[vis3] = meanF + aggregated[vis4] = meanF + aggregated[sig1] = meanerrF + aggregated[sig2] = meanerrF + aggregated[sig3] = meanerrF + aggregated[sig4] = meanerrF + + #ACTUAL AVERAGING + vis_avg = vis.groupby(grouping).agg(aggregated).reset_index() + + if err_type=='measured': + vis_avg[sig1] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['dummy'])] + vis_avg[sig2] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['qdummy'])] + vis_avg[sig3] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['udummy'])] + vis_avg[sig4] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['vdummy'])] + + vis_avg['amp'] = list(map(np.abs,vis_avg[vis1])) + vis_avg['phase'] = list(map(lambda x: (180./np.pi)*np.angle(x),vis_avg[vis1])) + vis_avg['snr'] = vis_avg['amp']/vis_avg[sig1] + + if scan_avg==False: + #round datetime and time to the begining of the bucket and add half of a bucket time + half_bucket = dt/2. + vis_avg['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x) + half_bucket), vis_avg['round_time'])) + vis_avg['time'] = list(map(lambda x: (Time(x).mjd-obs.mjd)*24., vis_avg['datetime'])) + else: + #drop values that couldn't be matched to any scan + vis_avg.drop(list(vis_avg[vis_avg.scan<0].index.values),inplace=True) + if err_type=='measured': + vis_avg.drop(labels=['udummy','vdummy','qdummy','dummy'],axis='columns',inplace=True) + if return_type=='rec': + if obs.polrep=='stokes': + return df_to_rec(vis_avg,'vis') + elif obs.polrep=='circ': + return df_to_rec(vis_avg,'vis_circ') + elif return_type=='df': + return vis_avg + + +def coh_moving_avg_vis(obs,dt=50,return_type='rec'): + """coherently averages visibilities with moving window + Args: + obs: ObsData object + dt (float): averaging window size in seconds + return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame + Returns: + vis: coherently averaged visibilities on same grid + """ + min_periods=1 + if dt <= 0: + raise Exception('Time dt must be positive!') + if obs.polrep=='stokes': + vis1='vis'; vis2='qvis'; vis3='uvis'; vis4='vvis' + sig1='sigma'; sig2='qsigma'; sig3='usigma'; sig4='vsigma' + elif obs.polrep=='circ': + vis1='rrvis'; vis2='llvis'; vis3='rlvis'; vis4='lrvis' + sig1='rrsigma'; sig2='llsigma'; sig3='rlsigma'; sig4='lrsigma' + + vis = make_df(obs) + vis = vis.sort_values(['baseline','datetime']).reset_index().copy() + #vis['total_seconds'] = list(map(lambda x: int(x.total_seconds()), vis['datetime'] - vis['datetime'].min())) + vis['total_seconds'] = [pd.Timestamp(x) for x in vis.datetime] + vis['roll_vis'] = list(zip(vis['total_seconds'],vis[vis1],vis[vis2],vis[vis3],vis[vis4],vis['datetime'])) + vis['roll_sig'] = list(zip(vis['total_seconds'],vis[sig1],vis[sig2],vis[sig3],vis[sig4],vis['datetime'])) + + roll_vis_local = lambda x: roll_vis(x,dt=str(int(dt))+'s',min_periods=min_periods) + roll_sig_local = lambda x: roll_sig(x,dt=str(int(dt))+'s',min_periods=min_periods) + vis_avg_roll_vis = vis[['baseline','roll_vis']].groupby('baseline').transform(roll_vis_local)['roll_vis'].copy() + vis_avg_roll_sig = vis[['baseline','roll_sig']].groupby('baseline').transform(roll_sig_local)['roll_sig'].copy() + + for cou,col in enumerate([vis1,vis2,vis3,vis4]): + vis[col] = [x[2*cou] + 1j*x[2*cou+1] for x in vis_avg_roll_vis] + for cou,col in enumerate([sig1,sig2,sig3,sig4]): + vis[col] = [x[cou] for x in vis_avg_roll_sig] + #shift to match with original data + vis.datetime = vis.datetime.apply(lambda x: x-datetime.timedelta(seconds=dt)) + vis.time = vis.time - dt/2.0/3600. + if return_type=='rec': + if obs.polrep=='stokes': + return df_to_rec(vis.copy(),'vis') + elif obs.polrep=='circ': + return df_to_rec(vis.copy(),'vis_circ') + elif return_type=='df': + return vis.copy() + + +def roll_vis(ser,dt='1s',min_periods=1): + """functtion helper for coh_moving_avg_vis + """ + foo = pd.DataFrame({'REvis1': [np.real(x[1]) for x in ser],'IMvis1': [np.imag(x[1]) for x in ser], + 'REvis2': [np.real(x[2]) for x in ser],'IMvis2': [np.imag(x[2]) for x in ser], + 'REvis3': [np.real(x[3]) for x in ser],'IMvis3': [np.imag(x[3]) for x in ser], + 'REvis4': [np.real(x[4]) for x in ser],'IMvis4': [np.imag(x[4]) for x in ser]}, + index=[x[0] for x in ser]) + avg = foo.rolling(dt, min_periods=min_periods).mean() + avg_list = list(zip(avg['REvis1'],avg['IMvis1'],avg['REvis2'],avg['IMvis2'],avg['REvis3'],avg['IMvis3'],avg['REvis4'],avg['IMvis4'],[x[5] for x in ser])) + return avg_list + +def roll_sig(ser,dt='1s',min_periods=1): + """functtion helper for coh_moving_avg_vis + """ + foo = pd.DataFrame({'sig1': [x[1]**2 for x in ser],'sig2': [x[2]**2 for x in ser], + 'sig3': [x[3]**2 for x in ser],'sig4': [x[4]**2 for x in ser]}, + index=[x[0] for x in ser]) + avg0 = foo.rolling(dt, min_periods=min_periods).mean() + sumSq = foo.rolling(dt, min_periods=min_periods).sum() + avg = pd.DataFrame({},index=[x[0] for x in ser]) + avg['sig1'] = (avg0['sig1']**1.0)/(sumSq['sig1']**0.5) + avg['sig2'] = (avg0['sig2']**1.0)/(sumSq['sig2']**0.5) + avg['sig3'] = (avg0['sig3']**1.0)/(sumSq['sig3']**0.5) + avg['sig4'] = (avg0['sig4']**1.0)/(sumSq['sig4']**0.5) + avg_list = list(zip(avg['sig1'],avg['sig2'],avg['sig3'],avg['sig4'],[x[5] for x in ser])) + return avg_list + +def incoh_avg_vis(obs,dt=0,debias=True,scan_avg=False,return_type='rec',rec_type='vis',err_type='predicted',num_samples=int(1e3)): + """incoherently averages visibilities + Args: + obs: ObsData object + dt (float): integration time in seconds + return_type (str): 'rec' for numpy record array (as used by ehtim), 'df' for data frame + rec_type (str): 'vis' for DTPOL and 'amp' for DTAMP + err_type (str): 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator + num_samples: 'bootstrap' resample set size for measured error + scan_avg (bool): should scan-long averaging be performed. If True, overrides dt + + Returns: + vis_avg: coherently averaged visibilities + """ + if (dt<=0)&(scan_avg==False): + print('Either averaging time must be positive, or scan_avg option should be selected!') + return obs.data + else: + vis = make_df(obs) + if scan_avg==False: + #TODO + #we don't have to work on datetime products at all + #change it to only use 'time' in mjd + t0 = datetime.datetime(1960,1,1) + vis['round_time'] = list(map(lambda x: np.floor((x- t0).total_seconds()/float(dt)),vis.datetime)) + grouping=['tau1','tau2','polarization','band','baseline','t1','t2','round_time'] + else: + bins, labs = get_bins_labels(obs.scans) + vis['scan'] = list(pd.cut(vis.time, bins,labels=labs)) + grouping=['tau1','tau2','polarization','band','baseline','t1','t2','scan'] + #column just for counting the elements + vis['number'] = 1 + aggregated = {'datetime': np.min, 'time': np.min, + 'number': lambda x: len(x), 'u':np.mean, 'v':np.mean,'tint': np.sum} + + if err_type not in ['measured', 'predicted']: + print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.") + err_type='predicted' + + #AVERAGING------------------------------- + vis['dummy'] = list(zip(np.abs(vis['vis']),vis['sigma'])) + vis['udummy'] = list(zip(np.abs(vis['uvis']),vis['usigma'])) + vis['vdummy'] = list(zip(np.abs(vis['vvis']),vis['vsigma'])) + vis['qdummy'] = list(zip(np.abs(vis['qvis']),vis['qsigma'])) + + if err_type=='predicted': + aggregated['dummy'] = lambda x: mean_incoh_avg(x,debias=debias) + aggregated['udummy'] = lambda x: mean_incoh_avg(x,debias=debias) + aggregated['vdummy'] = lambda x: mean_incoh_avg(x,debias=debias) + aggregated['qdummy'] = lambda x: mean_incoh_avg(x,debias=debias) + + elif err_type=='measured': + aggregated['dummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False) + aggregated['udummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False) + aggregated['vdummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False) + aggregated['qdummy'] = lambda x: bootstrap(np.abs(np.asarray([y[0] for y in x])), np.mean, num_samples=num_samples,wrapping_variable=False) + + #ACTUAL AVERAGING + vis_avg = vis.groupby(grouping).agg(aggregated).reset_index() + + if err_type=='predicted': + vis_avg['vis'] = [x[0] for x in list(vis_avg['dummy'])] + vis_avg['uvis'] = [x[0] for x in list(vis_avg['udummy'])] + vis_avg['qvis'] = [x[0] for x in list(vis_avg['qdummy'])] + vis_avg['vvis'] = [x[0] for x in list(vis_avg['vdummy'])] + vis_avg['sigma'] = [x[1] for x in list(vis_avg['dummy'])] + vis_avg['usigma'] = [x[1] for x in list(vis_avg['udummy'])] + vis_avg['qsigma'] = [x[1] for x in list(vis_avg['qdummy'])] + vis_avg['vsigma'] = [x[1] for x in list(vis_avg['vdummy'])] + + elif err_type=='measured': + vis_avg['vis'] = [x[0] for x in list(vis_avg['dummy'])] + vis_avg['uvis'] = [x[0] for x in list(vis_avg['udummy'])] + vis_avg['qvis'] = [x[0] for x in list(vis_avg['qdummy'])] + vis_avg['vvis'] = [x[0] for x in list(vis_avg['vdummy'])] + vis_avg['sigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['dummy'])] + vis_avg['usigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['udummy'])] + vis_avg['qsigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['qdummy'])] + vis_avg['vsigma'] = [0.5*(x[1][1]-x[1][0]) for x in list(vis_avg['vdummy'])] + + vis_avg['amp'] = list(map(np.abs,vis_avg['vis'])) + vis_avg['phase'] = 0 + vis_avg['snr'] = vis_avg['amp']/vis_avg['sigma'] + if scan_avg==False: + #round datetime and time to the begining of the bucket and add half of a bucket time + half_bucket = dt/2. + vis_avg['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x) + half_bucket), vis_avg['round_time'])) + vis_avg['time'] = list(map(lambda x: (Time(x).mjd-obs.mjd)*24., vis_avg['datetime'])) + else: + #drop values that couldn't be matched to any scan + vis_avg.drop(list(vis_avg[vis_avg.scan<0].index.values),inplace=True) + + vis_avg.drop(labels=['udummy','vdummy','qdummy','dummy'],axis='columns',inplace=True) + if return_type=='rec': + return df_to_rec(vis_avg,rec_type) + elif return_type=='df': + return vis_avg + + +def make_cphase_df(obs,band='unknown',polarization='unknown',mode='all',count='max',round_s=0.1,snrcut=0.,uv_min=False): + + """generate DataFrame of closure phases + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: closure phase data in DataFrame format + """ + + data=obs.c_phases(mode=mode,count=count,snrcut=snrcut,uv_min=uv_min) + sour=obs.source + df = pd.DataFrame(data=data).copy() + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + df['triangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2],zip(df['t1'],df['t2'],df['t3']))) + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + return df + +def make_cphase_diag_df(obs,vtype='vis',band='unknown',polarization='unknown',count='min',round_s=0.1,snrcut=0.,uv_min=False): + + """generate DataFrame of diagonalized closure phases + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: diagonalized closure phase data in DataFrame format + """ + + data=obs.c_phases_diag(vtype=vtype,count=count,snrcut=snrcut,uv_min=uv_min) + sour=obs.source + + tarr = [] + dcparr = [] + dcperrarr = [] + triangle_arr = [] + u_arr = [] + v_arr = [] + tform_arr = [] + for d in data: + tarr.append(list(d[0]['time'])) + dcparr.append(list(d[0]['cphase'])) + dcperrarr.append(list(d[0]['sigmacp'])) + + triarr = [] + u = [] + v = [] + for ant in d[1]: + triarr.append(ant[0][0]+'-'+ant[1][0]+'-'+ant[2][0]) + for iu in d[2]: + u.append((iu[0][0],iu[1][0],iu[2][0])) + for iv in d[3]: + v.append((iv[0][0],iv[1][0],iv[2][0])) + + for i in range(len(d[0])): + triangle_arr.append(triarr) + u_arr.append(u) + v_arr.append(v) + tform_arr.append(list(d[4].view('f8'))) + + df = pd.DataFrame() + df['time'] = np.concatenate(tarr) + df['cphase'] = np.concatenate(dcparr) + df['sigmacp'] = np.concatenate(dcperrarr) + df['triangles'] = triangle_arr + df['u'] = u_arr + df['v'] = v_arr + df['tform_matrix'] = tform_arr + + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] = list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + return df + +def make_camp_df(obs,ctype='logcamp',debias=False,band='unknown',polarization='unknown',mode='all',count='max',round_s=0.1,snrcut=0.): + + """generate DataFrame of closure amplitudes + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: closure amplitude data in DataFrame format + """ + + data = obs.c_amplitudes(mode=mode,count=count,debias=debias,ctype=ctype,snrcut=snrcut) + sour=obs.source + df = pd.DataFrame(data=data).copy() + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + df['quadrangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2]+'-'+x[3],zip(df['t1'],df['t2'],df['t3'],df['t4']))) + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + df['catype'] = ctype + return df + +def make_logcamp_diag_df(obs,debias=True,band='unknown',polarization='unknown',mode='all',count='min',round_s=0.1,snrcut=0.): + + """generate DataFrame of closure amplitudes + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: closure amplitude data in DataFrame format + """ + + data = obs.c_log_amplitudes_diag(mode=mode,count=count,debias=debias,snrcut=snrcut) + sour=obs.source + + tarr = [] + dlcaarr = [] + dlcaerrarr = [] + quadrangle_arr = [] + u_arr = [] + v_arr = [] + tform_arr = [] + for d in data: + tarr.append(list(d[0]['time'])) + dlcaarr.append(list(d[0]['camp'])) + dlcaerrarr.append(list(d[0]['sigmaca'])) + + quadarr = [] + u = [] + v = [] + for ant in d[1]: + quadarr.append(ant[0][0]+'-'+ant[1][0]+'-'+ant[2][0]+'-'+ant[3][0]) + for iu in d[2]: + u.append((iu[0][0],iu[1][0],iu[2][0],iu[3][0])) + for iv in d[3]: + v.append((iv[0][0],iv[1][0],iv[2][0],iv[3][0])) + + for i in range(len(d[0])): + quadrangle_arr.append(quadarr) + u_arr.append(u) + v_arr.append(v) + tform_arr.append(list(d[4].view('f8'))) + + df = pd.DataFrame() + df['time'] = np.concatenate(tarr) + df['camp'] = np.concatenate(dlcaarr) + df['sigmaca'] = np.concatenate(dlcaerrarr) + df['quadrangles'] = quadrangle_arr + df['u'] = u_arr + df['v'] = v_arr + df['tform_matrix'] = tform_arr + + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] = list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + return df + +def make_bsp_df(obs,band='unknown',polarization='unknown',mode='all',count='min',round_s=0.1,snrcut=0., uv_min=False): + + """generate DataFrame of bispectra + + Args: + obs: ObsData object + round_s: accuracy of datetime object in seconds + + Returns: + df: bispectra data in DataFrame format + """ + + data = obs.bispectra(mode=mode,count=count,snrcut=snrcut,uv_min=uv_min) + sour=obs.source + df = pd.DataFrame(data=data).copy() + df['fmjd'] = df['time']/24. + df['mjd'] = obs.mjd + df['fmjd'] + df['triangle'] = list(map(lambda x: x[0]+'-'+x[1]+'-'+x[2],zip(df['t1'],df['t2'],df['t3']))) + df['datetime'] = Time(df['mjd'], format='mjd').datetime + df['datetime'] =list(map(lambda x: round_time(x,round_s=round_s),df['datetime'])) + df['jd'] = Time(df['mjd'], format='mjd').jd + df['polarization'] = polarization + df['band'] = band + df['source'] = sour + return df + +def average_cphases(cdf,dt,return_type='rec',err_type='predicted',num_samples=1000,snrcut=0.): + + """averages DataFrame of cphases + + Args: + cdf: data frame of closure phases + dt: integration time in seconds + return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame + err_type: 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator + + Returns: + cdf2: averaged closure phases + """ + + cdf2 = cdf.copy() + t0 = datetime.datetime(1960,1,1) + cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime)) + grouping=['polarization','band','triangle','t1','t2','t3','round_time'] + #column just for counting the elements + cdf2['number'] = 1 + aggregated = {'datetime': np.min, 'time': np.mean, + 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean,'v1':np.mean, 'v2': np.mean, 'v3':np.mean} + + #AVERAGING------------------------------- + if err_type=='measured': + cdf2['dummy'] = cdf2['cphase'] + aggregated['dummy'] = lambda x: bootstrap(x, circular_mean, num_samples=num_samples,wrapping_variable=True) + elif err_type=='predicted': + aggregated['cphase'] = circular_mean + aggregated['sigmacp'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2) + else: + print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.") + aggregated['cphase'] = circular_mean + aggregated['sigmacp'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2) + + #ACTUAL AVERAGING + cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index() + + if err_type=='measured': + cdf2['cphase'] = [x[0] for x in list(cdf2['dummy'])] + cdf2['sigmacp'] = [0.5*(x[1][1]-x[1][0]) for x in list(cdf2['dummy'])] + + # snrcut + # CHECK + if snrcut==0: + snrcut=EP + cdf2 = cdf2[cdf2['sigmacp'] < 180./np.pi/snrcut].copy() # TODO CHECK + + #round datetime + cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time'])) + + #ANDREW TODO-- this can lead to big problems!! + #drop values averaged from less than 3 datapoints + #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True) + if return_type=='rec': + return df_to_rec(cdf2,'cphase') + elif return_type=='df': + return cdf2 + + +def average_bispectra(cdf,dt,return_type='rec',num_samples=int(1e3), snrcut=0.): + + """averages DataFrame of bispectra + + Args: + cdf: data frame of bispectra + dt: integration time in seconds + return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame + + Returns: + cdf2: averaged bispectra + """ + + cdf2 = cdf.copy() + t0 = datetime.datetime(1960,1,1) + cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime)) + grouping=['polarization','band','triangle','t1','t2','t3','round_time'] + #column just for counting the elements + cdf2['number'] = 1 + aggregated = {'datetime': np.min, 'time': np.mean, + 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean,'v1':np.mean, 'v2': np.mean, 'v3':np.mean} + + #AVERAGING------------------------------- + aggregated['bispec'] = np.mean + aggregated['sigmab'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2) + + #ACTUAL AVERAGING + cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index() + + # snrcut + cdf2 = cdf2[np.abs(cdf2['bispec']/cdf2['sigmab']) > snrcut].copy() # TODO CHECK + + #round datetime + cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time'])) + + #ANDREW TODO -- this can lead to big problems!! + #drop values averaged from less than 3 datapoints + #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True) + if return_type=='rec': + return df_to_rec(cdf2,'bispec') + elif return_type=='df': + return cdf2 + + +def average_camp(cdf,dt,return_type='rec',err_type='predicted',num_samples=int(1e3)): + #TODO: SNRCUT? + """averages DataFrame of closure amplitudes + + Args: + cdf: data frame of closure amplitudes + dt: integration time in seconds + return_type: 'rec' for numpy record array (as used by ehtim), 'df' for data frame + err_type: 'predicted' for modeled error, 'measured' for bootstrap empirical variability estimator + + Returns: + cdf2: averaged closure amplitudes + """ + + cdf2 = cdf.copy() + t0 = datetime.datetime(1960,1,1) + cdf2['round_time'] = list(map(lambda x: np.round((x- t0).total_seconds()/float(dt)),cdf2.datetime)) + grouping=['polarization','band','quadrangle','t1','t2','t3','t4','round_time'] + #column just for counting the elements + cdf2['number'] = 1 + aggregated = {'datetime': np.min, 'time': np.mean, + 'number': lambda x: len(x), 'u1':np.mean, 'u2': np.mean, 'u3':np.mean, 'u4': np.mean, 'v1':np.mean, 'v2': np.mean, 'v3':np.mean,'v4':np.mean} + + #AVERAGING------------------------------- + if err_type=='measured': + cdf2['dummy'] = cdf2['camp'] + aggregated['dummy'] = lambda x: bootstrap(x, np.mean, num_samples=num_samples,wrapping_variable=False) + elif err_type=='predicted': + aggregated['camp'] = np.mean + aggregated['sigmaca'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2) + else: + print("Error type can only be 'predicted' or 'measured'! Assuming 'predicted'.") + aggregated['camp'] = np.mean + aggregated['sigmaca'] = lambda x: np.sqrt(np.sum(x**2)/len(x)**2) + + #ACTUAL AVERAGING + cdf2 = cdf2.groupby(grouping).agg(aggregated).reset_index() + + if err_type=='measured': + cdf2['camp'] = [x[0] for x in list(cdf2['dummy'])] + cdf2['sigmaca'] = [0.5*(x[1][1]-x[1][0]) for x in list(cdf2['dummy'])] + + #round datetime + cdf2['datetime'] = list(map(lambda x: t0 + datetime.timedelta(seconds= int(dt*x)), cdf2['round_time'])) + + #ANDREW TODO -- this can lead to big problems!! + #drop values averaged from less than 3 datapoints + #cdf2.drop(cdf2[cdf2.number < 3.].index, inplace=True) + if return_type=='rec': + return df_to_rec(cdf2,'camp') + elif return_type=='df': + return cdf2 + +def df_to_rec(df,product_type): + + """converts DataFrame to numpy recarray used by ehtim + + Args: + df: DataFrame to convert + product_type: vis, cphase, camp, amp, bispec, cphase_diag, logcamp_diag + """ + if product_type=='cphase': + out= df[['time','t1','t2','t3','u1','v1','u2','v2','u3','v3','cphase','sigmacp']].to_records(index=False) + return np.array(out,dtype=DTCPHASE) + elif product_type=='camp': + out= df[['time','t1','t2','t3','t4','u1','v1','u2','v2','u3','v3','u4','v4','camp','sigmaca']].to_records(index=False) + return np.array(out,dtype=DTCAMP) + elif product_type=='vis': + out= df[['time','tint','t1','t2','tau1','tau2','u','v','vis','qvis','uvis','vvis','sigma','qsigma','usigma','vsigma']].to_records(index=False) + return np.array(out,dtype=DTPOL_STOKES) + elif product_type=='vis_circ': + out= df[['time','tint','t1','t2','tau1','tau2','u','v','rrvis','llvis','rlvis','lrvis','rrsigma','llsigma','rlsigma','lrsigma']].to_records(index=False) + return np.array(out,dtype=DTPOL_CIRC) + elif product_type=='amp': + out= df[['time','tint','t1','t2','u','v','amp','sigma']].to_records(index=False) + return np.array(out,dtype=DTAMP) + elif product_type=='bispec': + out= df[['time','t1','t2','t3','u1','v1','u2','v2','u3','v3','bispec','sigmab']].to_records(index=False) + return np.array(out,dtype=DTBIS) + elif product_type=='cphase_diag': + out= df[['time','cphase','sigmacp','triangles','u','v','tform_matrix']].to_records(index=False) + return np.array(out,dtype=DTCPHASEDIAG) + elif product_type=='logcamp_diag': + out= df[['time','camp','sigmaca','quadrangles','u','v','tform_matrix']].to_records(index=False) + return np.array(out,dtype=DTLOGCAMPDIAG) + + +def round_time(t,round_s=0.1): + + """rounding time to given accuracy + + Args: + t: time + round_s: delta time to round to in seconds + + Returns: + round_t: rounded time + """ + t0 = datetime.datetime(t.year,1,1) + foo = t - t0 + foo_s = foo.days*24*3600 + foo.seconds + foo.microseconds*(1e-6) + foo_s = np.round(foo_s/round_s)*round_s + days = np.floor(foo_s/24/3600) + seconds = np.floor(foo_s - 24*3600*days) + microseconds = int(1e6*(foo_s - days*3600*24 - seconds)) + round_t = t0+datetime.timedelta(days,seconds,microseconds) + return round_t + +def get_bins_labels(intervals,dt=0.00001): + '''gets bins and labels necessary to perform averaging by scan + Args: + intervals: + dt (float): time margin to add to the scan limits + ''' + + def fix_midnight_overlap(x): + if x[1] < x[0]: + x[1]+= 24. + return x + + def is_overlapping(interval0,interval1): + if ((interval1[0]<=interval0[0])&(interval1[1]>=interval0[0]))|((interval1[0]<=interval0[1])&(interval1[1]>=interval0[1])): + return True + else: return False + + def merge_overlapping_intervals(intervals): + return (np.min([x[0] for x in intervals]),np.max([x[1] for x in intervals])) + + def replace_overlapping_intervals(intervals,element_ind): + indic_not_overlap=[not is_overlapping(x,intervals[element_ind]) for x in intervals] + indic_overlap=[is_overlapping(x,intervals[element_ind]) for x in intervals] + fooarr=np.asarray(intervals) + return sorted([tuple(x) for x in fooarr[indic_not_overlap]]+[merge_overlapping_intervals(list(fooarr[indic_overlap]))]) + + intervals = sorted(list(set(zip(intervals[:,0],intervals[:,1])))) + intervals = [fix_midnight_overlap(x) for x in intervals] + cou=0 + while cou < len(intervals): + intervals = replace_overlapping_intervals(intervals,cou) + cou+=1 + + binsT=[None]*(2*np.shape(intervals)[0]) + binsT[::2] = [x[0]-dt for x in intervals] + binsT[1::2] = [x[1]+dt for x in intervals] + labels=[None]*(2*np.shape(intervals)[0]-1) + labels[::2] = [cou for cou in range(1,len(intervals)+1)] + labels[1::2] = [-cou for cou in range(1,len(intervals))] + + return binsT, labels + +def common_set(obs1, obs2, tolerance = 0,uniquely=False, by_what='uvdist'): + ''' + Selects common subset of obs1, obs2 data + tolerance: time tolerance to accept common subsets [s] if by_what = 'ut' or in [h] if 'gmst' + or u,v tolerance in lambdas if by_what='uvdist' + uniquely: whether matching single value to single value + by_what: matching common sets by ut time 'ut' or by uvdistance 'uvdist' or by 'gmst' + ''' + if obs1.polrep!=obs2.polrep: + raise ValueError('Observations must be in the same polrep!') + #make a dataframe with visibilities + #tolerance in seconds + df1 = make_df(obs1) + df2 = make_df(obs2) + + #we need this to match baselines with possibly different station orders between the pipelines + df1['ta'] = list(map(lambda x: sorted(x)[0],zip(df1.t1,df1.t2))) + df1['tb'] = list(map(lambda x: sorted(x)[1],zip(df1.t1,df1.t2))) + df2['ta'] = list(map(lambda x: sorted(x)[0],zip(df2.t1,df2.t2))) + df2['tb'] = list(map(lambda x: sorted(x)[1],zip(df2.t1,df2.t2))) + + if by_what=='ut': + if tolerance>0: + d_mjd = tolerance/24.0/60.0/60.0 + df1['roundtime']=np.round(df1.mjd/d_mjd) + df2['roundtime']=np.round(df2.mjd/d_mjd) + else: + df1['roundtime'] = df1['time'] + df2['roundtime'] = df2['time'] + #matching data + df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundtime'],uniquely=uniquely) + + elif by_what=='gmst': + df1 = add_gmst(df1) + df2 = add_gmst(df2) + if tolerance>0: + d_gmst = tolerance + df1['roundgmst']=np.round(df1.gmst/d_gmst) + df2['roundgmst']=np.round(df2.gmst/d_gmst) + else: + df1['roundgmst'] = df1['gmst'] + df2['roundgmst'] = df2['gmst'] + #matching data + df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundgmst'],uniquely=uniquely) + + elif by_what=='uvdist': + if tolerance>0: + d_lambda = tolerance + df1['roundu'] = np.round(df1.u/d_lambda) + df1['roundv'] = np.round(df1.v/d_lambda) + df2['roundu'] = np.round(df2.u/d_lambda) + df2['roundv'] = np.round(df2.v/d_lambda) + else: + df1['roundu'] = df1['u'] + df1['roundv'] = df1['v'] + df2['roundu'] = df2['u'] + df2['roundv'] = df2['v'] + #matching data + df1,df2 = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundu','roundv'],uniquely=uniquely) + + #replace visibility data with common subset + obs1cut = obs1.copy() + obs2cut = obs2.copy() + if obs1.polrep=='stokes': + obs1cut.data = df_to_rec(df1,'vis') + obs2cut.data = df_to_rec(df2,'vis') + elif obs1.polrep=='circ': + obs1cut.data = df_to_rec(df1,'vis_circ') + obs2cut.data = df_to_rec(df2,'vis_circ') + + return obs1cut,obs2cut + +""" +def common_multiple_sets(obsL, tolerance = 0,uniquely=False, by_what='uvdist'): + ''' + Selects common subset of obs1, obs2 data + tolerance: time tolerance to accept common subsets [s] if by_what = 'ut' or 'gmst' + or u,v tolerance in lambdas if by_what='uvdist' + uniquely: whether matching single value to single value + by_what: matching common sets by ut time 'ut' or by uvdistance 'uvdist' or by 'gmst' + ''' + polrepL = list(set([obs.polrep for obs in obsL])) + if len(polrepL)>1: + raise ValueError('Observations must be in the same polrep!') + #make a dataframe with visibilities + #tolerance in seconds + dfL = [make_df(obs) for obs in obsL] + + #we need this to match baselines with possibly different station orders between the pipelines + for df in dfL: + df['ta'] = list(map(lambda x: sorted(x)[0],zip(df.t1,df.t2))) + df['tb'] = list(map(lambda x: sorted(x)[1],zip(df.t1,df.t2))) + + if by_what=='ut': + if tolerance>0: + d_mjd = tolerance/24.0/60.0/60.0 + for df in dfL: df['roundtime']=np.round(df.mjd/d_mjd) + else: + for df in dfL: df['roundtime']=df['time'] + #matching data + dfcout = match_multiple_frames(dfL,['ta','tb','roundtime'],uniquely=uniquely) + + elif by_what=='gmst': + dfL = [add_gmst(df) for df in dfL] + if tolerance>0: + d_gmst = tolerance + for df in dfL: df['roundgmst']=np.round(df.gmst/d_gmst) + else: + for df in dfL: df['roundgmst']= df['gmst'] + #matching data + dfcut = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundgmst'],uniquely=uniquely) + + + elif by_what=='uvdist': + if tolerance>0: + d_lambda = tolerance + for df in dfL: df['roundu'] = np.round(df.u/d_lambda) + for df in dfL: df['roundv'] = np.round(df.v/d_lambda) + else: + for df in dfL: df['roundu'] = df['u'] + for df in dfL: df['roundv'] = df['v'] + #matching data + dfcut = match_multiple_frames([df1.copy(),df2.copy()],['ta','tb','roundu','roundv'],uniquely=uniquely) + + #replace visibility data with common subset + obscutL = [obs.copy() for obs in obsL] + + if obs1.polrep=='stokes': + + for obscut in obscutL: obscut = df_to_rec(df1,'vis') + obs2cut.data = df_to_rec(df2,'vis') + elif obs1.polrep=='circ': + obs1cut.data = df_to_rec(df1,'vis_circ') + obs2cut.data = df_to_rec(df2,'vis_circ') + + return obscut_list +""" + +def match_multiple_frames(frames, what_is_same, dt = 0,uniquely=True): + + if dt > 0: + for frame in frames: + frame['round_time'] = list(map(lambda x: np.round((x- datetime.datetime(2017,4,4)).total_seconds()/dt),frame['datetime'])) + what_is_same += ['round_time'] + + frames_common = {} + for frame in frames: + frame['all_ind'] = list(zip(*[frame[x] for x in what_is_same])) + if frames_common != {}: + frames_common = frames_common&set(frame['all_ind']) + else: + frames_common = set(frame['all_ind']) + + frames_out = [] + for frame in frames: + frame = frame[list(map(lambda x: x in frames_common, frame.all_ind))].copy() + if uniquely: + frame.drop_duplicates(subset=['all_ind'], keep='first', inplace=True) + + frame = frame.sort_values('all_ind').reset_index(drop=True) + frame.drop('all_ind', axis=1,inplace=True) + frames_out.append(frame.copy()) + return frames_out + + +def add_gmst(df): + #Lindy Blackburn's work borrowed from eat + """add *gmst* column to data frame with *datetime* field using astropy for conversion""" + from astropy import time + g = df.groupby('datetime') + (timestamps, indices) = list(zip(*iter(g.groups.items()))) + # this broke in pandas 0.9 with API changes + if type(timestamps[0]) is np.datetime64: # pandas < 0.9 + times_unix = 1e-9*np.array( + timestamps).astype('float') # note one float64 is not [ns] precision + elif type(timestamps[0]) is pd.Timestamp: + times_unix = np.array([1e-9 * t.value for t in timestamps]) # will be int64's + else: + raise Exception("do not know how to convert timestamp of type " + repr(type(timestamps[0]))) + times_gmst = time.Time( + times_unix, format='unix').sidereal_time('mean', 'greenwich').hour # vectorized + df['gmst'] = 0. # initialize new column + for (gmst, idx) in zip(times_gmst, indices): + df.ix[idx, 'gmst'] = gmst + return df diff --git a/statistics/stats.py b/statistics/stats.py new file mode 100644 index 00000000..708b53e2 --- /dev/null +++ b/statistics/stats.py @@ -0,0 +1,336 @@ +# stats.py +# variety of statistical functions useful for +# +# Copyright (C) 2018 Maciek Wielgus +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from __future__ import division +from __future__ import print_function +from builtins import str +from builtins import map +from builtins import range + +import numpy as np +import numpy.random as npr +import sys + +from ehtim.const_def import * + +def circular_mean(theta, unit='deg'): + '''circular mean for averaging angular quantities + Args: + theta: list/vector of angles to average + unit: degrees ('deg') or radians (any other string) + + Returns: + circular mean + ''' + theta = np.asarray(theta, dtype=np.float32) + theta= theta.flatten() + theta = theta[theta==theta] + if unit=='deg': + theta *= np.pi/180. + if len(theta)==0: + return None + else: + C = np.mean(np.cos(theta)) + S = np.mean(np.sin(theta)) + circ_mean = np.arctan2(S,C) + if unit=='deg': + circ_mean *= 180./np.pi + return np.mod(circ_mean+180.,360.)-180. + else: + return np.mod(circ_mean+np.pi,2.*np.pi)-np.pi + +def circular_std(theta,unit='deg'): + '''standard deviation of a circular distribution + Args: + theta: list/vector of angles + unit: degrees ('deg') or radians (any other string) + + Returns: + circular standard deviation + ''' + theta = np.asarray(theta, dtype=np.float32) + theta= theta.flatten() + theta = theta[theta==theta] + if unit=='deg': + theta *= np.pi/180. + if len(theta)<2: + return None + else: + C = np.mean(np.cos(theta)) + S = np.mean(np.sin(theta)) + circ_std = np.sqrt(-2.*np.log(np.sqrt(C**2+S**2))) + if unit=='deg': + circ_std *= 180./np.pi + return circ_std + +def circular_std_of_mean(theta,unit='deg'): + '''standard deviation of mean for a circular distribution + Args: + theta: list/vector of angles + unit: degrees ('deg') or radians (any other string) + + Returns: + circular standard deviation of mean + ''' + theta = np.asarray(theta, dtype=np.float32) + theta= theta.flatten() + theta = theta[theta==theta] + return circular_std(theta,unit)/np.sqrt(len(theta)) + +def mean_incoh_amp(amp,sigma,debias=True,err_type='predicted',num_samples=int(1e3)): + """amplitude from ensemble of Rice-distributed measurements with debiasing + Args: + amp: vector of (biased) amplitudes + sigma: vector of errors + debias: whether debiasing is applied + Returns: + amp0: estimator of unbiased amplitude + """ + if (not hasattr(amp, "__len__")): + amp = [amp] + amp = np.asarray(amp, dtype=np.float32) + N = len(amp) + if (not hasattr(sigma, "__len__")): + sigma = sigma*np.ones(N) + elif len(sigma)==1: + sigma = sigma*np.ones(N) + sigma = np.asarray(sigma, dtype=np.float32) + if len(sigma)!=len(amp): + print('Inconsistent length of amp and sigma') + return None + else: + amp_clean=amp[(amp==amp)&(sigma==sigma)&(sigma>0)&(amp>0)] + sigma_clean=sigma[(amp==amp)&(sigma==sigma)&(sigma>0)&(amp>0)] + #eq. 9.86 from Thompson et al. + if debias==True: + amp0sq = ( np.mean(amp_clean**2 - (2. - 1./N)*sigma_clean**2) ) + else: amp0sq = np.mean(amp_clean**2) + amp0sq = np.maximum(amp0sq,0.) + amp0 = np.sqrt(amp0sq) + + #getting errors + if err_type=='predicted': + sigma0 = np.sqrt(np.sum(sigma_clean**2)/len(sigma_clean)**2) + elif err_type=='measured': + ampfoo, ci = bootstrap(amp_clean, np.mean, num_samples=num_samples,wrapping_variable=False) + sigma0 = 0.5*(ci[1]-ci[0]) + return amp0,sigma0 + +def mean_incoh_amp_from_vis(vis,sigma,debias=True,err_type='predicted',num_samples=int(1e3)): + """Amplitude from ensemble of visibility measurements with debiasing + Args: + amp: vector of (biased) amplitudes + sigma: vector of errors + debias: whether debiasing is applied + Returns: + amp0: estimator of unbiased amplitude + """ + if (not hasattr(vis, "__len__")): + vis = [vis] + vis= np.asarray(vis) + vis= vis[vis==vis] + amp=np.abs(vis) + + N = len(amp) + if (not hasattr(sigma, "__len__")): + sigma = sigma*np.ones(N) + elif len(sigma)==1: + sigma = sigma*np.ones(N) + sigma = np.asarray(sigma, dtype=np.float32) + if len(sigma)!=len(amp): + print('Inconsistent length of amp and sigma') + return None, None + else: + amp_clean=amp[(amp==amp)&(sigma==sigma)&(sigma>=0)&(amp>=0)] + sigma_clean=sigma[(amp==amp)&(sigma==sigma)&(sigma>=0)&(amp>=0)] + Nc=len(amp_clean) + if Nc<1: + return None, None + else: + #eq. 9.86 from Thompson et al. + if debias==True: + amp0sq = ( np.mean(amp_clean**2 - (2. - 1./Nc)*sigma_clean**2) ) + else: amp0sq = np.mean(amp_clean**2) + if (amp0sq!=amp0sq): amp0sq=0. + amp0sq = np.maximum(amp0sq,0.) + amp0 = np.sqrt(amp0sq) + #getting errors + if err_type=='predicted': + #sigma0 = np.sqrt(np.sum(sigma_clean**2)/Nc**2) + #Esigma = np.median(sigma_clean) + #snr0 = amp0/Esigma + #snrA = 1./(np.sqrt(1. + 2./np.sqrt(Nc)*(1./snr0)*np.sqrt(1.+1./snr0**2)) - 1.) + #sigma0=amp0/snrA + sigma0 = np.sqrt(np.sum(sigma_clean**2)/Nc**2) + + elif err_type=='measured': + ampfoo, ci = bootstrap(amp_clean, np.mean, num_samples=num_samples,wrapping_variable=False,alpha='1sig') + sigma0 = 0.5*(ci[1]-ci[0]) + return amp0,sigma0 + +def bootstrap(data, statistic, num_samples=int(1e3), alpha='1sig',wrapping_variable=False): + """bootstrap estimate of 100.0*(1-alpha) confidence interval for a given statistic + Args: + data: vector of data to estimate bootstrap statistic on + statistic: function representing the statistic to be evaluated + num_samples: number of bootstrap (re)samples + alpha: parameter of the confidence interval, '1s' gets an analog of 1 sigma confidence for a normal variable + wrapping_variable: True for circular variables, attempts to avoid problem related to estimating variability of wrapping variable + + Returns: + bootstrap_value: bootstrap-estimated value of the statistic + bootstrap_CI: bootstrap-estimated confidence interval + """ + if alpha=='1sig': + alpha=0.3173 + elif alpha=='2sig': + alpha=0.0455 + elif alpha=='3sig': + alpha=0.0027 + stat = np.zeros(num_samples) + data = np.asarray(data) + if wrapping_variable==True: + m=statistic(data) + else: + m=0 + data = data-m + n = len(data) + idx = npr.randint(0, n, (num_samples, n)) + samples = data[idx] + for cou in range(num_samples): + stat[cou] = statistic(samples[cou,:]) + stat = np.sort(stat) + bootstrap_value = np.median(stat)+m + bootstrap_CI = [stat[int((alpha/2.0)*num_samples)]+m, stat[int((1-alpha/2.0)*num_samples)]+m] + return bootstrap_value, bootstrap_CI + +def mean_incoh_avg(x,debias=True): + amp = np.abs(np.asarray([y[0] for y in x])) + sig = np.asarray([y[1] for y in x]) + ampN = amp[(amp==amp)&(amp>=0)&(sig==sig)&(sig>=0)] + sigN = sig[(amp==amp)&(amp>=0)&(sig==sig)&(sig>=0)] + amp = ampN + sig = sigN + Nc = len(sig) + if Nc==0: + amp0 = -1 + sig0 = -1 + elif Nc==1: + amp0 = amp[0] + sig0 = sig[0] + else: + if debias==True: + amp0 = deb_amp(amp,sig) + else: + amp0= np.sqrt(np.maximum(np.mean(amp**2),0.)) + sig0 = inc_sig(amp,sig) + #sig0 = coh_sig(amp,sig) + return amp0,sig0 + +def deb_amp(amp,sig): + #eq. 9.86 from Thompson et al. + amp = np.abs(np.asarray(amp)) + sig = np.asarray(sig) + Nc = len(amp) + amp0sq = ( np.mean(amp**2 - (2. - 1./Nc)*sig**2) ) + amp0sq = np.maximum(amp0sq,0.) + amp0 = np.sqrt(amp0sq) + return amp0 + +def inc_sig(amp,sig): + amp = np.abs(np.asarray(amp)) + sig = np.asarray(sig) + Nc = len(amp) + amp0 = deb_amp(amp,sig) + Esigma = np.median(sig) + snr0 = amp0/Esigma + snrA = 1./(np.sqrt(1. + 2./np.sqrt(Nc)*(1./snr0)*np.sqrt(1.+1./snr0**2)) - 1.) + if snrA>0: + sigma0=amp0/snrA + else: sigma0=coh_sig(amp,sig) + return sigma0 + +def coh_sig(amp,sig): + amp = np.abs(np.asarray(amp)) + sig = np.asarray(sig) + Nc = len(amp) + sigma0 = np.sqrt(np.sum(sig**2)/Nc**2) + return sigma0 + + +def dicts_TV_report(obs,snr_cut=2.): + """Computes mean total variation reports + Args: + obs: ObsData object + snr_cut: threshold for data snr + Returns: + amptv: dictionary of baseline mean TV + cptv: dictionary of triangle mean TV + lcatv: dictionary of quadrangle mean TV + """ + amp = obs.data + baselines = list(set([(x[0],x[1]) for x in lca[['t1','t2']]])) + amptv = {} + for cou,quad in enumerate(baselines): + amptv[quad] = np.mean(np.abs(np.diff(np.abs(amp[(amp['t1']==baselines[cou][0])&(amp['t2']==baselines[cou][1])]['vis'])))) + + cp = obs.c_phases() + obs = obs.flag_low_snr(snr_cut=snr_cut) + triangles = list(set([(x[0],x[1],x[2]) for x in cp[['t1','t2','t3']]])) + cptv = {} + for cou,tri in enumerate(triangles): + cptv[tri] = np.mean(np.abs(np.diff(cp[(cp['t1']==triangles[cou][0])&(cp['t2']==triangles[cou][1])&(cp['t3']==triangles[cou][2])]['cphase']))) + + lca = obs.c_amplitudes(ctype='logcamp') + quadrangles = list(set([(x[0],x[1],x[2],x[3]) for x in lca[['t1','t2','t3','t4']]])) + lcatv = {} + for cou,quad in enumerate(quadrangles): + lcatv[quad] = np.mean(np.abs(np.diff(lca[(lca['t1']==quadrangles[cou][0])&(lca['t2']==quadrangles[cou][1])&(lca['t3']==quadrangles[cou][2])&(lca['t4']==quadrangles[cou][3])]['camp']))) + + return amptv, cptv, lcatv + +def compare_TV(obs,obsref,snr_cut=2.,output=''): + """Computes mean total variation reports + Args: + obs: ObsData object + obref: ObsData object to use as a reference + snr_cut: threshold for data snr + output (str): if empty, returns median relative TV across all (baselines, triangles, quadrangles) + if 'Full', returns full dictionaries comparing all (baselines, triangles, quadrangles) + Returns: + amprel / ampmed: dictionary of baseline relative differences in mean TV / median of baseline relative differences in mean TV + cprel / cpmed: dictionary of triangle relative differences in mean TV / median of triangle relative differences in mean TV + lcarel / lcamed: dictionary of quadrangle relative differences in mean TV / median of quadrangle relative differences in mean TV + """ + amptv, cptv, lcatv = dicts_TV(obs,snr_cut=snr_cut) + ampref, cpref, lcaref = dicts_TV(obsref,snr_cut=snr_cut) + + cprel = {key: (cptv[key] - cpref[key])/cpref[key] for key in cptv.keys() if key in set(cpref.keys())} + amprel = {key: (amptv[key] - ampref[key])/ampref[key] for key in amptv.keys() if key in set(ampref.keys())} + lcarel = {key: (lcatv[key] - lcaref[key])/lcaref[key] for key in lcatv.keys() if key in set(lcaref.keys())} + amprel = {key:amprel[key] for key in amprel.keys() if amprel[key]==amprel[key]} + cprel = {key:cprel[key] for key in cprel.keys() if cprel[key]==cprel[key]} + lcarel = {key:lcarel[key] for key in lcarel.keys() if lcarel[key]==lcarel[key]} + + if output=='Full': + return amprel, cprel, lcarel + else: + ampmed =np.median([amprel[key] for key in amprel.keys() if amprel[key]==amprel[key]]) + cpmed =np.median([cprel[key] for key in cprel.keys() if cprel[key]==cprel[key]]) + lcamed =np.median([lcarel[key] for key in lcarel.keys() if lcarel[key]==lcarel[key]]) + return ampmed,cpmed,lcamed \ No newline at end of file diff --git a/survey.py b/survey.py new file mode 100644 index 00000000..fc7c2b0a --- /dev/null +++ b/survey.py @@ -0,0 +1,626 @@ +# survey.py +# a parameter survey class +# +# Copyright (C) 2018 Andrew Chael +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import os + +import ehtim as eh +import paramsurvey +import paramsurvey.params + +import warnings +warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) + + +################################################################################################## +# ParameterSet object +################################################################################################## + +class ParameterSet: + + """ + Attributes: + paramset (dict): dict containing single parameter set. Each key in dict becomes its own attribute + params_fixed (dict): dict containing non-varying parameters. Each key in dict becomes its own attribute + outfile (str): base of outfile with 7 digit index of pset added to it + obs (Obsdata): observation object from input uvfits file, generated by load_data() + zbl_tot (float): flux on ALMA-APEX baseline, generated in preimcal() + obs_sc_init (Obsdata): original observation with systematic noise added, reverse taper applied, and extended flux removed + obs_sc (Obsdata): obs_sc_init after all rounds of self-calibration + initimg (Image): prior/initial image + im_out (Image): final reconstructed image + im_addcmp (Image): final reconstructed image with extended flux added back in + obs_sc_addcmp (Obsdata): original observation self-calibrated with im_addcmp + caltab (Caltable): final calibration table associated with im_out + + """ + + def __init__(self, paramset, params_fixed={}): + + """ + An object for one parameter set + + Args: + paramset (dict): A dict containing single parameter set + params_fixed (dict): A dict containing non-varying parameters + + Returns: + a ParameterSet object + """ + + # set each item in paramset dict to an attribute + for param in paramset: + setattr(self, param, paramset[param]) + + if len(params_fixed) > 0: + for param in params_fixed: + setattr(self, param, params_fixed[param]) + + os.makedirs(self.outpath, exist_ok=True) + self.paramset = paramset + self.params_fixed = params_fixed + self.outfile = '%s_%0.7i' % (self.outfile_base, self.i) + + self.fov *= eh.RADPERUAS + self.reverse_taper_uas *= eh.RADPERUAS + self.prior_fwhm *= eh.RADPERUAS + + def load_data(self): + """ + Loads in uvfits file from self.infile into eht-imaging obs object and averages using self.avg_time + Creates self.obs + + Args: + + Returns: + + """ + + # load the uvfits file + self.obs = eh.obsdata.load_uvfits(self.infile) + + # identify the scans (times of continuous observation) in the data + self.obs.add_scans() + + # coherently average + if self.avg_time == 'scan': + self.obs = self.obs.avg_coherent(0., scan_avg=True) + else: + self.obs = self.obs.avg_coherent(self.avg_time) + + def preimcal(self): + + """ + Applies pre-imaging calibration to self.obs. This includes flagging sites with no measurements, rescaling + short baselines so only compact flux is being imaged, applying a u-v taper, and adding extra sys noise. Creates + self.obs_sc which will have self-cal applied in future steps and self.obs_sc_init which is a copy of the initial + preimcal observation + + Args: + + Returns: + + """ + + # handle site name change of APEX between 2017 (AP) to 2018 (AX) + sites = list(self.obs.tkey.keys()) + + alma_options = ['ALMA', 'AA'] + for opt in alma_options: + if opt in sites: + alma = opt + + apex_options = ['APEX', 'AP', 'AX'] + for opt in apex_options: + if opt in sites: + apex = opt + + self.zbl_tot = np.median(self.obs.unpack_bl(alma, apex, 'amp')['amp']) + + + # Flag out sites in the obs.tarr table with no measurements + allsites = set(self.obs.unpack(['t1'])['t1']) | set(self.obs.unpack(['t2'])['t2']) + self.obs.tarr = self.obs.tarr[[o in allsites for o in self.obs.tarr['site']]] + self.obs = eh.obsdata.Obsdata(self.obs.ra, self.obs.dec, self.obs.rf, self.obs.bw, self.obs.data, self.obs.tarr, + source=self.obs.source, mjd=self.obs.mjd, + ampcal=self.obs.ampcal, phasecal=self.obs.phasecal) + + self.obs_orig = self.obs.copy() + + # Rescale short baselines to excise contributions from extended flux. + # setting zbl < zbl_tot assumes there is an extended constant flux component of zbl_tot-zbl Jy + + if self.zbl != self.zbl_tot: + for j in range(len(self.obs.data)): + if (self.obs.data['u'][j] ** 2 + self.obs.data['v'][j] ** 2) ** 0.5 >= self.uv_zblcut: continue + for field in ['vis', 'qvis', 'uvis', 'vvis', 'sigma', 'qsigma', 'usigma', 'vsigma']: + self.obs.data[field][j] *= self.zbl / self.zbl_tot + + self.obs.reorder_tarr_snr() + + self.obs_sc = self.obs.copy() + # Reverse taper the observation: this enforces a maximum resolution on reconstructed features + if self.reverse_taper_uas > 0: + self.obs_sc = self.obs_sc.reverse_taper(self.reverse_taper_uas) + + # Add non-closing systematic noise to the observation + self.obs_sc = self.obs_sc.add_fractional_noise(self.sys_noise) + + # Make a copy of the initial data (before any self-calibration but after the taper) + self.obs_sc_init = self.obs_sc.copy() + + def init_img(self): + + """ + Creates initial/prior image. Only gaussian prior option at present, but creates prior attritubute self.initimg + using self.zbl and self.prior_fwhm + + Args: + + Returns: + + """ + # create guassian prior/inital image + emptyprior = eh.image.make_square(self.obs_sc, self.npixels, self.fov) + + gaussprior = emptyprior.add_gauss(self.zbl, (self.prior_fwhm, self.prior_fwhm, 0, 0, 0)) + # To avoid gradient singularities in the first step, add an additional small Gaussian + gaussprior = gaussprior.add_gauss(self.zbl * 1e-3, (self.prior_fwhm, self.prior_fwhm, 0, + self.prior_fwhm, self.prior_fwhm)) + self.initimg = gaussprior.copy() + + def make_img(self): + + """ + Reconstructs image with specified parameters (data weights, amount of self-cal, etc) described in paramset dict + Creates attributes self.im_out containing the final image and self.caltab containing a corresponding calibration + table object for the final image + + Args: + + Returns: + + """ + + # specify data terms + data_term = {} + if hasattr(self, 'vis') and self.vis != 0.: + data_term['vis'] = self.vis + if hasattr(self, 'amp') and self.amp != 0.: + data_term['amp'] = self.amp + if hasattr(self, 'diag_closure') and self.diag_closure is True: + if hasattr(self, 'logcamp_diag') and self.logcamp_diag != 0.: + data_term['logcamp_diag'] = self.logcamp + if hasattr(self, 'cphase_diag') and self.cphase_diag != 0.: + data_term['cphase_diag'] = self.cphase + else: + if hasattr(self, 'logcamp') and self.logcamp != 0.: + data_term['logcamp'] = self.logcamp + if hasattr(self, 'cphase') and self.cphase != 0.: + data_term['cphase'] = self.cphase + + # specify regularizer terms + reg_term = {} + if hasattr(self, 'simple') and self.simple != 0.: + reg_term['simple'] = self.simple + if hasattr(self, 'tv2') and self.tv2 != 0.: + reg_term['tv2'] = self.tv2 + if hasattr(self, 'tv') and self.tv != 0.: + reg_term['tv'] = self.tv + if hasattr(self, 'l1') and self.l1 != 0.: + reg_term['l1'] = self.l1 + if hasattr(self, 'flux') and self.flux != 0.: + reg_term['flux'] = self.flux + if hasattr(self, 'rgauss') and self.rgauss != 0.: + reg_term['rgauss'] = self.rgauss + + ### How to make this more general? ### + # Add systematic noise tolerance for amplitude a-priori calibration errors + # Start with the SEFD noise (but need sqrt) + # then rescale to ensure that final results respect the stated error budget + systematic_noise = self.SEFD_error_budget.copy() + for key in systematic_noise.keys(): + systematic_noise[key] = ((1.0 + systematic_noise[key]) ** 0.5 - 1.0) * 0.25 + + # set up imager + imgr = eh.imager.Imager(self.obs_sc, self.initimg, prior_im=self.initimg, flux=self.zbl, + data_term=data_term, maxit=self.maxit, norm_reg=True, systematic_noise=systematic_noise, + reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop) + + res = self.obs.res() + + imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac) + + if self.selfcal: + # Self-calibrate to the previous model (phase-only); + # The solution_interval is 0 to align phases from high and low bands if needed + self.obs_sc = eh.selfcal(self.obs_sc, imgr.out_last(), method='phase', ttype=self.ttype, + solution_interval=0.0, processes=-1) + + sc_p_idx = 0 + dterms = data_term.keys() + while sc_p_idx < self.sc_phase: + + # Blur the previous reconstruction to the intrinsic resolution + init = imgr.out_last().blur_circ(res) + + # Increase the data weights and reinitialize imaging + if sc_p_idx == 0: + for key in dterms: + data_term[key] *= self.xdw_phase + + # set up imager + imgr = eh.imager.Imager(self.obs_sc, init, prior_im=self.initimg, flux=self.zbl, + data_term=data_term, maxit=self.maxit, norm_reg=True, + systematic_noise=systematic_noise, + reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop) + + # Imaging + imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac) + + # apply self-calibration to original calibrated data + self.obs_sc = eh.selfcal(self.obs_sc_init, imgr.out_last(), method='phase', ttype=self.ttype) + + sc_p_idx += 1 + + # repeat amp+phase self-calibration + sc_ap_idx = 0 + + while sc_ap_idx < self.sc_ap: + + # Blur the previous reconstruction to the intrinsic resolution + init = imgr.out_last().blur_circ(res) + + # Increase the data weights and reinitialize imaging + if sc_p_idx == 0: + for key in dterms: + data_term[key] *= self.xdw_ap + + # set up imager + imgr = eh.imager.Imager(self.obs_sc, init, prior_im=self.initimg, flux=self.zbl, + data_term=data_term, maxit=self.maxit, norm_reg=True, + systematic_noise=systematic_noise, + reg_term=reg_term, ttype=self.ttype, cp_uv_min=self.uv_zblcut, stop=self.stop) + + # Imaging + imgr.make_image_I(show_updates=False, niter=self.niter_static, blur_frac=self.blurfrac) + + caltab = eh.selfcal(self.obs_sc_init, imgr.out_last(), method='both', + ttype=self.ttype, gain_tol=self.gaintol, caltable=True, processes=-1) + self.obs_sc = caltab.applycal(self.obs_sc_init, interp='nearest', extrapolate=True) + + sc_ap_idx += 1 + + self.im_out = imgr.out_last().copy() + + # if no self-cal, no caltabs will be saved + if self.sc_phase == 0 and self.sc_ap == 0: + self.save_caltab = False + + else: + self.caltab = caltab + + def output_results(self): + + """ + Outputs all requested files pertaining to final image + + Args: + + Returns: + + """ + + # Add a large gaussian component to account for the missing flux + # so the final image can be compared with the original data + self.im_addcmp = self.im_out.add_zblterm(self.obs_orig, self.uv_zblcut, debias=True) + self.obs_sc_addcmp = eh.selfcal(self.obs_orig, self.im_addcmp, method='both', ttype=self.ttype) + + # If an inverse taper was used, restore the final image + # to be consistent with the original data + if self.reverse_taper_uas > 0.0: + self.im_out = self.im_out.blur_circ(self.reverse_taper_uas) + + # Save the final image + outfits = os.path.join(self.outpath, '%s.fits' % (self.outfile)) + self.im_out.save_fits(outfits) + + # Save caltab + if hasattr(self, 'save_caltab') and self.save_caltab == True: + outcal = os.path.join(self.outpath, '%s/' % (self.outfile)) + eh.caltable.save_caltable(self.caltab, self.obs_sc_init, outcal) + + # Save self-calibrated uvfits + if self.save_uvfits: + outuvfits = os.path.join(self.outpath, '%s.uvfits' % (self.outfile)) + self.obs_sc_addcmp.save_uvfits(outuvfits) + + # Save pdf of final image + if self.save_pdf: + outpdf = os.path.join(self.outpath, '%s.pdf' % (self.outfile)) + self.im_out.display(cbar_unit=['Tb'], label_type='scale', export_pdf=outpdf) + + # Save pdf of image summary + if self.save_imgsums: + # Save an image summary sheet + plt.close('all') + outimgsum = os.path.join(self.outpath, '%s_imgsum.pdf' % (self.outfile)) + eh.imgsum(self.im_addcmp, self.obs_sc_addcmp, self.obs_orig, outimgsum, cp_uv_min=self.uv_zblcut, + processes=-1) + + def save_statistics(self): + + """ + Saves a csv file with the following statistics: + chi^2 closure phase, logcamp, vis wrt the original observation + chi^2 vis wrt to original observation with self-cal to final image applied + chi^2 closure phase, logcamp, vis wrt the original observation with sys noise and self-cal applied + + Args: + + Returns: + + """ + stats_dict = {} + stats_dict['i'] = [self.i] + + outstats = os.path.join(self.outpath, '%s_stats.csv' % (self.outfile)) + + # if ground truth image available, compute nxcorr + if self.ground_truth_img != 'None': + gt_im = eh.image.load_fits(self.ground_truth_img) + + fov_ = 200 * eh.RADPERUAS + psize_ = fov_ / 256 + nxcorr_, _, _ = gt_im.compare_images(self.im_addcmp, metric='nxcorr', target_fov=fov_, psize=psize_) + nxcorr = nxcorr_[0] + + stats_dict['nxcorr'] = [nxcorr] + + # chi^2 for closure phase (cp) and log camp (lc) + # original uv data + obs_ref = self.obs_orig + chi2_cp_ref = obs_ref.chisq(self.im_addcmp, dtype='cphase', + ttype=self.ttype, systematic_noise=0., + systematic_cphase_noise=0, maxset=False, + cp_uv_min=self.uv_zblcut) + chi2_lc_ref = obs_ref.chisq(self.im_addcmp, dtype='logcamp', + ttype=self.ttype, systematic_noise=0., + snrcut=1.0, maxset=False, + cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point + chi2_vis_ref = obs_ref.chisq(self.im_addcmp, dtype='vis', + ttype=self.ttype, systematic_noise=0., + snrcut=1.0, maxset=False, + cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point + + stats_dict['chi2_cp_ref'] = [chi2_cp_ref] + stats_dict['chi2_lc_ref'] = [chi2_lc_ref] + stats_dict['chi2_vis_ref'] = [chi2_vis_ref] + + # orig data self-cal to final image + obs_sub = self.obs_sc_addcmp + chi2_vis_sub = obs_sub.chisq(self.im_addcmp, dtype='vis', + ttype=self.ttype, systematic_noise=0., + snrcut=1.0, maxset=False, + cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point + + stats_dict['chi2_vis_sub'] = [chi2_vis_sub] + + # orig data with sys noise added + self-cal to final image + self.obs_sc_addcmp_sys = eh.selfcal(self.obs_sc_init, self.im_addcmp, method='both', ttype=self.ttype) + obs_sys = self.obs_sc_addcmp_sys + chi2_cp_sys = obs_sys.chisq(self.im_addcmp, dtype='cphase', + ttype=self.ttype, systematic_noise=0., + systematic_cphase_noise=0, maxset=False, + cp_uv_min=self.uv_zblcut) + chi2_lc_sys = obs_sys.chisq(self.im_addcmp, dtype='logcamp', + ttype=self.ttype, systematic_noise=0., + snrcut=1.0, maxset=False, + cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point + chi2_vis_sys = obs_sys.chisq(self.im_addcmp, dtype='vis', + ttype=self.ttype, systematic_noise=0., + snrcut=1.0, maxset=False, + cp_uv_min=self.uv_zblcut) # snrcut to remove large chi2 point + + stats_dict['chi2_cp_sys'] = [chi2_cp_sys] + stats_dict['chi2_lc_sys'] = [chi2_lc_sys] + stats_dict['chi2_vis_sys'] = [chi2_vis_sys] + + df = pd.DataFrame.from_dict(stats_dict) + df.to_csv(outstats) + + def save_params(self): + """ + Saves a csv file with parameter set details + + Args: + + Returns: + + """ + self.paramset['fov'] = self.fov + self.paramset['zbl_tot'] = self.zbl_tot + + df = pd.DataFrame.from_dict([self.paramset]) + + outparams = os.path.join(self.outpath, '%s_params.csv' % (self.outfile)) + df.to_csv(outparams) + + def run(self): + + """ + Run imaging pipeline for one parameter set. + + Args: + + Returns: + + """ + + # if a *_params.csv file exists, it means this parameter has already been run and can be skipped + # useful in case survey with multiple parameter sets gets interrupted + outcsv = os.path.join(self.outpath, '%s_params.csv' % (self.outfile)) + + if not self.overwrite: + if os.path.exists(outcsv): + pass + + else: + + # load in data + self.load_data() + + # do pre-imaging calibration + self.preimcal() + + # create initial/prior image + self.init_img() + + # run imaging step + self.make_img() + + # output the results + self.output_results() + + # save params to text + self.save_params() + + if self.save_stats: + self.save_statistics() + + +def run_pset(pset, system_kwargs, params_fixed): + """ + Run imaging for one parameter set. Not to be used individually, but called in map function for run_survey + + Args: + pset (dict): A dict containing single parameter set + params_fixed (dict): A dict containing non-varying parameters + + Returns: + + """ + + PSet = ParameterSet(pset, params_fixed) + PSet.run() + +def run_survey(psets, params_fixed): + """Run survey for all parameter sets using paramsurvey + + Args: + psets (DataFrame): A pandas DataFrame containing all parameter sets + params_fixed (dict): A dict containing non-varying parameters + + Returns: + + """ + # run whole survey using map function + paramsurvey.init(backend=params_fixed['backend'], ncores=params_fixed['nproc'], + verbose=0,vstats=0) + paramsurvey.map(run_pset, psets, user_kwargs=params_fixed, verbose=0) + +def create_params_fixed(infile, outfile_base, outpath, ground_truth_img='None', + save_imgsums=False, save_uvfits=True, save_pdf=False, save_stats=True, save_caltab=True, + nproc=1, backend='multiprocessing', ttype='nfft', overwrite=False, + selfcal=True, gaintol=[0.02,0.2], niter_static=3, blurfrac=1, + maxit=100, stop=1e-4, fov=128, npixels=64, reverse_taper_uas=5, uv_zblcut=0.1e9, + SEFD_error_budget={'AA':0.1,'AX':0.1,'GL':0.1,'LM':0.1,'MG':0.1,'MM':0.1,'PV':0.1,'SW':0.1}): + """Create a dict for all non-varying survery parameters + + Args: + infile (str): path to input uvfits observation file + outfile_base (str): name of base filename for all outputs + outpath (str): path to directory where all outputs should be stored + ground_truth_img (str): if applicable, path to ground truth fits file + save_imgsums (bool): save summary pdf for each image + save_uvfits (bool): save final self-cal observation to uvfits file + save_pdf (bool): save pdf of each image + save_stats (bool): save csv file containing statistics for each image + save_caltab (bool): save a calibration table for each image + nproc (int): number of parallel processes + backend (str): either 'multiprocessing' or 'ray' + ttype (str): “fast” or “nfft” or “direct” + overwrite (bool): if True, write over existing files with same names - else, skip parameter set + selfcal (bool): perform self-calibration steps during imaging + gaintol (array): tolerance for gains to be under/over unity respectively during self-cal + niter_static (int): number of iterations for each imaging step + blurfrac (int): factor to blur initial image between iterations + maxit (int): maximum number of iterations if image does not converge + stop (float): convergence criterion for imaging + fov (int): image field of view in uas + npixels (int): number of image pixels + reverse_taper_uas (int): fwhm of gassuain in uas to reverse taper observation + uv_zblcut (float): maximum uv-distance to which is considered short baseline flux + SEFD_error_budget (dict): SEFD percentage error for each station + + + Returns: + (dict): dict containing all non-varying survery parameters + """ + + # take all arguments and put them in a dict + args = list(locals().keys()) + params_fixed = {} + for arg in args: + params_fixed[arg] = locals().get(arg) + + return params_fixed + +def create_survey_psets(zbl=[0.6], sys_noise=[0.02], avg_time=['scan'], prior_fwhm=[40], + sc_phase=[0], xdw_phase=[10], sc_ap=[0], xdw_ap=[1], amp=[0.2], cphase=[1], logcamp=[1], + simple=[1], l1=[1], tv=[1], tv2=[1], flux=[1], epsilon_tv=[1e-10]): + """Create a dataframe given all survey parameters. Default values will create an example dataframe but these values should be adjusted for each specific observation + + Args: + zbl (array): compact flux value (Jy) + sys_noise (array): percent addition of systematic noise + avg_time (array): in seconds or 'scan' for scan averaging + prior_fwhm (array): fwhm of gaussian prior in uas + sc_phase (array): number of iterations to perform phase-only self-cal + xdw_phase (array): multiplicative factor for data weights after one round of phase-only self-cal + sc_ap (array): number of iterations to perform amp+phase self-cal + xdw_ap (array): multiplicative factor for data weights after one round of amp+phase self-cal + amp (array): data weight to be placed on amplitudes + cphase (array): data weight to be placed on closure phases + logcamp (array): data weight to be placed on log closure amplitudes + simple (array): regularizer weight for relative entropy, favoring similarity to prior image + l1 (array): regularizer weight for l1 norm, favoring image sparsity + tv (array): regularizer weight for total variation, favoring sharp edges + tv2 (array): regularizer weight for total squared variation, favoring smooth edges + flux (array): regularizer weight for total flux density, favoring final images with flux close to zbl + epsilon_tv (array): epsilon value used in definition of total variation - rarely need to change this + + Returns: + (DataFrame): pandas DataFrame containing all combination of parameter sets along with index values + """ + + # take all arguments and put them in a dict + args = list(locals().keys()) + params = {} + for arg in args: + params[arg] = locals().get(arg) + + # pandas dataframe containing all combinations of parameters to survey + psets = paramsurvey.params.product(params) + + # add pset index number to each row of dataframe + psets['i'] = np.array(range(len(psets))) + + return psets diff --git a/vex.py b/vex.py new file mode 100644 index 00000000..38a42e43 --- /dev/null +++ b/vex.py @@ -0,0 +1,332 @@ +# vex.py +# a interferometric array vex schedule class +# +# Copyright (C) 2018 Hotaka Shiokawa +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + + +from __future__ import division +from __future__ import print_function + +from builtins import str +from builtins import range +from builtins import object + +import numpy as np +import re + +from astropy.time import Time +import os +import ehtim.array +import ehtim.const_def as ehc + +################################################################################################### +# Vex object +################################################################################################### + + +class Vex(object): + """Read in observing schedule data from .vex files. + Assumes there is only 1 MODE in vex file + + Attributes: + filename (str): The .vex filename. + source (str): The source name. + metalist (list): The observation information. + sched (list): The schedule information. + array (Array): an Array object of sites. + """ + + def __init__(self, filename): + + f = open(filename) + raw = f.readlines() + f.close() + + self.filename = filename + + # Divide 'raw' data into sectors of '$' marks + # ASSUMING '$' is the very first character in a line (no space in front) + metalist = [] # meaning list of metadata + + for i in range(len(raw)): + if raw[i][0] == '$': + temp = [raw[i]] + break + + for j in range(i + 1, len(raw)): + if raw[j][0] != '$': + temp.append(raw[j]) + elif raw[j][0] == '$': + metalist.append(temp) + temp = [raw[j]] + else: + print('Something is wrong.') + metalist.append(temp) # don't forget to add the final one + self.metalist = metalist + + # Extract desired information + # SOURCE ======================================================== + SOURCE = self.get_sector('SOURCE') + source = [] + indef = False + + for i in range(len(SOURCE)): + + line = SOURCE[i] + if line[0:3] == "def": + indef = True + + if indef: + ret = self.get_variable("source_name", line) + if len(ret) > 0: + source_name = ret + ret = self.get_variable("ra", line) + if len(ret) > 0: + ra = ret + ret = self.get_variable("dec", line) + if len(ret) > 0: + dec = ret + ret = self.get_variable("ref_coord_frame", line) + if len(ret) > 0: + ref_coord_frame = ret + + if line[0:6] == "enddef": + source.append({'source': source_name, 'ra': ra, 'dec': dec, + 'ref_coord_frame': ref_coord_frame}) + indef = False + + self.source = source + + # FREQ ========================================================== + FREQ = self.get_sector('FREQ') + indef = False + nfreq = 0 + for i in range(len(FREQ)): + + line = FREQ[i] + if line[0:3] == "def": + if nfreq > 0: + print("Not implemented yet.") + nfreq += 1 + indef = True + + if indef: + idx = line.find('chan_def') + if idx >= 0 and line[0] != '*': + chan_def = re.findall(r"[-+]?\d+[\.]?\d*", line) + self.freq = float(chan_def[0]) * 1.e6 + self.bw_hz = float(chan_def[1]) * 1.e6 + + if line[0:6] == "enddef": + indef = False + + # SITE ========================================================== + SITE = self.get_sector('SITE') + sites = [] + site_ID_dict = {} + indef = False + + for i in range(len(SITE)): + + line = SITE[i] + if line[0:3] == "def": + indef = True + + if indef: + # get site_name and SEFD + ret = self.get_variable("site_name", line) + if len(ret) > 0: + site_name = ret + SEFD = self.get_SEFD(site_name) + + # making dictionary of site_ID:site_name + ret = self.get_variable("site_ID", line) + if len(ret) > 0: + site_ID_dict[ret] = site_name + + # get site_position + ret = self.get_variable("site_position", line) + if len(ret) > 0: + site_position = re.findall(r"[-+]?\d+[\.]?\d*", line) + + # same format as Andrew's array tables + if line[0:6] == "enddef": + sites.append([site_name, site_position[0], + site_position[1], site_position[2], SEFD]) + indef = False + + # Construct Array() object of Andrew's format + # mimic the function "load_array(filename)" + # TODO this does not store d-term and pol cal. information! + tdataout = [np.array((x[0], float(x[1]), float(x[2]), float(x[3]), float(x[4]), float(x[4]), + 0.0, 0.0, 0.0, 0.0, 0.0), + dtype=ehc.DTARR) for x in sites] + + tdataout = np.array(tdataout) + self.array = ehtim.array.Array(tdataout) + + # SCHED ========================================================= + SCHED = self.get_sector('SCHED') + sched = [] + inscan = False + + for i in range(len(SCHED)): + + line = SCHED[i] + if line[0:4] == "scan": + inscan = True + temp = {} + temp['scan'] = {} + cnt = 0 + + if inscan: + ret = self.get_variable("start", line) + if len(ret) > 0: + mjd, hr = vexdate_to_MJD_hr(ret) # convert vex time format to mjd and hour + temp['mjd_floor'] = mjd + temp['start_hr'] = hr + + ret = self.get_variable("mode", line) + if len(ret) > 0: + temp['mode'] = ret + + ret = self.get_variable("source", line) + if len(ret) > 0: + temp['source'] = ret + + ret = self.get_variable("station", line) + if len(ret) > 0: + site_ID = ret + site_name = site_ID_dict[site_ID] # convert to more familier site name + sdur = re.findall(r"[-+]?\d+[\.]?\d*", line) + s_st = float(sdur[0]) # start time in sec + s_en = float(sdur[1]) # end time in sec + d_size = float(sdur[2]) # data size(?) in GB + temp['scan'][cnt] = {'site': site_name, 'scan_sec_start': s_st, + 'scan_sec': s_en, 'data_size': d_size} + cnt += 1 + + if line[0:7] == "endscan": + sched.append(temp) + inscan = False + + self.sched = sched + + # Function to obtain a desired sector from 'metalist' + + def get_sector(self, sname): + """Obtain a desired sector from 'metalist'. + """ + + for i in range(len(self.metalist)): + if sname in self.metalist[i][0]: + return self.metalist[i] + print('No sector named %s' % sname) + return False + + # Function to get a value of 'vname' in a line which has format of + # 'vname' = value ;(or :) + def get_variable(self, vname, line): + """Function to get a value of 'vname' in a line. + """ + + idx = self.find_variable(vname, line) + name = '' + if idx >= 0: + start = False + for i in range(idx + len(vname), len(line)): + if start is True: + if line[i] == ';' or line[i] == ':': + break + elif line[i] != ' ': + name += line[i] + if start is False and line[i] != ' ' and line[i] != '=': + break + if line[i] == '=': + start = True + return name + + # check if a variable 'vname' exists by itself in a line. + # returns index of vname[0] in a line, or -1 + def find_variable(self, vname, line): + """Function to find a variable 'vname' in a line. + """ + idx = line.find(vname) + if ((idx > 0 and line[idx - 1] == ' ') or idx == 0) and line[0] != '*': + if idx + len(vname) == len(line): + return idx + if (line[idx + len(vname)] == '=' or + line[idx + len(vname)] == ' ' or + line[idx + len(vname)] == ':' or + line[idx + len(vname)] == ';'): + return idx + return -1 + + # Find SEFD for a given station name. + # For now look for it in Andrew's tables + # Vex files could have SEFD sector. + def get_SEFD(self, station): + """Find SEFD for a given station. + """ + f = open(os.path.dirname(os.path.abspath(__file__)) + "/../arrays/SITES.txt") + sites = f.readlines() + f.close() + for i in range(len(sites)): + if sites[i].split()[0] == station: + return float(re.findall(r"[-+]?\d+[\.]?\d*", sites[i])[3]) + print('No station named %s' % station) + return 10000. # some arbitrary value + + # Find the time that any station starts observing the source in MJD. + # Find the time that the last station stops observing the source in MHD. + def get_obs_timerange(self, source): + """Find the time that any station starts observing the source in MJD, + and the time that the last station stops observing the source. + """ + + sched = self.sched + first = True + for i_scan in range(len(sched)): + if sched[i_scan]['source'] == source and first is True: + Tstart_hr = sched[i_scan]['start_hr'] + mjd_s = sched[i_scan]['mjd_floor'] + Tstart_hr / 24. + first = False + if sched[i_scan]['source'] == source and first is False: + Tstop_hr = sched[i_scan]['start_hr'] + sched[i_scan]['scan'][0]['scan_sec'] / 3600. + mjd_e = sched[i_scan]['mjd_floor'] + Tstop_hr / 24. + + return mjd_s, mjd_e + +# ================================================================= +# ================================================================= + +# Function to find MJD (int!) and hour in UT from vex format, +# e.g, 2016y099d05h00m00s + + +def vexdate_to_MJD_hr(vexdate): + """Find the integer MJD and UT hour from vex format date. + """ + + time = re.findall(r"[-+]?\d+[\.]?\d*", vexdate) + year = int(time[0]) + date = int(time[1]) + yeardatetime = ("%04i" % year) + ':' + ("%03i" % date) + ":00:00:00.000" + t = Time(yeardatetime, format='yday') + mjd = t.mjd + hour = int(time[2]) + float(time[3]) / 60. + float(time[4]) / 60. / 60. + + return mjd, hour