Fix existing xgboost examples (#2830)
YuanTingHsieh authored Aug 23, 2024
1 parent f251be4 commit da54ac1
Showing 35 changed files with 173 additions and 512 deletions.
@@ -1,9 +1,5 @@
 {
   "format_version": 2,
-  "server": {
-    "heart_beat_timeout": 600,
-    "task_request_interval": 0.05
-  },
   "task_data_filters": [],
   "task_result_filters": [],
   "components": [
@@ -41,25 +41,25 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename

-    def load_data(self, client_id: str):
+    def load_data(self):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)

         data_path = data_split["data_path"]
         data_index = data_split["data_index"]

         # check if site_id and "valid" in the mapping dict
-        if client_id not in data_index.keys():
+        if self.client_id not in data_index.keys():
             raise ValueError(
-                f"Data does not contain Client {client_id} split",
+                f"Data does not contain Client {self.client_id} split",
             )

         if "valid" not in data_index.keys():
             raise ValueError(
                 "Data does not contain Validation split",
             )

-        site_index = data_index[client_id]
+        site_index = data_index[self.client_id]
         valid_index = data_index["valid"]

         # training
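This same hunk is applied to each example's data_loader.py below: client_id is no longer passed into load_data(); the loader instead reads self.client_id, which the surrounding NVFlare machinery is expected to set on the instance before load_data() runs. A minimal standalone sketch of the updated loader follows; the "start"/"end" layout of the split file and the CSV handling are illustrative assumptions, not details taken from this commit.

import json

import pandas as pd
import xgboost as xgb


class DataLoader:
    """Sketch of the updated loader: client_id is an attribute, not an argument."""

    def __init__(self, data_split_filename):
        self.data_split_filename = data_split_filename
        self.client_id = None  # assumed to be set by the framework (e.g. "site-1")

    def load_data(self):
        with open(self.data_split_filename, "r") as file:
            data_split = json.load(file)

        data_path = data_split["data_path"]
        data_index = data_split["data_index"]

        # check that the site id and "valid" are in the mapping dict
        if self.client_id not in data_index:
            raise ValueError(f"Data does not contain Client {self.client_id} split")
        if "valid" not in data_index:
            raise ValueError("Data does not contain Validation split")

        site_index = data_index[self.client_id]
        valid_index = data_index["valid"]

        # Assumed split layout: each entry carries "start"/"end" row offsets into
        # one shared CSV whose first column is the label.
        df = pd.read_csv(data_path, header=None)
        train_df = df.iloc[site_index["start"] : site_index["end"]]
        valid_df = df.iloc[valid_index["start"] : valid_index["end"]]

        dmat_train = xgb.DMatrix(train_df.iloc[:, 1:], label=train_df.iloc[:, 0])
        dmat_valid = xgb.DMatrix(valid_df.iloc[:, 1:], label=valid_df.iloc[:, 0])
        return dmat_train, dmat_valid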
@@ -41,25 +41,25 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename

-    def load_data(self, client_id: str):
+    def load_data(self):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)

         data_path = data_split["data_path"]
         data_index = data_split["data_index"]

         # check if site_id and "valid" in the mapping dict
-        if client_id not in data_index.keys():
+        if self.client_id not in data_index.keys():
             raise ValueError(
-                f"Data does not contain Client {client_id} split",
+                f"Data does not contain Client {self.client_id} split",
             )

         if "valid" not in data_index.keys():
             raise ValueError(
                 "Data does not contain Validation split",
             )

-        site_index = data_index[client_id]
+        site_index = data_index[self.client_id]
         valid_index = data_index["valid"]

         # training
@@ -41,25 +41,25 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename

-    def load_data(self, client_id: str):
+    def load_data(self):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)

         data_path = data_split["data_path"]
         data_index = data_split["data_index"]

         # check if site_id and "valid" in the mapping dict
-        if client_id not in data_index.keys():
+        if self.client_id not in data_index.keys():
             raise ValueError(
-                f"Data does not contain Client {client_id} split",
+                f"Data does not contain Client {self.client_id} split",
             )

         if "valid" not in data_index.keys():
             raise ValueError(
                 "Data does not contain Validation split",
             )

-        site_index = data_index[client_id]
+        site_index = data_index[self.client_id]
         valid_index = data_index["valid"]

         # training
@@ -41,25 +41,25 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename

-    def load_data(self, client_id: str):
+    def load_data(self):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)

         data_path = data_split["data_path"]
         data_index = data_split["data_index"]

         # check if site_id and "valid" in the mapping dict
-        if client_id not in data_index.keys():
+        if self.client_id not in data_index.keys():
             raise ValueError(
-                f"Data does not contain Client {client_id} split",
+                f"Data does not contain Client {self.client_id} split",
             )

         if "valid" not in data_index.keys():
             raise ValueError(
                 "Data does not contain Validation split",
             )

-        site_index = data_index[client_id]
+        site_index = data_index[self.client_id]
         valid_index = data_index["valid"]

         # training
@@ -0,0 +1,48 @@
+{
+  "format_version": 2,
+  "executors": [
+    {
+      "tasks": [
+        "config",
+        "start"
+      ],
+      "executor": {
+        "id": "Executor",
+        "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_executor.FedXGBHistogramExecutor",
+        "args": {
+          "data_loader_id": "dataloader",
+          "model_file_name": "test.model.json",
+          "metrics_writer_id": "metrics_writer"
+        }
+      }
+    }
+  ],
+  "task_result_filters": [],
+  "task_data_filters": [],
+  "components": [
+    {
+      "id": "dataloader",
+      "path": "data_loader.DataLoader",
+      "args": {
+        "data_split_filename": "/tmp/dataset/horizontal_xgb_data/data_{SITE_NAME}.json"
+      }
+    },
+    {
+      "id": "metrics_writer",
+      "path": "nvflare.app_opt.tracking.tb.tb_writer.TBWriter",
+      "args": {
+        "event_type": "analytix_log_stats"
+      }
+    },
+    {
+      "id": "event_to_fed",
+      "path": "nvflare.app_common.widgets.convert_to_fed_event.ConvertToFedEvent",
+      "args": {
+        "events_to_convert": [
+          "analytix_log_stats"
+        ],
+        "fed_event_prefix": "fed."
+      }
+    }
+  ]
+}
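Two details in this new client config are worth noting. The {SITE_NAME} placeholder in data_split_filename is resolved per site at runtime, so site-1 reads data_site-1.json, site-2 reads data_site-2.json, and so on; and the dataloader component points at the DataLoader class whose load_data() signature was changed above. For orientation only, a split file of the shape that loader expects might look like the sketch below: the key names data_path, data_index, and valid appear in the loader code, but the concrete file path, site name, and "start"/"end" index layout are assumptions, not contents of this commit.

{
  "data_path": "/tmp/dataset/horizontal_xgb_data/higgs.csv",
  "data_index": {
    "valid": {"start": 0, "end": 1000000},
    "site-1": {"start": 1000000, "end": 2200000}
  }
}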
@@ -0,0 +1,37 @@
+{
+  "format_version": 2,
+  "num_rounds": 100,
+  "task_data_filters": [],
+  "task_result_filters": [],
+  "components": [
+    {
+      "id": "tb_receiver",
+      "path": "nvflare.app_opt.tracking.tb.tb_receiver.TBAnalyticsReceiver",
+      "args": {
+        "tb_folder": "tb_events"
+      }
+    }
+  ],
+  "workflows": [
+    {
+      "id": "xgb_controller",
+      "path": "nvflare.app_opt.xgboost.histogram_based_v2.fed_controller.XGBFedController",
+      "args": {
+        "num_rounds": "{num_rounds}",
+        "data_split_mode": 0,
+        "secure_training": false,
+        "xgb_params": {
+          "max_depth": 8,
+          "eta": 0.1,
+          "objective": "binary:logistic",
+          "eval_metric": "auc",
+          "tree_method": "hist",
+          "nthread": 16
+        },
+        "xgb_options": {
+          "early_stopping_rounds": 2
+        }
+      }
+    }
+  ]
+}
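The xgb_params and xgb_options blocks in this server config carry ordinary XGBoost settings. As a rough, non-federated point of reference, the same settings in a plain local run would look like the sketch below; the CSV paths and DMatrix construction are illustrative assumptions, and in the actual job it is XGBFedController plus the per-site executors that drive this loop rather than a single xgb.train() call.

import xgboost as xgb

xgb_params = {
    "max_depth": 8,
    "eta": 0.1,
    "objective": "binary:logistic",
    "eval_metric": "auc",
    "tree_method": "hist",
    "nthread": 16,
}

# Hypothetical local files; XGBoost can read CSV via a URI with a label column.
dtrain = xgb.DMatrix("train.csv?format=csv&label_column=0")
dvalid = xgb.DMatrix("valid.csv?format=csv&label_column=0")

booster = xgb.train(
    xgb_params,
    dtrain,
    num_boost_round=100,  # matches "num_rounds" in the config
    evals=[(dvalid, "validate"), (dtrain, "train")],
    early_stopping_rounds=2,  # matches "xgb_options"
)
booster.save_model("test.model.json")  # matches "model_file_name" in the client config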
@@ -41,25 +41,25 @@ def __init__(self, data_split_filename):
         """
         self.data_split_filename = data_split_filename

-    def load_data(self, client_id: str):
+    def load_data(self):
         with open(self.data_split_filename, "r") as file:
             data_split = json.load(file)

         data_path = data_split["data_path"]
         data_index = data_split["data_index"]

         # check if site_id and "valid" in the mapping dict
-        if client_id not in data_index.keys():
+        if self.client_id not in data_index.keys():
             raise ValueError(
-                f"Data does not contain Client {client_id} split",
+                f"Data does not contain Client {self.client_id} split",
             )

         if "valid" not in data_index.keys():
             raise ValueError(
                 "Data does not contain Validation split",
             )

-        site_index = data_index[client_id]
+        site_index = data_index[self.client_id]
         valid_index = data_index["valid"]

         # training

This file was deleted.

This file was deleted.

This file was deleted.
