Commit 9df3ed7
Parent: 49cb475
Co-authored-by: Yuan Tian <[email protected]>
Co-authored-by: Anirudh Dagar <[email protected]>

Showing 19 changed files with 1,717 additions and 1 deletion.
@@ -0,0 +1,10 @@
[server]
maxUploadSize = 4096
enableStaticServing = true

[theme]
primaryColor = "#4C7DE7"
backgroundColor = "#FFFFFF"
secondaryBackgroundColor = "#fbfcfc"
textColor = "#404040"

[client]
showSidebarNavigation = false
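This hunk is a Streamlit config.toml: it raises the upload limit to 4096 MB, enables static file serving, sets the theme colors, and hides the built-in sidebar navigation. The diff does not show the file path, but Streamlit conventionally reads this file from .streamlit/config.toml next to the app entry point. As a quick check that the options are being picked up, the values can be read back with st.get_option; a minimal sketch, assuming the conventional location:

import streamlit as st

# Read back a few of the options defined in the TOML above; the dotted keys
# mirror the [section] / key names. Assumes the file is saved where Streamlit
# looks for it, e.g. .streamlit/config.toml relative to the working directory.
print("maxUploadSize (MB):", st.get_option("server.maxUploadSize"))
print("primaryColor:", st.get_option("theme.primaryColor"))
print("showSidebarNavigation:", st.get_option("client.showSidebarNavigation"))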
@@ -0,0 +1,82 @@
from copy import deepcopy

import streamlit as st
import streamlit.components.v1 as components
from constants import DEFAULT_SESSION_VALUES, LOGO_PATH
from pages.demo import main as demo
from pages.feature import main as feature
from pages.nav_bar import nav_bar
from pages.preview import main as preview
from pages.task import main as run
from pages.tutorial import main as tutorial

st.set_page_config(
    page_title="AutoGluon Assistant",
    page_icon=LOGO_PATH,
    layout="wide",
    initial_sidebar_state="collapsed",
)

# Font Awesome icons
st.markdown(
    """
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
    """,
    unsafe_allow_html=True,
)

# Bootstrap 4.1.3
st.markdown(
    """
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css" integrity="sha384-MCw98/SFnGE8fJT3GXwEOngsV7Zt27NXFoaoApmYm81iuXoPkFOJwJ8ERdknLPMO" crossorigin="anonymous">
    """,
    unsafe_allow_html=True,
)

# Inject the custom stylesheet shipped alongside the app
with open("style.css") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)


# Warn the user before the page is reloaded or closed, since that resets the session
reload_warning = """
<script>
    window.onbeforeunload = function () {
        return "Are you sure you want to leave? The current session will be lost.";
    };
</script>
"""

# height=0 keeps the injected component invisible on the page
components.html(reload_warning, height=0)


def initial_session_state():
    """
    Initialize st.session_state with default values.

    Mutable defaults (dicts and lists) are deep-copied so each session gets its
    own instance rather than a reference shared with DEFAULT_SESSION_VALUES.
    """
    for key, default_value in DEFAULT_SESSION_VALUES.items():
        if key not in st.session_state:
            st.session_state[key] = (
                deepcopy(default_value) if isinstance(default_value, (dict, list)) else default_value
            )


def main():
    initial_session_state()
    nav_bar()
    tutorial()
    demo()
    feature()
    run()
    preview()

    st.markdown(
        """
        <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
        <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/umd/popper.min.js" integrity="sha384-ZMP7rVo3mIykV+2+9J3UJ46jBk0WLaUAdn689aCwoqbBJiSnjAK/l8WvCWPIPm49" crossorigin="anonymous"></script>
        <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.min.js" integrity="sha384-ChfqqxuZUCnJSK3+MXmPNIyE6ZbWh2IMqE241rYiqJxyMiZ6OW/JmZQ5stwEULTy" crossorigin="anonymous"></script>
        """,
        unsafe_allow_html=True,
    )


if __name__ == "__main__":
    main()
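One detail worth calling out in initial_session_state above: dict and list defaults are deep-copied before being stored. Without that, every session would share the same mutable objects held in DEFAULT_SESSION_VALUES (notably stage_container), and appending to one session's stage logs would silently mutate the module-level constant. A minimal, Streamlit-free illustration of the aliasing problem the deepcopy avoids:

from copy import deepcopy

DEFAULTS = {"stage_container": {"Task Understanding": []}}

# Direct assignment shares the nested list with the constant...
aliased = DEFAULTS["stage_container"]
aliased["Task Understanding"].append("log line")
assert DEFAULTS["stage_container"]["Task Understanding"] == ["log line"]  # constant was mutated

# ...while deepcopy gives each consumer its own independent copy.
DEFAULTS["stage_container"]["Task Understanding"].clear()
isolated = deepcopy(DEFAULTS["stage_container"])
isolated["Task Understanding"].append("log line")
assert DEFAULTS["stage_container"]["Task Understanding"] == []  # constant untouched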
@@ -0,0 +1,134 @@
from copy import deepcopy

BASE_DATA_DIR = "./user_data"


# Preset configurations
PRESET_DEFAULT_CONFIG = {
    "Best Quality": {"time_limit": "4 hrs", "feature_generation": True},
    "High Quality": {"time_limit": "1 hr", "feature_generation": False},
    "Medium Quality": {"time_limit": "10 mins", "feature_generation": False},
}
DEFAULT_PRESET = "Medium Quality"

PRESET_MAPPING = {
    "Best Quality": "best_quality",
    "High Quality": "high_quality",
    "Medium Quality": "medium_quality",
}
PRESET_OPTIONS = ["Best Quality", "High Quality", "Medium Quality"]

# Time limit configurations (in seconds)
TIME_LIMIT_MAPPING = {
    "1 min": 60,
    "10 mins": 600,
    "30 mins": 1800,
    "1 hr": 3600,
    "2 hrs": 7200,
    "4 hrs": 14400,
}

DEFAULT_TIME_LIMIT = "10 mins"

TIME_LIMIT_OPTIONS = ["1 min", "10 mins", "30 mins", "1 hr", "2 hrs", "4 hrs"]

# LLM configurations
LLM_MAPPING = {
    "Claude 3.5 with Amazon Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0",
    "GPT 4o": "gpt-4o-mini-2024-07-18",
}

LLM_OPTIONS = ["Claude 3.5 with Amazon Bedrock"]

# Provider configuration
PROVIDER_MAPPING = {"Claude 3.5 with Amazon Bedrock": "bedrock", "GPT 4o": "openai"}


API_KEY_LOCATION = {"Claude 3.5 with Amazon Bedrock": "BEDROCK_API_KEY", "GPT 4o": "OPENAI_API_KEY"}

INITIAL_STAGE = {
    "Task Understanding": [],
    "Feature Generation": [],
    "Model Training": [],
    "Prediction": [],
}
# Initial session state defaults
DEFAULT_SESSION_VALUES = {
    "config_overrides": [],
    "preset": DEFAULT_PRESET,
    "time_limit": DEFAULT_TIME_LIMIT,
    "llm": None,
    "pid": None,
    "logs": "",
    "process": None,
    "clicked": False,
    "task_running": False,
    "output_file": None,
    "output_filename": None,
    "task_description": "",
    "sample_description": "",
    "return_code": None,
    "task_canceled": False,
    "uploaded_files": {},
    "sample_files": {},
    "selected_dataset": None,
    "sample_dataset_dir": None,
    "description_uploader_key": 0,
    "sample_dataset_selector": None,
    "current_stage": None,
    "feature_generation": False,
    "stage_status": {},
    "show_remaining_time": False,
    "model_path": None,
    "increment_time": 0,
    "progress_bar": None,
    "increment": 2,
    "zip_path": None,
    "stage_container": deepcopy(INITIAL_STAGE),
}

# Progress-bar percentage shown when each log message appears
STATUS_BAR_STAGE = {
    "Task loaded!": 10,
    "Model training starts": 25,
    "Fitting model": 50,
    "AutoGluon training complete": 80,
    "Prediction starts": 90,
}

STAGE_COMPLETE_SIGNAL = [
    "Task understanding complete",
    "Automatic feature generation complete",
    "Model training complete",
    "Prediction complete",
]

# Stage names
STAGE_TASK_UNDERSTANDING = "Task Understanding"
STAGE_FEATURE_GENERATION = "Feature Generation"
STAGE_MODEL_TRAINING = "Model Training"
STAGE_PREDICTION = "Prediction"

# Log messages that mark the start of each stage
MSG_TASK_UNDERSTANDING = "Task understanding starts"
MSG_FEATURE_GENERATION = "Automatic feature generation starts"
MSG_MODEL_TRAINING = "Model training starts"
MSG_PREDICTION = "Prediction starts"

# Mapping from start-of-stage log message to stage name
STAGE_MESSAGES = {
    MSG_TASK_UNDERSTANDING: STAGE_TASK_UNDERSTANDING,
    MSG_FEATURE_GENERATION: STAGE_FEATURE_GENERATION,
    MSG_MODEL_TRAINING: STAGE_MODEL_TRAINING,
    MSG_PREDICTION: STAGE_PREDICTION,
}
# Dataset source options
DATASET_OPTIONS = ["Sample Dataset", "Upload Dataset"]

# Captions shown under the dataset source options
CAPTIONS = ["Run with sample dataset", "Upload train, test, and (optional) output datasets"]

DEMO_URL = "https://automl-mm-bench.s3.amazonaws.com/autogluon-assistant/aga-kaggle-demo.mp4"

SAMPLE_DATASET_DESCRIPTION = """You are solving this data science task: The dataset presented here (knot theory) comprises a large number of numerical features. Some of the features may be missing, with NaN values. Your task is to predict the 'signature', which has 18 unique integers. The evaluation metric is classification accuracy."""
LOGO_PATH = "static/page_icon.png"
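The stage constants above are clearly meant to be matched against AutoGluon Assistant's log output: STAGE_MESSAGES maps a start-of-stage log line to a stage name, STATUS_BAR_STAGE maps selected log lines to progress percentages, and STAGE_COMPLETE_SIGNAL marks stage completion. The consuming code is not part of this hunk, so the following is only a minimal sketch, built around a hypothetical update_progress_from_log helper, of how one log line could be translated into UI state:

from constants import STAGE_COMPLETE_SIGNAL, STAGE_MESSAGES, STATUS_BAR_STAGE


def update_progress_from_log(line):
    """Hypothetical helper: map one log line to (stage, percent, completed), or None if irrelevant."""
    stage = next((name for msg, name in STAGE_MESSAGES.items() if msg in line), None)
    percent = next((pct for msg, pct in STATUS_BAR_STAGE.items() if msg in line), None)
    completed = any(signal in line for signal in STAGE_COMPLETE_SIGNAL)
    if stage is None and percent is None and not completed:
        return None
    return stage, percent, completed


# Example: "Model training starts" maps to the "Model Training" stage at 25%.
print(update_progress_from_log("Model training starts"))  # ('Model Training', 25, False)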
@@ -0,0 +1,57 @@
import os

import pandas as pd
import streamlit as st
from utils import get_user_data_dir


def save_description_file(description):
    """
    Save the task description to a file in the user's data directory.

    Args:
        description (str): The task description to save.
    """
    try:
        user_data_dir = get_user_data_dir()
        description_file = os.path.join(user_data_dir, "description.txt")
        with open(description_file, "w") as f:
            f.write(description)
    except Exception as e:
        print(f"Error saving file: {str(e)}")


def description_file_uploader():
    """
    Handle task description file uploads.
    """
    uploaded_file = st.file_uploader(
        "Upload task description file",
        type="txt",
        key=st.session_state.description_uploader_key,
        help="Accepted file format: .txt",
        label_visibility="collapsed",
    )
    if uploaded_file:
        task_description = uploaded_file.read().decode("utf-8")
        st.session_state.task_description = task_description
        save_description_file(st.session_state.task_description)
        # Bump the widget key to reset the uploader, then rerun so the change takes effect
        st.session_state.description_uploader_key += 1
        st.rerun()


def file_uploader():
    """
    Handle dataset file uploads (CSV or Excel).
    """
    st.markdown("#### Upload Dataset")
    uploaded_files = st.file_uploader(
        "Select the dataset", accept_multiple_files=True, label_visibility="collapsed", type=["csv", "xlsx"]
    )
    st.session_state.uploaded_files = {}
    for file in uploaded_files:
        # The type filter above limits uploads to .csv and .xlsx files
        if file.name.endswith(".csv"):
            df = pd.read_csv(file)
        elif file.name.endswith(".xlsx"):
            df = pd.read_excel(file)
        else:
            continue
        st.session_state.uploaded_files[file.name] = {"file": file, "df": df}
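file_uploader keeps uploads in memory as st.session_state.uploaded_files entries of the form {"file": UploadedFile, "df": DataFrame}. This commit does not show how those files reach the training process; one plausible approach, mirroring save_description_file above and sketched here with a hypothetical save_uploaded_files helper, would be to persist them into the same per-user data directory:

import os

from utils import get_user_data_dir  # same helper used by save_description_file above


def save_uploaded_files(uploaded_files):
    """Hypothetical helper: persist in-memory uploads to the user's data directory."""
    user_data_dir = get_user_data_dir()
    os.makedirs(user_data_dir, exist_ok=True)
    for filename, entry in uploaded_files.items():
        # Streamlit's UploadedFile behaves like BytesIO, so the full contents
        # can be written out regardless of the current read position.
        with open(os.path.join(user_data_dir, filename), "wb") as f:
            f.write(entry["file"].getbuffer())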