rewrite ds model evaluate
TPLin22 committed Dec 13, 2024
1 parent 00526ff commit 6980e1d
Showing 6 changed files with 75 additions and 174 deletions.
6 changes: 4 additions & 2 deletions rdagent/components/coder/data_science/model/es.py
@@ -10,7 +10,9 @@
CoSTEERQueriedKnowledge,
CoSTEERQueriedKnowledgeV2,
)
from rdagent.components.coder.data_science.model.exp import ModelFBWorkspace, ModelTask
from rdagent.components.coder.data_science.model.exp import ModelTask

from rdagent.core.experiment import FBWorkspace
from rdagent.core.prompts import Prompts
from rdagent.oai.llm_conf import LLM_SETTINGS
from rdagent.oai.llm_utils import APIBackend
@@ -111,7 +113,7 @@ def assign_code_list_to_evo(self, code_list, evo):
if code_list[index] is None:
continue
if evo.sub_workspace_list[index] is None:
evo.sub_workspace_list[index] = ModelFBWorkspace(target_task=evo.sub_tasks[index])
evo.sub_workspace_list[index] = FBWorkspace(target_task=evo.sub_tasks[index])
# TODO: avoid hardcode of file name
evo.sub_workspace_list[index].inject_code(**{"model01.py": code_list[index]})
return evo
84 changes: 33 additions & 51 deletions rdagent/components/coder/data_science/model/eval.py
@@ -2,22 +2,29 @@
Beyond previous tests
-
"""

import json
from rdagent.components.coder.CoSTEER.evaluators import (
CoSTEEREvaluator,
CoSTEERMultiFeedback,
CoSTEERSingleFeedback,
CoSTEERSingleFeedbackDeprecated,
)
from rdagent.components.coder.data_science.model.eva_utils import (
ModelCodeEvaluator,
ModelFinalEvaluator,
expected_shape_evaluate,
)
from rdagent.components.coder.data_science.model.exp import ModelFBWorkspace
from rdagent.core.evolving_framework import QueriedKnowledge
from rdagent.core.experiment import Task, Workspace
from rdagent.core.experiment import Task, Workspace, FBWorkspace
from rdagent.utils.env import DSDockerConf, DockerEnv
from rdagent.oai.llm_utils import APIBackend
from pathlib import Path
from rdagent.utils.agent.tpl import T

DIRNAME = Path(__file__).absolute().resolve().parent

ModelSingleFeedback = CoSTEERSingleFeedbackDeprecated

ModelSingleFeedback = CoSTEERSingleFeedback
ModelMultiFeedback = CoSTEERMultiFeedback


@@ -35,11 +42,11 @@ class ModelGeneralCaseSpecEvaluator(CoSTEEREvaluator):
def evaluate(
self,
target_task: Task,
implementation: Workspace,
gt_implementation: Workspace,
implementation: FBWorkspace,
gt_implementation: FBWorkspace,
queried_knowledge: QueriedKnowledge = None,
**kwargs,
) -> ModelSingleFeedback:
) -> CoSTEERSingleFeedbackDeprecated:
target_task_information = target_task.get_task_information()
if (
queried_knowledge is not None
@@ -58,54 +65,29 @@ def evaluate(
# assert isinstance(target_task, ModelTask)

batch_size = 8
assert isinstance(implementation, ModelFBWorkspace)
model_execution_feedback, pred_list = implementation.execute(
assert isinstance(implementation, FBWorkspace)
"""model_execution_feedback, pred_list= implementation.execute(
batch_size=batch_size,
)"""
de = DockerEnv(conf=DSDockerConf())
fname = "model_execute.py"
with (DIRNAME / "eval_tests" / "model_execute.py").open("r") as f:
test_code = f.read()
implementation.inject_code(**{fname: test_code})
stdout = implementation.execute(env=de, entry=f"python {fname}")
system_prompt = T(".prompts:model_eval.system").r(
test_code=test_code,
scenario="No scenario information yet.",
spec=target_task.spec,
)
shape_feedback = ""
if pred_list is None:
shape_feedback += "No output generated from the model. No shape evaluation conducted."
else:
val_pred_array, test_pred_array, hypers = pred_list
# spec_message = implementation.code_dict["spec/model.md"]
spec_message = target_task.spec
val_shape_feedback = expected_shape_evaluate(
val_pred_array,
spec_message,
model_execution_feedback=model_execution_feedback,
)
test_shape_feedback = expected_shape_evaluate(
test_pred_array,
spec_message,
model_execution_feedback=model_execution_feedback,
)

shape_feedback += f"Validation Output: {val_shape_feedback}\n"
shape_feedback += f"Test Output: {test_shape_feedback}\n"
value_feedback = "The value feedback is ignored, and the value decision is automatically set as true."
code_feedback, _ = ModelCodeEvaluator(scen=self.scen).evaluate(
target_task=target_task,
implementation=implementation,
model_execution_feedback=model_execution_feedback,
)
final_feedback, final_decision = ModelFinalEvaluator(scen=self.scen).evaluate(
target_task=target_task,
implementation=implementation,
model_execution_feedback=model_execution_feedback,
model_shape_feedback=shape_feedback,
model_code_feedback=code_feedback,
user_prompt = T(".prompts:model_eval.user").r(
stdout=stdout,
code=implementation.code_dict["model01.py"],
)
resp = APIBackend().build_messages_and_create_chat_completion(user_prompt, system_prompt, json_mode=True)
return ModelSingleFeedback(**json.loads(resp))

return ModelSingleFeedback(
execution_feedback=model_execution_feedback,
shape_feedback=shape_feedback,
value_feedback=value_feedback,
code_feedback=code_feedback,
final_feedback=final_feedback,
final_decision=final_decision,
value_generated_flag=(pred_list is not None),
final_decision_based_on_gt=False,
)
"""feedback"""


class XXX2SpecEval:
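For reference, a minimal sketch of the evaluation path this commit introduces, condensed from the calls visible in the eval.py diff above. The wrapper name `run_model_eval` is introduced here for illustration only; `implementation` (an FBWorkspace), `target_task`, and the relative template lookup `.prompts:model_eval.*` are assumed to be set up exactly as in eval.py.

```python
# Sketch of the rewritten evaluation flow, using only calls that appear in the diff above.
import json
from pathlib import Path

from rdagent.oai.llm_utils import APIBackend
from rdagent.utils.agent.tpl import T
from rdagent.utils.env import DockerEnv, DSDockerConf


def run_model_eval(implementation, target_task, dirname: Path):
    # 1. Copy the generic test driver into the workspace and run it in Docker,
    #    capturing stdout instead of pickled prediction arrays.
    test_code = (dirname / "eval_tests" / "model_execute.py").read_text()
    implementation.inject_code(**{"model_execute.py": test_code})
    stdout = implementation.execute(env=DockerEnv(conf=DSDockerConf()), entry="python model_execute.py")

    # 2. Ask the LLM to judge the captured stdout against the task spec
    #    (template path resolves relative to eval.py, as in the diff).
    system_prompt = T(".prompts:model_eval.system").r(
        test_code=test_code, scenario="No scenario information yet.", spec=target_task.spec
    )
    user_prompt = T(".prompts:model_eval.user").r(stdout=stdout, code=implementation.code_dict["model01.py"])
    resp = APIBackend().build_messages_and_create_chat_completion(user_prompt, system_prompt, json_mode=True)

    # 3. The JSON reply maps directly onto the feedback object
    #    (eval.py does ModelSingleFeedback(**json.loads(resp))).
    return json.loads(resp)
```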
rdagent/components/coder/data_science/model/eval_tests/model_execute.py
@@ -30,6 +30,8 @@
test_X=test_X,
hyper_params={}
)
#val_pred = np.random.rand(8, 1)
#test_pred = np.random.rand(8, 1)

execution_feedback_str = "Execution successful.\n"
if val_pred is not None:
@@ -39,10 +41,6 @@
if test_pred is not None:
execution_feedback_str += f"Test predictions shape: {test_pred.shape}\n"
else:
execution_feedback_str += "Test predictions are None.\n"

# Save the outputs
pred_list = [val_pred, test_pred, hypers]
pickle.dump(pred_list, open("pred_list.pkl", "wb"))
pickle.dump(execution_feedback_str, open("execution_feedback_str.pkl", "wb"))

execution_feedback_str += "Test predictions are None.\n" ''

print(execution_feedback_str)
137 changes: 26 additions & 111 deletions rdagent/components/coder/data_science/model/prompts.yaml
@@ -1,44 +1,3 @@
extract_model_formulation_system: |-
offer description of the proposed model in this paper, write a latex formula with variable as well as the architecture of the model. the format should be like
{
"model_name (The name of the model)": {
"description": "A detailed description of the model",
"formulation": "A LaTeX formula representing the model's formulation",
"architecture": "A detailed description of the model's architecture, e.g., neural network layers or tree structures",
"variables": {
"\\hat{y}_u": "The predicted output for node u",
"variable_name_2": "Description of variable 2",
"variable_name_3": "Description of variable 3"
},
"hyperparameters": {
"hyperparameter_name_1": "value of hyperparameter 1",
"hyperparameter_name_2": "value of hyperparameter 2",
"hyperparameter_name_3": "value of hyperparameter 3"
},
"model_type": "Tabular or TimeSeries or Graph or XGBoost" # Should be one of "Tabular", "TimeSeries", "Graph", or "XGBoost"
}
}
Eg.
{
"ABC Model": {
"description": "A detailed description of the model",
"formulation": "A LaTeX formula representing the model's formulation",
"architecture": "A detailed description of the model's architecture, e.g., neural network layers or tree structures",
"variables": {
"\\hat{y}_u": "The predicted output for node u",
"variable_name_2": "Description of variable 2",
"variable_name_3": "Description of variable 3"
},
"hyperparameters": {
"hyperparameter_name_1": "value of hyperparameter 1",
"hyperparameter_name_2": "value of hyperparameter 2",
"hyperparameter_name_3": "value of hyperparameter 3"
},
"model_type": "Tabular or TimeSeries or Graph or RandomForest or XGBoost" # If torch & Neural network models are required, the choice should be one of "Tabular", "TimeSeries", or "Graph"
}
}
such format content should be begin with ```json and end with ``` and the content should be in json format.
evolving_strategy_model_coder:
system: |-
User is trying to implement some pytorch models in the following scenario:
@@ -102,80 +61,36 @@ evolving_strategy_model_coder:
{% endfor %}
{% endif %}
evaluator_shape_feedback:
system: |-
User is trying to evaluate whether a model output shape is correct or not. The correct message about the ground truth shape is given in spec.md as below:
{{ spec }}
The user will provide you the actual output of the model. The model is a part for solving a task in an given scenario. This model takes train dataset as input. Valid and test dataset are optional. The model workflow will generate prediction output of valid and test dataset.
The user will provide the execution result message.
Your job is to compare the output user provide and the message from spec.md to evaluate whether the user's model output is correct.
In your response you should give a clear judgement and also point out the expected shape and actual shape of the model output.
Here is an example structure for the output:
Expected prediction shape: (8, 1). The actual output shape: (8, 1). The shape of the output is correct.
user: |-
--------------Actual Output Shape:---------------
{{ pre_shape }}
--------------Execution feedback:---------------
{{ model_execution_feedback }}
evaluator_code_feedback:
model_eval:
system: |-
You are data scientist.
User is trying to implement some models in the following scenario:
{{ scenario }}
User will provide you the information of the model.
Your job is to check whether user's code is align with the model information and the scenario.
The user will provide the source python code and the execution error message if execution failed.
The user might provide you the ground truth code for you to provide the critic. You should not leak the ground truth code to the user in any form but you can use it to provide the critic.
User has also compared the output generated by the user's code and the ground truth code. The user will provide you some analysis results comparing two output. You may find some error in the code which caused the difference between the two output.
If the ground truth code is provided, your critic should only consider checking whether the user's code is align with the ground truth code since the ground truth is definitely correct.
If the ground truth code is not provided, your critic should consider checking whether the user's code is reasonable and correct to the description and to the scenario.
Notice that your critics are not for user to debug the code. They are sent to the coding agent to correct the code. So don't give any following items for the user to check like "Please check the code line XXX".
You suggestion should not include any code, just some clear and short suggestions. Please point out very critical issues in your response, ignore non-important issues to avoid confusion. If no big issue found in the code, you can response "No critics found".
You should provide the suggestion to each of your critic to help the user improve the code. Please response the critic in the following format. Here is an example structure for the output:
critic 1: The critic message to critic 1
critic 2: The critic message to critic 2
user: |-
--------------Model information:---------------
{{ model_information }}
--------------Python code:---------------
{{ code }}
--------------Execution feedback:---------------
{{ model_execution_feedback }}
evaluator_final_feedback:
system: |-
User is trying to implement a model in the following scenario:
{{ scenario }}
User has finished evaluation and got some feedback from the evaluator.
The evaluator run the code and get the output and provide several feedback regarding user's code and code output. You should analyze the feedback and considering the scenario and model description to give a final decision about the evaluation result. The final decision concludes whether the model is implemented correctly and if not, detail feedback containing reason and suggestion if the final decision is False.
The implementation final decision is considered in the following logic:
1. If the value and the ground truth value are exactly the same under a small tolerance, the implementation is considered correct.
2. If no ground truth value is not provided, the implementation is considered correct if the code execution is successful and the code feedback is align with the scenario and model description.
Please response the critic in the json format. Here is an example structure for the JSON output, please strictly follow the format:
The information about how to implement the model is given in spec.md as below:
{{ spec }}
You are testing the model with the following code:
```python
{{test_code}}
```
You should evaluate the code given by user. You should concern about whether the user implement it correctly, including whether the shape of model's output is aligned with request, the equality of code, and any other thing you think necessary.
You will be given the code generated by user and the stdout of the testing process.
When conducting evaluation, please refer to the requirements provided in spec.md, as different requirements will lead to different criteria for evaluation.
For example, in some cases, the model's output may be required to have predictions for both the valid and test sets, while in other cases, only one of them may be required. Some cases may also require the model's hyperparameters to be preserved and outputted.
Please respond with your feedback in the following JSON format and order
```json
{
"final_decision": True,
"final_feedback": "The final feedback message",
"execution": "Describe whether the model execute successfully, including any errors or issues encountered.",
"return_checking": "Checks about the generated value, including whether the value generated. Especially compare the shape of model output and the requirement in spec.md.",
"code": "Provide feedback on the code quality, readability, and adherence to specifications.",
"final_decision": <true/false>
}
```
user: |-
--------------Model information:---------------
{{ model_information }}
--------------Model Execution feedback:---------------
{{ model_execution_feedback }}
--------------Model shape feedback:---------------
{{ model_shape_feedback }}
--------------Model Code feedback:---------------
{{ model_code_feedback }}
--------------Code generated by user:---------------
{{ code }}
--------------stdoutput:---------------
'''
{{ stdout }}
'''
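The flattened diff above interleaves the removed and added JSON examples, so for clarity here is a hypothetical reply that follows the new `model_eval` schema. Only the four field names (`execution`, `return_checking`, `code`, `final_decision`) come from the added prompt; the values are invented.

```python
import json

# Hypothetical LLM reply conforming to the new model_eval schema.
sample_resp = """{
    "execution": "model_execute.py ran to completion inside the Docker environment.",
    "return_checking": "Validation and test prediction shapes match spec.md, and the hyper-parameters were returned.",
    "code": "model_workflow follows the spec; the hardcoded batch size could be parameterised.",
    "final_decision": true
}"""

feedback_kwargs = json.loads(sample_resp)
# eval.py then builds the feedback object via ModelSingleFeedback(**feedback_kwargs).
print(feedback_kwargs["final_decision"])
```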
6 changes: 5 additions & 1 deletion rdagent/components/coder/data_science/model/test.py
@@ -15,6 +15,10 @@
from rdagent.components.coder.data_science.model.exp import ModelFBWorkspace, ModelTask
from rdagent.scenarios.data_science.experiment.experiment import ModelExperiment
from rdagent.scenarios.data_science.scen import DataScienceScen
from rdagent.components.coder.CoSTEER.config import CoSTEER_SETTINGS
from rdagent.components.coder.data_science.model.es import ModelMultiProcessEvolvingStrategy
from rdagent.core.experiment import FBWorkspace



# Take tasks, spec.md and feat as input, generate a feedback as output
@@ -36,7 +40,7 @@ def develop_one_competition(competition: str):
tpl_ex_path = Path(__file__).resolve() / Path("rdagent/scenarios/kaggle/tpl_ex").resolve() / competition
injected_file_names = ["spec/model.md", "load_data.py", "feat01.py", "model01.py"]

modelexp = ModelFBWorkspace()
modelexp = FBWorkspace()
for file_name in injected_file_names:
file_path = tpl_ex_path / file_name
modelexp.inject_code(**{file_name: file_path.read_text()})
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ def model_workflow(
val_y: np.ndarray = None,
test_X: np.ndarray = None,
**hyper_params,
) -> tuple[np.ndarray | None, np.ndarray | None]:
) -> tuple[np.ndarray | None, np.ndarray | None, dict]:
"""
Manages the workflow of a machine learning model, including training, validation, and testing.
@@ -150,4 +150,4 @@ def model_workflow(
else:
test_pred = None

return val_pred, test_pred
return val_pred, test_pred, hyper_params
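As a usage note, a hypothetical call against the updated signature with dummy arrays. The import assumes the generated model is injected as model01.py (as in es.py and eval.py above); the parameter names before `val_y` are not visible in this hunk, so the leading arguments are passed positionally.

```python
import numpy as np

from model01 import model_workflow  # assumed: the workspace injects the generated model as model01.py

# Dummy data purely for illustration; the (8, ...) shapes mirror the examples in the eval test script.
train_X, train_y = np.random.rand(32, 10), np.random.rand(32, 1)
val_X, val_y = np.random.rand(8, 10), np.random.rand(8, 1)
test_X = np.random.rand(8, 10)

# The third return value (the hyper-parameters dict) is new in this commit.
val_pred, test_pred, hypers = model_workflow(train_X, train_y, val_X, val_y=val_y, test_X=test_X)
print(val_pred if val_pred is None else val_pred.shape,
      test_pred if test_pred is None else test_pred.shape,
      hypers)
```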
