baler/modules/helper.py: 38 additions & 42 deletions
@@ -73,18 +73,14 @@ def get_arguments():
         " 2. If workspace exists but project does not, create project in workspace.\n"
         " 3. If workspace does not exist, create workspace directory and project.",
     )
-    parser.add_argument(
-        "--verbose", dest="verbose", action="store_true", help="Verbose mode"
-    )
+    parser.add_argument("--verbose", dest="verbose", action="store_true", help="Verbose mode")
     parser.set_defaults(verbose=False)
 
     args = parser.parse_args()
 
     workspace_name = args.project[0]
     project_name = args.project[1]
-    config_path = (
-        f"workspaces.{workspace_name}.{project_name}.config.{project_name}_config"
-    )
+    config_path = f"workspaces.{workspace_name}.{project_name}.config.{project_name}_config"
 
     if args.mode == "newProject":
         config = None
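
A note on the dotted config_path (illustration, not part of the diff): the module.sub.module form suggests it is resolved with importlib rather than opened as a file path. A minimal sketch under that assumption; load_config_module is a hypothetical name:

import importlib

def load_config_module(config_path: str):
    # "workspaces.<ws>.<proj>.config.<proj>_config" -> imported config module
    return importlib.import_module(config_path)
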
@@ -141,9 +137,7 @@ def create_new_project(
         os.makedirs(directory, exist_ok=True)
 
     # Populate default config
-    with open(
-        os.path.join(project_path, "config", f"{project_name}_config.py"), "w"
-    ) as f:
+    with open(os.path.join(project_path, "config", f"{project_name}_config.py"), "w") as f:
         f.write(create_default_config(workspace_name, project_name))
 
 
@@ -177,6 +171,8 @@ class Config:
     emd: bool
     l1: bool
     deterministic_algorithm: bool
+    dtype: str
+    plot_negative: bool
 
 
 def create_default_config(workspace_name: str, project_name: str) -> str:
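
For reference (illustration, not part of the diff): the two new Config fields would be set in a generated project config roughly as below, assuming the set_config(c) style of baler-generated configs; the values and the reading of plot_negative are guesses, not taken from this PR.

# Hypothetical excerpt from a generated <project>_config.py
def set_config(c):
    c.dtype = "float32"      # precision applied to the decompressed output
    c.plot_negative = True   # assumed meaning: keep negative values on plots
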
@@ -268,9 +264,7 @@ def normalize(data, custom_norm):
     Returns:
         ndarray: Normalized data
     """
-    data = np.apply_along_axis(
-        data_processing.normalize, axis=0, arr=data, custom_norm=custom_norm
-    )
+    data = np.apply_along_axis(data_processing.normalize, axis=0, arr=data, custom_norm=custom_norm)
    return data
 
 
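Since axis=0 hands apply_along_axis one column at a time, each feature is normalized independently. A self-contained illustration with a stand-in min-max normalizer (data_processing.normalize itself is not shown in this diff):

import numpy as np

def minmax(col, custom_norm=False):  # stand-in, NOT baler's data_processing.normalize
    span = col.max() - col.min()
    return (col - col.min()) / span if span else col

x = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0]])
print(np.apply_along_axis(minmax, axis=0, arr=x, custom_norm=False))
# [[0.  0. ]
#  [0.5 0.5]
#  [1.  1. ]]
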
@@ -312,9 +306,7 @@ def process(
         train_set = data
         test_set = train_set
     else:
-        train_set, test_set = train_test_split(
-            data, test_size=test_size, random_state=1
-        )
+        train_set, test_set = train_test_split(data, test_size=test_size, random_state=1)
 
     return (train_set, test_set, normalization_features, original_shape)
 
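The test_size == 0 branch reuses the full set as both train and test instead of splitting; otherwise sklearn's train_test_split runs with a fixed random_state so the split is reproducible. A quick illustration of the split path:

import numpy as np
from sklearn.model_selection import train_test_split

data = np.arange(10).reshape(5, 2)
train, test = train_test_split(data, test_size=0.4, random_state=1)
print(train.shape, test.shape)  # (3, 2) (2, 2)
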
@@ -347,9 +339,7 @@ def train(model, number_of_columns, train_set, test_set, project_path, config):
     Returns:
         _type_: _description_
     """
-    return training.train(
-        model, number_of_columns, train_set, test_set, project_path, config
-    )
+    return training.train(model, number_of_columns, train_set, test_set, project_path, config)
 
 
 def plotter(output_path, config):
@@ -405,9 +395,9 @@ def encoder_decoder_saver(model, encoder_path, decoder_path):
         .pt file: `.pt` File containing the decoder state dictionary
 
     """
-    return data_processing.encoder_saver(
-        model, encoder_path
-    ), data_processing.decoder_saver(model, decoder_path)
+    return data_processing.encoder_saver(model, encoder_path), data_processing.decoder_saver(
+        model, decoder_path
+    )
 
 
 def detacher(tensor):
@@ -450,13 +440,9 @@ def save_error_bounded_requirement(config, decoded_output, data_batch):
 
     # Ignoring undefined RMS values because the ground truth is zero
     rms_pred_error[
-        (rms_pred_error == np.inf)
-        | (rms_pred_error == -np.inf)
-        | (rms_pred_error == np.nan)
+        np.isinf(rms_pred_error) | np.isnan(rms_pred_error)
     ] = 0.0
-    rms_pred_error_index = np.where(
-        abs(rms_pred_error) > config.error_bounded_requirement
-    )
+    rms_pred_error_index = np.where(abs(rms_pred_error) > config.error_bounded_requirement)
     rows_idx, col_idx = rms_pred_error_index
     if len(rows_idx) > 0 and len(col_idx) > 0:
         rms_pred_error_exceeding_error_bound = np.subtract(
@@ -493,9 +479,7 @@ def compress(model_path, config):
     original_shape = data_before.shape
 
     if hasattr(config, "convert_to_blocks") and config.convert_to_blocks:
-        data_before = data_processing.convert_to_blocks_util(
-            config.convert_to_blocks, data_before
-        )
+        data_before = data_processing.convert_to_blocks_util(config.convert_to_blocks, data_before)
 
     if config.apply_normalization:
         print("Normalizing...")
Expand All @@ -508,9 +492,7 @@ def compress(model_path, config):
if config.data_dimension == 1:
column_names = np.load(config.input_path)["names"]
number_of_columns = len(column_names)
config.latent_space_size = ceil(
number_of_columns / config.compression_ratio
)
config.latent_space_size = ceil(number_of_columns / config.compression_ratio)
config.number_of_columns = number_of_columns
n_features = number_of_columns
elif config.data_dimension == 2:
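
Worked arithmetic for the latent size above (illustration, not part of the diff): ceil rounds up whenever the compression ratio does not divide the column count evenly, so no feature is left without a latent slot.

from math import ceil

print(ceil(24 / 4))  # 6: 24 columns at compression_ratio 4
print(ceil(25 / 4))  # 7: a leftover column still claims a full latent slot
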
@@ -653,12 +635,8 @@ def decompress(
     )
 
     if config.save_error_bounded_deltas:
-        loaded_deltas = np.load(
-            gzip.GzipFile(input_path_deltas, "r"), allow_pickle=True
-        )
-        loaded_batch_indexes = np.load(
-            gzip.GzipFile(input_batch_index, "r"), allow_pickle=True
-        )
+        loaded_deltas = np.load(gzip.GzipFile(input_path_deltas, "r"), allow_pickle=True)
+        loaded_batch_indexes = np.load(gzip.GzipFile(input_batch_index, "r"), allow_pickle=True)
         error_bound_batch = loaded_batch_indexes[0]
         error_bound_deltas = loaded_deltas
         error_bound_index = loaded_batch_indexes[1]
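
For orientation (illustration, not part of the diff): reading through gzip.GzipFile pairs with a writer that saves numpy arrays into a gzip stream. A minimal round trip under that assumption; the file name is a placeholder:

import gzip
import numpy as np

deltas = np.array([0.1, -0.2, 0.05])
with gzip.GzipFile("deltas.npy.gz", "w") as f:
    np.save(f, deltas)
print(np.load(gzip.GzipFile("deltas.npy.gz", "r"), allow_pickle=True))
# [ 0.1  -0.2   0.05]
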
@@ -730,6 +708,27 @@ def decompress(
             (len(decompressed), original_shape[1], original_shape[2])
         )
 
+    # Cast the decompressed output to the precision configured via config.dtype
+    try:
+        if config.dtype:
+            try:
+                dtype = np.dtype(config.dtype)
+            except TypeError:
+                raise TypeError(f'invalid dtype "{config.dtype}" found in config file') from None
+            if "int" in config.dtype:
+                # Integer target: clip to the representable range before casting
+                info = np.iinfo(dtype)
+                decompressed = np.clip(decompressed, a_min=info.min, a_max=info.max).astype(dtype)
+            elif "float" in config.dtype:
+                # Float target: clip to the finite range before casting
+                info = np.finfo(dtype)
+                decompressed = np.clip(decompressed, a_min=info.min, a_max=info.max).astype(dtype)
+            else:
+                raise TypeError(f'invalid dtype "{config.dtype}" found in config file')
+    except AttributeError:
+        # Older configs may not define dtype; leave the output unchanged
+        pass
+
     return decompressed, names, normalization_features
 
 
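What the clip-and-cast does for an integer target (illustration, not part of the diff): values outside the dtype's range are saturated instead of wrapping around.

import numpy as np

dec = np.array([-3.2e4, 1.7, 9.9e4])  # toy decompressed floats
info = np.iinfo(np.dtype("int16"))    # min = -32768, max = 32767
print(np.clip(dec, info.min, info.max).astype("int16"))
# [-32000      1  32767]
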
@@ -846,6 +844,4 @@ def perform_hls4ml_conversion(output_path, config):
 
     hls_model.compile()
 
-    hls_model.build(
-        csim=config.csim, synth=config.synth, cosim=config.cosim, export=config.export
-    )
+    hls_model.build(csim=config.csim, synth=config.synth, cosim=config.cosim, export=config.export)
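
The four switches map onto hls4ml build stages (C simulation, HLS synthesis, RTL co-simulation, IP export). As a usage note, a fast functional check might enable only csim; the flag values below are illustrative, reusing the hls_model from the function above, and are not taken from this PR:

hls_model.build(csim=True, synth=False, cosim=False, export=False)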