From d618a1ac6dd770aee39ef639b7a7a889d614c577 Mon Sep 17 00:00:00 2001 From: Darren Edge Date: Sun, 10 Nov 2024 17:56:30 +0000 Subject: [PATCH] Fix data generation batch sizes --- app/util/schema_ui.py | 2 - app/workflows/generate_mock_data/workflow.py | 1 - .../generate_mock_data/data_generator.py | 38 ++++++++++++------- 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/app/util/schema_ui.py b/app/util/schema_ui.py index 301c2b8c..4c61c280 100644 --- a/app/util/schema_ui.py +++ b/app/util/schema_ui.py @@ -19,7 +19,6 @@ def build_schema_ui(global_schema, last_filename): jsn = loads(file.read()) for k, v in jsn.items(): global_schema[k] = v - print(f'Loaded schema: {global_schema}') st.markdown('### Edit data schema') generate_form_from_json_schema( global_schema=global_schema, @@ -255,7 +254,6 @@ def create_enum_ui(field_location, key, key_with_prefix, value): value['enum'].pop(i) st.rerun() new_enum_value = st.text_input(f'New value', key=f'{key_with_prefix}_new_enum_{"_".join([str(x) for x in value["enum"]])}', value="") - print(new_enum_value) if new_enum_value != "" and new_enum_value not in value['enum']: if value['type'] == 'string': value['enum'].append(new_enum_value) diff --git a/app/workflows/generate_mock_data/workflow.py b/app/workflows/generate_mock_data/workflow.py index fabf5d01..8bff47f1 100644 --- a/app/workflows/generate_mock_data/workflow.py +++ b/app/workflows/generate_mock_data/workflow.py @@ -79,7 +79,6 @@ async def create(sv: bds_variables.SessionVariables, workflow: None): dl_placeholders.append(dl_placeholder) def on_dfs_update(path_to_df): - print(path_to_df) for ix, record_array in enumerate(sv.record_arrays.value): with df_placeholders[ix]: df = path_to_df[record_array] diff --git a/intelligence_toolkit/generate_mock_data/data_generator.py b/intelligence_toolkit/generate_mock_data/data_generator.py index e9572418..c35a6282 100644 --- a/intelligence_toolkit/generate_mock_data/data_generator.py +++ b/intelligence_toolkit/generate_mock_data/data_generator.py @@ -23,7 +23,6 @@ async def generate_data( callback_batch, parallel_batches=5, ): - num_iterations = num_records_overall // (records_per_batch * parallel_batches) record_arrays = extract_array_fields(data_schema) primary_record_array = record_arrays[0] generated_objects = [] @@ -31,22 +30,36 @@ async def generate_data( ai_configuration=ai_configuration, generation_guidance=generation_guidance, primary_record_array=primary_record_array, - total_records=parallel_batches, + total_records=records_per_batch, data_schema=data_schema, temperature=temperature, ) first_object_json = loads(first_object) - current_object_json = {} + try: + first_object_json = loads(first_object) + except Exception as e: + msg = f"AI did not return a valid JSON response. Please try again. {e}" + raise ValueError(msg) from e + generated_objects.append(first_object_json) + current_object_json = first_object_json.copy() dfs = {} - for i in range(num_iterations): - if i == 0: - sample_records = sample_from_record_array( - first_object_json, primary_record_array, records_per_batch - ) - else: - sample_records = sample_from_record_array( - current_object_json, primary_record_array, parallel_batches - ) + for record_array in record_arrays: + df = extract_df(current_object_json, record_array) + dfs[".".join(record_array)] = df + if df_update_callback is not None: + df_update_callback(dfs) + + num_records = records_per_batch + while num_records < num_records_overall: + remainder = num_records_overall - num_records + required = remainder / records_per_batch + if not required.is_integer(): + required += 1 + batches = min(parallel_batches, required) + sample_records = sample_from_record_array( + current_object_json, primary_record_array, batches + ) + num_records += records_per_batch * parallel_batches # Use each as seed for parallel gen new_objects = await generate_seeded_data( ai_configuration=ai_configuration, @@ -62,7 +75,6 @@ async def generate_data( ) for new_object in new_objects: - print(new_object) try: new_object_json = loads(new_object) except Exception as e: