diff --git a/galaxy2galaxy/bin/g2g-exporter b/galaxy2galaxy/bin/g2g-exporter
index e332d06..53d851f 100755
--- a/galaxy2galaxy/bin/g2g-exporter
+++ b/galaxy2galaxy/bin/g2g-exporter
@@ -16,6 +16,7 @@ from tensor2tensor.utils import decoding
 from tensor2tensor.utils import t2t_model
 from tensor2tensor.utils import trainer_lib
 from tensor2tensor.utils import usr_dir
+from tensor2tensor.utils import registry

 from galaxy2galaxy import models
 from galaxy2galaxy import problems
@@ -101,7 +102,12 @@ def main(_):
   estimator = create_estimator(run_config, hparams)

   hparams.img_len = problem.get_hparams().img_len
-  hparams.attributes = problem.get_hparams().attributes
+  try:
+    if len(hparams.attributes) == 0:
+      hparams.attributes = hparams.problem.get_hparams().attributes
+  except AttributeError:
+    pass
+
   # Use tf hub to export any module that has been registered
   exporter = hub.LatestModuleExporter("tf_hub",
@@ -115,4 +121,4 @@ def main(_):

 if __name__ == "__main__":
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.app.run()
+  tf.app.run()
\ No newline at end of file
diff --git a/galaxy2galaxy/data_generators/all_problems.py b/galaxy2galaxy/data_generators/all_problems.py
index fd8064c..38a4eda 100644
--- a/galaxy2galaxy/data_generators/all_problems.py
+++ b/galaxy2galaxy/data_generators/all_problems.py
@@ -14,6 +14,7 @@
 try:
   import galsim
   MODULES += ["galaxy2galaxy.data_generators.cosmos"]
+  MODULES += ["galaxy2galaxy.data_generators.candels"]
 except:
   print("Could not import GalSim, excluding some data generators")
diff --git a/galaxy2galaxy/data_generators/candels.py b/galaxy2galaxy/data_generators/candels.py
new file mode 100644
index 0000000..d8bcb9f
--- /dev/null
+++ b/galaxy2galaxy/data_generators/candels.py
@@ -0,0 +1,534 @@
+""" CANDELS Datasets """
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import glob
+import os
+import re
+import sys
+
+from . import hsc_utils
+from . import astroimage_utils
+
+from tensor2tensor.data_generators import generator_utils
+from tensor2tensor.data_generators import image_utils
+from tensor2tensor.data_generators import problem
+from tensor2tensor.layers import modalities, common_layers
+from tensor2tensor.utils import metrics
+
+from galaxy2galaxy.utils import registry
+
+import tensorflow as tf
+import numpy as np
+import fits2hdf.pyhdfits as fits
+import h5py
+import galsim
+import sep
+from astropy.table import Table
+from astropy.visualization import make_lupton_rgb
+from skimage.transform import resize, rescale
+from scipy.ndimage import binary_dilation, gaussian_filter, rotate
+from scipy.spatial import KDTree
+
+
+def _resize_image(im, size):
+  """Center-crop an image to a square stamp of side `size`."""
+  centh = im.shape[0]/2
+  centw = im.shape[1]/2
+  lh, rh = int(centh-size/2), int(centh+size/2)
+  lw, rw = int(centw-size/2), int(centw+size/2)
+  cropped = im[lh:rh, lw:rw, :]
+  assert cropped.shape[0]==size and cropped.shape[1]==size, f"Wrong size! Still {cropped.shape}"
+  return cropped
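+# Illustrative usage of _resize_image (the array here is a stand-in, not
+# pipeline data): center-cropping a 128x128 single-band stamp to 64x64:
+#
+#   stamp = np.zeros((128, 128, 1))
+#   crop = _resize_image(stamp, 64)   # -> shape (64, 64, 1)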
+
+
+@registry.register_problem
+class Img2imgCandelsMultires(astroimage_utils.AstroImageProblem):
+  """ Base class for image problems with the CANDELS catalog, with
+  multi-resolution images.
+  """
+
+  @property
+  def dataset_splits(self):
+    """Splits of data to produce and number of output shards for each."""
+    return [{
+        "split": problem.DatasetSplit.TRAIN,
+        "shards": 20,
+    }, {
+        "split": problem.DatasetSplit.EVAL,
+        "shards": 2,
+    }]
+
+  @property
+  def multiprocess_generate(self):
+    """Whether to generate the data in multiple parallel processes."""
+    return True
+
+  # START: Subclass interface
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.img_len = 128
+    p.sigmas = {"high": [1e-4], "low": [4.0e-3]}
+    p.filters = {"high": ['acs_f814w'], "low": ['wfc3_f160w']}
+    p.resolutions = ["high", "low"]
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "targets": None}
+    p.add_hparam("psf", None)
+
+  @property
+  def num_bands(self):
+    """Number of bands."""
+    p = self.get_hparams()
+    return np.sum([len(p.filters[res]) for res in p.resolutions])
+
+  def generator(self, data_dir, tmp_dir, dataset_split, task_id=-1):
+    """Generator yielding individual postage stamps."""
+    print(task_id)
+
+    p = self.get_hparams()
+    band_num = np.sum([len(p.filters[res]) for res in p.resolutions])
+    scalings = {}
+    for res in p.resolutions:
+      scalings[res] = p.pixel_scale[res]/p.base_pixel_scale[res]
+    target_pixel_scale = p.pixel_scale[p.resolutions[0]]
+    target_scaling = target_pixel_scale/p.base_pixel_scale[p.resolutions[0]]
+    target_size = p.img_len
+
+    # Load the catalogue containing all the fields and all the filters
+    all_cat = Table.read(os.path.join(data_dir, 'CANDELS_morphology_v8_3dhst_galfit_ALLFIELDS.fit'))
+    all_cat['FIELD_1'][np.where(all_cat['FIELD_1']=='gdn   ')] = 'GDN'
+    all_cat['FIELD_1'][np.where(all_cat['FIELD_1']=='egs   ')] = 'EGS'
+    all_cat['FIELD_1'][np.where(all_cat['FIELD_1']=='GDS   ')] = 'GDS'
+    all_cat['FIELD_1'][np.where(all_cat['FIELD_1']=='UDS   ')] = 'UDS'
+    all_cat['FIELD_1'][np.where(all_cat['FIELD_1']=='COSMOS   ')] = 'COSMOS'
+
+    # Load the PSF for each filter and resize
+    cube_psf = np.zeros((2*p.img_len, 2*p.img_len // 2 + 1, band_num))
+    interp_factor=2
+    padding_factor=1
+    Nk = p.img_len*interp_factor*padding_factor
+    bounds = galsim.bounds._BoundsI(0, Nk//2, -Nk//2, Nk//2-1)
+    k = 0
+    for res in p.resolutions:
+      cube_psf_tmp = np.zeros((2*p.img_len, 2*p.img_len // 2 + 1, len(p.filters[res])))
+      for i, filt in enumerate(p.filters[res]):
+        psf = galsim.InterpolatedImage(data_dir + '/psfs/psf_' + filt + '.fits', scale=0.06)
+
+        imCp = psf.drawKImage(bounds=bounds,
+                              scale=2.*np.pi/(Nk * p.pixel_scale[res] / interp_factor),
+                              recenter=False)
+
+        # Transform the psf array into proper format, remove the phase
+        im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')
+        cube_psf_tmp[:, :, i] = im_psf
+      cube_psf_tmp = resize(cube_psf_tmp, (2*p.img_len, 2*p.img_len // 2 + 1, len(p.filters[res])))
+      cube_psf[:, :, k:k+len(p.filters[res])] = cube_psf_tmp
+      k += len(p.filters[res])
+
+    psf = cube_psf
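+    # The PSF cube stores the modulus of an x2-interpolated Fourier-domain
+    # image of each PSF, hence the (2*img_len, 2*img_len//2 + 1) slice
+    # shape: the half-plane layout that np.fft.rfft2 produces for real
+    # inputs, sampled at twice the target stamp size.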
+
+    sigmas = p.sigmas
+
+    # Step 2: Extract postage stamps, resize them to requested size
+    n_gal_creat = 0
+    index = 0
+
+    # Create a subcat containing only the galaxies (in all filters) of the
+    # selected fields
+    sub_cat = all_cat[np.where(np.isin(list(all_cat["FIELD_1"]), ["GDS", "GDN", "EGS", "COSMOS", "UDS"]))]
+    sub_cat = sub_cat[np.where(sub_cat['mag'] <= 25.3)]
+    assert(task_id > -1)
+    n_shards = self.dataset_splits[0]["shards"] + self.dataset_splits[1]["shards"]
+    indexes = list(range(task_id*len(sub_cat)//n_shards,
+                         min((task_id+1)*len(sub_cat)//n_shards, len(sub_cat))))
+    sub_cat = sub_cat[indexes]
+
+    # Loop on all the galaxies of the selected fields
+    for m, gal in enumerate(sub_cat['RB_ID']):
+      if gal == index or gal == 15431 or sub_cat["mag"][m] < 0:  # To take care of the redundancy inside the cat
+        continue
+      index = gal
+
+      try:
+        # Loop on the filters
+        im = np.zeros((target_size, target_size, band_num))
+
+        k = 0
+        for res in p.resolutions:
+          im_tmp = np.zeros((128, 128, len(p.filters[res])))
+          for n_filter, filt in enumerate(p.filters[res]):
+            # Open the image corresponding to the index of the current galaxy
+            tmp_file = glob.glob(os.path.join(data_dir, sub_cat["FIELD_1"][m], filt)+'/galaxy_'+str(index)+'_*')[0]
+            im_import = fits.open(tmp_file)[0].data
+            cleaned_image = clean_rotate_stamp(im_import, sigma_sex=1.5)  # ,noise_level=p.sigmas[res][n_filter]
+
+            if res == p.resolutions[0] and n_filter == 0:
+              flux_ratio = 1/np.max(cleaned_image) if np.max(cleaned_image) != 0 else 1
+
+            im_tmp[:, :, n_filter] = cleaned_image * flux_ratio
+            if np.max(cleaned_image) <= 5*10**(-3):
+              raise ValueError("Very weak image")
+
+          # Resize the image to the low resolution
+          new_size = np.ceil(128/scalings[res])+1
+          im_tmp = resize(im_tmp, (new_size, new_size, len(p.filters[res])))
+          # Resize the image to the highest resolution to get consistent array sizes
+          im_tmp = rescale(im_tmp, p.pixel_scale[res]/target_pixel_scale, multichannel=True, preserve_range=True)
+          im_tmp = _resize_image(im_tmp, target_size)
+
+          im[:, :, k:k+len(p.filters[res])] = im_tmp
+          k += len(p.filters[res])
+
+        im = _resize_image(im, p.img_len)
+
+        # Check that there is still a galaxy
+        img_s = im[:, :, 0]
+        img_s = img_s.copy(order='C')
+        bkg = sep.Background(img_s)
+        cat_s = sep.extract(img_s-bkg, 2, err=bkg.globalrms)
+        if len(cat_s) == 0:
+          raise ValueError('No galaxy detected in the field')
+
+        # Load the wanted physical parameters of the galaxy
+        if hasattr(p, 'attributes'):
+          attributes = {k: float(sub_cat[k][m]) for k in p.attributes}
+        else:
+          attributes = None
+
+        # Create the noise power spectrum
+        k = 0
+        noise_im = np.zeros((p.img_len, p.img_len, band_num))
+        for res in p.resolutions:
+          for n_filter in range(len(p.filters[res])):
+            noise_im[:, :, n_filter+k] = np.random.normal(0, p.sigmas[res][n_filter], (p.img_len, p.img_len))
+          k += len(p.filters[res])
+        noise_im = np.transpose(noise_im, [2, 0, 1])
+        ps = np.abs(np.fft.rfft2(noise_im))
+        ps = np.transpose(ps, [1, 2, 0])
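+        # The power spectrum is estimated from a white-noise realisation at
+        # the per-band sigmas above. np.fft.rfft2 acts on the trailing two
+        # axes, hence the transpose to (band, H, W) and back, giving a
+        # (img_len, img_len//2+1, band_num) array that matches the "ps"
+        # feature declared in example_reading_spec below.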
+
+        # Add a flag corresponding to the field
+        field_info = np.asarray(1 if sub_cat["FIELD_1"][m] == "GDS" else 0)
+
+        sigmas_array = []
+        for res in p.resolutions:
+          sigmas_array += sigmas[res]
+        sigmas_array = np.array(sigmas_array)
+
+        # Create the output to match T2T format
+        serialized_output = {"image/encoded": [im.astype('float32').tostring()],
+                             "image/format": ["raw"],
+                             "psf/encoded": [psf.astype('float32').tostring()],
+                             "psf/format": ["raw"],
+                             "ps/encoded": [ps.astype('float32').tostring()],
+                             "ps/format": ["raw"],
+                             "sigma_noise/encoded": [sigmas_array.astype('float32').tostring()],
+                             "sigma_noise/format": ["raw"],
+                             "field/encoded": [field_info.astype('float32').tostring()],
+                             "field/format": ["raw"]}
+
+        if attributes is not None:
+          for k in attributes:
+            serialized_output['attrs/'+k] = [attributes[k]]
+
+        # Increment the number of galaxies created on this shard
+        n_gal_creat += 1
+
+        if n_gal_creat > p.example_per_shard:
+          print('out ', n_gal_creat)
+          break
+        yield serialized_output
+      except Exception:
+        print(sys.exc_info()[0], sys.exc_info()[1])
+        continue
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+    """ Preprocess the examples, can be used for further augmentation or
+    image standardization.
+    """
+    p = self.get_hparams()
+    image = example["inputs"]
+
+    # Clip to 1 the values of the image
+    # image = tf.clip_by_value(image, -1, 1)
+
+    # Aggregate the conditions
+    if hasattr(p, 'attributes'):
+      example['attributes'] = tf.stack([example[k] for k in p.attributes])
+
+    example["inputs"] = image
+    example["targets"] = image
+    return example
+
+  def example_reading_spec(self):
+    """Define how data is serialized to file and read back.
+
+    Returns:
+      data_fields: A dictionary mapping data names to its feature type.
+      data_items_to_decoders: A dictionary mapping data names to TF Example
+          decoders, to be used when reading back TF examples from disk.
+    """
+    p = self.get_hparams()
+
+    data_fields = {
+        "image/encoded": tf.FixedLenFeature((), tf.string),
+        "image/format": tf.FixedLenFeature((), tf.string),
+
+        "psf/encoded": tf.FixedLenFeature((), tf.string),
+        "psf/format": tf.FixedLenFeature((), tf.string),
+
+        "ps/encoded": tf.FixedLenFeature((), tf.string),
+        "ps/format": tf.FixedLenFeature((), tf.string),
+
+        "sigma_noise/encoded": tf.FixedLenFeature((), tf.string),
+        "sigma_noise/format": tf.FixedLenFeature((), tf.string),
+
+        "field/encoded": tf.FixedLenFeature((), tf.string),
+        "field/format": tf.FixedLenFeature((), tf.string),
+    }
+
+    # Adds additional attributes to be decoded as specified in the configuration
+    if hasattr(p, 'attributes'):
+      for k in p.attributes:
+        data_fields['attrs/'+k] = tf.FixedLenFeature([], tf.float32, -1)
+
+    data_items_to_decoders = {
+        "inputs": tf.contrib.slim.tfexample_decoder.Image(
+            image_key="image/encoded",
+            format_key="image/format",
+            shape=[p.img_len, p.img_len, self.num_bands],
+            dtype=tf.float32),
+
+        "psf": tf.contrib.slim.tfexample_decoder.Image(
+            image_key="psf/encoded",
+            format_key="psf/format",
+            shape=[2*p.img_len, 2*p.img_len // 2 + 1, self.num_bands],
+            dtype=tf.float32),
+
+        "ps": tf.contrib.slim.tfexample_decoder.Image(
+            image_key="ps/encoded",
+            format_key="ps/format",
+            shape=[p.img_len, p.img_len//2+1, self.num_bands],
+            dtype=tf.float32),
+
+        "sigma_noise": tf.contrib.slim.tfexample_decoder.Image(
+            image_key="sigma_noise/encoded",
+            format_key="sigma_noise/format",
+            shape=[self.num_bands],
+            dtype=tf.float32),
+
+        "field": tf.contrib.slim.tfexample_decoder.Image(
+            image_key="field/encoded",
+            format_key="field/format",
+            shape=[1],
+            dtype=tf.float32),
+    }
+    if hasattr(p, 'attributes'):
+      for k in p.attributes:
+        data_items_to_decoders[k] = tf.contrib.slim.tfexample_decoder.Tensor('attrs/'+k)
+
+    return data_fields, data_items_to_decoders
+  # END: Subclass interface
+
+  @property
+  def is_generate_per_split(self):
+    return False
+
+
+@registry.register_problem
+class Attrs2imgCandelsEuclid64(Img2imgCandelsMultires):
+  """For generating images with the Euclid bands.
+  """
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = {'high': 0.1, 'low': 0.3}
+    p.base_pixel_scale = {'high': 0.06, 'low': 0.06}
+    p.img_len = 64
+    p.sigmas = {"high": [1e-4], "low": [0.003954237367399534, 0.003849901319445, 0.004017507500562]}
+    p.filters = {"high": ['acs_f814w'], "low": ['f105w', 'f125w', 'wfc3_f160w']}
+    p.resolutions = ["high", "low"]
+    p.example_per_shard = 2000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+    p.attributes = ['mag', 're', 'q']
+
+
+@registry.register_problem
+class Attrs2imgCandelsEuclid64TwoBands(Img2imgCandelsMultires):
+  """ For generating two-band images (visible and infrared).
+  """
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = {'high': 0.1, 'low': 0.1}
+    p.base_pixel_scale = {'high': 0.06, 'low': 0.06}
+    p.img_len = 64
+    p.sigmas = {"high": [0.004094741966557142], "low": [0.004017507500562]}
+    p.filters = {"high": ['acs_f606w'], "low": ['wfc3_f160w']}
+    p.zeropoint = 26.49
+    p.resolutions = ["high", "low"]
+    p.example_per_shard = 1000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+    p.attributes = ['mag', 're', 'q', 'ZPHOT', 'F_IRR', 'F_SPHEROID', 'F_DISK']
+
+
+def find_central(sex_cat, center_coords=64):
+  """Find the central galaxy in a catalog provided by SExtractor.
+  """
+  n_detect = len(sex_cat)
+
+  # Match the predicted and true catalogs
+  pred_pos = np.zeros((n_detect, 2))
+  pred_pos[:, 0] = sex_cat['x']
+  pred_pos[:, 1] = sex_cat['y']
+
+  true_pos = np.zeros((1, 2))
+  true_pos[:, 0] = center_coords
+  true_pos[:, 1] = center_coords
+
+  _, match_index = KDTree(pred_pos).query(true_pos)
+
+  return match_index
+
+
+def sort_nicely(l):
+  """ Sort the given list in the way that humans expect.
+  """
+  convert = lambda text: int(text) if text.isdigit() else text
+  alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
+  l.sort(key=alphanum_key)
+  return l
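+# For example, sort_nicely(['f10w', 'f2w']) returns ['f2w', 'f10w'],
+# whereas a plain lexicographic sort would put 'f10w' first. Note the
+# list is also sorted in place.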
+
+
+def mask_out_pixels(img, segmap, segval,
+                    n_iter: int = 5, shuffle: bool = False,
+                    noise_factor: int = 1, noise_level=None):
+  """
+  Replace central galaxy neighbours with background noise
+
+  Basic recipe to replace the detected sources around the central galaxy
+  with either randomly selected pixels from the background, or a random
+  realisation of the background noise.
+  """
+  masked_img = img.copy()
+  # Create binary masks of all segmented sources
+  sources = binary_dilation(segmap, iterations=n_iter)
+  background_mask = np.logical_and(np.logical_not(sources), np.array(img, dtype=bool))
+  # Create binary mask of the central galaxy
+  central_source = binary_dilation(np.where(segmap == segval, 1, 0),
+                                   iterations=n_iter)
+  # Compute the binary mask of all sources BUT the central galaxy
+  sources_except_central = np.logical_xor(sources, central_source)
+
+  if shuffle:
+    # Select random pixels from the noise in the image
+    n_pixels_to_fill_in = sources_except_central.sum()
+    random_background_pixels = np.random.choice(
+        img[background_mask],
+        size=n_pixels_to_fill_in
+    )
+    # Fill in the voids with these pixels
+    masked_img[sources_except_central] = random_background_pixels
+  else:
+    # Create a realisation of the background from its measured std
+    if noise_level is None:
+      background_std = np.std(img[background_mask])
+    else:
+      background_std = noise_level
+    random_background = np.random.normal(scale=background_std, size=img.shape)
+    masked_img[sources_except_central] = random_background[sources_except_central]
+    masked_img[np.where(masked_img == 0.0)] = random_background[np.where(masked_img == 0.0)]
+
+  return masked_img.astype(img.dtype), sources, background_mask, central_source, sources_except_central
+
+
+def clean_rotate_stamp(img, eps=5, sigma_sex=2, noise_level=None, rotate_b=False, blend_threshold=0.1):
+  """Clean an image by removing all galaxies other than the central one.
+  """
+  # Detect galaxies with SExtractor
+  img = img.byteswap().newbyteorder()
+  im_size = img.shape[0]
+  bkg = sep.Background(img)
+
+  cat, sex_seg = sep.extract(img-bkg, sigma_sex, err=bkg.globalrms, segmentation_map=True)
+
+  if len(cat) == 0:
+    raise ValueError('No galaxy detected in the field')
+
+  central_index = find_central(cat, im_size//2)[0]
+  middle_pos = [cat[central_index]['x'], cat[central_index]['y']]
+
+  distance = np.sqrt((middle_pos[0]-im_size//2)**2 + (middle_pos[1]-im_size//2)**2)
+  if distance > 10:
+    raise ValueError('No galaxy detected in the center')
+
+  middle = np.max(sex_seg[int(round(middle_pos[0]))-eps:int(round(middle_pos[0]))+eps,
+                          int(round(middle_pos[1]))-eps:int(round(middle_pos[1]))+eps])
+  if middle == 0:
+    raise ValueError('No galaxy detected in the center')
+
+  cleaned, _, _, central, _ = mask_out_pixels(img, sex_seg, middle, n_iter=5, noise_level=noise_level)
+
+  blended_pixels = np.logical_and(np.not_equal(sex_seg, 0), np.not_equal(sex_seg, middle))*central
+  blend_flux = np.sum(img[np.nonzero(blended_pixels)])
+  if np.any(blended_pixels):
+    loc = np.argwhere(blended_pixels)
+    blended_galaxies = np.unique(sex_seg[loc])
+    for blended_galaxy in blended_galaxies:
+      blended_galaxy_flux = np.sum(img[np.where(sex_seg == blended_galaxy)])
+      if blend_flux/blended_galaxy_flux > blend_threshold:
+        raise ValueError('Blending suspected')
+
+  # Rotate
+  if rotate_b:
+    PA = cat[find_central(cat)[0]][4]
+    img_rotate = rotate(cleaned, PA, reshape=False)
+  else:
+    img_rotate = cleaned
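+  # The empty (exactly-zero) pixels left by the cleaning and rotation are
+  # backfilled below with a Gaussian noise realisation, so that the stamp
+  # has statistically uniform background noise everywhere.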
+
+  # Add noise
+  background_mask = np.logical_and(sex_seg == 0, np.array(img, dtype=bool))
+  if noise_level is None:
+    background_std = np.std(img * background_mask)
+  else:
+    background_std = noise_level
+  random_background = np.random.normal(scale=background_std, size=img_rotate.shape)
+  rotated = np.where(img_rotate == 0, random_background, img_rotate)
+
+  return rotated
diff --git a/galaxy2galaxy/data_generators/cosmos.py b/galaxy2galaxy/data_generators/cosmos.py
index 732ede6..d6d4d68 100644
--- a/galaxy2galaxy/data_generators/cosmos.py
+++ b/galaxy2galaxy/data_generators/cosmos.py
@@ -52,6 +52,7 @@ def hparams(self, defaults, model_hparams):
     p.pixel_scale = 0.03
     p.img_len = 64
     p.example_per_shard = 1000
+
     p.modality = {"inputs": modalities.ModalityType.IDENTITY,
                   "targets": modalities.ModalityType.IDENTITY}
     p.vocab_size = {"inputs": None,
@@ -69,6 +70,7 @@ def generator(self, data_dir, tmp_dir, dataset_split, task_id=-1):
     Generates and yields postage stamps obtained with GalSim.
     """
     p = self.get_hparams()
+
     try:
       # try to use default galsim path to the data
       catalog = galsim.COSMOSCatalog()
@@ -106,6 +108,7 @@ def generator(self, data_dir, tmp_dir, dataset_split, task_id=-1):
     cat_param = append_fields(cat_param, 'sersic_n', sparams[:,2])
     cat_param = append_fields(cat_param, 'sersic_beta', sparams[:,7])
+
     for ind in index:
       # Draw a galaxy using GalSim, any kind of operation can be done here
       gal = catalog.makeGalaxy(ind, noise_pad_size=p.img_len * p.pixel_scale*2)
@@ -247,6 +250,46 @@ def hparams(self, defaults, model_hparams):
     p.vocab_size = {"inputs": None,
                     "targets": None}
+
+@registry.register_problem
+class Img2imgCosmos64(Img2imgCosmos):
+  """ Version of the Img2imgCosmos problem at a coarser (0.1 arcsec)
+  pixel scale, on 64px postage stamps.
+  """
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.1
+    p.img_len = 64
+    p.example_per_shard = 1000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "targets": None}
+
+@registry.register_problem
+class Attrs2imgCosmos64(Img2imgCosmos64):
+  """ Conditional version of Img2imgCosmos64, which also exposes galaxy
+  attributes for attribute-conditioned generation.
+  """
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.1
+    p.img_len = 64
+    p.example_per_shard = 1000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+    p.attributes = ['mag_auto', 'flux_radius', 'zphot']
+
 @registry.register_problem
 class Attrs2imgCosmos128(Img2imgCosmos128):
   """ Smaller version of the Img2imgCosmos problem, at half the pixel
@@ -281,7 +324,7 @@ def eval_metrics(self):
 
   def hparams(self, defaults, model_hparams):
     p = defaults
-    p.pixel_scale = 0.03
+    p.pixel_scale = 0.1
     p.img_len = 128
     p.example_per_shard = 1000
     p.modality = {"inputs": modalities.ModalityType.IDENTITY,
@@ -291,6 +334,52 @@ def hparams(self, defaults, model_hparams):
                     "attributes": None,
                     "targets": None}
     p.attributes = ['mag_auto', 'flux_radius', 'sersic_n', 'sersic_q']
+
+@registry.register_problem
+class Attrs2imgCosmos64Euclid(Img2imgCosmos128):
+  """ Attribute-conditioned COSMOS problem at a Euclid-like 0.1 arcsec
+  pixel scale, on 64px postage stamps.
+  """
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.1
+    p.img_len = 64
+    p.example_per_shard = 1000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+#    p.attributes = ['mag_auto', 'flux_radius', 'sersic_n', 'sersic_q']
+    p.attributes = ['flux_radius', 'sersic_n', 'sersic_q']
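+
+# Once registered, these problems are addressed by their snake_case names
+# (e.g. Attrs2imgCosmos64 -> attrs2img_cosmos64). A sketch of data
+# generation with the t2t-style datagen entry point -- assuming the
+# g2g-datagen wrapper and the usual t2t flags, with placeholder paths:
+#
+#   g2g-datagen --problem=attrs2img_cosmos64 \
+#       --data_dir=/path/to/data --tmp_dir=/path/to/tmp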
+
+@registry.register_problem
+class Attrs2imgCosmos64EuclidWithMorpho(Img2imgCosmos128):
+  """ Same as Attrs2imgCosmos64Euclid, but also conditioning on the Mph3
+  morphology class.
+  """
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.1
+    p.img_len = 64
+    p.example_per_shard = 1000
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+    p.attributes = ['mag_auto', 'flux_radius', 'sersic_n', 'sersic_q', 'Mph3']
+
 @registry.register_problem
 class Attrs2imgCosmos32(Attrs2imgCosmos):
@@ -310,3 +399,194 @@ def hparams(self, defaults, model_hparams):
                     "targets": None}
     p.attributes = ['mag_auto', 'flux_radius', 'zphot', 'bulge_q', 'bulge_beta'
                     , 'disk_q', 'disk_beta', 'bulge_hlr', 'disk_hlr']
+
+
+@registry.register_problem
+class Img2imgCosmosMultiband(galsim_utils.GalsimProblem):
+  """
+  Img2img problem on GalSim's COSMOS 25.2 sample, at native pixel resolution,
+  on 64px postage stamps.
+  """
+
+  @property
+  def dataset_splits(self):
+    """Splits of data to produce and number of output shards for each.
+    Note that each shard will be produced in parallel.
+    We are going to split the GalSim data into shards of 1000 galaxies each,
+    with 80 shards for training, 2 shards for validation.
+    """
+    return [{
+        "split": problem.DatasetSplit.TRAIN,
+        "shards": 80,
+    }, {
+        "split": problem.DatasetSplit.EVAL,
+        "shards": 2,
+    }]
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.03
+    p.img_len = 64
+    p.flux_ratio = [1.0]
+
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "targets": None}
+    p.add_hparam("psf", None)
+    p.add_hparam("rotation", False)
+
+  @property
+  def num_bands(self):
+    """Number of bands."""
+    return 1
+
+  def generator(self, data_dir, tmp_dir, dataset_split, task_id=-1):
+    """
+    Generates and yields postage stamps obtained with GalSim.
+    """
+    p = self.get_hparams()
+
+    try:
+      # try to use default galsim path to the data
+      catalog = galsim.COSMOSCatalog()
+    except Exception:
+      # If that fails, tries to use the specified tmp_dir
+      catalog = galsim.COSMOSCatalog(dir=tmp_dir+'/COSMOS_25.2_training_sample')
+
+    # Create a list of galaxy indices for this task; remember, there is a task
+    # per shard, and each shard is 1000 galaxies.
+    assert(task_id > -1)
+    index = range(task_id*p.example_per_shard,
+                  min((task_id+1)*p.example_per_shard, catalog.getNObjects()))
+
+    # Extracts additional information about the galaxies
+    cat_param = catalog.param_cat[catalog.orig_index]
+    from numpy.lib.recfunctions import append_fields
+    import numpy as np
+
+    bparams = cat_param['bulgefit']
+    sparams = cat_param['sersicfit']
+    # Parameters for a 2 component fit
+    cat_param = append_fields(cat_param, 'bulge_q', bparams[:,11])
+    cat_param = append_fields(cat_param, 'bulge_beta', bparams[:,15])
+    cat_param = append_fields(cat_param, 'disk_q', bparams[:,3])
+    cat_param = append_fields(cat_param, 'disk_beta', bparams[:,7])
+    cat_param = append_fields(cat_param, 'bulge_hlr', cat_param['hlr'][:,1])
+    cat_param = append_fields(cat_param, 'bulge_flux_log10', np.where(cat_param['use_bulgefit'] == 1, np.log10(cat_param['flux'][:,1]), np.zeros(len(cat_param))))
+    cat_param = append_fields(cat_param, 'disk_hlr', cat_param['hlr'][:,2])
+    cat_param = append_fields(cat_param, 'disk_flux_log10', np.where(cat_param['use_bulgefit'] == 1, np.log10(cat_param['flux'][:,2]), np.log10(cat_param['flux'][:,0])))
+
+    # Parameters for a single component fit
+    cat_param = append_fields(cat_param, 'sersic_flux_log10', np.log10(sparams[:,0]))
+    cat_param = append_fields(cat_param, 'sersic_q', sparams[:,3])
+    cat_param = append_fields(cat_param, 'sersic_hlr', sparams[:,1])
+    cat_param = append_fields(cat_param, 'sersic_n', sparams[:,2])
+    cat_param = append_fields(cat_param, 'sersic_beta', sparams[:,7])
+
+    for ind in index:
+      # Draw a galaxy using GalSim, any kind of operation can be done here
+      gal = catalog.makeGalaxy(ind, noise_pad_size=p.img_len * p.pixel_scale)
+
+      # We apply the original psf if a different PSF is not requested
+      if not hasattr(p, "psf") or p.psf is None:
+        psf = gal.original_psf
+      else:
+        psf = p.psf
+
+      # Apply rotation so that the galaxy is at 0 PA
+#      if hasattr(p, "rotation") and p.rotation:
+#        rotation_angle = galsim.Angle(-cat_param[ind]['sersic_beta'],
+#                                      galsim.radians)
+#        gal = gal.rotate(rotation_angle)
+#        psf = psf.rotate(rotation_angle)
+
+      # We save the corresponding attributes for this galaxy
+      if hasattr(p, 'attributes'):
+        params = cat_param[ind]
+        attributes = {k: params[k] for k in p.attributes}
+      else:
+        attributes = None
+
+      # Draw a random flux ratio for each band relative to the first band
+      flux_r = [1.0]
+      for i in range(1, self.num_bands):
+        flux_r.append(max(np.random.normal(p.flux_ratio_mean[i], p.flux_ratio_std[i]), 0))
+
+      # Utility function encodes the postage stamp for serialized features
+      yield galsim_utils.draw_and_encode_stamp(gal, psf,
+                                               stamp_size=p.img_len,
+                                               pixel_scale=p.pixel_scale,
+                                               attributes=attributes,
+                                               flux_r=flux_r,
+                                               num_bands=self.num_bands)
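+  # Note on flux_r: draw_and_encode_stamp peak-normalises the stamp in each
+  # band and multiplies it by the corresponding flux_r entry, so the first
+  # band always peaks at 1 while the other bands get randomly drawn relative
+  # amplitudes (see flux_ratio_mean / flux_ratio_std in
+  # Attrs2imgCosmosMultiband64 below).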
+
+  def preprocess_example(self, example, unused_mode, unused_hparams):
+    """ Preprocess the examples, can be used for further augmentation or
+    image standardization.
+    """
+    p = self.get_hparams()
+    image = example["inputs"]
+
+    # Random rotation augmentation
+    image = galsim_utils.tf_rotate(image)
+    # Clip to 1 the values of the image
+    # image = tf.clip_by_value(image, -1, 1)
+
+    # Aggregate the conditions
+    if hasattr(p, 'attributes'):
+      example['attributes'] = tf.stack([example[k] for k in p.attributes])
+
+    example["inputs"] = image
+    example["targets"] = image
+    return example
+
+
+@registry.register_problem
+class Attrs2imgCosmosMultiband64(Img2imgCosmosMultiband):
+  """ Two-band, attribute-conditioned version of Img2imgCosmosMultiband,
+  at 0.1 arcsec pixel scale on 64px stamps.
+  """
+  @property
+  def num_bands(self):
+    """Number of bands."""
+    return 2
+
+  def eval_metrics(self):
+    eval_metrics = []
+    return eval_metrics
+
+  def hparams(self, defaults, model_hparams):
+    p = defaults
+    p.pixel_scale = 0.1
+    p.img_len = 64
+    p.example_per_shard = 1000
+    p.flux_ratio_mean = [1, 0.299]
+    p.flux_ratio_std = [0, 1.038]
+    p.modality = {"inputs": modalities.ModalityType.IDENTITY,
+                  "attributes": modalities.ModalityType.IDENTITY,
+                  "targets": modalities.ModalityType.IDENTITY}
+    p.vocab_size = {"inputs": None,
+                    "attributes": None,
+                    "targets": None}
+    p.attributes = ['mag_auto', 'flux_radius', 'zphot']
diff --git a/galaxy2galaxy/data_generators/galsim_utils.py b/galaxy2galaxy/data_generators/galsim_utils.py
index 221d587..6643aae 100644
--- a/galaxy2galaxy/data_generators/galsim_utils.py
+++ b/galaxy2galaxy/data_generators/galsim_utils.py
@@ -92,14 +92,14 @@ def example_reading_spec(self):
         "inputs": tf.contrib.slim.tfexample_decoder.Image(
             image_key="image/encoded",
             format_key="image/format",
-            channels=self.num_bands,
+            channels=None,
             shape=[p.img_len, p.img_len, self.num_bands],
             dtype=tf.float32),
 
         "psf": tf.contrib.slim.tfexample_decoder.Image(
             image_key="psf/encoded",
             format_key="psf/format",
-            channels=self.num_bands,
+            channels=None,
             # The factor 2 here is to account for x2 interpolation
             shape=[2*p.img_len, 2*p.img_len // 2 + 1, self.num_bands],
             dtype=tf.float32),
@@ -107,8 +107,8 @@ def example_reading_spec(self):
         "ps": tf.contrib.slim.tfexample_decoder.Image(
             image_key="ps/encoded",
             format_key="ps/format",
-            channels=self.num_bands,
-            shape=[p.img_len, p.img_len // 2 + 1],
+            channels=None,
+            shape=[p.img_len, p.img_len // 2 + 1, self.num_bands],
             dtype=tf.float32),
     }
 
@@ -143,7 +143,7 @@ def _float_feature(value):
 def _bytes_feature(value):
   return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
 
-def draw_and_encode_stamp(gal, psf, stamp_size, pixel_scale, attributes=None):
+def draw_and_encode_stamp(gal, psf, stamp_size, pixel_scale, num_bands=1, flux_r=[1.0], attributes=None):
   """
   Draws the galaxy, psf and noise power spectrum on a postage stamp and
   encodes it to be exported in a TFRecord.
   """
@@ -152,57 +152,73 @@ def draw_and_encode_stamp(gal, psf, stamp_size, pixel_scale, attributes=None):
   # Apply the PSF
   gal = galsim.Convolve(gal, psf)
-  # Draw a kimage of the galaxy, just to figure out what maxk is, there might
-  # be more efficient ways to do this though...
-  bounds = _BoundsI(0, stamp_size//2, -stamp_size//2, stamp_size//2-1)
-  imG = gal.drawKImage(bounds=bounds,
-                       scale=2.*np.pi/(stamp_size * pixel_scale),
-                       recenter=False)
-  mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)
-
-  # We draw the pixel image of the convolved image
-  im = gal.drawImage(nx=stamp_size, ny=stamp_size, scale=pixel_scale,
-                     method='no_pixel', use_true_center=False).array.astype('float32')
-
-  # Draw the Fourier domain image of the galaxy, using x1 zero padding,
-  # and x2 subsampling
-  interp_factor=2
-  padding_factor=1
-  Nk = stamp_size*interp_factor*padding_factor
-  bounds = _BoundsI(0, Nk//2, -Nk//2, Nk//2-1)
-  imCp = psf.drawKImage(bounds=bounds,
-                        scale=2.*np.pi/(Nk * pixel_scale / interp_factor),
-                        recenter=False)
-
-  # Transform the psf array into proper format, remove the phase
-  im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')
-
-  # Compute noise power spectrum, at the resolution and stamp size of target
-  # image
-  ps = gal.noise._get_update_rootps((stamp_size, stamp_size),
-                                    wcs=galsim.PixelScale(pixel_scale))
-
-  # The following comes from correlatednoise.py
-  rt2 = np.sqrt(2.)
-  shape = (stamp_size, stamp_size)
-  ps[0, 0] = rt2 * ps[0, 0]
-  # Then make the changes necessary for even sized arrays
-  if shape[1] % 2 == 0:  # x dimension even
-    ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
-  if shape[0] % 2 == 0:  # y dimension even
-    ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
-  # Both dimensions even
-  if shape[1] % 2 == 0:
-    ps[shape[0] // 2, shape[1] // 2] = rt2 * \
-        ps[shape[0] // 2, shape[1] // 2]
-
-  # Apply mask to power spectrum so that it is very large outside maxk
-  ps = np.where(mask, np.log(ps**2), 10).astype('float32')
-  serialized_output = {"image/encoded": [im.tostring()],
+  im_multi = np.zeros((stamp_size, stamp_size, num_bands))
+  psf_multi = np.zeros((2*stamp_size, 2*stamp_size//2+1, num_bands))
+  ps_multi = np.zeros((stamp_size, stamp_size//2+1, num_bands))
+  # Draw the Fourier domain image of the galaxy in each band
+  for i in range(num_bands):
+    # Draw a kimage of the galaxy, just to figure out what maxk is, there might
+    # be more efficient ways to do this though...
+    bounds = _BoundsI(0, stamp_size//2, -stamp_size//2, stamp_size//2-1)
+    imG = gal.drawKImage(bounds=bounds,
+                         scale=2.*np.pi/(stamp_size * pixel_scale),
+                         recenter=False)
+    mask = ~(np.fft.fftshift(imG.array, axes=0) == 0)
+
+    # We draw the pixel image of the convolved image
+    im = gal.drawImage(nx=stamp_size, ny=stamp_size, scale=pixel_scale,
+                       method='no_pixel', use_true_center=False).array.astype('float32')
+
+    # Peak-normalise the band and apply its flux ratio
+    im = im/np.max(im) * flux_r[i]
+
+    # Draw the Fourier domain image of the galaxy, using x1 zero padding,
+    # and x2 subsampling
+    interp_factor=2
+    padding_factor=1
+    Nk = stamp_size*interp_factor*padding_factor
+    bounds = _BoundsI(0, Nk//2, -Nk//2, Nk//2-1)
+    imCp = psf.drawKImage(bounds=bounds,
+                          scale=2.*np.pi/(Nk * pixel_scale / interp_factor),
+                          recenter=False)
+
+    # Transform the psf array into proper format, remove the phase
+    im_psf = np.abs(np.fft.fftshift(imCp.array, axes=0)).astype('float32')
+
+    im_multi[:,:,i] = im
+    psf_multi[:,:,i] = im_psf
+
+    # Compute noise power spectrum, at the resolution and stamp size of target
+    # image
+    ps = gal.noise._get_update_rootps((stamp_size, stamp_size),
+                                      wcs=galsim.PixelScale(pixel_scale))
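+    # The sqrt(2) corrections below come from GalSim's correlatednoise.py:
+    # for a real noise field, the DC mode and (for even-sized arrays) the
+    # Nyquist modes carry no imaginary part, so their root-power-spectrum
+    # entries are boosted by sqrt(2) to keep the noise variance consistent.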
+    # The following comes from correlatednoise.py
+    rt2 = np.sqrt(2.)
+    shape = (stamp_size, stamp_size)
+    ps[0, 0] = rt2 * ps[0, 0]
+    # Then make the changes necessary for even sized arrays
+    if shape[1] % 2 == 0:  # x dimension even
+      ps[0, shape[1] // 2] = rt2 * ps[0, shape[1] // 2]
+    if shape[0] % 2 == 0:  # y dimension even
+      ps[shape[0] // 2, 0] = rt2 * ps[shape[0] // 2, 0]
+    # Both dimensions even
+    if shape[1] % 2 == 0:
+      ps[shape[0] // 2, shape[1] // 2] = rt2 * \
+          ps[shape[0] // 2, shape[1] // 2]
+
+    # Apply mask to power spectrum so that it is very large outside maxk
+    ps = np.where(mask, np.log(ps**2), 10).astype('float32')
+    ps_multi[:,:,i] = ps
+
+  if num_bands == 1:
+    im_multi = np.squeeze(im_multi)
+    psf_multi = np.squeeze(psf_multi)
+    ps_multi = np.squeeze(ps_multi)
+
+  serialized_output = {"image/encoded": [im_multi.astype('float32').tostring()],
                        "image/format": ["raw"],
-                       "psf/encoded": [im_psf.tostring()],
+                       "psf/encoded": [psf_multi.astype('float32').tostring()],
                        "psf/format": ["raw"],
-                       "ps/encoded": [ps.tostring()],
+                       "ps/encoded": [ps_multi.astype('float32').tostring()],
                        "ps/format": ["raw"]}
 
   # Adding the parameters provided
@@ -271,3 +287,21 @@ def maybe_download_cosmos(target_dir, sample="25.2"):
     if do_remove:
       logger.info("Removing the tarball to save space")
       os.remove(target)
+
+def tf_rotate(input_image, min_angle=-np.pi/2, max_angle=np.pi/2):
+  '''Randomly rotate an image with TensorFlow.
+
+  :param input_image: input image tensor
+  :param min_angle: minimum rotation angle, in radians
+  :param max_angle: maximum rotation angle, in radians
+  :return: randomly rotated image
+  '''
+  distorted_image = tf.expand_dims(input_image, 0)
+  random_angles = tf.random.uniform(shape=(tf.shape(distorted_image)[0],), minval=min_angle, maxval=max_angle)
+  distorted_image = tf.contrib.image.transform(
+      distorted_image,
+      tf.contrib.image.angles_to_projective_transforms(
+          random_angles, tf.cast(tf.shape(distorted_image)[1], tf.float32), tf.cast(tf.shape(distorted_image)[2], tf.float32)
+      ))
+  rotate_image = tf.squeeze(distorted_image, [0])
+  return rotate_image
\ No newline at end of file
diff --git a/galaxy2galaxy/models/autoencoders.py b/galaxy2galaxy/models/autoencoders.py
index 0540341..333ba94 100644
--- a/galaxy2galaxy/models/autoencoders.py
+++ b/galaxy2galaxy/models/autoencoders.py
@@ -263,7 +263,7 @@ def continuous_autoencoder_residual_128():
   hparams.batch_size = 32
   hparams.bottleneck_bits = 64
-  hparams.bottleneck_warmup_steps = 5000
+  hparams.bottleneck_warmup_steps = 2000
   hparams.add_hparam("autoregressive_decode_steps", 0)
   hparams.add_hparam("num_residual_layers", 2)
diff --git a/galaxy2galaxy/models/autoencoders_utils.py b/galaxy2galaxy/models/autoencoders_utils.py
index e7b35f2..77e3457 100644
--- a/galaxy2galaxy/models/autoencoders_utils.py
+++ b/galaxy2galaxy/models/autoencoders_utils.py
@@ -38,26 +38,32 @@ def loglikelihood_fn(xin, yin, features, hparams):
   size = xin.get_shape().as_list()[1]
   if hparams.likelihood_type == 'Fourier':
     # Compute FFT normalization factor
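+    # The Fourier-domain term implements a Gaussian log-likelihood with
+    # stationary correlated noise: both images are transformed with rfft2d
+    # and whitened by the root power spectrum sqrt(exp(ps)), so
+    # -0.5 * sum |x - y|^2 is evaluated in units of the noise.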
-    x = tf.spectral.rfft2d(xin[...,0]) / tf.complex(tf.sqrt(tf.exp(features['ps'])),0.) / size**2 * (2*np.pi)**2
-    y = tf.spectral.rfft2d(yin[...,0]) / tf.complex(tf.sqrt(tf.exp(features['ps'])),0.) / size**2 * (2*np.pi)**2
-
-    pz = 0.5 * tf.reduce_sum(tf.abs(x - y)**2, axis=[-1, -2]) #/ size**2
+    x = tf.transpose(xin, [0,3,1,2])
+    y = tf.transpose(yin, [0,3,1,2])
+    ps = tf.reshape(features['ps'], tf.shape(tf.transpose(tf.spectral.rfft2d(x), [0,2,3,1])))
+    x = tf.transpose(tf.spectral.rfft2d(x), [0,2,3,1]) / tf.complex(tf.sqrt(tf.exp(ps)), 0.) / size**2 * (2*np.pi)**2
+    y = tf.transpose(tf.spectral.rfft2d(y), [0,2,3,1]) / tf.complex(tf.sqrt(tf.exp(ps)), 0.) / size**2 * (2*np.pi)**2
+
+    pz = 0.5 * tf.reduce_sum(tf.abs(x - y)**2, axis=[-1, -2, -3]) #/ size**2
     return -pz
   elif hparams.likelihood_type == 'Pixel':
     # TODO: include per example noise std
-    pz = 0.5 * tf.reduce_sum(tf.abs(xin[:,:,:,0] - yin[...,0])**2, axis=[-1, -2]) / hparams.noise_rms**2 #/ size**2
+    pz = 0.5 * tf.reduce_sum(tf.abs(xin - yin)**2, axis=[-1, -2, -3]) / hparams.noise_rms**2 #/ size**2
     return -pz
   else:
     raise NotImplementedError
 
 def image_summary(name, image_logits, max_outputs=1, rows=4, cols=4):
   """Helper for image summaries that are safe on TPU."""
-  if len(image_logits.get_shape()) != 4:
+  shape = image_logits.get_shape()
+  if len(shape) != 4:
     tf.logging.info("Not generating image summary, maybe not an image.")
     return
-  return tf.summary.image(name, pack_images(image_logits, rows, cols),
+  for i in range(shape[3]):
+    tf.summary.image(name+str(i), pack_images(tf.expand_dims(image_logits[...,i],-1), rows, cols),
                          max_outputs=max_outputs)
-
+  return 0
 def autoencoder_body(self, features):
   """ Customized body function for autoencoders acting on continuous images.
   This is based on tensor2tensor.models.research.AutoencoderBasic.body
@@ -125,11 +131,14 @@ def make_model_spec():
       input_layer = tf.placeholder(tf.float32, shape=b_shape)
       x = self.unbottleneck(input_layer, res_size)
       x = self.decoder(x, None)
-      reconstr = tf.layers.dense(x, self.num_channels, name="autoencoder_final",
+      reconstr = tf.layers.dense(x, input_shape[-1], name="autoencoder_final",
                                  activation=output_activation)
       hub.add_signature(inputs=input_layer, outputs=reconstr)
       hub.attach_message("stamp_size", tf.train.Int64List(value=[hparams.problem_hparams.img_len]))
-      hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))
+      try:
+        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale[res] for res in hparams.problem_hparams.resolutions]))
+      except AttributeError:
+        hub.attach_message("pixel_size", tf.train.FloatList(value=[hparams.problem_hparams.pixel_scale]))
     spec = hub.create_module_spec(make_model_spec, drop_collections=['checkpoints'])
     decoder = hub.Module(spec, name="decoder_module")
     hub.register_module_for_export(decoder, "decoder")
@@ -222,7 +231,7 @@ def make_model_spec():
       res = x[:, :shape[1], :shape[2], :]
 
     with tf.variable_scope('decoder_module'):
-      reconstr = tf.layers.dense(res, self.num_channels, name="autoencoder_final",
+      reconstr = tf.layers.dense(res, shape[-1], name="autoencoder_final",
                                  activation=output_activation)
 
     # We apply an optional apodization of the output before taking the
@@ -250,14 +259,15 @@ def make_model_spec():
     #tv = tf.reduce_sum(tf.sqrt(im_dx**2 + im_dy**2 + 1e-6), axis=[1,2,3])
     #tv = tf.reduce_mean(tv)
 
+    image_summary("without_psf", tf.reshape(reconstr, labels_shape))
+
     # Apply channel-wise convolution with the PSF if requested
-    # TODO: Handle multiple bands
     if hparams.apply_psf and 'psf' in features:
-      if self.num_channels > 1:
-        raise NotImplementedError
-
-      reconstr = convolve(reconstr, tf.cast(features['psf'][...,0], tf.complex64),
-                          zero_padding_factor=1)
+      output_list = []
+      for i in range(shape[3]):
+        output_list.append(tf.squeeze(convolve(tf.expand_dims(reconstr[...,i],-1), tf.cast(features['psf'][...,i], tf.complex64),
+                                               zero_padding_factor=1)))
+      reconstr = tf.stack(output_list, axis=-1)
+      reconstr = tf.reshape(reconstr, shape)
 
     # Losses.
     losses = {
diff --git a/galaxy2galaxy/models/latent_flow.py b/galaxy2galaxy/models/latent_flow.py
index 320c92a..4d27032 100644
--- a/galaxy2galaxy/models/latent_flow.py
+++ b/galaxy2galaxy/models/latent_flow.py
@@ -17,6 +17,7 @@
 from galaxy2galaxy.layers.flows import masked_autoregressive_conditional_template, ConditionalNeuralSpline, conditional_neural_spline_template, autoregressive_conditional_neural_spline_template
 from galaxy2galaxy.layers.tfp_utils import RealNVP, MaskedAutoregressiveFlow
+from galaxy2galaxy.layers.image_utils import pack_images
 
 import tensorflow as tf
 import tensorflow_hub as hub
@@ -24,9 +25,19 @@
 tfb = tfp.bijectors
 tfd = tfp.distributions
 
+def image_summary(name, image_logits, max_outputs=1, rows=4, cols=4):
+  """Helper for image summaries that are safe on TPU."""
+  shape = image_logits.get_shape()
+  if len(shape) != 4:
+    tf.logging.info("Not generating image summary, maybe not an image.")
+    return
+  for i in range(shape[3]):
+    tf.summary.image(name+str(i), pack_images(tf.expand_dims(image_logits[...,i],-1), rows, cols),
+                     max_outputs=max_outputs)
+  return 0
+
 class LatentFlow(t2t_model.T2TModel):
   """ Base class for latent flows
-
   This assumes that an already exported tensorflow hub autoencoder is provided
   in hparams.
   """
@@ -50,10 +61,15 @@ def infer(self,
   def body(self, features):
     hparams = self.hparams
-    hparamsp = hparams.problem.get_hparams()
-
+    attributes = hparams.attributes
+    if len(attributes[0]) == 0:
+      hparamsp = hparams.problem.get_hparams()
+      attributes = hparamsp.attributes
+
     x = features['inputs']
-    cond = {k: features[k] for k in hparamsp.attributes}
+    cond = {k: features[k] for k in attributes}
+
+    image_summary("input", x)
 
     # Load the encoder and decoder modules
     encoder = hub.Module(hparams.encoder_module, trainable=False)
@@ -64,7 +80,7 @@ def body(self, features):
     code_shape = [-1, code_shape[1].value, code_shape[2].value, code_shape[3].value]
 
     def get_flow(inputs, is_training=True):
-      y = tf.concat([tf.expand_dims(inputs[k], axis=1) for k in hparamsp.attributes], axis=1)
+      y = tf.concat([tf.expand_dims(inputs[k], axis=1) for k in attributes], axis=1)
       y = tf.layers.batch_normalization(y, name="y_norm", training=is_training)
       flow = self.normalizing_flow(y, latent_size)
       return flow
@@ -72,7 +88,7 @@ def get_flow(inputs, is_training=True):
     if hparams.mode == tf.estimator.ModeKeys.PREDICT:
       # Export the latent flow alone
       def flow_module_spec():
-        inputs_params = {k: tf.placeholder(tf.float32, shape=[None]) for k in hparamsp.attributes}
+        inputs_params = {k: tf.placeholder(tf.float32, shape=[None]) for k in attributes}
         random_normal = tf.placeholder(tf.float32, shape=[None, latent_size])
         flow = get_flow(inputs_params, is_training=False)
         samples = flow._bijector.forward(random_normal)
@@ -82,7 +98,7 @@ def flow_module_spec():
       flow_spec = hub.create_module_spec(flow_module_spec)
       flow = hub.Module(flow_spec, name='flow_module')
       hub.register_module_for_export(flow, "code_sampler")
-      cond['random_normal'] = tf.random_normal(shape=[tf.shape(cond[hparamsp.attributes[0]])[0], latent_size])
+      cond['random_normal'] = tf.random_normal(shape=[tf.shape(cond[attributes[0]])[0], latent_size])
       samples = flow(cond)
       return samples, {'loglikelihood': 0}
@@ -226,6 +242,8 @@ def latent_flow():
 
   # hparams related to the PSF
   hparams.add_hparam("encode_psf", True)  # Should we use the PSF at the encoder
 
+  hparams.add_hparam("attributes", [""])
+
   return hparams
 
@@ -254,6 +272,8 @@ def latent_flow_larger():
 
   # hparams related to the PSF
   hparams.add_hparam("encode_psf", True)  # Should we use the PSF at the encoder
 
+  hparams.add_hparam("attributes", [""])
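+  # "attributes" defaults to [""], which LatentFlow.body treats as unset:
+  # the conditioning attributes are then read from the problem hparams
+  # rather than overridden here.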
+
   return hparams
 
 @registry.register_hparams
 def latent_flow_nsf():
@@ -283,4 +303,6 @@ def latent_flow_nsf():
 
   # hparams related to the PSF
   hparams.add_hparam("encode_psf", True)  # Should we use the PSF at the encoder
 
+  hparams.add_hparam("attributes", [""])
+
   return hparams