@@ -101,10 +101,7 @@ def load(self, name, version="1.1", provider="MDF", download=True, globus=True,
         Args:
             name (str): Name of the foundry dataset
             download (bool): If True, download the data associated with the package (default is True)
-            globus (bool): If True, download using Globus, otherwise https
-            verbose (bool): If True, print additional debug information
-            metadata (dict): **For debug purposes.** A search result analog to prepopulate metadata.
-
+
         Keyword Args:
             interval (int): How often to poll Globus to check if transfers are complete
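For orientation, a minimal sketch of calling load() with the documented keyword argument. The dataset name is a hypothetical placeholder, and the Foundry import and instantiation (which may involve Globus auth) are assumptions about the package's usual entry point, not something this diff shows:

    from foundry import Foundry

    f = Foundry()
    # "my_foundry_dataset" is a hypothetical name; interval is the Keyword Arg above
    f = f.load("my_foundry_dataset", download=True, globus=True, interval=10)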
@@ -123,6 +120,7 @@ def load(self, name, version="1.1", provider="MDF", download=True, globus=True,
                .match_field("mdf.organizations", "foundry")
                .search()
            )
+
        # Handle MDF source_ids
        else:
            print("Loading by source_id")
@@ -152,6 +150,7 @@ def load(self, name, version="1.1", provider="MDF", download=True, globus=True,
            self.download(
                interval=kwargs.get("interval", 10), globus=globus, verbose=verbose
            )
+
        return self

    def list(self):
@@ -252,22 +251,44 @@ def load_data(self, source_id=None, globus=True):

        Args:
            inputs (list): List of strings for input columns
-           targets (list): List of strings for output columns
+           outputs (list): List of strings for output columns

        Returns
-       -------s
+       -------
        (tuple): Tuple of X, y values
        """
-       data = {}
-
-       # Handle splits if they exist. Return as a labeled dictionary of tuples
-       if self.dataset.splits:
-           for split in self.dataset.splits:
-               data[split.label] = self._load_data(file=split.path,
-                                                   source_id=source_id, globus=globus)
-           return data
+
+       if source_id:
+           path = os.path.join(self.config.local_cache_dir, source_id)
+           print("Here")
        else:
-           return {"data": self._load_data(source_id=source_id, globus=globus)}
+           path = os.path.join(self.config.local_cache_dir, self.mdf["source_id"])
+       # Handle Foundry-defined types.
+       if self.dataset.type.value == "tabular":
+           # If the file is not local, fetch the contents with Globus
+           # Check if the contents are local
+           # TODO: Add hashes and versioning to metadata and checking to the file
+           try:
+               self.dataset.dataframe = pd.read_json(
+                   os.path.join(path, self.config.dataframe_file)
+               )
+           except:
+               # Try to read individual lines instead
+               self.dataset.dataframe = pd.read_json(
+                   os.path.join(path, self.config.dataframe_file), lines=True
+               )
+
+           return (
+               self.dataset.dataframe[self.dataset.inputs],
+               self.dataset.dataframe[self.dataset.outputs],
+           )
+       elif self.dataset.type.value == "hdf5":
+           f = h5py.File(os.path.join(path, self.config.data_file), "r")
+           inputs = [f[i[0:]] for i in self.dataset.inputs]
+           outputs = [f[i[0:]] for i in self.dataset.outputs]
+           return (inputs, outputs)
+       else:
+           raise NotImplementedError

    def describe(self):
        print("DC:{}".format(self.dc))
@@ -607,15 +628,14 @@ def download(self, globus=True, verbose=False, **kwargs):
            num_cores = multiprocessing.cpu_count()

            def download_file(file):
-                requests.packages.urllib3.disable_warnings(
-                    InsecureRequestWarning)
+                requests.packages.urllib3.disable_warnings(InsecureRequestWarning)

                url = "https://data.materialsdatafacility.org" + file["path"]
                destination = (
                    "data/"
                    + source_id
                    + "/"
-                    + file["path"][file["path"].rindex("/") + 1 :]
+                    + file["path"][file["path"].rindex("/") + 1:]
                )
                response = requests.get(url, verify=False)
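A standalone sketch of the HTTPS download pattern this hunk edits: fetch each listed file and fan the work out across cores. The file listing, source_id, and the use of a module-level function (for picklability under multiprocessing) are assumptions, not from the diff:

    import multiprocessing
    import os
    import requests
    from requests.packages.urllib3.exceptions import InsecureRequestWarning

    SOURCE_ID = "example_source_id"  # hypothetical

    def download_file(file):
        requests.packages.urllib3.disable_warnings(InsecureRequestWarning)
        url = "https://data.materialsdatafacility.org" + file["path"]
        destination = "data/" + SOURCE_ID + "/" + file["path"][file["path"].rindex("/") + 1:]
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        response = requests.get(url, verify=False)
        with open(destination, "wb") as fh:
            fh.write(response.content)

    if __name__ == "__main__":
        files = [{"path": "/foundry/example/data.json"}]  # hypothetical listing
        with multiprocessing.Pool(multiprocessing.cpu_count()) as pool:
            pool.map(download_file, files)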
@@ -633,7 +653,6 @@ def download_file(file):

        return self

-
    def build(self, spec, globus=False, interval=3, file=False):
        """Build a Foundry Data Package
        Args:
@@ -670,62 +689,3 @@ def start_download(ds, interval=interval, globus=False):
            )

        return self
-
-    def get_keys(self, type, as_object=False):
-        """Get keys for a Foundry dataset
-
-        Arguments:
-            type (str): The type of key to be returned e.g., "input", "target"
-            as_object (bool): When ``False``, will return a list of keys as strings.
-                When ``True``, will return the full key objects.
-                **Default:** ``False``
-        Returns: (list) String representations of keys if ``as_object``
-            is ``False``, otherwise the full key objects.
-
-        """
-        if as_object:
-            return [key for key in self.dataset.keys if key.type == type]
-        else:
-            return [key.key for key in self.dataset.keys if key.type == type]
-
-    def _load_data(self, file=None, source_id=None, globus=True):
-
-        # Build the path to access the cached data
-        if source_id:
-            path = os.path.join(self.config.local_cache_dir, source_id)
-        else:
-            path = os.path.join(self.config.local_cache_dir,
-                                self.mdf["source_id"])
-
-        # Handle Foundry-defined types.
-        if self.dataset.type.value == "tabular":
-            # Determine which file to load, defaults to config.dataframe_file
-            if not file:
-                file = self.config.dataframe_file
-
-            # If the file is not local, fetch the contents with Globus
-            # Check if the contents are local
-            # TODO: Add hashes and versioning to metadata and checking to the file
-            try:
-                self.dataset.dataframe = pd.read_json(
-                    os.path.join(path, file)
-                )
-            except:
-                # Try to read individual lines instead
-                self.dataset.dataframe = pd.read_json(
-                    os.path.join(path, file), lines=True
-                )
-
-            return (
-                self.dataset.dataframe[self.get_keys("input")],
-                self.dataset.dataframe[self.get_keys("target")],
-            )
-        elif self.dataset.type.value == "hdf5":
-            if not file:
-                file = self.config.data_file
-            f = h5py.File(os.path.join(path, file), "r")
-            inputs = [f[i[0:]] for i in self.get_keys("input")]
-            targets = [f[i[0:]] for i in self.get_keys("target")]
-            return (inputs, targets)
-        else:
-            raise NotImplementedError
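The removed hdf5 branch indexes an h5py file by key strings. A brief sketch of that access pattern; the filename and key lists are hypothetical stand-ins for config.data_file and the dataset's input/target keys, and values are read out with [()] before the file closes:

    import h5py

    with h5py.File("data/example_source_id/foundry.hdf5", "r") as f:
        inputs = [f[key][()] for key in ["input_feature"]]    # hypothetical key paths
        targets = [f[key][()] for key in ["target_value"]]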