Skip to content

Commit

Permalink
To tf dataset (#201) with rebase
Browse files Browse the repository at this point in the history
* Add Custom Dataset and Implement

* Clean up Branch

* Clean up Branch

* Resolve Some of Logan's Changes

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Reflect Logan's Requests

* Fix Import Issues

* Simplify Imports

* Fix Imports

* Apply Logan's Changes

* Comments

* Refactor Common Logic Into New Function

* Add Documentation

* Add Documentation

* Add Custom Dataset and Implement

* Clean up Branch

* Replace keep_hdf5 with as_hdf5

* Resolve Some of Logan's Changes

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Resolve Testing Issues?

* Reflect Logan's Requests

* Fix Import Issues

* Simplify Imports

* Fix Imports

* Apply Logan's Changes

* Comments

* Refactor Common Logic Into New Function

* Add Documentation

* Add Documentation

* fix reference to _get_inputs_to_targets(); also, whitespace

* remove unused * import

* fix test_foundry.py to have the proper tests from the dev branch

* remove outdated test_to_pytorch() test

* fix passing of self for _get_inputs_targets()

Co-authored-by: Aristana Scourtas <[email protected]>
  • Loading branch information
Aadit-Ambadkar and ascourtas authored Aug 16, 2022
1 parent 7055fbc commit de2f7e0
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 46 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
*.DS_STORE
*.pyc
*.idea
*/foundry_ml.egg-info/*
67 changes: 47 additions & 20 deletions foundry/foundry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
FoundrySpecification,
FoundryDataset
)
from foundry.external_data_architectures import (
FoundryDataset_Torch
)

import logging
import warnings
Expand Down Expand Up @@ -737,7 +734,6 @@ def get_keys(self, type=None, as_object=False):
key_list = key_list + k
return key_list


def _load_data(self, file=None, source_id=None, globus=True, as_hdf5=False):
# Build the path to access the cached data
if source_id:
Expand All @@ -752,7 +748,6 @@ def _load_data(self, file=None, source_id=None, globus=True, as_hdf5=False):
if not file:
file = self.config.dataframe_file


# Check to make sure the path can be created
try:
path_to_file = os.path.join(path, file)
Expand Down Expand Up @@ -817,22 +812,16 @@ def _load_data(self, file=None, source_id=None, globus=True, as_hdf5=False):
else:
raise NotImplementedError


def toTorch(self, raw=None, split=None):
"""Convert Foundry Dataset to a PyTorch Dataset
def _get_inputs_targets(self, split: str = None):
"""Get Inputs and Outputs from a Foundry Dataset
Arguments:
raw (dict): The output of running ``f.load_data(as_hdf5=False)``
Recommended that this is left as ``None``
split (string): Split to get inputs and outputs from.
**Default:** ``None``
split (string): Split to create PyTorch Dataset on.
**Default:** ``None``
Returns: (FoundryDataset_Torch) PyTorch Dataset of all the data from the specified split
Returns: (Tuple) Tuple of the inputs and outputs
"""
if not raw:
raw = self.load_data(as_hdf5=False)
raw = self.load_data(as_hdf5=False)

if not split:
split = self.dataset.splits[0].type
Expand All @@ -841,16 +830,23 @@ def toTorch(self, raw=None, split=None):
inputs = []
targets = []
for key in self.dataset.keys:
# raw[split][key.type][key.key[0]] gets the data values for the given key.
#
# For example, if the key was coordinates and had type target, then
# raw[split][key.type][key.key[0]] would return all the coordinates for each item
# and raw[split][key.type][key.key[0]].keys() are the indexes of the item.
if len(raw[split][key.type][key.key[0]].keys()) != self.dataset.n_items:
continue

# Get a numpy array of all the values for each item for that key
val = np.array([raw[split][key.type][key.key[0]][k] for k in raw[split][key.type][key.key[0]].keys()])
if key.type == 'input':
inputs.append(val)
else:
targets.append(val)

return (inputs, targets)

return FoundryDataset_Torch(inputs, targets)
elif self.dataset.data_type.value == "tabular":
inputs = []
targets = []
Expand All @@ -859,11 +855,42 @@ def toTorch(self, raw=None, split=None):
df = raw[split][index]
for key in df.keys():
arr.append(df[key].values)

return FoundryDataset_Torch(inputs, targets)

return (inputs, targets)

else:
raise NotImplementedError

def to_torch(self, split: str = None):
"""Convert Foundry Dataset to a PyTorch Dataset
Arguments:
split (string): Split to create PyTorch Dataset on.
**Default:** ``None``
Returns: (TorchDataset) PyTorch Dataset of all the data from the specified split
"""
from foundry.loaders.torch_wrapper import TorchDataset

inputs, targets = self._get_inputs_targets(split)
return TorchDataset(inputs, targets)

def to_tensorflow(self, split: str = None):
"""Convert Foundry Dataset to a Tensorflow Sequence
Arguments:
split (string): Split to create Tensorflow Sequence on.
**Default:** ``None``
Returns: (TensorflowSequence) Tensorflow Sequence of all the data from the specified split
"""
from foundry.loaders.tf_wrapper import TensorflowSequence

inputs, targets = self._get_inputs_targets(split)
return TensorflowSequence(inputs, targets)


def is_pandas_pytable(group):
if 'axis0' in group.keys() and 'axis1' in group.keys():
Expand Down
Empty file added foundry/loaders/__init__.py
Empty file.
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import numpy as np
import torch
from torch.utils.data import Dataset
from tensorflow.keras.utils import Sequence

class FoundryDataset_Torch(Dataset):
"""Foundry Dataset Converted to Pytorch Format"""
class TensorflowSequence(Sequence):
"""Foundry Dataset Converted to Tensorflow Format"""

def __init__(self, inputs, targets):
self.inputs=inputs
Expand All @@ -24,4 +23,4 @@ def __getitem__(self, idx):
item["target"] = np.array(item["target"])

return item

28 changes: 28 additions & 0 deletions foundry/loaders/torch_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import numpy as np
from torch.utils.data import Dataset

class TorchDataset(Dataset):
"""Foundry Dataset Converted to Pytorch Format"""

def __init__(self, inputs, targets):
self.inputs=inputs
self.targets=targets

def __len__(self):
return len(self.inputs[0])

def __getitem__(self, idx):
item = {"input": [], "target": []}

# adds the correct item at index idx from each input from self.inputs to the item dictionary
for input in self.inputs:
item["input"].append(np.array(input[idx]))
item["input"] = np.array(item["input"])

# adds the correct item at index idx from each target from self.targets to the item dictionary
for target in self.targets:
item["target"].append(np.array(target[idx]))
item["target"] = np.array(item["target"])

return item

1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ mdf-connect-client>=0.4.0
json2table>=1.1.5
joblib>=1.1.0
torch>=1.8.0
tensorflow>=2
64 changes: 43 additions & 21 deletions tests/test_foundry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os, shutil
import re
import types
import pytest
from datetime import datetime
import mdf_toolbox
Expand Down Expand Up @@ -44,6 +46,7 @@
test_dataset = "foundry_experimental_band_gaps_v1.1"
expected_title = "Graph Network Based Deep Learning of Band Gaps - Experimental Band Gaps"


# Kept the Old metadata format in case we ever want to refer back
old_test_metadata = {
"inputs": ["sepal length (cm)", "sepal width (cm)", "petal length (cm)", "petal width (cm)"],
Expand All @@ -58,7 +61,7 @@
"package_type": "tabular"
}

test_metadata = {
pub_test_metadata = {
"keys":[
{
"key": ["sepal length (cm)"],
Expand Down Expand Up @@ -116,11 +119,11 @@
'n_items': 1000
}

# Globus endpoint for '_iris_dev'
test_data_source = "https://app.globus.org/file-manager?origin_id=e38ee745-6d04-11e5-ba46-22000b92c6ec&origin_path=%2Ffoundry-test%2Firis-dev%2F"
# Globus endpoint for '_iris_dev' for test publication
pub_test_data_source = "https://app.globus.org/file-manager?origin_id=e38ee745-6d04-11e5-ba46-22000b92c6ec&origin_path=%2Ffoundry-test%2Firis-dev%2F"


#Quick function to delete any downloaded test data
# Quick function to delete any downloaded test data
def _delete_test_data(foundry_obj):
path = os.path.join(foundry_obj.config.local_cache_dir, test_dataset)
if os.path.isdir(path):
Expand Down Expand Up @@ -184,19 +187,6 @@ def test_dataframe_load():
_delete_test_data(f)


def test_to_pytorch():
f = Foundry(authorizers=auths, no_browser=True, no_local_server=True)
_delete_test_data(f)

f = f.load(test_dataset, download=True, globus=False, authorizers=auths)
raw = f.load_data()
ds = f.toTorch(raw=raw, split='train')

assert raw['train'][0].iloc[0][0] == ds[0]['input'][0]
assert len(raw['train'][0]) == len(ds)
_delete_test_data(f)


@pytest.mark.skipif(bool(is_gha), reason="Test does not succeed online") # PLEASE CONFIRM THIS BEHAVIOR IS INTENDED
def test_download_globus():
f = Foundry(authorizers=auths, no_browser=True, no_local_server=True)
Expand Down Expand Up @@ -234,7 +224,7 @@ def test_publish():
short_name = "example_AS_iris_test_{:.0f}".format(timestamp)
authors = ["A Scourtas"]

res = f.publish(test_metadata, test_data_source, title, authors, short_name=short_name)
res = f.publish(pub_test_metadata, pub_test_data_source, title, authors, short_name=short_name)

# publish with short name
assert res['success']
Expand All @@ -247,19 +237,51 @@ def test_publish():
# assert res['source_id'] == "_test_scourtas_example_iris_publish_{:.0f}_v1.1".format(timestamp)

# check that pushing same dataset without update flag fails
res = f.publish(test_metadata, test_data_source, title, authors, short_name=short_name)
res = f.publish(pub_test_metadata, pub_test_data_source, title, authors, short_name=short_name)
assert not res['success']

# check that using update flag allows us to update dataset
res = f.publish(test_metadata, test_data_source, title, authors, short_name=short_name, update=True)
res = f.publish(pub_test_metadata, pub_test_data_source, title, authors, short_name=short_name, update=True)
assert res['success']

# check that using update flag for new dataset fails
new_short_name = short_name + "_update"
res = f.publish(test_metadata, test_data_source, title, authors, short_name=new_short_name, update=True)
res = f.publish(pub_test_metadata, pub_test_data_source, title, authors, short_name=new_short_name, update=True)
assert not res['success']


def test_check_status():
# TODO: the 'active messages' in MDF CC's check_status() don't appear to do anything? need to determine how to test
pass


def test_to_pytorch():
f = Foundry(authorizers=auths, no_browser=True, no_local_server=True)

_delete_test_data(f)

f = f.load(test_dataset, download=True, globus=False, authorizers=auths)
raw = f.load_data()

ds = f.to_torch(split='train')

assert raw['train'][0].iloc[0][0] == ds[0]['input'][0]
assert len(raw['train'][0]) == len(ds)

_delete_test_data(f)


def test_to_tensorflow():
f = Foundry(authorizers=auths, no_browser=True, no_local_server=True)

_delete_test_data(f)

f = f.load(test_dataset, download=True, globus=False, authorizers=auths)
raw = f.load_data()

ds = f.to_tensorflow(split='train')

assert raw['train'][0].iloc[0][0] == ds[0]['input'][0]
assert len(raw['train'][0]) == len(ds)

_delete_test_data(f)

0 comments on commit de2f7e0

Please sign in to comment.