diff --git a/app.py b/app.py index 972d2eb..8485040 100644 --- a/app.py +++ b/app.py @@ -6,324 +6,582 @@ import streamlit.components.v1 as components import maidr +# Import plot modules +from plots.utils import set_theme, color_palettes +from plots.histogram import create_histogram, create_custom_histogram +from plots.boxplot import create_boxplot, create_custom_boxplot +from plots.scatterplot import create_scatterplot, create_custom_scatterplot +from plots.barplot import create_barplot, create_custom_barplot +from plots.lineplot import create_lineplot, create_custom_lineplot +from plots.heatmap import create_heatmap, create_custom_heatmap +from plots.multilineplot import generate_multiline_data, create_multiline_plot, create_custom_multiline_plot +from plots.multilayerplot import create_multilayer_plot, create_custom_multilayer_plot +from plots.multipanelplot import create_multipanel_plot, create_custom_multipanel_plot + # Set random seed np.random.seed(1000) -# Define color palettes -color_palettes = { - "Default": "#007bc2", - "Red": "#FF0000", - "Green": "#00FF00", - "Blue": "#0000FF", - "Purple": "#800080", - "Orange": "#FFA500" -} - -# Helper function to set theme -def set_theme(fig, ax): - theme = st.session_state.get('theme', 'Light') - if theme == "Dark": - plt.style.use('dark_background') - fig.patch.set_facecolor('#2E2E2E') - ax.set_facecolor('#2E2E2E') - else: - plt.style.use('default') - fig.patch.set_facecolor('white') - ax.set_facecolor('white') - -# Sidebar for user input -st.sidebar.title("Settings") -theme = st.sidebar.selectbox("Theme:", ["Light", "Dark"]) -st.session_state['theme'] = theme +# Set page config +st.set_page_config( + page_title="Learning Data Visualization with MAIDR", + page_icon="📊", + layout="wide" +) + +# Define functions to render MAIDR plots +def render_maidr_plot(ax): + """Renders a matplotlib plot with MAIDR accessibility features""" + # Apply figure size from sliders + fig_width = st.session_state.get('fig_width', 10) + fig_height = st.session_state.get('fig_height', 6) + + # Resize the figure + ax.figure.set_size_inches(fig_width, fig_height) + + # Only display the MAIDR accessible output (which includes the plot) + try: + components.html( + maidr.render(ax).get_html_string(), + scrolling=False, # Disable scrolling + height=fig_height * 110, # Slightly larger height to prevent scrolling + width=fig_width * 110, # Slightly larger width to prevent scrolling + ) + except Exception as e: + # If MAIDR rendering fails, fall back to standard matplotlib rendering + st.error(f"Error rendering MAIDR accessibility features: {str(e)}") + st.warning("Falling back to standard matplotlib plot without accessibility features.") + st.pyplot(ax.figure) + +# Function to render custom plot based on uploaded data +def render_custom_plot(df, plot_type, color, theme, **kwargs): + """Render a custom plot based on user data and selections""" + if plot_type == "Histogram": + var = kwargs.get('var') + if var: + ax = create_custom_histogram(df, var, color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Box Plot": + var_x = kwargs.get('var_x') + var_y = kwargs.get('var_y') + if var_x: + ax = create_custom_boxplot(df, var_x, var_y, color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Scatter Plot": + var_x = kwargs.get('var_x') + var_y = kwargs.get('var_y') + if var_x and var_y: + ax = create_custom_scatterplot(df, var_x, var_y, color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Bar Plot": + var = kwargs.get('var') + if var: + ax = create_custom_barplot(df, var, color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Line Plot": + var_x = kwargs.get('var_x') + var_y = kwargs.get('var_y') + if var_x and var_y: + ax = create_custom_lineplot(df, var_x, var_y, color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Heatmap": + var_x = kwargs.get('var_x') + var_y = kwargs.get('var_y') + var_value = kwargs.get('var_value') + colorscale = kwargs.get('colorscale', 'YlGnBu') + if var_x and var_y: + ax = create_custom_heatmap(df, var_x, var_y, var_value, colorscale, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Multiline Plot": + var_x = kwargs.get('var_x') + var_y = kwargs.get('var_y') + var_group = kwargs.get('var_group') + palette = kwargs.get('palette', 'Default') + if var_x and var_y and var_group: + ax = create_custom_multiline_plot(df, var_x, var_y, var_group, palette, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Multilayer Plot": + var_x = kwargs.get('var_x') + var_background = kwargs.get('var_background') + var_line = kwargs.get('var_line') + background_type = kwargs.get('background_type', 'Bar Plot') + background_color = kwargs.get('background_color', 'Default') + line_color = kwargs.get('line_color', 'Default') + if var_x and var_background and var_line: + ax = create_custom_multilayer_plot(df, var_x, var_background, var_line, + background_type, background_color, line_color, theme) + if ax: + render_maidr_plot(ax) + + elif plot_type == "Multipanel Plot": + vars_config = kwargs.get('vars_config', {}) + layout_type = kwargs.get('layout_type', 'Grid 2x2') + palette = kwargs.get('palette', 'Default') + if vars_config: + ax = create_custom_multipanel_plot(df, vars_config, layout_type, palette, theme) + if ax: + render_maidr_plot(ax) -# Sliders to adjust figure size (now in the sidebar) -fig_width = st.sidebar.slider("Figure width", min_value=5, max_value=15, value=10) -fig_height = st.sidebar.slider("Figure height", min_value=3, max_value=10, value=6) +# Sidebar for theme and figure settings +with st.sidebar: + st.title("Settings") + theme = st.radio("Select Theme:", ["Light", "Dark"]) + st.session_state['theme'] = theme + + # Store the slider values directly in session state with keys + fig_width = st.slider("Figure Width", min_value=6, max_value=15, value=10, key="fig_width") + fig_height = st.slider("Figure Height", min_value=4, max_value=10, value=6, key="fig_height") # Main content st.title("Learning Data Visualization with MAIDR") # Tabs for different plots -tab1, tab2, tab3, tab4, tab5, tab6, tab7 = st.tabs([ - "Practice", "Histogram", "Box Plot", "Scatter Plot", "Bar Plot", "Line Plot", "Heatmap" +tab1, tab2, tab3, tab4, tab5, tab6, tab7, tab8, tab9, tab10 = st.tabs([ + "Practice", "Histogram", "Box Plot", "Scatter Plot", "Bar Plot", "Line Plot", "Heatmap", + "Multilayer Plot", "Multipanel Plot", "Multiline Plot" ]) -# Function to render plots using Maidr -def render_maidr_plot(plot): - # Render the plot using maidr and display it in Streamlit - components.html( - maidr.render( - plot - ).get_html_string(), - scrolling=True, - height=fig_height * 100, - width=fig_width * 100, - ) - -# Practice tab: Allows users to upload a CSV and generate plots -# Practice tab: Allows users to upload a CSV and generate plots +# Practice tab with tab1: - st.header("Practice with your own data") - - # Upload CSV file - uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"]) + st.header("Create your own Custom Plot") + uploaded_file = st.file_uploader("Choose a CSV file", type="csv") if uploaded_file is not None: df = pd.read_csv(uploaded_file) - st.write("Data preview:", df.head()) - # Select the plot type - plot_type = st.selectbox("Select plot type:", [ - "Histogram", "Box Plot", "Scatter Plot", "Bar Plot", "Line Plot", "Heatmap" - ]) + # Show data preview + st.subheader("Data Preview") + st.dataframe(df.head()) - # Color palette selection - plot_color = st.selectbox("Select plot color:", list(color_palettes.keys())) + # Show data types + st.subheader("Data Types") + st.dataframe(pd.DataFrame({ + 'Column': df.columns, + 'Type': df.dtypes, + 'Unique Values': [df[col].nunique() for col in df.columns] + })) - # Select columns from uploaded data for plots - numeric_columns = df.select_dtypes(include=['float64', 'int64']).columns.tolist() - categorical_columns = df.select_dtypes(include=['object']).columns.tolist() - - if plot_type == "Histogram": - var = st.selectbox("Select numeric variable for histogram:", numeric_columns) - if var: - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.histplot(data=df, x=var, kde=True, color=color_palettes[plot_color], ax=ax) - ax.set_title(f"{var}") - ax.set_xlabel(var) - render_maidr_plot(ax) - - elif plot_type == "Box Plot": - var_x = st.selectbox("Select numerical variable for X-axis:", numeric_columns) - var_y = st.selectbox("Select categorical variable for Y-axis (optional):", [""] + categorical_columns) - if var_x: - fig, ax = plt.subplots(figsize=(10, 6)) - set_theme(fig, ax) - if var_y: - sns.boxplot(x=var_y, y=var_x, data=df, palette=[color_palettes[plot_color]], ax=ax) - ax.set_title(f"{var_x} grouped by {var_y}") - ax.set_xlabel(var_y.replace("_", " ").title()) - ax.set_ylabel(var_x.replace("_", " ").title()) - else: - sns.boxplot(y=df[var_x], color=color_palettes[plot_color], ax=ax) - ax.set_title(f"{var_x}") - ax.set_ylabel(var_x.replace("_", " ").title()) - render_maidr_plot(ax) - - elif plot_type == "Scatter Plot": - x_var = st.selectbox("Select X variable:", numeric_columns) - y_var = st.selectbox("Select Y variable:", [col for col in numeric_columns if col != x_var]) - if x_var and y_var: - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.scatterplot(data=df, x=x_var, y=y_var, color=color_palettes[plot_color], ax=ax) - ax.set_title(f"{x_var} vs {y_var}") - render_maidr_plot(ax) - - elif plot_type == "Bar Plot": - var = st.selectbox("Select categorical variable for bar plot:", categorical_columns) - if var: - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.countplot(x=var, data=df, color=color_palettes[plot_color], ax=ax) - ax.set_title(f"{var}") - render_maidr_plot(ax) - - elif plot_type == "Line Plot": - x_var = st.selectbox("Select X variable:", numeric_columns) - y_var = st.selectbox("Select Y variable:", [col for col in numeric_columns if col != x_var]) - if x_var and y_var: - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.lineplot(data=df, x=x_var, y=y_var, color=color_palettes[plot_color], ax=ax) - ax.set_title(f"{x_var} vs {y_var}") - render_maidr_plot(ax) - - elif plot_type == "Heatmap": - x_var = st.selectbox("Select X variable:", numeric_columns) - y_var = st.selectbox("Select Y variable:", [col for col in numeric_columns if col != x_var]) - if x_var and y_var: - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.heatmap(pd.crosstab(df[x_var], df[y_var]), ax=ax, cmap="YlGnBu", annot=True) - ax.set_title(f"{x_var} vs {y_var}") - render_maidr_plot(ax) + # Plot selection + st.subheader("Create Plot") + col1, col2 = st.columns(2) + + with col1: + plot_type = st.selectbox( + "Select Plot Type:", + ["", "Histogram", "Box Plot", "Scatter Plot", "Bar Plot", "Line Plot", + "Heatmap", "Multiline Plot", "Multilayer Plot", "Multipanel Plot"] + ) + + # Get numeric and categorical columns + numeric_cols = df.select_dtypes(include=np.number).columns.tolist() + categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist() + + # Plot specific inputs + if plot_type == "Histogram": + var = st.selectbox("Select Variable for Histogram:", numeric_cols) + plot_color = st.selectbox("Select Color:", list(color_palettes.keys()), key="histogram_color") + + if st.button("Generate Histogram"): + with col2: + render_custom_plot(df, plot_type, color_palettes[plot_color], theme, var=var) + + elif plot_type == "Box Plot": + var_x = st.selectbox("Select Numerical Variable:", numeric_cols) + var_y = st.selectbox("Select Categorical Variable (optional):", [""] + categorical_cols) + plot_color = st.selectbox("Select Color:", list(color_palettes.keys()), key="boxplot_color") + + if st.button("Generate Box Plot"): + with col2: + render_custom_plot(df, plot_type, color_palettes[plot_color], theme, + var_x=var_x, var_y=var_y if var_y else None) + + elif plot_type == "Scatter Plot": + var_x = st.selectbox("Select X Variable:", numeric_cols, key="scatter_x") + var_y = st.selectbox("Select Y Variable:", [col for col in numeric_cols if col != var_x], key="scatter_y") + plot_color = st.selectbox("Select Color:", list(color_palettes.keys()), key="scatter_color") + + if st.button("Generate Scatter Plot"): + with col2: + render_custom_plot(df, plot_type, color_palettes[plot_color], theme, + var_x=var_x, var_y=var_y) + + elif plot_type == "Bar Plot": + var = st.selectbox("Select Categorical Variable:", categorical_cols) + plot_color = st.selectbox("Select Color:", list(color_palettes.keys()), key="barplot_color") + + if st.button("Generate Bar Plot"): + with col2: + render_custom_plot(df, plot_type, color_palettes[plot_color], theme, var=var) + + elif plot_type == "Line Plot": + var_x = st.selectbox("Select X Variable:", numeric_cols, key="line_x") + var_y = st.selectbox("Select Y Variable:", [col for col in numeric_cols if col != var_x], key="line_y") + plot_color = st.selectbox("Select Color:", list(color_palettes.keys()), key="line_color") + + if st.button("Generate Line Plot"): + with col2: + render_custom_plot(df, plot_type, color_palettes[plot_color], theme, + var_x=var_x, var_y=var_y) + + elif plot_type == "Heatmap": + var_x = st.selectbox("Select X Variable (categorical):", categorical_cols, key="heatmap_x") + var_y = st.selectbox("Select Y Variable (categorical):", + [col for col in categorical_cols if col != var_x], key="heatmap_y") + var_value = st.selectbox("Select Value Variable (numeric, optional):", + [""] + numeric_cols, key="heatmap_value") + colorscale = st.selectbox("Select Color Scale:", + ["YlGnBu", "viridis", "plasma", "inferno", "RdBu_r", "coolwarm"]) + + if st.button("Generate Heatmap"): + with col2: + render_custom_plot(df, plot_type, None, theme, + var_x=var_x, var_y=var_y, + var_value=var_value if var_value else None, + colorscale=colorscale) + + elif plot_type == "Multiline Plot": + var_x = st.selectbox("Select X Variable:", numeric_cols, key="multiline_x") + var_y = st.selectbox("Select Y Variable:", + [col for col in numeric_cols if col != var_x], key="multiline_y") + var_group = st.selectbox("Select Group Variable (categorical):", + categorical_cols, key="multiline_group") + palette = st.selectbox("Select Color Palette:", + ["Default", "Colorful", "Pastel", "Dark Tones", "Paired Colors", "Rainbow"]) + + if st.button("Generate Multiline Plot"): + with col2: + render_custom_plot(df, plot_type, None, theme, + var_x=var_x, var_y=var_y, var_group=var_group, + palette=palette) + + elif plot_type == "Multilayer Plot": + var_x = st.selectbox("Select X Variable:", df.columns.tolist(), key="multilayer_x") + var_background = st.selectbox("Select Background Variable (numeric):", + numeric_cols, key="multilayer_bg") + var_line = st.selectbox("Select Line Variable (numeric):", + [col for col in numeric_cols if col != var_background], key="multilayer_line") + background_type = st.selectbox("Select Background Plot Type:", + ["Bar Plot", "Histogram", "Scatter Plot"]) + background_color = st.selectbox("Select Background Color:", + list(color_palettes.keys()), key="multilayer_bg_color") + line_color = st.selectbox("Select Line Color:", + list(color_palettes.keys()), key="multilayer_line_color") + + if st.button("Generate Multilayer Plot"): + with col2: + render_custom_plot(df, plot_type, None, theme, + var_x=var_x, var_background=var_background, var_line=var_line, + background_type=background_type, + background_color=background_color, line_color=line_color) + + elif plot_type == "Multipanel Plot": + st.subheader("Panel 1") + plot1_type = st.selectbox("Plot Type:", ["line", "bar", "scatter", "hist"], key="panel1_type") + plot1_x = st.selectbox("X Variable:", df.columns.tolist(), key="panel1_x") + plot1_y = st.selectbox("Y Variable (if applicable):", + [""] + numeric_cols, key="panel1_y") + + st.subheader("Panel 2") + plot2_type = st.selectbox("Plot Type:", ["line", "bar", "scatter", "hist"], key="panel2_type") + plot2_x = st.selectbox("X Variable:", df.columns.tolist(), key="panel2_x") + plot2_y = st.selectbox("Y Variable (if applicable):", + [""] + numeric_cols, key="panel2_y") + + st.subheader("Panel 3") + plot3_type = st.selectbox("Plot Type:", ["line", "bar", "scatter", "hist"], key="panel3_type") + plot3_x = st.selectbox("X Variable:", df.columns.tolist(), key="panel3_x") + plot3_y = st.selectbox("Y Variable (if applicable):", + [""] + numeric_cols, key="panel3_y") + + layout_type = st.selectbox("Select Layout Type:", + ["Grid 2x2", "Column", "Row"], key="multi_layout") + palette = st.selectbox("Select Color Palette:", + ["Default", "Colorful", "Pastel", "Dark Tones", "Paired Colors", "Rainbow"], + key="multi_palette") + + # Create vars_config dictionary + vars_config = { + 'plot1': {'type': plot1_type, 'x': plot1_x, 'y': plot1_y if plot1_y else None}, + 'plot2': {'type': plot2_type, 'x': plot2_x, 'y': plot2_y if plot2_y else None}, + 'plot3': {'type': plot3_type, 'x': plot3_x, 'y': plot3_y if plot3_y else None} + } + + if st.button("Generate Multipanel Plot"): + with col2: + render_custom_plot(df, plot_type, None, theme, + vars_config=vars_config, layout_type=layout_type, palette=palette) + + else: + st.info("Please upload a CSV file to practice creating visualizations with your own data.") + + # Sample data option + if st.button("Use Sample Data"): + # Load the sample data included with the app + try: + df = pd.read_csv("dummy_data_for_practice.csv") + st.session_state['sample_data'] = df + st.experimental_rerun() + except Exception as e: + st.error(f"Error loading sample data: {e}") # Histogram tab with tab2: st.header("Histogram") - hist_type = st.selectbox("Select histogram distribution type:", [ - "Normal Distribution", "Positively Skewed", "Negatively Skewed", - "Unimodal Distribution", "Bimodal Distribution", "Multimodal Distribution" - ]) - hist_color = st.selectbox("Select histogram color:", list(color_palettes.keys()), key="hist_color") - - # Generate data based on user selection - def hist_data(): - if hist_type == "Normal Distribution": - return np.random.normal(size=1000) - elif hist_type == "Positively Skewed": - return np.random.exponential(scale=3, size=1000) - elif hist_type == "Negatively Skewed": - return -np.random.exponential(scale=1.5, size=1000) - elif hist_type == "Unimodal Distribution": - return np.random.normal(loc=0, scale=2.5, size=1000) - elif hist_type == "Bimodal Distribution": - return np.concatenate([np.random.normal(-2, 0.5, size=500), np.random.normal(2, 0.5, size=500)]) - elif hist_type == "Multimodal Distribution": - return np.concatenate([np.random.normal(-2, 0.5, size=300), np.random.normal(2, 0.5, size=300), np.random.normal(5, 0.5, size=400)]) - - # Plot the histogram using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.histplot(hist_data(), kde=True, bins=20, color=color_palettes[hist_color], edgecolor="white", ax=ax) - ax.set_title(f"{hist_type}") - ax.set_xlabel(hist_type) - ax.set_ylabel("Count") - - # Render using Maidr - render_maidr_plot(ax) + col1, col2 = st.columns([1, 3]) + with col1: + distribution_type = st.selectbox( + "Select histogram distribution type:", + [ + "Normal Distribution", + "Positively Skewed", + "Negatively Skewed", + "Unimodal Distribution", + "Bimodal Distribution", + "Multimodal Distribution", + ], + key='hist_dist' + ) + hist_color = st.selectbox( + "Select histogram color:", + list(color_palettes.keys()), + key='hist_color' + ) + + with col2: + # Create and render the histogram + ax = create_histogram(distribution_type, hist_color, theme) + render_maidr_plot(ax) # Box Plot tab with tab3: st.header("Box Plot") - - box_type = st.selectbox("Select box plot type:", [ - "Positively Skewed with Outliers", "Negatively Skewed with Outliers", - "Symmetric with Outliers", "Symmetric without Outliers" - ]) - box_color = st.selectbox("Select box plot color:", list(color_palettes.keys()), key="box_color") - - def box_data(): - if box_type == "Positively Skewed with Outliers": - return np.random.lognormal(mean=0, sigma=0.5, size=1000) - elif box_type == "Negatively Skewed with Outliers": - return -np.random.lognormal(mean=0, sigma=0.5, size=1000) - elif box_type == "Symmetric with Outliers": - return np.random.normal(loc=0, scale=1, size=1000) - elif box_type == "Symmetric without Outliers": - data = np.random.normal(loc=0, scale=1, size=1000) - return data[(data > -1.5) & (data < 1.5)] - - # Plot the box plot using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.boxplot(x=box_data(), ax=ax, color=color_palettes[box_color]) - ax.set_title(f"{box_type}") - - # Render using Maidr - render_maidr_plot(ax) + + col1, col2 = st.columns([1, 3]) + with col1: + boxplot_type = st.selectbox( + "Select box plot type:", + [ + "Positively Skewed with Outliers", + "Negatively Skewed with Outliers", + "Symmetric with Outliers", + "Symmetric without Outliers", + ], + key='box_type' + ) + boxplot_color = st.selectbox( + "Select box plot color:", + list(color_palettes.keys()), + key='box_color' + ) + + with col2: + # Create and render the box plot + ax = create_boxplot(boxplot_type, boxplot_color, theme) + render_maidr_plot(ax) # Scatter Plot tab with tab4: st.header("Scatter Plot") - - scatter_type = st.selectbox("Select scatter plot type:", [ - "No Correlation", "Weak Positive Correlation", "Strong Positive Correlation", - "Weak Negative Correlation", "Strong Negative Correlation" - ]) - scatter_color = st.selectbox("Select scatter plot color:", list(color_palettes.keys()), key="scatter_color") - - def scatter_data(): - num_points = np.random.randint(20, 31) - x = np.random.uniform(size=num_points) - if scatter_type == "No Correlation": - y = np.random.uniform(size=num_points) - elif scatter_type == "Weak Positive Correlation": - y = 0.3 * x + np.random.uniform(size=num_points) - elif scatter_type == "Strong Positive Correlation": - y = 0.9 * x + np.random.uniform(size=num_points) * 0.1 - elif scatter_type == "Weak Negative Correlation": - y = -0.3 * x + np.random.uniform(size=num_points) - elif scatter_type == "Strong Negative Correlation": - y = -0.9 * x + np.random.uniform(size=num_points) * 0.1 - return x, y - - # Plot the scatter plot using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - data_x, data_y = scatter_data() - sns.scatterplot(x=data_x, y=data_y, ax=ax, color=color_palettes[scatter_color]) - ax.set_title(f"{scatter_type}") - - # Render using Maidr - render_maidr_plot(ax) + + col1, col2 = st.columns([1, 3]) + with col1: + scatterplot_type = st.selectbox( + "Select scatter plot type:", + [ + "No Correlation", + "Weak Positive Correlation", + "Strong Positive Correlation", + "Weak Negative Correlation", + "Strong Negative Correlation", + ], + key='scatter_type' + ) + scatter_color = st.selectbox( + "Select scatter plot color:", + list(color_palettes.keys()), + key='scatter_color_main' + ) + + with col2: + # Create and render the scatter plot + ax = create_scatterplot(scatterplot_type, scatter_color, theme) + render_maidr_plot(ax) # Bar Plot tab with tab5: st.header("Bar Plot") - - bar_color = st.selectbox("Select bar plot color:", list(color_palettes.keys()), key="bar_color") - - def bar_data(): - categories = ["Category A", "Category B", "Category C", "Category D", "Category E"] - values = np.random.randint(10, 100, size=5) - return categories, values - - # Plot the bar plot using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - categories, values = bar_data() - sns.barplot(x=categories, y=values, ax=ax, color=color_palettes[bar_color]) - ax.set_title("Plot of Categories") - - # Render using Maidr - render_maidr_plot(ax) + + col1, col2 = st.columns([1, 3]) + with col1: + barplot_color = st.selectbox( + "Select bar plot color:", + list(color_palettes.keys()), + key='bar_color' + ) + + with col2: + # Create and render the bar plot + ax = create_barplot(barplot_color, theme) + render_maidr_plot(ax) # Line Plot tab with tab6: st.header("Line Plot") - - line_type = st.selectbox("Select line plot type:", [ - "Linear Trend", "Exponential Growth", "Sinusoidal Pattern", "Random Walk" - ]) - line_color = st.selectbox("Select line plot color:", list(color_palettes.keys()), key="line_color") - - def line_data(): - x = np.linspace(0, 10, 20) - if line_type == "Linear Trend": - y = 2 * x + 1 + np.random.normal(0, 1, 20) - elif line_type == "Exponential Growth": - y = np.exp(0.5 * x) + np.random.normal(0, 1, 20) - elif line_type == "Sinusoidal Pattern": - y = 5 * np.sin(x) + np.random.normal(0, 0.5, 20) - elif line_type == "Random Walk": - y = np.cumsum(np.random.normal(0, 1, 20)) - return x, y - - # Plot the line plot using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - data_x, data_y = line_data() - sns.lineplot(x=data_x, y=data_y, ax=ax, color=color_palettes[line_color]) - ax.set_title(f"{line_type}") - - # Render using Maidr - render_maidr_plot(ax) + + col1, col2 = st.columns([1, 3]) + with col1: + lineplot_type = st.selectbox( + "Select line plot type:", + [ + "Linear Trend", + "Exponential Growth", + "Sinusoidal Pattern", + "Random Walk", + ], + key='line_type' + ) + lineplot_color = st.selectbox( + "Select line plot color:", + list(color_palettes.keys()), + key='line_color_main' + ) + + with col2: + # Create and render the line plot + ax = create_lineplot(lineplot_type, lineplot_color, theme) + render_maidr_plot(ax) # Heatmap tab with tab7: st.header("Heatmap") - - heatmap_type = st.selectbox("Select heatmap type:", [ - "Random", "Correlated", "Checkerboard" - ]) - - def heatmap_data(): - if heatmap_type == "Random": - return np.random.rand(5, 5) - elif heatmap_type == "Correlated": - return np.random.multivariate_normal([0] * 5, np.eye(5), size=5) - elif heatmap_type == "Checkerboard": - return np.indices((5, 5)).sum(axis=0) % 2 - - # Plot the heatmap using Matplotlib - fig, ax = plt.subplots(figsize=(fig_width, fig_height)) - set_theme(fig, ax) - sns.heatmap(heatmap_data(), ax=ax, cmap="YlGnBu", annot=True, fmt=".2f") - ax.set_title(f"{heatmap_type}") - - # Render using Maidr - render_maidr_plot(ax) + + col1, col2 = st.columns([1, 3]) + with col1: + heatmap_type = st.selectbox( + "Select heatmap type:", + [ + "Random", + "Correlated", + "Checkerboard", + ], + key='heatmap_type' + ) + + with col2: + # Create and render the heatmap + ax = create_heatmap(heatmap_type, theme) + render_maidr_plot(ax) + +# Multilayer Plot tab +with tab8: + st.header("Multilayer Plot") + + col1, col2 = st.columns([1, 3]) + with col1: + multilayer_background_type = st.selectbox( + "Select background plot type:", + [ + "Bar Plot", + "Histogram", + "Scatter Plot" + ], + key='multilayer_bg_type' + ) + multilayer_background_color = st.selectbox( + "Select background color:", + list(color_palettes.keys()), + key='multilayer_bg_color_main' + ) + multilayer_line_color = st.selectbox( + "Select line color:", + list(color_palettes.keys()), + key='multilayer_line_color_main' + ) + + with col2: + # Create and render the multilayer plot + ax = create_multilayer_plot( + multilayer_background_type, + multilayer_background_color, + multilayer_line_color, + theme + ) + render_maidr_plot(ax) + +# Multipanel Plot tab +with tab9: + st.header("Multipanel Plot") + + col1, col2 = st.columns([1, 3]) + with col1: + # Removed the layout dropdown + multipanel_color = st.selectbox( + "Select color palette:", + [ + "Default", + "Colorful", + "Pastel", + "Dark Tones", + "Paired Colors", + "Rainbow" + ], + key='multipanel_color' + ) + + with col2: + # Create and render the multipanel plot with fixed layout (removed layout parameter) + ax = create_multipanel_plot("Grid 2x2", multipanel_color, theme) + render_maidr_plot(ax) + +# Multiline Plot tab +with tab10: + st.header("Multiline Plot") + + col1, col2 = st.columns([1, 3]) + with col1: + multiline_type = st.selectbox( + "Select multiline plot type:", + [ + "Simple Trends", + "Seasonal Patterns", + "Growth Comparison", + "Random Series", + ], + key='multiline_type' + ) + multiline_color = st.selectbox( + "Select color palette:", + [ + "Default", + "Colorful", + "Pastel", + "Dark Tones", + "Paired Colors", + "Rainbow" + ], + key='multiline_color' + ) + + with col2: + # Generate data and create the multiline plot + data = generate_multiline_data(multiline_type) + ax = create_multiline_plot(data, multiline_type, multiline_color, theme) + render_maidr_plot(ax) + +# Footer +st.markdown("---") +st.markdown("Learning Data Visualization with MAIDR - Explore different visualization types and make them accessible") diff --git a/plots/barplot.py b/plots/barplot.py new file mode 100644 index 0000000..c62a794 --- /dev/null +++ b/plots/barplot.py @@ -0,0 +1,34 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\barplot.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_barplot(input_barplot_color, theme): + """Create a bar plot based on input parameters""" + color = color_palettes[input_barplot_color] + categories = ["Category A", "Category B", "Category C", "Category D", "Category E"] + values = np.random.randint(10, 100, size=5) + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.barplot(x=categories, y=values, ax=ax, color=color) + ax.set_title("Plot of Categories") + ax.set_xlabel("Categories") + ax.set_ylabel("Values") + + return ax + +def create_custom_barplot(df, var, color, theme): + """Create a bar plot from user data""" + if not var or df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.countplot(data=df, x=var, color=color, ax=ax) + ax.set_title(f"{var}") + ax.set_xlabel(var.replace("_", " ").title()) + ax.set_ylabel("Count") + + return ax \ No newline at end of file diff --git a/plots/boxplot.py b/plots/boxplot.py new file mode 100644 index 0000000..007f121 --- /dev/null +++ b/plots/boxplot.py @@ -0,0 +1,54 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\boxplot.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_boxplot(input_boxplot_type, input_boxplot_color, theme): + """Create a box plot based on input parameters""" + boxplot_type = input_boxplot_type + color = color_palettes[input_boxplot_color] + + # Generate data based on the selected box plot type + if boxplot_type == "Positively Skewed with Outliers": + data = np.random.lognormal(mean=0, sigma=0.5, size=1000) + elif boxplot_type == "Negatively Skewed with Outliers": + data = -np.random.lognormal(mean=0, sigma=0.5, size=1000) + elif boxplot_type == "Symmetric with Outliers": + data = np.random.normal(loc=0, scale=1, size=1000) + elif boxplot_type == "Symmetric without Outliers": + data = np.random.normal(loc=0, scale=1, size=1000) + data = data[(data > -1.5) & (data < 1.5)] # Strict range to avoid outliers + else: + data = np.random.normal(loc=0, scale=1, size=1000) + + # Create the plot using matplotlib + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.boxplot(x=data, ax=ax, color=color) # Horizontal box plot + ax.set_title(f"{boxplot_type}") + ax.set_xlabel("Value") + + return ax + +def create_custom_boxplot(df, var_x, var_y, color, theme): + """Create a box plot from user data""" + if df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + + if var_x and var_y: + sns.boxplot(x=var_y, y=var_x, data=df, palette=[color], ax=ax) + ax.set_title(f"{var_x} grouped by {var_y}") + ax.set_xlabel(var_y.replace("_", " ").title()) + ax.set_ylabel(var_x.replace("_", " ").title()) + elif var_x: + sns.boxplot(y=df[var_x], color=color, ax=ax) + ax.set_title(f"{var_x}") + ax.set_ylabel(var_x.replace("_", " ").title()) + else: + return None + + return ax \ No newline at end of file diff --git a/plots/heatmap.py b/plots/heatmap.py new file mode 100644 index 0000000..49e3864 --- /dev/null +++ b/plots/heatmap.py @@ -0,0 +1,51 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\heatmap.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import pandas as pd +from plots.utils import set_theme + +def create_heatmap(input_heatmap_type, theme): + """Create a heatmap based on input parameters""" + heatmap_type = input_heatmap_type + + if heatmap_type == "Random": + data = np.random.rand(5, 5) # Reduced size + elif heatmap_type == "Correlated": + data = np.random.multivariate_normal( + [0] * 5, np.eye(5), size=5 + ) # Reduced size + elif heatmap_type == "Checkerboard": + data = np.indices((5, 5)).sum(axis=0) % 2 # Reduced size + else: + data = np.random.rand(5, 5) + + fig, ax = plt.subplots(figsize=(10, 8)) + set_theme(fig, ax, theme) + sns.heatmap(data, ax=ax, cmap="YlGnBu", annot=True, fmt=".2f") + ax.set_title(f"{heatmap_type}") + + return ax + +def create_custom_heatmap(df, var_x, var_y, var_value, colorscale, theme): + """Create a heatmap from user data""" + if not var_x or not var_y or df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 8)) + set_theme(fig, ax, theme) + + # Check if both variables are categorical and a value variable is provided + if var_value: + # Create a pivot table + pivot_table = pd.pivot_table(df, values=var_value, index=var_y, columns=var_x, aggfunc='mean') + else: + # If no value variable, use crosstab for frequency counts + pivot_table = pd.crosstab(df[var_y], df[var_x], normalize='all') + + sns.heatmap(pivot_table, ax=ax, cmap=colorscale, annot=True, fmt=".2f") + ax.set_title(f"Heatmap of {var_y} vs {var_x}") + ax.set_xlabel(var_x.replace("_", " ").title()) + ax.set_ylabel(var_y.replace("_", " ").title()) + + return ax \ No newline at end of file diff --git a/plots/histogram.py b/plots/histogram.py new file mode 100644 index 0000000..2852728 --- /dev/null +++ b/plots/histogram.py @@ -0,0 +1,61 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\histogram.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_histogram(input_distribution_type, input_hist_color, theme): + """Create a histogram based on input parameters""" + distribution_type = input_distribution_type + color = color_palettes[input_hist_color] + + # Generate data based on the selected distribution + if distribution_type == "Normal Distribution": + data = np.random.normal(size=1000) + elif distribution_type == "Positively Skewed": + data = np.random.exponential(scale=3, size=1000) + elif distribution_type == "Negatively Skewed": + data = -np.random.exponential(scale=1.5, size=1000) + elif distribution_type == "Unimodal Distribution": + data = np.random.normal(loc=0, scale=2.5, size=1000) + elif distribution_type == "Bimodal Distribution": + data = np.concatenate( + [ + np.random.normal(-2, 0.5, size=500), + np.random.normal(2, 0.5, size=500), + ] + ) + elif distribution_type == "Multimodal Distribution": + data = np.concatenate( + [ + np.random.normal(-2, 0.5, size=300), + np.random.normal(2, 0.5, size=300), + np.random.normal(5, 0.5, size=400), + ] + ) + else: + data = np.random.normal(size=1000) + + # Create the plot using matplotlib + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.histplot(data, kde=True, bins=20, color=color, edgecolor="white", ax=ax) + ax.set_title(f"{distribution_type}") + ax.set_xlabel("Value") + ax.set_ylabel("Frequency") + + return ax + +def create_custom_histogram(df, var, color, theme): + """Create a histogram from user data""" + if not var or df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.histplot(data=df, x=var, kde=True, color=color, ax=ax) + ax.set_title(f"{var}") + ax.set_xlabel(var.replace("_", " ").title()) + ax.set_ylabel("Count") + + return ax \ No newline at end of file diff --git a/plots/lineplot.py b/plots/lineplot.py new file mode 100644 index 0000000..2115252 --- /dev/null +++ b/plots/lineplot.py @@ -0,0 +1,45 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\lineplot.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_lineplot(input_lineplot_type, input_lineplot_color, theme): + """Create a line plot based on input parameters""" + lineplot_type = input_lineplot_type + color = color_palettes[input_lineplot_color] + + x = np.linspace(0, 10, 20) # Reduced number of points + if lineplot_type == "Linear Trend": + y = 2 * x + 1 + np.random.normal(0, 1, 20) + elif lineplot_type == "Exponential Growth": + y = np.exp(0.5 * x) + np.random.normal(0, 1, 20) + elif lineplot_type == "Sinusoidal Pattern": + y = 5 * np.sin(x) + np.random.normal(0, 0.5, 20) + elif lineplot_type == "Random Walk": + y = np.cumsum(np.random.normal(0, 1, 20)) + else: + y = x + np.random.normal(0, 1, 20) + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.lineplot(x=x, y=y, ax=ax, color=color) + ax.set_title(f"{lineplot_type}") + ax.set_xlabel("X") + ax.set_ylabel("Y") + + return ax + +def create_custom_lineplot(df, var_x, var_y, color, theme): + """Create a line plot from user data""" + if not var_x or not var_y or df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.lineplot(data=df, x=var_x, y=var_y, color=color, ax=ax) + ax.set_title(f"{var_y} vs {var_x}") + ax.set_xlabel(var_x.replace("_", " ").title()) + ax.set_ylabel(var_y.replace("_", " ").title()) + + return ax \ No newline at end of file diff --git a/plots/multilayerplot.py b/plots/multilayerplot.py new file mode 100644 index 0000000..a91d2d0 --- /dev/null +++ b/plots/multilayerplot.py @@ -0,0 +1,197 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\multilayerplot.py +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_multilayer_plot(input_background_type, background_color, line_color, theme): + """ + Create a multilayer plot with a selected background plot type and a line chart in the foreground. + + Parameters + ---------- + input_background_type : str + The type of background plot ('Bar Plot', 'Histogram', or 'Scatter Plot') + background_color : str + The color to use for the background plot + line_color : str + The color to use for the line plot + theme : str + The theme to apply to the plot ('Light' or 'Dark') + + Returns + ------- + plt.Axes + The axes object of the created plot. + """ + # Generate sample data + x = np.arange(8) + bar_data = np.array([3, 5, 2, 7, 3, 6, 4, 5]) + hist_data = np.concatenate([np.random.normal(loc=i, scale=0.5, size=20) for i in x]) + scatter_data = np.array([4, 6, 3, 8, 2, 7, 5, 6]) + line_data = np.array([10, 8, 12, 14, 9, 11, 13, 10]) + + # Create a figure and a set of subplots + fig, ax1 = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax1, theme) + + # Get colors from the color_palettes dictionary or use default if not found + bg_color = color_palettes.get(background_color, "#007bc2") + ln_color = color_palettes.get(line_color, "#FF0000") + + # Create the background plot based on the selected type + if input_background_type == "Bar Plot": + ax1.bar(x, bar_data, color=bg_color, label="Bar Data", alpha=0.7) + ax1.set_ylabel("Bar Values", color=bg_color) + y_min, y_max = 0, max(bar_data) * 1.2 + + elif input_background_type == "Histogram": + # For histogram, we need to adjust the scale to fit with line plot + bins = np.linspace(min(hist_data), max(hist_data), 20) + ax1.hist(hist_data, bins=bins, color=bg_color, label="Histogram Data", alpha=0.7) + ax1.set_ylabel("Frequency", color=bg_color) + y_min, y_max = 0, ax1.get_ylim()[1] * 1.2 + + elif input_background_type == "Scatter Plot": + ax1.scatter(x, scatter_data, color=bg_color, label="Scatter Data", alpha=0.7, s=100) + ax1.set_ylabel("Y Values", color=bg_color) + y_min, y_max = min(scatter_data) * 0.8, max(scatter_data) * 1.2 + + ax1.tick_params(axis="y", labelcolor=bg_color) + ax1.set_xlabel("X Values") + + # Set y-axis limits for the background plot + ax1.set_ylim(y_min, y_max) + + # Create a second y-axis sharing the same x-axis + ax2 = ax1.twinx() + + # Create the line chart on the second y-axis + ax2.plot(x, line_data, color=ln_color, marker="o", linestyle="-", linewidth=2, label="Line Data") + ax2.set_ylabel("Line Values", color=ln_color) + ax2.tick_params(axis="y", labelcolor=ln_color) + + # Add title + ax1.set_title(f"Multilayer Plot: {input_background_type} with Line Plot") + + # Add legends for both axes + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left") + + # Adjust layout + fig.tight_layout() + + return ax1 + +def create_custom_multilayer_plot(df, var_x, var_background, var_line, background_type, background_color, line_color, theme): + """ + Create a custom multilayer plot from user data. + + Parameters + ---------- + df : pandas.DataFrame + The dataframe containing the data to plot + var_x : str + The name of the column to use for the x-axis + var_background : str + The name of the column to use for the background plot + var_line : str + The name of the column to use for the line chart + background_type : str + The type of background plot ('Bar Plot', 'Histogram', or 'Scatter Plot') + background_color : str + The color to use for the background plot + line_color : str + The color to use for the line plot + theme : str + The theme to apply to the plot ('Light' or 'Dark') + + Returns + ------- + plt.Axes + The axes object of the created plot. + """ + if not var_x or not var_background or not var_line or df is None: + return None + + # Create a figure and a set of subplots + fig, ax1 = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax1, theme) + + # Get data from dataframe + x = df[var_x].values + background_data = df[var_background].values + line_data = df[var_line].values + + # Get colors from the color_palettes dictionary or use default if not found + bg_color = color_palettes.get(background_color, "#007bc2") + ln_color = color_palettes.get(line_color, "#FF0000") + + # Create the background plot based on the selected type + if background_type == "Bar Plot": + # For categorical x, we may need to handle the x-axis differently + if df[var_x].dtype == 'object': + x_positions = np.arange(len(x)) + ax1.bar(x_positions, background_data, color=bg_color, label=var_background, alpha=0.7) + ax1.set_xticks(x_positions) + ax1.set_xticklabels(x) + else: + ax1.bar(x, background_data, color=bg_color, label=var_background, alpha=0.7) + + ax1.set_ylabel(var_background.replace("_", " ").title(), color=bg_color) + y_min, y_max = 0, max(background_data) * 1.2 + + elif background_type == "Histogram": + # For histogram, we just use the background data column + bins = np.linspace(min(background_data), max(background_data), 20) + ax1.hist(background_data, bins=bins, color=bg_color, label=var_background, alpha=0.7) + ax1.set_ylabel("Frequency", color=bg_color) + y_min, y_max = 0, ax1.get_ylim()[1] * 1.2 + + elif background_type == "Scatter Plot": + # For categorical x, we may need to handle the x-axis differently + if df[var_x].dtype == 'object': + x_positions = np.arange(len(x)) + ax1.scatter(x_positions, background_data, color=bg_color, label=var_background, alpha=0.7, s=100) + ax1.set_xticks(x_positions) + ax1.set_xticklabels(x) + else: + ax1.scatter(x, background_data, color=bg_color, label=var_background, alpha=0.7, s=100) + + ax1.set_ylabel(var_background.replace("_", " ").title(), color=bg_color) + y_min, y_max = min(background_data) * 0.8, max(background_data) * 1.2 + + ax1.tick_params(axis="y", labelcolor=bg_color) + ax1.set_xlabel(var_x.replace("_", " ").title()) + + # Set y-axis limits for the background plot + ax1.set_ylim(y_min, y_max) + + # Create a second y-axis sharing the same x-axis + ax2 = ax1.twinx() + + # Create the line chart on the second y-axis + # For categorical x, we may need to handle the x-axis differently + if df[var_x].dtype == 'object' and background_type != "Histogram": + x_positions = np.arange(len(x)) + ax2.plot(x_positions, line_data, color=ln_color, marker="o", linestyle="-", linewidth=2, label=var_line) + else: + ax2.plot(x, line_data, color=ln_color, marker="o", linestyle="-", linewidth=2, label=var_line) + + ax2.set_ylabel(var_line.replace("_", " ").title(), color=ln_color) + ax2.tick_params(axis="y", labelcolor=ln_color) + + # Add title + ax1.set_title(f"{var_background} ({background_type}) and {var_line} vs {var_x}") + + # Add legends for both axes + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper left") + + # Adjust layout + fig.tight_layout() + + return ax1 \ No newline at end of file diff --git a/plots/multilineplot.py b/plots/multilineplot.py new file mode 100644 index 0000000..6ce73aa --- /dev/null +++ b/plots/multilineplot.py @@ -0,0 +1,121 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\multilineplot.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +import pandas as pd +from plots.utils import set_theme + +def generate_multiline_data(multiline_type): + """Generate data for multiline plots based on the selected type""" + x = np.linspace(0, 10, 30) # 30 points for x-axis + series_names = ["Series 1", "Series 2", "Series 3"] + + if multiline_type == "Simple Trends": + # Linear trends with different slopes + y1 = 1.5 * x + np.random.normal(0, 1, 30) + y2 = 0.5 * x + 5 + np.random.normal(0, 1, 30) + y3 = -x + 15 + np.random.normal(0, 1, 30) + elif multiline_type == "Seasonal Patterns": + # Sinusoidal patterns with different phases + y1 = 5 * np.sin(x) + 10 + np.random.normal(0, 0.5, 30) + y2 = 5 * np.sin(x + np.pi/2) + 10 + np.random.normal(0, 0.5, 30) + y3 = 5 * np.sin(x + np.pi) + 10 + np.random.normal(0, 0.5, 30) + elif multiline_type == "Growth Comparison": + # Different growth patterns + y1 = np.exp(0.2 * x) + np.random.normal(0, 0.5, 30) + y2 = x**2 / 10 + np.random.normal(0, 1, 30) + y3 = np.log(x + 1) * 5 + np.random.normal(0, 0.5, 30) + else: # Random Series + # Random walks with different volatilities + y1 = np.cumsum(np.random.normal(0, 0.5, 30)) + y2 = np.cumsum(np.random.normal(0.1, 0.7, 30)) + y3 = np.cumsum(np.random.normal(-0.05, 0.9, 30)) + + # Create a dataframe with the generated data + return pd.DataFrame({ + "x": np.tile(x, 3), + "y": np.concatenate([y1, y2, y3]), + "series": np.repeat(series_names, len(x)) + }) + +def create_multiline_plot(data, input_multiline_type, input_multiline_color, theme): + """Create a multiline plot based on input parameters and data""" + multiline_type = input_multiline_type + palette = input_multiline_color + + # Map friendly palette names to seaborn palette names + palette_mapping = { + "Default": None, # Use default seaborn palette + "Colorful": "Set1", + "Pastel": "Set2", + "Dark Tones": "Dark2", + "Paired Colors": "Paired", + "Rainbow": "Spectral" + } + + # Create the plot + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + + # Use seaborn lineplot for multiple lines + if palette == "Default": + # Use default seaborn color palette + sns.lineplot( + x="x", y="y", hue="series", style="series", + markers=True, dashes=False, data=data, ax=ax + ) + else: + # Use selected color palette + sns.lineplot( + x="x", y="y", hue="series", style="series", + markers=True, dashes=False, data=data, ax=ax, + palette=palette_mapping[palette] + ) + + # Customize the plot + ax.set_title(f"Multiline Plot: {multiline_type}") + ax.set_xlabel("X values") + ax.set_ylabel("Y values") + + return ax + +def create_custom_multiline_plot(df, var_x, var_y, var_group, palette, theme): + """Create a multiline plot from user data""" + if not var_x or not var_y or not var_group or df is None: + return None + + # Map friendly palette names to seaborn palette names + palette_mapping = { + "Default": None, # Use default seaborn palette + "Colorful": "Set1", + "Pastel": "Set2", + "Dark Tones": "Dark2", + "Paired Colors": "Paired", + "Rainbow": "Spectral" + } + + # Create the plot + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + + # Use seaborn lineplot for multiple lines + if palette == "Default": + # Use default seaborn color palette + sns.lineplot( + x=var_x, y=var_y, hue=var_group, style=var_group, + markers=True, dashes=False, data=df, ax=ax + ) + else: + # Use selected color palette + sns.lineplot( + x=var_x, y=var_y, hue=var_group, style=var_group, + markers=True, dashes=False, data=df, ax=ax, + palette=palette_mapping[palette] + ) + + # Customize the plot + ax.set_title(f"{var_y} vs {var_x} by {var_group}") + ax.set_xlabel(var_x.replace("_", " ").title()) + ax.set_ylabel(var_y.replace("_", " ").title()) + + return ax \ No newline at end of file diff --git a/plots/multipanelplot.py b/plots/multipanelplot.py new file mode 100644 index 0000000..a01c595 --- /dev/null +++ b/plots/multipanelplot.py @@ -0,0 +1,210 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\multipanelplot.py +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from plots.utils import set_theme + +def create_multipanel_plot(layout_type, color_palette, theme): + """ + Create a multipanel plot with different subplot types arranged in a specified layout. + + Parameters + ---------- + layout_type : str + The type of layout ('Grid 2x2', 'Row', 'Column', 'Mixed') + color_palette : str + The color palette to use for the plots + theme : str + The theme to apply to the plot ('Light' or 'Dark') + + Returns + ------- + plt.Figure + The figure containing the created plot. + """ + # Generate sample data + # Data for line plot + x_line = np.array([1, 2, 3, 4, 5, 6, 7, 8]) + y_line = np.array([2, 4, 1, 5, 3, 7, 6, 8]) + + # Data for first bar plot + categories = ["A", "B", "C", "D", "E"] + values = np.random.rand(5) * 10 + + # Data for scatter plot + x_scatter = np.random.randn(50) + y_scatter = np.random.randn(50) + + # Create a figure with subplots arranged according to the layout + if layout_type == "Grid 2x2": + fig, axs = plt.subplots(2, 2, figsize=(10, 8)) + axs = axs.flatten() # Flatten to make it easier to index + else: # Default to a vertical layout with 3 plots + fig, axs = plt.subplots(3, 1, figsize=(10, 12)) + + # Apply theme to all subplots + for ax in axs: + set_theme(fig, ax, theme) + + # First panel: Line plot + axs[0].plot(x_line, y_line, color="blue", linewidth=2) + axs[0].set_title("Line Plot: Random Data") + axs[0].set_xlabel("X-axis") + axs[0].set_ylabel("Values") + axs[0].grid(True, linestyle="--", alpha=0.7) + + # Second panel: Bar plot + axs[1].bar(categories, values, color="green", alpha=0.7) + axs[1].set_title("Bar Plot: Random Values") + axs[1].set_xlabel("Categories") + axs[1].set_ylabel("Values") + + # Third panel: Scatter plot + if len(axs) > 2: # Check if we have a third subplot (for Grid 2x2 and Column layouts) + axs[2].scatter(x_scatter, y_scatter, color="red", alpha=0.7) + axs[2].set_title("Scatter Plot: Random Points") + axs[2].set_xlabel("X-axis") + axs[2].set_ylabel("Y-axis") + + # Fourth panel (if Grid 2x2): Histogram + if len(axs) > 3: # Only for Grid 2x2 layout + data = np.random.normal(0, 1, 1000) + axs[3].hist(data, bins=20, color="purple", alpha=0.7) + axs[3].set_title("Histogram: Normal Distribution") + axs[3].set_xlabel("Values") + axs[3].set_ylabel("Frequency") + + # Adjust layout to prevent overlap + plt.tight_layout() + + return axs[0] + +def create_custom_multipanel_plot(df, vars_config, layout_type, color_palette, theme): + """ + Create a custom multipanel plot from user data. + + Parameters + ---------- + df : pandas.DataFrame + The dataframe containing the data to plot + vars_config : dict + Dictionary containing variables for each subplot + Example: { + 'plot1': {'type': 'line', 'x': 'col1', 'y': 'col2'}, + 'plot2': {'type': 'bar', 'x': 'col3', 'y': 'col4'}, + 'plot3': {'type': 'scatter', 'x': 'col5', 'y': 'col6'}, + 'plot4': {'type': 'hist', 'x': 'col7', 'y': None}, + } + layout_type : str + The type of layout ('Grid 2x2', 'Row', 'Column', 'Mixed') + color_palette : str + The color palette to use for the plots + theme : str + The theme to apply to the plot ('Light' or 'Dark') + + Returns + ------- + plt.Figure + The figure containing the created plot. + """ + if df is None or not vars_config: + return None + + # Extract configuration for each plot + plot1_config = vars_config.get('plot1', {}) + plot2_config = vars_config.get('plot2', {}) + plot3_config = vars_config.get('plot3', {}) + plot4_config = vars_config.get('plot4', {}) + + # Map palette names to colors + palette_mapping = { + "Default": ["blue", "green", "red", "purple"], + "Colorful": ["#FF5733", "#33FF57", "#3357FF", "#F033FF"], + "Pastel": ["#FFB6C1", "#B6FFB6", "#B6C1FF", "#FFB6FF"], + "Dark Tones": ["#8B0000", "#006400", "#00008B", "#8B008B"], + "Paired Colors": ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728"], + "Rainbow": ["#FF0000", "#00FF00", "#0000FF", "#FF00FF"] + } + + colors = palette_mapping.get(color_palette, palette_mapping["Default"]) + + # Create figure with appropriate layout + if layout_type == "Grid 2x2": + fig, axs = plt.subplots(2, 2, figsize=(10, 8)) + axs = axs.flatten() # Flatten to make it easier to index + else: # Default to a vertical layout with 3 plots + fig, axs = plt.subplots(3, 1, figsize=(10, 12)) + if len(axs) < 4: # If we have fewer than 4 axes but 4 plot configs + axs = list(axs) + [None] # Add None to handle the 4th plot gracefully + + # Apply theme to all subplots + for ax in axs: + if ax is not None: + set_theme(fig, ax, theme) + + # Function to create plot based on type + def create_plot(ax, plot_config, color_idx): + if ax is None or not plot_config: + return + + plot_type = plot_config.get('type', 'line') + x_var = plot_config.get('x', None) + y_var = plot_config.get('y', None) + + if not x_var or (plot_type != 'hist' and not y_var): + return + + color = colors[color_idx % len(colors)] + + if plot_type == 'line': + sns.lineplot(data=df, x=x_var, y=y_var, ax=ax, color=color) + ax.set_title(f"Line Plot: {y_var} vs {x_var}") + elif plot_type == 'bar': + if df[x_var].dtype == 'object' or df[x_var].nunique() < 15: + # For categorical x, use count or mean + if y_var: + value_counts = df.groupby(x_var)[y_var].mean() + value_counts.plot(kind='bar', ax=ax, color=color, alpha=0.7) + ax.set_title(f"Bar Plot: Mean {y_var} by {x_var}") + else: + value_counts = df[x_var].value_counts() + value_counts.plot(kind='bar', ax=ax, color=color, alpha=0.7) + ax.set_title(f"Bar Plot: Counts of {x_var}") + else: + # For numeric x with many values, create bins + ax.bar(df[x_var], df[y_var], color=color, alpha=0.7) + ax.set_title(f"Bar Plot: {y_var} by {x_var}") + elif plot_type == 'scatter': + sns.scatterplot(data=df, x=x_var, y=y_var, ax=ax, color=color, alpha=0.7) + ax.set_title(f"Scatter Plot: {y_var} vs {x_var}") + elif plot_type == 'hist': + sns.histplot(data=df, x=x_var, ax=ax, color=color, alpha=0.7, kde=True) + ax.set_title(f"Histogram: {x_var}") + elif plot_type == 'multiline': + group_var = plot_config.get('group', None) + if group_var and group_var in df.columns: + sns.lineplot(data=df, x=x_var, y=y_var, hue=group_var, style=group_var, + markers=True, dashes=False, ax=ax) + ax.set_title(f"Multiline Plot: {y_var} vs {x_var} by {group_var}") + else: + sns.lineplot(data=df, x=x_var, y=y_var, ax=ax, color=color) + ax.set_title(f"Line Plot: {y_var} vs {x_var}") + + # Common labels + ax.set_xlabel(x_var.replace("_", " ").title()) + if y_var and plot_type != 'hist': + ax.set_ylabel(y_var.replace("_", " ").title()) + + # Create each plot + create_plot(axs[0], plot1_config, 0) + create_plot(axs[1], plot2_config, 1) + if len(axs) > 2 and axs[2] is not None: + create_plot(axs[2], plot3_config, 2) + if len(axs) > 3 and axs[3] is not None: + create_plot(axs[3], plot4_config, 3) + + # Adjust layout to prevent overlap + plt.tight_layout() + + return axs[0] \ No newline at end of file diff --git a/plots/scatterplot.py b/plots/scatterplot.py new file mode 100644 index 0000000..e0bef4e --- /dev/null +++ b/plots/scatterplot.py @@ -0,0 +1,49 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\scatterplot.py +import matplotlib.pyplot as plt +import numpy as np +import seaborn as sns +from plots.utils import set_theme, color_palettes + +def create_scatterplot(input_scatterplot_type, input_scatter_color, theme): + """Create a scatter plot based on input parameters""" + scatterplot_type = input_scatterplot_type + color = color_palettes[input_scatter_color] + + num_points = np.random.randint(20, 31) # Randomly select between 20 and 30 points + x = np.random.uniform(size=num_points) + if scatterplot_type == "No Correlation": + y = np.random.uniform(size=num_points) + elif scatterplot_type == "Weak Positive Correlation": + y = 0.3 * x + np.random.uniform(size=num_points) + elif scatterplot_type == "Strong Positive Correlation": + y = 0.9 * x + np.random.uniform(size=num_points) * 0.1 + elif scatterplot_type == "Weak Negative Correlation": + y = -0.3 * x + np.random.uniform(size=num_points) + elif scatterplot_type == "Strong Negative Correlation": + y = -0.9 * x + np.random.uniform(size=num_points) * 0.1 + else: + y = np.random.uniform(size=num_points) + + # Create the plot using matplotlib + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.scatterplot(x=x, y=y, ax=ax, color=color) + ax.set_title(f"{scatterplot_type}") + ax.set_xlabel("X") + ax.set_ylabel("Y") + + return ax + +def create_custom_scatterplot(df, var_x, var_y, color, theme): + """Create a scatter plot from user data""" + if not var_x or not var_y or df is None: + return None + + fig, ax = plt.subplots(figsize=(10, 6)) + set_theme(fig, ax, theme) + sns.scatterplot(data=df, x=var_x, y=var_y, color=color, ax=ax) + ax.set_title(f"{var_y} vs {var_x}") + ax.set_xlabel(var_x.replace("_", " ").title()) + ax.set_ylabel(var_y.replace("_", " ").title()) + + return ax \ No newline at end of file diff --git a/plots/utils.py b/plots/utils.py new file mode 100644 index 0000000..d71b3dd --- /dev/null +++ b/plots/utils.py @@ -0,0 +1,23 @@ +# filepath: c:\Users\kamat\OneDrive\Desktop\Work\UIUC\Research\maidr\maidr_streamlit\plots\utils.py +import matplotlib.pyplot as plt + +def set_theme(fig, ax, theme="Light"): + """Apply the appropriate theme to a plot""" + if theme == "Dark": + plt.style.use('dark_background') + fig.patch.set_facecolor('#2E2E2E') + ax.set_facecolor('#2E2E2E') + else: + plt.style.use('default') + fig.patch.set_facecolor('white') + ax.set_facecolor('white') + +# Dictionary of color palettes +color_palettes = { + "Default": "#007bc2", + "Red": "#FF0000", + "Green": "#00FF00", + "Blue": "#0000FF", + "Purple": "#800080", + "Orange": "#FFA500" +} \ No newline at end of file