-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add Custom Dataset and Implement * Clean up Branch * Clean up Branch * Resolve Some of Logan's Changes * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Reflect Logan's Requests * Fix Import Issues * Simplify Imports * Fix Imports * Apply Logan's Changes * Comments * Refactor Common Logic Into New Function * Add Documentation * Add Documentation * Add Custom Dataset and Implement * Clean up Branch * Replace keep_hdf5 with as_hdf5 * Resolve Some of Logan's Changes * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Resolve Testing Issues? * Reflect Logan's Requests * Fix Import Issues * Simplify Imports * Fix Imports * Apply Logan's Changes * Comments * Refactor Common Logic Into New Function * Add Documentation * Add Documentation * fix reference to _get_inputs_to_targets(); also, whitespace * remove unused * import * fix test_foundry.py to have the proper tests from the dev branch * remove outdated test_to_pytorch() test * fix passing of self for _get_inputs_targets() Co-authored-by: Aristana Scourtas <[email protected]>
- Loading branch information
1 parent
7055fbc
commit de2f7e0
Showing
7 changed files
with
124 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ | |
*.DS_STORE | ||
*.pyc | ||
*.idea | ||
*/foundry_ml.egg-info/* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import numpy as np | ||
from torch.utils.data import Dataset | ||
|
||
class TorchDataset(Dataset): | ||
"""Foundry Dataset Converted to Pytorch Format""" | ||
|
||
def __init__(self, inputs, targets): | ||
self.inputs=inputs | ||
self.targets=targets | ||
|
||
def __len__(self): | ||
return len(self.inputs[0]) | ||
|
||
def __getitem__(self, idx): | ||
item = {"input": [], "target": []} | ||
|
||
# adds the correct item at index idx from each input from self.inputs to the item dictionary | ||
for input in self.inputs: | ||
item["input"].append(np.array(input[idx])) | ||
item["input"] = np.array(item["input"]) | ||
|
||
# adds the correct item at index idx from each target from self.targets to the item dictionary | ||
for target in self.targets: | ||
item["target"].append(np.array(target[idx])) | ||
item["target"] = np.array(item["target"]) | ||
|
||
return item | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,3 +14,4 @@ mdf-connect-client>=0.4.0 | |
json2table>=1.1.5 | ||
joblib>=1.1.0 | ||
torch>=1.8.0 | ||
tensorflow>=2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters