Skip to content

Commit

Permalink
Fix header parameter handling in sklearn's data_loader, update README…
Browse files Browse the repository at this point in the history
… with new AUC (#2363)

* Fix header parameter of data_loader, update README with new AUC

* Fix header parameter in load_data and load_data_for_range functions

* Update header handling in load_data function by removing redundancy

* Fix formatting issues

* Remove read_json

---------

Co-authored-by: Ziyue Xu <[email protected]>
Co-authored-by: Chester Chen <[email protected]>
Co-authored-by: Yuan-Ting Hsieh (謝沅廷) <[email protected]>
  • Loading branch information
4 people authored Mar 8, 2024
1 parent 6fbcef5 commit 2d75cd1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 12 deletions.
2 changes: 1 addition & 1 deletion examples/advanced/sklearn-svm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,4 @@ We can run the FL simulator with three clients under the uniform data split with
bash run_experiment_simulator.sh
```
Running with default [SVC](https://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html) classifier, the
resulting global model's AUC is 0.806 which can be seen in the clients' logs.
resulting global model's AUC is 0.8088 which can be seen in the clients' logs.
21 changes: 10 additions & 11 deletions nvflare/app_opt/sklearn/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
"csv": pd.read_csv,
"xls": pd.read_excel,
"xlsx": pd.read_excel,
"json": pd.read_json,
}


Expand All @@ -42,13 +41,10 @@ def get_pandas_reader(data_path: str):

def load_data(data_path: str, require_header: bool = False):
reader = get_pandas_reader(data_path)
if hasattr(reader, "header"):
if require_header:
data = reader(data_path)
else:
data = reader(data_path, header=None)
else:
if hasattr(reader, "header") and require_header:
data = reader(data_path)
else:
data = reader(data_path, header=None)

return _to_data_tuple(data)

Expand All @@ -58,11 +54,14 @@ def load_data_for_range(data_path: str, start: int, end: int, require_header: bo

if hasattr(reader, "skiprows"):
data_size = end - start
if hasattr(reader, "header") and not require_header:
data = reader(data_path, header=None, skiprows=start, nrows=data_size)
else:
if hasattr(reader, "header") and require_header:
data = reader(data_path, skiprows=start, nrows=data_size)
else:
data = reader(data_path, header=None, skiprows=start, nrows=data_size)
else:
data = reader(data_path).iloc[start:end]
if hasattr(reader, "header") and require_header:
data = reader(data_path).iloc[start:end]
else:
data = reader(data_path, header=None).iloc[start:end]

return _to_data_tuple(data)

0 comments on commit 2d75cd1

Please sign in to comment.