diff --git a/protzilla/data_analysis/classification.py b/protzilla/data_analysis/classification.py index 15733193..fdca2fe4 100644 --- a/protzilla/data_analysis/classification.py +++ b/protzilla/data_analysis/classification.py @@ -26,12 +26,15 @@ def perform_classification( clf_parameters, scoring, model_selection_scoring="accuracy", - test_validate_split=None, - **parameters, + train_validate_split=None, + n_splits: int = 5, + n_repeats: int = 10, + random_state_cv: int = 42, + p_samples = None, ): if validation_strategy == "Manual" and grid_search_method == "Manual": X_train, X_val, y_train, y_val = perform_train_test_split( - input_df, labels_df, test_size=test_validate_split + input_df, labels_df, test_size=train_validate_split ) model = clf.set_params(**clf_parameters) model.fit(X_train, y_train) @@ -54,7 +57,7 @@ def perform_classification( return "Please select a cross validation strategy" elif validation_strategy != "Manual" and grid_search_method == "Manual": model = clf.set_params(**clf_parameters) - cv = perform_cross_validation(validation_strategy, **parameters) + cv = perform_cross_validation(validation_strategy, n_splits,n_repeats,random_state_cv=random_state_cv, p_samples=p_samples) scores = cross_validate( model, input_df, labels_df, scoring=scoring, cv=cv, return_train_score=True ) @@ -66,7 +69,7 @@ def perform_classification( return model, model_evaluation_df elif validation_strategy != "Manual" and grid_search_method != "Manual": clf_parameters = create_dict_with_lists_as_values(clf_parameters) - cv = perform_cross_validation(validation_strategy, **parameters) + cv = perform_cross_validation(validation_strategy, n_splits, n_repeats, random_state_cv=random_state_cv, p_samples=p_samples) model = perform_grid_search_cv( grid_search_method, clf, @@ -83,21 +86,34 @@ def perform_classification( ) return model.best_estimator_, model_evaluation_df - def random_forest( input_df: pd.DataFrame, metadata_df: pd.DataFrame, labels_column: str, positive_label: 
str = None, - n_estimators=100, - criterion="gini", - max_depth=None, - bootstrap=True, - random_state=42, + n_estimators: int = 100, + criterion: str = "gini", + max_depth: int = None, + bootstrap: bool = True, + + #test_split_parameters + test_size: float = 0.2, + split_stratify: str = "yes", + shuffle: bool = True, + random_state: int = 42, + + #classification_parameters model_selection: str = "Grid search", - validation_strategy: str = "Cross Validation", scoring: list[str] = ["accuracy"], - **kwargs, + model_selection_scoring: str = "accuracy", + train_val_split: float = 0.25, + validation_strategy: str = "Cross Validation", + + #cross_validation_parameters + n_splits: int = 5, + n_repeats: int = 10, + random_state_cv: int = 42, + p_samples = None, ): """ Perform classification using a random forest classifier from sklearn. @@ -109,9 +125,8 @@ def random_forest( :param labels_column: The column name in the `metadata_df` dataframe that contains the target variable (labels) for classification. :type labels_column: str - :param train_test_split: The proportion of data to be used for testing. Default is - 0.2 (80-20 train-test split). - :type train_test_split: int, optional + :param positive_label: The label that should be considered as the positive class. + :type positive_label: str, optional :param n_estimators: The number of decision trees to be used in the random forest. :type n_estimators: int, optional :param criterion: The impurity measure used for tree construction. @@ -121,16 +136,35 @@ def random_forest( :type max_depth: int or None, optional :param bootstrap: Whether bootstrap samples should be used when building trees. :type bootstrap: bool, optional + :param test_size: The proportion of data to be used for testing. Default is + 0.2 (80-20 train-test split). + :type test_size: float, optional + :param split_stratify: If not None, data is split in a stratified fashion, using this as + the class labels. 
+ :type split_stratify: str, optional + :param shuffle: Whether to shuffle the data before splitting. + :type shuffle: bool, optional :param random_state: The random seed for reproducibility. - :type random_state: int + :type random_state: int, optional :param model_selection: The model selection method for hyperparameter tuning. :type model_selection: str - :param validation_strategy: The strategy for model validation. - :type validation_strategy: str :param scoring: The scoring metric(s) used to evaluate the model's performance during validation. :type scoring: list[str] - :param **kwargs: Additional keyword arguments to be passed to the function. + :param model_selection_scoring: The scoring metric used to select the best model. + :type model_selection_scoring: str, optional + :param train_val_split: The proportion of data to be used for validation from the train part of the train-test-split. Default is 0.25. + :type train_val_split: float, optional + :param validation_strategy: The strategy for model validation. + :type validation_strategy: str + :param n_splits: The number of folds in a KFold. + :type n_splits: int, optional + :param n_repeats: The number of times cross-validator needs to be repeated. + :type n_repeats: int, optional + :param random_state_cv: The random seed for reproducibility. + :type random_state_cv: int, optional + :param p_samples: The number of samples to be used in the cross-validation. + :type p_samples: float, optional :return: A RandomForestClassifier instance, a dataframe consisting of the model's training parameters and the validation score, along with four dataframes containing the respective test and training samples and labels. 
@@ -155,7 +189,9 @@ def random_forest( X_train, X_test, y_train, y_test = perform_train_test_split( input_df_wide, labels_df["Encoded Label"], - **kwargs, + test_size, + shuffle=shuffle, + split_stratify=split_stratify, ) clf = RandomForestClassifier() @@ -179,7 +215,12 @@ def random_forest( clf, clf_parameters, scoring, - **kwargs, + model_selection_scoring, + train_val_split, + n_splits, + n_repeats, + random_state_cv, + p_samples, ) X_test.reset_index(inplace=True) @@ -206,14 +247,28 @@ def svm( gamma="scale", # only relevant ‘rbf’, ‘poly’ and ‘sigmoid’. coef0=0.0, # relevant for "poly" and "sigmoid" probability=True, - tol=0.001, + tolerance=0.001, class_weight=None, max_iter=-1, random_state=42, + + #test_split_parameters + test_size: float = 0.2, + split_stratify: str = "yes", + shuffle: bool = True, + + #classification_parameters model_selection: str = "Grid search", - validation_strategy: str = "Cross Validation", scoring: list[str] = ["accuracy"], - **kwargs, + model_selection_scoring = "accuracy", + train_val_split: float | None = None, + validation_strategy: str = "Cross Validation", + + #cross_validation_parameters + n_splits: int = 5, + n_repeats: int = 10, + random_state_cv: int = 42, + p_samples = None, ): """ Perform classification using the support vector machine classifier from sklearn. @@ -225,6 +280,8 @@ def svm( :param labels_column: The column name in the `metadata_df` dataframe that contains the target variable (labels) for classification. :type labels_column: str + :param positive_label: The label that should be considered as the positive class. + :type positive_label: str, optional :param C: Regularization parameter :type C: float :param kernel: Specifies the kernel type. @@ -245,14 +302,34 @@ def svm( :type max_iter: int :param random_state: The random seed for reproducibility. :type random_state: int + :param test_size: The proportion of data to be used for testing. Default is + 0.2 (80-20 train-test split). 
+ :type test_size: float, optional + :param split_stratify: If not None, data is split in a stratified fashion, using this as + the class labels. + :type split_stratify: str, optional + :param shuffle: Whether to shuffle the data before splitting. + :type shuffle: bool, optional + :param model_selection: The model selection method for hyperparameter tuning. :type model_selection: str - :param validation_strategy: The strategy for model validation. - :type validation_strategy: str :param scoring: The scoring metric(s) used to evaluate the model's performance during validation. :type scoring: list[str] - :param **kwargs: Additional keyword arguments to be passed to the function. + :param model_selection_scoring: The scoring metric used to select the best model. + :type model_selection_scoring: str, optional + :param train_val_split: The proportion of data to be used for validation from the train part of the train-test-split. Default is 0.25. + :type train_val_split: float, optional + :param validation_strategy: The strategy for model validation. + :type validation_strategy: str + :param n_splits: The number of folds in a KFold. + :type n_splits: int, optional + :param n_repeats: The number of times cross-validator needs to be repeated. + :type n_repeats: int, optional + :param random_state_cv: The random seed for reproducibility. + :type random_state_cv: int, optional + :param p_samples: The number of samples to be used in the cross-validation. + :type p_samples: float, optional :return: A dict containing: a SVC instance, a dataframe consisting of the model's training parameters and the validation score, along with four dataframes containing the respective test and training samples and labels. 
@@ -276,7 +353,9 @@ def svm( X_train, X_test, y_train, y_test = perform_train_test_split( input_df_wide, labels_df["Encoded Label"], - **kwargs, + test_size, + shuffle=shuffle, + split_stratify=split_stratify ) clf = SVC() @@ -287,7 +366,7 @@ def svm( gamma=gamma, coef0=coef0, probability=probability, - tol=tol, + tol=tolerance, class_weight=class_weight, max_iter=max_iter, random_state=random_state, @@ -303,7 +382,12 @@ def svm( clf, clf_parameters, scoring, - **kwargs, + model_selection_scoring, + train_val_split, + n_splits, + n_repeats, + random_state_cv, + p_samples, ) X_test.reset_index(inplace=True) diff --git a/protzilla/data_analysis/classification_helper.py b/protzilla/data_analysis/classification_helper.py index 83e56689..bffc1bc1 100644 --- a/protzilla/data_analysis/classification_helper.py +++ b/protzilla/data_analysis/classification_helper.py @@ -95,7 +95,6 @@ def perform_cross_validation( shuffle="yes", random_state_cv=42, p_samples=None, - **parameters, ): shuffle = shuffle == "yes" random_state_cv = None if not shuffle else random_state_cv @@ -213,7 +212,6 @@ def perform_train_test_split( random_state=42, shuffle=True, split_stratify="yes", - **kwargs, ): # by default this contains already filtered samples from metadata, we need to remove those labels_df = labels_df[labels_df.index.isin(input_df.index)] diff --git a/protzilla/data_analysis/clustering.py b/protzilla/data_analysis/clustering.py index ea52aaa6..40addaa0 100644 --- a/protzilla/data_analysis/clustering.py +++ b/protzilla/data_analysis/clustering.py @@ -28,7 +28,6 @@ def k_means( n_init: int = 10, max_iter: int = 300, tolerance: float = 1e-4, - **kwargs, ): """ A method that uses k-means to partition a number of samples in k clusters. 
The @@ -106,8 +105,7 @@ def k_means( clf, clf_parameters, scoring, - labels_df=labels_df["Encoded Label"], - **kwargs, + labels_df=labels_df["Encoded Label"] ) # create dataframes for ouput dict @@ -159,7 +157,7 @@ def expectation_maximisation( init_params: str = "kmeans", max_iter: int = 100, random_state=42, - **kwargs, + model_selection_scoring=None, ): """ Performs expectation maximization clustering with a Gaussian Mixture Model, using @@ -236,7 +234,7 @@ def expectation_maximisation( clf_parameters, scoring, labels_df=labels_df["Encoded Label"], - **kwargs, + model_selection_scoring = model_selection_scoring, ) cluster_labels_df = pd.DataFrame( @@ -264,7 +262,7 @@ def hierarchical_agglomerative_clustering( n_clusters: int = 2, metric: str = "euclidean", linkage: str = "ward", - **kwargs, + model_selection_scoring=None, ): """ Performs Agglomerative Clustering by recursively merging a pair of clusters of @@ -327,7 +325,7 @@ def hierarchical_agglomerative_clustering( clf_parameters, scoring, labels_df=labels_df["Encoded Label"], - **kwargs, + model_selection_scoring = model_selection_scoring, ) cluster_labels_df = pd.DataFrame( @@ -348,7 +346,6 @@ def perform_clustering( scoring, labels_df=None, model_selection_scoring=None, - **parameters, ): if model_selection == "Manual": model = clf.set_params(**clf_parameters) diff --git a/protzilla/data_analysis/plots.py b/protzilla/data_analysis/plots.py index a01967a7..7d62e0cf 100644 --- a/protzilla/data_analysis/plots.py +++ b/protzilla/data_analysis/plots.py @@ -305,9 +305,7 @@ def prot_quant_plot( :param similarity_measure: method to compare the chosen proteingroup with all others. The two methods are "cosine similarity" and "euclidean distance". :param similarity: similarity score of the chosen similarity measurement method. 
- - - :return: returns a dictionary containing a list with a plotly figure and/or a list of messages + :return: returns a dictionary containing a list with a plotly figure """ wide_df = long_to_wide(input_df) if is_long_format(input_df) else input_df diff --git a/protzilla/data_integration/enrichment_analysis.py b/protzilla/data_integration/enrichment_analysis.py index 0f46b890..5fb3af94 100644 --- a/protzilla/data_integration/enrichment_analysis.py +++ b/protzilla/data_integration/enrichment_analysis.py @@ -456,7 +456,6 @@ def GO_analysis_with_Enrichr( background_path=None, background_number=None, background_biomart=None, - **kwargs, ): """ A method that performs online over-representation analysis for a given set of proteins @@ -680,7 +679,6 @@ def GO_analysis_offline( direction="both", background_path=None, background_number=None, - **kwargs, ): """ A method that performs offline over-representation analysis for a given set of proteins diff --git a/protzilla/data_integration/enrichment_analysis_gsea.py b/protzilla/data_integration/enrichment_analysis_gsea.py index 483a24ad..3147ca03 100644 --- a/protzilla/data_integration/enrichment_analysis_gsea.py +++ b/protzilla/data_integration/enrichment_analysis_gsea.py @@ -86,7 +86,6 @@ def gsea_preranked( weighted_score=1.0, seed=123, threads=4, - **kwargs, ): """ Ranks proteins by a provided value column according to ranking_direction and @@ -294,7 +293,6 @@ def gsea( weighted_score=1.0, seed=123, threads=4, - **kwargs, ): """ Performs Gene Set Enrichment Analysis (GSEA) on a dataframe with protein IDs, samples and intensities. 
diff --git a/protzilla/data_preprocessing/filter_proteins.py b/protzilla/data_preprocessing/filter_proteins.py index 5e0bb7b8..3ee6420a 100644 --- a/protzilla/data_preprocessing/filter_proteins.py +++ b/protzilla/data_preprocessing/filter_proteins.py @@ -1,6 +1,7 @@ import pandas as pd from protzilla.data_preprocessing.plots import create_bar_plot, create_pie_plot + from ..utilities.transform_dfs import long_to_wide @@ -30,9 +31,7 @@ def by_samples_missing( filtered_proteins_list = ( transformed_df.drop(remaining_proteins_list, axis=1).columns.unique().tolist() ) - filtered_df = protein_df[ - (protein_df["Protein ID"].isin(remaining_proteins_list)) - ] + filtered_df = protein_df[(protein_df["Protein ID"].isin(remaining_proteins_list))] filtered_peptide_df = None if peptide_df is not None: filtered_peptide_df = peptide_df[ @@ -46,12 +45,14 @@ def by_samples_missing( ) -def _build_pie_bar_plot(remaining_proteins, filtered_proteins, graph_type): +def by_samples_missing_plot( + output_remaining_proteins, output_filtered_proteins, graph_type +): if graph_type == "Pie chart": fig = create_pie_plot( values_of_sectors=[ - len(remaining_proteins), - len(filtered_proteins), + len(output_remaining_proteins), + len(output_filtered_proteins), ], names_of_sectors=["Proteins kept", "Proteins filtered"], heading="Number of Filtered Proteins", @@ -59,19 +60,11 @@ def _build_pie_bar_plot(remaining_proteins, filtered_proteins, graph_type): elif graph_type == "Bar chart": fig = create_bar_plot( values_of_sectors=[ - len(remaining_proteins), - len(filtered_proteins), + len(output_remaining_proteins), + len(output_filtered_proteins), ], names_of_sectors=["Proteins kept", "Proteins filtered"], heading="Number of Filtered Proteins", y_title="Number of Proteins", ) return [fig] - - -def by_samples_missing_plot(method_inputs, method_outputs, graph_type): - return _build_pie_bar_plot( - method_outputs["remaining_proteins"], - method_outputs["filtered_proteins"], - graph_type, - ) diff --git 
a/protzilla/data_preprocessing/filter_samples.py b/protzilla/data_preprocessing/filter_samples.py index 626aabb4..a8fc7e45 100644 --- a/protzilla/data_preprocessing/filter_samples.py +++ b/protzilla/data_preprocessing/filter_samples.py @@ -133,22 +133,18 @@ def by_proteins_missing( ) -def by_protein_intensity_sum_plot(method_inputs, method_outputs, graph_type): - return _build_pie_bar_plot( - method_outputs["protein_df"], method_outputs["filtered_samples"], graph_type - ) +def by_protein_intensity_sum_plot( + output_protein_df, output_filtered_samples, graph_type +): + return _build_pie_bar_plot(output_protein_df, output_filtered_samples, graph_type) -def by_proteins_missing_plot(method_inputs, method_outputs, graph_type): - return _build_pie_bar_plot( - method_outputs["protein_df"], method_outputs["filtered_samples"], graph_type - ) +def by_proteins_missing_plot(output_protein_df, output_filtered_samples, graph_type): + return _build_pie_bar_plot(output_protein_df, output_filtered_samples, graph_type) -def by_protein_count_plot(method_inputs, method_outputs, graph_type): - return _build_pie_bar_plot( - method_outputs["protein_df"], method_outputs["filtered_samples"], graph_type - ) +def by_protein_count_plot(output_protein_df, output_filtered_samples, graph_type): + return _build_pie_bar_plot(output_protein_df, output_filtered_samples, graph_type) def _build_pie_bar_plot(result_df, filtered_sampels, graph_type): diff --git a/protzilla/data_preprocessing/imputation.py b/protzilla/data_preprocessing/imputation.py index 48af773b..9bd5d2ea 100644 --- a/protzilla/data_preprocessing/imputation.py +++ b/protzilla/data_preprocessing/imputation.py @@ -58,7 +58,7 @@ def flag_invalid_values(df: pd.DataFrame, messages: list) -> dict: def by_knn( protein_df: pd.DataFrame, number_of_neighbours: int = 5, - **kwargs, # quantile, default is median + fit_params = {} ) -> dict: """ A function to perform value imputation based on KNN @@ -79,9 +79,6 @@ def by_knn( :param 
number_of_neighbours: number of neighbouring samples used for imputation. Default: 5 :type number_of_neighbours: int - :param **kwargs: additional keyword arguments passed to - KNNImputer.fit_transform - :type kwargs: dict :return: returns an imputed dataframe in typical protzilla long format and a list of messages :rtype: pd.DataFrame @@ -93,7 +90,7 @@ def by_knn( columns = transformed_df.columns imputer = KNNImputer(n_neighbors=number_of_neighbours) - transformed_df = imputer.fit_transform(transformed_df, **kwargs) + transformed_df = imputer.fit_transform(transformed_df, **fit_params) transformed_df = pd.DataFrame(transformed_df, columns=columns, index=index) # Turn the wide format into the long format @@ -353,16 +350,16 @@ def by_normal_distribution_sampling( def by_knn_plot( - method_inputs, - method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, @@ -371,16 +368,16 @@ def by_knn_plot( def by_normal_distribution_sampling_plot( - method_inputs, - method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, @@ -389,16 +386,16 @@ def by_normal_distribution_sampling_plot( def by_simple_imputer_plot( - method_inputs, - method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, @@ -407,16 +404,16 @@ def by_simple_imputer_plot( def by_min_per_sample_plot( - method_inputs, - 
method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, @@ -425,16 +422,16 @@ def by_min_per_sample_plot( def by_min_per_protein_plot( - method_inputs, - method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, @@ -443,16 +440,16 @@ def by_min_per_protein_plot( def by_min_per_dataset_plot( - method_inputs, - method_outputs, + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, visual_transformation, ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], + protein_df, + output_protein_df, graph_type, graph_type_quantities, group_by, diff --git a/protzilla/data_preprocessing/normalisation.py b/protzilla/data_preprocessing/normalisation.py index e2be755b..0625e1e5 100644 --- a/protzilla/data_preprocessing/normalisation.py +++ b/protzilla/data_preprocessing/normalisation.py @@ -229,64 +229,34 @@ def by_reference_protein( def by_z_score_plot( - method_inputs, - method_outputs, - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ) def by_median_plot( - method_inputs, - method_outputs, - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ): return _build_box_hist_plot( - 
method_inputs["protein_df"], - method_outputs["protein_df"], - graph_type, group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ) def by_totalsum_plot( - method_inputs, - method_outputs, - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], - graph_type, group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ) def by_reference_protein_plot( - method_inputs, - method_outputs, - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ): return _build_box_hist_plot( - method_inputs["protein_df"], - method_outputs["protein_df"], - graph_type, - group_by, - visual_transformation + protein_df, output_protein_df, graph_type, group_by, visual_transformation ) diff --git a/protzilla/data_preprocessing/outlier_detection.py b/protzilla/data_preprocessing/outlier_detection.py index 93220c69..8c834e5d 100644 --- a/protzilla/data_preprocessing/outlier_detection.py +++ b/protzilla/data_preprocessing/outlier_detection.py @@ -10,14 +10,15 @@ create_pca_2d_scatter_plot, create_pca_3d_scatter_plot, ) + from ..utilities.transform_dfs import long_to_wide def by_isolation_forest( - protein_df: pd.DataFrame, - peptide_df: pd.DataFrame | None, - n_estimators: int = 100, - n_jobs: int = -1, + protein_df: pd.DataFrame, + peptide_df: pd.DataFrame | None, + n_estimators: int = 100, + n_jobs: int = -1, ) -> dict: """ This function filters out outliers using a clustering @@ -62,8 +63,11 @@ def by_isolation_forest( ].index.tolist() protein_df = protein_df[~(protein_df["Sample"].isin(outlier_list))] - peptide_df = (None if peptide_df is None - else peptide_df[~(peptide_df["Sample"].isin(outlier_list))]) + peptide_df = ( + None + if peptide_df is None + 
else peptide_df[~(peptide_df["Sample"].isin(outlier_list))] + ) return dict( protein_df=protein_df, @@ -125,8 +129,11 @@ def by_local_outlier_factor( outlier_list = df_lof_data[df_lof_data["Outlier"]].index.tolist() protein_df = protein_df[~(protein_df["Sample"].isin(outlier_list))] - peptide_df = (None if peptide_df is None - else peptide_df[~(peptide_df["Sample"].isin(outlier_list))]) + peptide_df = ( + None + if peptide_df is None + else peptide_df[~(peptide_df["Sample"].isin(outlier_list))] + ) return dict( protein_df=protein_df, @@ -232,8 +239,11 @@ def by_pca( df_transformed_pca_data["Outlier"] ].index.tolist() protein_df = protein_df[~(protein_df["Sample"].isin(outlier_list))] - peptide_df = (None if peptide_df is None - else peptide_df[~(peptide_df["Sample"].isin(outlier_list))]) + peptide_df = ( + None + if peptide_df is None + else peptide_df[~(peptide_df["Sample"].isin(outlier_list))] + ) return dict( protein_df=protein_df, @@ -255,19 +265,22 @@ def by_pca( ) -def by_isolation_forest_plot(method_inputs, method_outputs): - return [create_anomaly_score_bar_plot(method_outputs["anomaly_df"])] +def by_isolation_forest_plot(output_anomaly_df): + return [create_anomaly_score_bar_plot(output_anomaly_df)] -def by_local_outlier_factor_plot(method_inputs, method_outputs): - return [create_anomaly_score_bar_plot(method_outputs["anomaly_df"])] +def by_local_outlier_factor_plot(output_anomaly_df): + return [create_anomaly_score_bar_plot(output_anomaly_df)] -def by_pca_plot(method_inputs, method_outputs): - pca_df = method_outputs["pca_df"] - number_of_components = method_outputs["number_of_components"] - explained_variance_ratio = method_outputs["explained_variance_ratio"] - if number_of_components == 2: - return [create_pca_2d_scatter_plot(pca_df, explained_variance_ratio)] - if number_of_components == 3: - return [create_pca_3d_scatter_plot(pca_df, explained_variance_ratio)] +def by_pca_plot( + output_pca_df, output_number_of_components, 
output_explained_variance_ratio +): + if output_number_of_components == 2: + return [ + create_pca_2d_scatter_plot(output_pca_df, output_explained_variance_ratio) + ] + if output_number_of_components == 3: + return [ + create_pca_3d_scatter_plot(output_pca_df, output_explained_variance_ratio) + ] diff --git a/protzilla/data_preprocessing/peptide_filter.py b/protzilla/data_preprocessing/peptide_filter.py index 3b1caee9..0781a789 100644 --- a/protzilla/data_preprocessing/peptide_filter.py +++ b/protzilla/data_preprocessing/peptide_filter.py @@ -3,9 +3,7 @@ from protzilla.data_preprocessing.plots import create_bar_plot, create_pie_plot -def by_pep_value( - peptide_df: pd.DataFrame, threshold: float -) -> dict: +def by_pep_value(peptide_df: pd.DataFrame, threshold: float) -> dict: """ This function filters out all peptides with a PEP value (assigned to all samples together for each peptide) below a certain threshold. @@ -35,11 +33,11 @@ def by_pep_value( ) -def by_pep_value_plot(method_inputs, method_outputs, graph_type): +def by_pep_value_plot(output_peptide_df, output_filtered_peptides, graph_type): value_dict = dict( values_of_sectors=[ - len(method_outputs["peptide_df"]), - len(method_outputs["filtered_peptides"]), + len(output_peptide_df), + len(output_filtered_peptides), ], names_of_sectors=["Samples kept", "Samples filtered"], heading="Number of Filtered Samples", diff --git a/protzilla/data_preprocessing/transformation.py b/protzilla/data_preprocessing/transformation.py index 221b01ab..fa3a955c 100644 --- a/protzilla/data_preprocessing/transformation.py +++ b/protzilla/data_preprocessing/transformation.py @@ -44,17 +44,11 @@ def by_log(protein_df: pd.DataFrame, peptide_df: pd.DataFrame | None, log_base=" return dict(protein_df=transformed_df, peptide_df=transformed_peptide_df) -def by_log_plot(method_inputs, method_outputs, graph_type, group_by): - return _build_box_hist_plot( - method_inputs["protein_df"], method_outputs["protein_df"], graph_type, group_by - ) 
- - -def _build_box_hist_plot(df, result_df, graph_type, group_by): +def by_log_plot(protein_df, output_protein_df, graph_type, group_by): if graph_type == "Boxplot": fig = create_box_plots( - dataframe_a=df, - dataframe_b=result_df, + dataframe_a=protein_df, + dataframe_b=output_protein_df, name_a="Before Transformation", name_b="After Transformation", heading="Distribution of Protein Intensities", @@ -63,8 +57,8 @@ def _build_box_hist_plot(df, result_df, graph_type, group_by): ) if graph_type == "Histogram": fig = create_histograms( - dataframe_a=df, - dataframe_b=result_df, + dataframe_a=protein_df, + dataframe_b=output_protein_df, name_a="Before Transformation", name_b="After Transformation", heading="Distribution of Protein Intensities", diff --git a/protzilla/disk_operator.py b/protzilla/disk_operator.py index e41d5bcb..224255c9 100644 --- a/protzilla/disk_operator.py +++ b/protzilla/disk_operator.py @@ -88,7 +88,6 @@ class KEYS: STEP_OUTPUTS = "output" STEP_FORM_INPUTS = "form_inputs" STEP_INPUTS = "inputs" - STEP_PLOT_INPUTS = "plot_inputs" STEP_MESSAGES = "messages" STEP_PLOTS = "plots" STEP_INSTANCE_IDENTIFIER = "instance_identifier" @@ -208,8 +207,6 @@ def _read_step(self, step_data: dict, steps: StepManager) -> Step: instance_identifier=step_data.get(KEYS.STEP_INSTANCE_IDENTIFIER), ) step.inputs = step_data.get(KEYS.STEP_INPUTS, {}) - if step.section == "data_preprocessing": - step.plot_inputs = step_data.get(KEYS.STEP_PLOT_INPUTS, {}) step.messages = Messages(step_data.get(KEYS.STEP_MESSAGES, [])) step.output = self._read_outputs(step_data.get(KEYS.STEP_OUTPUTS, {})) step.plots = self._read_plots(step_data.get(KEYS.STEP_PLOTS, [])) @@ -220,8 +217,6 @@ def _read_step(self, step_data: dict, steps: StepManager) -> Step: def _write_step(self, step: Step, workflow_mode: bool = False) -> dict: with ErrorHandler(): step_data = {} - if step.section == "data_preprocessing": - step_data[KEYS.STEP_PLOT_INPUTS] = sanitize_inputs(step.plot_inputs) 
step_data[KEYS.STEP_TYPE] = step.__class__.__name__ step_data[KEYS.STEP_INSTANCE_IDENTIFIER] = step.instance_identifier step_data[KEYS.STEP_FORM_INPUTS] = sanitize_inputs(step.form_inputs) diff --git a/protzilla/methods/data_analysis.py b/protzilla/methods/data_analysis.py index 511629f5..2d42e7b1 100644 --- a/protzilla/methods/data_analysis.py +++ b/protzilla/methods/data_analysis.py @@ -41,29 +41,11 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotStep(DataAnalysisStep): - step = "plot" - - def handle_outputs(self, outputs: dict): - super().handle_outputs(outputs) - plots = self.output.output.pop("plots", []) - self.plots = Plots(plots) - - class DifferentialExpressionANOVA(DataAnalysisStep): display_name = "ANOVA" operation = "differential_expression" method_description = "A function that uses ANOVA to test the difference between two or more groups defined in the clinical data. The ANOVA test is conducted on the level of each protein. The p-values are corrected for multiple testing." 
- input_keys = [ - "intensity_df", - "multiple_testing_correction_method", - "alpha", - "log_base", - "grouping", - "selected_groups", - "metadata_df", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -73,8 +55,7 @@ class DifferentialExpressionANOVA(DataAnalysisStep): "filtered_proteins", ] - def method(self, inputs: dict) -> dict: - return anova(**inputs) + calc_method = staticmethod(anova) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["log_base"] = steps.get_step_input(TransformationLog, "log_base") @@ -82,26 +63,12 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["metadata_df"] = steps.metadata_df return inputs - def plot(self, inputs): - raise NotImplementedError("Plotting is not implemented yet for this step.") - class DifferentialExpressionTTest(DataAnalysisStep): display_name = "t-Test" operation = "differential_expression" method_description = "A function to conduct a two sample t-test between groups defined in the clinical data. The t-test is conducted on the level of each protein. The p-values are corrected for multiple testing. The fold change is calculated by group2/group1." 
- input_keys = [ - "ttest_type", - "intensity_df", - "multiple_testing_correction_method", - "alpha", - "log_base", - "grouping", - "group1", - "group2", - "metadata_df", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -111,8 +78,7 @@ class DifferentialExpressionTTest(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return t_test(**inputs) + calc_method = staticmethod(t_test) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["log_base"] = steps.get_step_input(TransformationLog, "log_base") @@ -120,25 +86,12 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["metadata_df"] = steps.metadata_df return inputs - def plot(self, inputs): - raise NotImplementedError("Plotting is not implemented yet for this step.") - class DifferentialExpressionLinearModel(DataAnalysisStep): display_name = "Linear Model" operation = "differential_expression" method_description = "A function to fit a linear model using ordinary least squares for each protein. The linear model fits the protein intensities on Y axis and the grouping on X for group1 X=-1 and group2 X=1. The p-values are corrected for multiple testing." 
- input_keys = [ - "intensity_df", - "multiple_testing_correction_method", - "alpha", - "log_base", - "grouping", - "group1", - "group2", - "metadata_df", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -148,8 +101,7 @@ class DifferentialExpressionLinearModel(DataAnalysisStep): "filtered_proteins", ] - def method(self, inputs: dict) -> dict: - return linear_model(**inputs) + calc_method = staticmethod(linear_model) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["log_base"] = steps.get_step_input(TransformationLog, "log_base") @@ -157,9 +109,6 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["metadata_df"] = steps.metadata_df return inputs - def plot(self, inputs): - raise NotImplementedError("Plotting is not implemented yet for this step.") - class DifferentialExpressionMannWhitneyOnIntensity(DataAnalysisStep): display_name = "Mann-Whitney Test" @@ -167,16 +116,6 @@ class DifferentialExpressionMannWhitneyOnIntensity(DataAnalysisStep): method_description = ("A function to conduct a Mann-Whitney U test between groups defined in the clinical data." 
"The p-values are corrected for multiple testing.") - input_keys = [ - "protein_df", - "metadata_df", - "grouping", - "group1", - "group2", - "alpha", - "multiple_testing_correction_method", - "p_value_calculation_method", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -186,8 +125,7 @@ class DifferentialExpressionMannWhitneyOnIntensity(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return mann_whitney_test_on_intensity_data(**inputs) + calc_method = staticmethod(mann_whitney_test_on_intensity_data) def insert_dataframes(self, steps: StepManager, inputs) -> dict: if steps.get_step_output(Step, "protein_df", inputs["protein_df"]) is not None: @@ -203,16 +141,6 @@ class DifferentialExpressionMannWhitneyOnPTM(DataAnalysisStep): method_description = ("A function to conduct a Mann-Whitney U test between groups defined in the clinical data." "The p-values are corrected for multiple testing.") - input_keys = [ - "ptm_df", - "metadata_df", - "grouping", - "group1", - "group2", - "alpha", - "multiple_testing_correction_method", - "p_value_calculation_method", - ] output_keys = [ "differentially_expressed_ptm_df", "significant_ptm_df", @@ -222,8 +150,7 @@ class DifferentialExpressionMannWhitneyOnPTM(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return mann_whitney_test_on_ptm_data(**inputs) + calc_method = staticmethod(mann_whitney_test_on_ptm_data) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["ptm_df"] = steps.get_step_output(Step, "ptm_df", inputs["ptm_df"]) @@ -237,15 +164,6 @@ class DifferentialExpressionKruskalWallisOnIntensity(DataAnalysisStep): method_description = ("A function to conduct a Kruskal-Wallis test between groups defined in the clinical data." 
"The p-values are corrected for multiple testing.") - input_keys = [ - "protein_df", - "metadata_df", - "grouping", - "selected_groups", - "alpha", - "log_base", - "multiple_testing_correction_method", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -253,8 +171,7 @@ class DifferentialExpressionKruskalWallisOnIntensity(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return kruskal_wallis_test_on_intensity_data(**inputs) + calc_method = staticmethod(kruskal_wallis_test_on_intensity_data) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["ptm_df"] = steps.get_step_output(Step, "ptm_df", inputs["ptm_df"]) @@ -268,15 +185,6 @@ class DifferentialExpressionKruskalWallisOnIntensity(DataAnalysisStep): method_description = ("A function to conduct a Kruskal-Wallis test between groups defined in the clinical data." "The p-values are corrected for multiple testing.") - input_keys = [ - "protein_df", - "metadata_df", - "grouping", - "selected_groups", - "alpha", - "log_base", - "multiple_testing_correction_method", - ] output_keys = [ "differentially_expressed_proteins_df", "significant_proteins_df", @@ -284,8 +192,7 @@ class DifferentialExpressionKruskalWallisOnIntensity(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return kruskal_wallis_test_on_intensity_data(**inputs) + calc_method = staticmethod(kruskal_wallis_test_on_intensity_data) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["protein_df"] = steps.get_step_output(Step, "protein_df", inputs["protein_df"]) @@ -300,14 +207,6 @@ class DifferentialExpressionKruskalWallisOnPTM(DataAnalysisStep): method_description = ("A function to conduct a Kruskal-Wallis test between groups defined in the clinical data." 
"The p-values are corrected for multiple testing.") - input_keys = [ - "ptm_df", - "metadata_df", - "grouping", - "selected_groups", - "alpha", - "multiple_testing_correction_method", - ] output_keys = [ "differentially_expressed_ptm_df", "significant_ptm_df", @@ -315,8 +214,7 @@ class DifferentialExpressionKruskalWallisOnPTM(DataAnalysisStep): "corrected_alpha", ] - def method(self, inputs: dict) -> dict: - return kruskal_wallis_test_on_ptm_data(**inputs) + calc_method = staticmethod(kruskal_wallis_test_on_ptm_data) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["ptm_df"] = steps.get_step_output(Step, "ptm_df", inputs["ptm_df"]) @@ -324,26 +222,16 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotVolcano(PlotStep): +class PlotVolcano(DataAnalysisStep): display_name = "Volcano Plot" operation = "plot" method_description = ("Plots the results of a differential expression analysis in a volcano plot. The x-axis shows " "the log2 fold change and the y-axis shows the -log10 of the corrected p-values. The user " "can define a fold change threshold and an alpha level to highlight significant items.") - input_keys = [ - "p_values", - "fc_threshold", - "alpha", - "group1", - "group2", - "item_type", - "items_of_interest", - "log2_fc", - ] + output_keys = [] - def method(self, inputs: dict) -> dict: - return create_volcano_plot(**inputs) + plot_method = staticmethod(create_volcano_plot) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["p_values"] = steps.get_step_output( @@ -368,19 +256,12 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotScatterPlot(PlotStep): +class PlotScatterPlot(DataAnalysisStep): display_name = "Scatter Plot" - operation = "plot" + operation = "plot" method_description = "Creates a scatter plot from data. 
This requires a dimension reduction method to be run first, as the input dataframe should contain only 2 or 3 columns." - input_keys = [ - "input_df", - "color_df", - ] - output_keys = [] - - def method(self, inputs: dict) -> dict: - return scatter_plot(**inputs) + plot_method = staticmethod(scatter_plot) # TODO: input def insert_dataframes(self, steps: StepManager, inputs) -> dict: @@ -391,20 +272,12 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotClustergram(PlotStep): +class PlotClustergram(DataAnalysisStep): display_name = "Clustergram" operation = "plot" method_description = "Creates a clustergram from data" - input_keys = [ - "input_df", - "sample_group_df", - "flip_axes", - ] - output_keys = ["plots"] - - def method(self, inputs: dict) -> dict: - return clustergram_plot(**inputs) + plot_method = staticmethod(clustergram_plot) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -412,18 +285,14 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotProtQuant(PlotStep): +class PlotProtQuant(DataAnalysisStep): display_name = "Protein Quantification Plot" operation = "plot" method_description = ( "Creates a line chart for intensity across samples for protein groups" ) - input_keys = ["input_df", "protein_group", "similarity_measure", "similarity"] - output_keys = [] - - def method(self, inputs: dict) -> dict: - return prot_quant_plot(**inputs) + plot_method = staticmethod(prot_quant_plot) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.get_step_output( @@ -432,40 +301,28 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class PlotPrecisionRecallCurve(PlotStep): +class PlotPrecisionRecallCurve(DataAnalysisStep): display_name = "Precision Recall" operation = "plot" method_description = "The precision-recall curve shows the tradeoff between precision and 
recall for different threshold" - input_keys = [ - # TODO: Input - "plot_title", - ] - # Todo: output_keys - def method(self, inputs: dict) -> dict: - return evaluate_classification_model(**inputs) + calc_method = staticmethod(evaluate_classification_model) def insert_dataframes(self, steps: StepManager, inputs) -> dict: # TODO: Input return inputs -class PlotROC(PlotStep): +class PlotROC(DataAnalysisStep): display_name = "Receiver Operating Characteristic curve" operation = "plot" method_description = "The ROC curve helps assess the model's ability to discriminate between positive and negative classes and determine an optimal threshold for decision making" - input_keys = [ - # TODO: Input - "plot_title", - ] - # Todo: output_keys - def method(self, inputs: dict) -> dict: - return evaluate_classification_model(**inputs) + calc_method = staticmethod(evaluate_classification_model) def insert_dataframes(self, steps: StepManager, inputs) -> dict: # Todo: Input @@ -477,20 +334,6 @@ class ClusteringKMeans(DataAnalysisStep): operation = "clustering" method_description = "Partitions a number of samples in k clusters using k-means" - input_keys = [ - "input_df", - "labels_column", - "positive_label", - "model_selection", - "model_selection_scoring", - "scoring", - "n_clusters", - "random_state", - "init_centroid_strategy", - "n_init", - "max_iter", - "tolerance" "metadata_df", - ] output_keys = [ "model", "model_evaluation_df", @@ -498,8 +341,7 @@ class ClusteringKMeans(DataAnalysisStep): "cluster_centers_df", ] - def method(self, inputs: dict) -> dict: - return k_means(**inputs) + calc_method = staticmethod(k_means) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -512,21 +354,6 @@ class ClusteringExpectationMaximisation(DataAnalysisStep): operation = "clustering" method_description = "A clustering algorithm that seeks to find the maximum likelihood estimates for a mixture of multivariate Gaussian distributions" - 
input_keys = [ - "input_df", - "labels_column", - "positive_label", - "model_selection", - "model_selection_scoring", - "scoring", - "n_components", - "reg_covar", - "covariance_type", - "init_params", - "max_iter", - "random_state", - "metadata_df", - ] output_keys = [ "model", "model_evaluation_df", @@ -534,8 +361,7 @@ class ClusteringExpectationMaximisation(DataAnalysisStep): "cluster_labels_probabilities_df", ] - def method(self, inputs: dict) -> dict: - return expectation_maximisation(**inputs) + calc_method = staticmethod(expectation_maximisation) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -550,26 +376,13 @@ class ClusteringHierarchicalAgglomerative(DataAnalysisStep): "Performs hierarchical clustering utilizing a bottom-up approach" ) - input_keys = [ - "input_df", - "labels_column", - "positive_label", - "model_selection", - "model_selection_scoring", - "scoring", - "n_clusters", - "metric", - "linkage", - "metadata_df", - ] output_keys = [ "model", "model_evaluation_df", "cluster_labels_df", ] - def method(self, inputs: dict) -> dict: - return hierarchical_agglomerative_clustering(**inputs) + calc_method = staticmethod(hierarchical_agglomerative_clustering) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -582,27 +395,6 @@ class ClassificationRandomForest(DataAnalysisStep): operation = "classification" method_description = "A random forest is a meta estimator that fits a number of decision tree classifiers on various sub-samples of the dataset and uses averaging to improve the predictive accuracy and control over-fitting." 
- input_keys = [ - "input_df", - "labels_column", - "positive_label", - "test_size", - "split_stratify", - "validation_strategy", - "train_val_split", - "n_splits", - "shuffle", - "n_repeats", - "random_state_cv", - "p_samples", - "scoring", - "model_selection", - "model_selection_scoring", - "criterion", - "max_depth", - "random_state", - "metadata_df", - ] output_keys = [ "model", "model_evaluation_df", @@ -612,8 +404,7 @@ class ClassificationRandomForest(DataAnalysisStep): "y_test_df", ] - def method(self, inputs: dict) -> dict: - return random_forest(**inputs) + calc_method = staticmethod(random_forest) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -626,27 +417,6 @@ class ClassificationSVM(DataAnalysisStep): operation = "classification" method_description = "A support vector machine constructs a hyperplane or set of hyperplanes in a high- or infinite-dimensional space, which can be used for classification." - input_keys = [ - "input_df", - "labels_column", - "positive_label", - "test_size", - "split_stratify", - "validation_strategy", - "train_val_split", - "n_splits", - "shuffle", - "n_repeats", - "random_state_cv", - "p_samples", - "scoring", - "model_selection", - "model_selection_scoring", - "C", - "kernel", - "tolerance" "random_state", - "metadata_df", - ] output_keys = [ "model", "model_evaluation_df", @@ -656,8 +426,7 @@ class ClassificationSVM(DataAnalysisStep): "y_test_df", ] - def method(self, inputs: dict) -> dict: - return svm(**inputs) + calc_method = staticmethod(svm) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -690,19 +459,9 @@ class DimensionReductionTSNE(DataAnalysisStep): operation = "dimension_reduction" method_description = "Dimension reduction of a dataframe using t-SNE" - input_keys = [ - "input_df", - "n_components", - "perplexity", - "metric", - "random_state", - "n_iter", - "n_iter_without_progress", - ] output_keys 
= ["embedded_data"] - def method(self, inputs: dict) -> dict: - return t_sne(**inputs) + calc_method = staticmethod(t_sne) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.protein_df @@ -715,18 +474,9 @@ class DimensionReductionUMAP(DataAnalysisStep): operation = "dimension_reduction" method_description = "Dimension reduction of a dataframe using UMAP" - input_keys = [ - "input_df", - "n_neighbors", - "n_components", - "min_dist", - "metric", - "random_state", - ] output_keys = ["embedded_data"] - def method(self, inputs: dict) -> dict: - return umap(**inputs) + calc_method = staticmethod(umap) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["input_df"] = steps.get_step_output( @@ -740,21 +490,15 @@ class ProteinGraphPeptidesToIsoform(DataAnalysisStep): operation = "protein_graph" method_description = "Create a variation graph (.graphml) for a Protein and map the peptides onto the graph for coverage visualisation. The protein data will be downloaded from https://rest.uniprot.org/uniprotkb/.txt. Only `Variant`-Features are included in the graph. This, currently, only works with Uniport-IDs and while you are online." - input_keys = [ - "protein_id", - "run_name", - "peptide_df", - "k" "allowed_mismatches", - ] output_keys = [ - "graph_path" "protein_id", + "graph_path", + "protein_id", "peptide_matches", "peptide_mismatches", "filtered_blocks", ] - def method(self, inputs: dict) -> dict: - return peptides_to_isoform(**inputs) + calc_method = staticmethod(peptides_to_isoform) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.peptide_df @@ -767,17 +511,12 @@ class ProteinGraphVariationGraph(DataAnalysisStep): operation = "protein_graph" method_description = "Create a variation graph (.graphml) for a protein, including variation-features. The protein data will be downloaded from https://rest.uniprot.org/uniprotkb/.txt. 
This, currently, only works with Uniport-IDs and while you are online." - input_keys = [ - "protein_id", - "run_name", - ] output_keys = [ "graph_path", "filtered_blocks", ] - def method(self, inputs: dict) -> dict: - return variation_graph(**inputs) + calc_method = staticmethod(variation_graph) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.peptide_df @@ -785,20 +524,11 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: return inputs -class FLEXIQuantLF(PlotStep): +class FLEXIQuantLF(DataAnalysisStep): display_name = "FLEXIQuant-LF" operation = "modification_quantification" method_description = "FLEXIQuant-LF is an unbiased, label-free computational tool to indirectly detect modified peptides and to quantify the degree of modification based solely on the unmodified peptide species." - input_keys = [ - "peptide_df", - "metadata_df", - "reference_group", - "protein_id", - "num_init", - "mod_cutoff", - "grouping_column", - ] output_keys = [ "raw_scores", "RM_scores", @@ -806,8 +536,7 @@ class FLEXIQuantLF(PlotStep): "removed_peptides", ] - def method(self, inputs: dict) -> dict: - return flexiquant_lf(**inputs) + plot_method = staticmethod(flexiquant_lf) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.get_step_output( @@ -822,16 +551,11 @@ class SelectPeptidesForProtein(DataAnalysisStep): operation = "Peptide analysis" method_description = "Filter peptides for the a selected Protein of Interest from a peptide dataframe" - input_keys = [ - "peptide_df", - "protein_ids", - ] output_keys = [ "peptide_df", ] - def method(self, inputs: dict) -> dict: - return select_peptides_of_protein(**inputs) + calc_method = staticmethod(select_peptides_of_protein) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.get_step_output( @@ -862,15 +586,11 @@ class PTMsPerSample(DataAnalysisStep): method_description = ("Analyze the 
post-translational modifications (PTMs) of a single protein of interest. " "This function requires a peptide dataframe with PTM information.") - input_keys = [ - "peptide_df", - ] output_keys = [ "ptm_df", ] - def method(self, inputs: dict) -> dict: - return ptms_per_sample(**inputs) + calc_method = staticmethod(ptms_per_sample) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.get_step_output( @@ -885,15 +605,11 @@ class PTMsProteinAndPerSample(DataAnalysisStep): method_description = ("Analyze the post-translational modifications (PTMs) of all Proteins. " "This function requires a peptide dataframe with PTM information.") - input_keys = [ - "peptide_df", - ] output_keys = [ "ptm_df", ] - def method(self, inputs: dict) -> dict: - return ptms_per_protein_and_sample(**inputs) + calc_method = staticmethod(ptms_per_protein_and_sample) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["peptide_df"] = steps.get_step_output( diff --git a/protzilla/methods/data_integration.py b/protzilla/methods/data_integration.py index 6a7896e6..4f080d64 100644 --- a/protzilla/methods/data_integration.py +++ b/protzilla/methods/data_integration.py @@ -18,8 +18,8 @@ def insert_dataframes(self, steps: StepManager, inputs) -> dict: class PlotStep(DataIntegrationStep): operation = "plot" - def handle_outputs(self, outputs: dict): - super().handle_outputs(outputs) + def handle_calc_outputs(self, outputs: dict): + super().handle_calc_outputs(outputs) plots = outputs["plots"] if "plots" in outputs else [] self.plots = Plots(plots) @@ -29,19 +29,9 @@ class EnrichmentAnalysisGOAnalysisWithString(DataIntegrationStep): operation = "enrichment_analysis" method_description = "Online GO analysis using STRING API" - input_keys = [ - "proteins_df", - "differential_expression_col", - "differential_expression_threshold", - "gene_sets_restring", - "organism", - "direction", - "background_path", - ] output_keys = ["enrichment_df"] - def 
method(self, inputs: dict) -> dict: - return enrichment_analysis.GO_analysis_with_STRING(**inputs) + calc_method = staticmethod(enrichment_analysis.GO_analysis_with_STRING) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["proteins_df"] = steps.get_step_output( @@ -63,25 +53,9 @@ class EnrichmentAnalysisGOAnalysisWithEnrichr(DataIntegrationStep): display_name = "GO analysis with Enrichr" operation = "enrichment_analysis" method_description = "Online GO analysis using Enrichr API" - input_keys = [ - "proteins_df", - "differential_expression_col", - "differential_expression_threshold", - "gene_mapping_df", - "gene_sets_field", - "gene_sets_path", - "gene_sets_enrichr", - "direction", - "organism", - "background_field", - "background_path", - "background_number", - "background_biomart", - ] output_keys = ["enrichment_df"] - def method(self, inputs: dict) -> dict: - return enrichment_analysis.GO_analysis_with_Enrichr(**inputs) + calc_method = staticmethod(enrichment_analysis.GO_analysis_with_Enrichr) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["proteins_df"] = steps.get_step_output( @@ -107,21 +81,11 @@ class EnrichmentAnalysisGOAnalysisOffline(DataIntegrationStep): display_name = "GO analysis offline" operation = "enrichment_analysis" method_description = "Offline GO Analysis using a hypergeometric test" - input_keys = [ - "protein_df", - "differential_expression_col", - "differential_expression_threshold", - "gene_mapping", # TODO adjust this method to use the gene_mapping_df from gene_mapping - "gene_sets_path", - "direction", - "background_field", - "background_path", - "background_number", - ] + output_keys = ["enrichment_df"] - def method(self, inputs: dict) -> dict: - return enrichment_analysis.GO_analysis_offline(**inputs) + calc_method = staticmethod(enrichment_analysis.GO_analysis_offline) + # TODO gene_mapping - adjust this method to use the gene_mapping_df from gene_mapping def insert_dataframes(self, 
steps: StepManager, inputs) -> dict: inputs["proteins_df"] = steps.get_step_output( @@ -143,66 +107,30 @@ class EnrichmentAnalysisWithGSEA(DataIntegrationStep): display_name = "GSEA" operation = "enrichment_analysis" method_description = "Perform gene set enrichment analysis" - input_keys = [ - "protein_df", - "gene_mapping_df", - "gene_sets_field", - "gene_sets_path", - "gene_sets_enrichr", - "grouping", - "group1", - "group2", - "min_size", - "max_size", - "number_of_permutations", - "permutation_type", - "ranking_method", - "weighted_score", - ] output_keys = ["enrichment_df", "ranking"] - def method(self, inputs: dict) -> dict: - return enrichment_analysis.gsea(**inputs) + calc_method = staticmethod(enrichment_analysis.gsea) class EnrichmentAnalysisWithPrerankedGSEA(DataIntegrationStep): display_name = "GSEA preranked" operation = "enrichment_analysis" method_description = "Maps proteins to genes and performs GSEA according using provided numerical column for ranking" - input_keys = [ - "protein_df", - "ranking_column", - "ranking_direction", - "gene_mapping_df", - "gene_sets_field", - "gene_sets_path", - "gene_sets_enrichr", - "min_size", - "max_size", - "number_of_permutations", - "permutation_type", - "weighted_score", - "seed", - "threads", - ] output_keys = ["enrichment_df", "ranking"] - def method(self, inputs: dict) -> dict: - return enrichment_analysis.gsea_preranked(**inputs) + calc_method = staticmethod(enrichment_analysis.gsea_preranked) class DatabaseIntegrationByGeneMapping(DataIntegrationStep): display_name = "Gene mapping" operation = "database_integration" method_description = "Map protein groups to genes" - input_keys = ["dataframe", "database_names", "use_biomart"] output_keys = ["gene_mapping_df", "filtered_protein_ids"] - def method(self, inputs: dict) -> dict: - return database_integration.gene_mapping(**inputs) + calc_method = staticmethod(database_integration.gene_mapping) def insert_dataframes(self, steps: StepManager, inputs) -> dict: 
inputs["dataframe"] = steps.get_step_output( @@ -215,32 +143,20 @@ class DatabaseIntegrationByUniprot(DataIntegrationStep): display_name = "Uniprot" operation = "database_integration" method_description = "Add Uniprot data to a dataframe" - input_keys = ["dataframe", "database_names", "fields"] output_keys = ["results_df"] - def method(self, inputs: dict) -> dict: - return database_integration.add_uniprot_data(**inputs) + calc_method = staticmethod(database_integration.add_uniprot_data) class PlotGOEnrichmentBarPlot(PlotStep): display_name = "Bar plot for GO enrichment analysis" operation = "plot" method_description = "Creates a bar plot from GO enrichment data" - input_keys = [ - "input_df", - "gene_sets", - "value", - "top_terms", - "cutoff", - "title", - "figsize", - ] - # TODO: input figsize optional? + output_keys = ["plots"] - def method(self, inputs: dict) -> dict: - return di_plots.GO_enrichment_bar_plot(**inputs) + calc_method = staticmethod(di_plots.GO_enrichment_bar_plot) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs[ @@ -256,57 +172,27 @@ class PlotGOEnrichmentDotPlot(PlotStep): display_name = "Dot plot for GO enrichment analysis (offline & with Enrichr) " operation = "plot" method_description = "Creates a categorical scatter plot from GO enrichment data" - input_keys = [ - "x_axis_type", - "gene_sets", - "top_terms", - "cutoff", - "title", - "rotate_x_labels", - "show_ring", - "dot_size", - "figsize", - ] + output_keys = ["plots"] - def method(self, inputs: dict) -> dict: - return di_plots.GO_enrichment_dot_plot(**inputs) + calc_method = staticmethod(di_plots.GO_enrichment_dot_plot) class PlotGSEADotPlot(PlotStep): display_name = "Dot plot for (pre-ranked) GSEA" operation = "plot" method_description = "Creates a categorical scatter plot from GSEA data" - input_keys = [ - "cutoff", - "gene_sets", - "dot_color_value", - "x_axis_value", - "title", - "show_ring", - "dot_size", - "remove_library_names", - "figsize", - ] + 
output_keys = ["plots"] - def method(self, inputs: dict) -> dict: - return di_plots.gsea_dot_plot(**inputs) + calc_method = staticmethod(di_plots.gsea_dot_plot) class PlotGSEAEnrichmentPlot(PlotStep): display_name = "Enrichment plot for (pre-ranked) GSEA" operation = "plot" method_description = "Creates an enrichment plot from (pre-ranked) GSEA data with the enrichment score, ranked_metric, gene rank and hits" - input_keys = [ - "term_dict", - "term_name", - "ranking", - "pos_pheno_label", - "neg_pheno_label", - "figsize", - ] + output_keys = ["plots"] - def method(self, inputs: dict) -> dict: - return di_plots.gsea_enrichment_plot(**inputs) + calc_method = staticmethod(di_plots.gsea_enrichment_plot) diff --git a/protzilla/methods/data_preprocessing.py b/protzilla/methods/data_preprocessing.py index 0565eaf0..38566d49 100644 --- a/protzilla/methods/data_preprocessing.py +++ b/protzilla/methods/data_preprocessing.py @@ -1,8 +1,5 @@ from __future__ import annotations -import logging -import traceback - from protzilla.data_preprocessing import ( filter_proteins, filter_samples, @@ -12,8 +9,7 @@ peptide_filter, transformation, ) -from protzilla.steps import Plots, Step, StepManager -from protzilla.utilities import format_trace +from protzilla.steps import Step, StepManager class DataPreprocessingStep(Step): @@ -32,34 +28,6 @@ def insert_dataframes(self, steps: StepManager, inputs: dict) -> dict: inputs["peptide_df"] = steps.get_step_output(Step, "peptide_df") return inputs - def plot(self, inputs: dict = None): - if inputs is None: - inputs = self.plot_inputs - else: - self.plot_inputs = inputs.copy() - inputs = self.insert_dataframes_for_plot(inputs) - try: - self.plots = Plots(self.plot_method(inputs)) - except Exception as e: - self.messages.append( - dict( - level=logging.ERROR, - msg=( - f"An error occurred while plotting this step: {e.__class__.__name__} {e} " - f"Please check your parameters or report a potential programming issue." 
- ), - trace=format_trace(traceback.format_exception(e)), - ) - ) - - def insert_dataframes_for_plot(self, inputs: dict) -> dict: - inputs["method_inputs"] = self.inputs - inputs["method_outputs"] = self.output - return inputs - - def plot_method(self, inputs): - raise NotImplementedError("Plot method not implemented for this step") - class FilterProteinsBySamplesMissing(DataPreprocessingStep): display_name = "By samples missing" @@ -68,13 +36,8 @@ class FilterProteinsBySamplesMissing(DataPreprocessingStep): "Filter proteins based on the amount of samples with nan values" ) - input_keys = ["protein_df", "peptide_df", "percentage"] - - def method(self, inputs): - return filter_proteins.by_samples_missing(**inputs) - - def plot_method(self, inputs): - return filter_proteins.by_samples_missing_plot(**inputs) + calc_method = staticmethod(filter_proteins.by_samples_missing) + plot_method = staticmethod(filter_proteins.by_samples_missing_plot) class FilterByProteinsCount(DataPreprocessingStep): @@ -82,13 +45,8 @@ class FilterByProteinsCount(DataPreprocessingStep): operation = "filter_samples" method_description = "Filter by protein count per sample" - input_keys = ["protein_df", "peptide_df", "deviation_threshold"] - - def method(self, inputs): - return filter_samples.by_protein_count(**inputs) - - def plot_method(self, inputs): - return filter_samples.by_protein_count_plot(**inputs) + calc_method = staticmethod(filter_samples.by_protein_count) + plot_method = staticmethod(filter_samples.by_protein_count_plot) class FilterSamplesByProteinsMissing(DataPreprocessingStep): @@ -98,13 +56,8 @@ class FilterSamplesByProteinsMissing(DataPreprocessingStep): "Filter samples based on the amount of proteins with nan values" ) - input_keys = ["protein_df", "peptide_df", "percentage"] - - def method(self, inputs): - return filter_samples.by_proteins_missing(**inputs) - - def plot_method(self, inputs): - return filter_samples.by_proteins_missing_plot(**inputs) + calc_method = 
staticmethod(filter_samples.by_proteins_missing) + plot_method = staticmethod(filter_samples.by_proteins_missing_plot) class FilterSamplesByProteinIntensitiesSum(DataPreprocessingStep): @@ -112,13 +65,8 @@ class FilterSamplesByProteinIntensitiesSum(DataPreprocessingStep): operation = "filter_samples" method_description = "Filter by sum of protein intensities per sample" - input_keys = ["protein_df", "peptide_df", "deviation_threshold"] - - def method(self, inputs): - return filter_samples.by_protein_intensity_sum(**inputs) - - def plot_method(self, inputs): - return filter_samples.by_protein_intensity_sum_plot(**inputs) + calc_method = staticmethod(filter_samples.by_protein_intensity_sum) + plot_method = staticmethod(filter_samples.by_protein_intensity_sum_plot) class OutlierDetectionByPCA(DataPreprocessingStep): @@ -126,13 +74,8 @@ class OutlierDetectionByPCA(DataPreprocessingStep): operation = "outlier_detection" method_description = "Detect outliers using PCA" - input_keys = ["protein_df", "peptide_df", "number_of_components", "threshold"] - - def method(self, inputs): - return outlier_detection.by_pca(**inputs) - - def plot_method(self, inputs): - return outlier_detection.by_pca_plot(**inputs) + calc_method = staticmethod(outlier_detection.by_pca) + plot_method = staticmethod(outlier_detection.by_pca_plot) class OutlierDetectionByLocalOutlierFactor(DataPreprocessingStep): @@ -140,13 +83,8 @@ class OutlierDetectionByLocalOutlierFactor(DataPreprocessingStep): operation = "outlier_detection" method_description = "Detect outliers using the local outlier factor" - input_keys = ["protein_df", "peptide_df", "number_of_neighbors"] - - def method(self, inputs): - return outlier_detection.by_local_outlier_factor(**inputs) - - def plot_method(self, inputs): - return outlier_detection.by_local_outlier_factor_plot(**inputs) + calc_method = staticmethod(outlier_detection.by_local_outlier_factor) + plot_method = staticmethod(outlier_detection.by_local_outlier_factor_plot) 
class OutlierDetectionByIsolationForest(DataPreprocessingStep): @@ -154,13 +92,8 @@ class OutlierDetectionByIsolationForest(DataPreprocessingStep): operation = "outlier_detection" method_description = "Detect outliers using Isolation Forest" - input_keys = ["protein_df", "peptide_df", "n_estimators"] - - def method(self, inputs): - return outlier_detection.by_isolation_forest(**inputs) - - def plot_method(self, inputs): - return outlier_detection.by_isolation_forest_plot(**inputs) + calc_method = staticmethod(outlier_detection.by_isolation_forest) + plot_method = staticmethod(outlier_detection.by_isolation_forest_plot) class TransformationLog(DataPreprocessingStep): @@ -168,13 +101,8 @@ class TransformationLog(DataPreprocessingStep): operation = "transformation" method_description = "Transform data by log" - input_keys = [ "protein_df", "peptide_df", "log_base"] - - def method(self, inputs): - return transformation.by_log(**inputs) - - def plot_method(self, inputs): - return transformation.by_log_plot(**inputs) + calc_method = staticmethod(transformation.by_log) + plot_method = staticmethod(transformation.by_log_plot) class NormalisationByZScore(DataPreprocessingStep): @@ -182,13 +110,8 @@ class NormalisationByZScore(DataPreprocessingStep): operation = "normalisation" method_description = "Normalise data by Z-Score" - plot_input_names = ["protein_df"] - - def method(self, inputs): - return normalisation.by_z_score(**inputs) - - def plot_method(self, inputs): - return normalisation.by_z_score_plot(**inputs) + calc_method = staticmethod(normalisation.by_z_score) + plot_method = staticmethod(normalisation.by_z_score_plot) class NormalisationByTotalSum(DataPreprocessingStep): @@ -196,13 +119,8 @@ class NormalisationByTotalSum(DataPreprocessingStep): operation = "normalisation" method_description = "Normalise data by total sum" - plot_input_names = ["protein_df"] - - def method(self, inputs): - return normalisation.by_totalsum(**inputs) - - def plot_method(self, 
inputs): - return normalisation.by_totalsum_plot(**inputs) + calc_method = staticmethod(normalisation.by_totalsum) + plot_method = staticmethod(normalisation.by_totalsum_plot) class NormalisationByMedian(DataPreprocessingStep): @@ -210,13 +128,8 @@ class NormalisationByMedian(DataPreprocessingStep): operation = "normalisation" method_description = "Normalise data by median" - input_keys = ["protein_df", "percentile"] - - def method(self, inputs): - return normalisation.by_median(**inputs) - - def plot_method(self, inputs): - return normalisation.by_median_plot(**inputs) + calc_method = staticmethod(normalisation.by_median) + plot_method = staticmethod(normalisation.by_median_plot) class NormalisationByReferenceProtein(DataPreprocessingStep): @@ -224,13 +137,8 @@ class NormalisationByReferenceProtein(DataPreprocessingStep): operation = "normalisation" method_description = "Normalise data by reference protein" - input_keys = ["protein_df", "reference_protein"] - - def method(self, inputs): - return normalisation.by_reference_protein(**inputs) - - def plot_method(self, inputs): - return normalisation.by_reference_protein_plot(**inputs) + calc_method = staticmethod(normalisation.by_reference_protein) + plot_method = staticmethod(normalisation.by_reference_protein_plot) class ImputationByMinPerDataset(DataPreprocessingStep): @@ -238,13 +146,8 @@ class ImputationByMinPerDataset(DataPreprocessingStep): operation = "imputation" method_description = "Impute missing values by the minimum per dataset" - input_keys = ["protein_df", "shrinking_value"] - - def method(self, inputs): - return imputation.by_min_per_dataset(**inputs) - - def plot_method(self, inputs): - return imputation.by_min_per_dataset_plot(**inputs) + calc_method = staticmethod(imputation.by_min_per_dataset) + plot_method = staticmethod(imputation.by_min_per_dataset_plot) class ImputationByMinPerProtein(DataPreprocessingStep): @@ -252,13 +155,8 @@ class ImputationByMinPerProtein(DataPreprocessingStep): 
operation = "imputation" method_description = "Impute missing values by the minimum per protein" - input_keys = ["protein_df", "shrinking_value"] - - def method(self, inputs): - return imputation.by_min_per_protein(**inputs) - - def plot_method(self, inputs): - return imputation.by_min_per_protein_plot(**inputs) + calc_method = staticmethod(imputation.by_min_per_protein) + plot_method = staticmethod(imputation.by_min_per_protein_plot) class ImputationByMinPerSample(DataPreprocessingStep): @@ -266,13 +164,8 @@ class ImputationByMinPerSample(DataPreprocessingStep): operation = "imputation" method_description = "Impute missing values by the minimum per sample" - input_keys = ["protein_df", "shrinking_value"] - - def method(self, inputs): - return imputation.by_min_per_protein(**inputs) - - def plot_method(self, inputs): - return imputation.by_min_per_sample_plot(**inputs) + calc_method = staticmethod(imputation.by_min_per_protein) + plot_method = staticmethod(imputation.by_min_per_sample_plot) class SimpleImputationPerProtein(DataPreprocessingStep): @@ -283,13 +176,8 @@ class SimpleImputationPerProtein(DataPreprocessingStep): "sklearn.SimpleImputer class" ) - input_keys = ["protein_df", "strategy"] - - def method(self, inputs): - return imputation.by_simple_imputer(**inputs) - - def plot_method(self, inputs): - return imputation.by_simple_imputer_plot(**inputs) + calc_method = staticmethod(imputation.by_simple_imputer) + plot_method = staticmethod(imputation.by_simple_imputer_plot) class ImputationByKNN(DataPreprocessingStep): @@ -301,13 +189,8 @@ class ImputationByKNN(DataPreprocessingStep): "the features that neither is missing are close." 
) - input_keys = ["protein_df", "number_of_neighbours"] - - def method(self, inputs): - return imputation.by_knn(**inputs) - - def plot_method(self, inputs): - return imputation.by_knn_plot(**inputs) + calc_method = staticmethod(imputation.by_knn) + plot_method = staticmethod(imputation.by_knn_plot) class ImputationByNormalDistributionSampling(DataPreprocessingStep): @@ -315,25 +198,15 @@ class ImputationByNormalDistributionSampling(DataPreprocessingStep): operation = "imputation" method_description = "Imputation methods include normal distribution sampling per protein or per dataset" - input_keys = ["protein_df", "strategy", "down_shift", "scaling_factor"] - - def method(self, inputs): - return imputation.by_normal_distribution_sampling(**inputs) - - def plot_method(self, inputs): - return imputation.by_normal_distribution_sampling_plot(**inputs) + calc_method = staticmethod(imputation.by_normal_distribution_sampling) + plot_method = staticmethod(imputation.by_normal_distribution_sampling_plot) class FilterPeptidesByPEPThreshold(DataPreprocessingStep): display_name = "PEP threshold" operation = "filter_peptides" method_description = "Filter by PEP-threshold" - - input_keys = ["protein_df", "peptide_df", "threshold"] output_keys = ["protein_df", "peptide_df", "filtered_peptides"] - def method(self, inputs): - return peptide_filter.by_pep_value(**inputs) - - def plot_method(self, inputs): - return peptide_filter.by_pep_value_plot(**inputs) + calc_method = staticmethod(peptide_filter.by_pep_value) + plot_method = staticmethod(peptide_filter.by_pep_value_plot) diff --git a/protzilla/methods/importing.py b/protzilla/methods/importing.py index 6a2f6835..5efb24b3 100644 --- a/protzilla/methods/importing.py +++ b/protzilla/methods/importing.py @@ -17,7 +17,7 @@ class ImportingStep(Step): section = "importing" - def method(self, inputs): + def calc_method(self): raise NotImplementedError("This method must be implemented in a subclass.") def insert_dataframes(self, steps: 
StepManager, inputs) -> dict: @@ -29,11 +29,9 @@ class MaxQuantImport(ImportingStep): operation = "Protein Data Import" method_description = "Import the protein groups file form output of MaxQuant" - input_keys = ["file_path", "map_to_uniprot", "intensity_name", "aggregation_method"] output_keys = ["protein_df"] - def method(self, inputs): - return max_quant_import(**inputs) + calc_method = staticmethod(max_quant_import) class DiannImport(ImportingStep): @@ -41,11 +39,9 @@ class DiannImport(ImportingStep): operation = "Protein Data Import" method_description = "DIA-NN data import" - input_keys = ["file_path", "map_to_uniprot", "aggregation_method"] output_keys = ["protein_df"] - def method(self, inputs): - return diann_import(**inputs) + calc_method = staticmethod(diann_import) class MsFraggerImport(ImportingStep): @@ -53,11 +49,9 @@ class MsFraggerImport(ImportingStep): operation = "Protein Data Import" method_description = "Import the combined_protein.tsv file form output of MS Fragger" - input_keys = ["file_path", "intensity_name", "map_to_uniprot", "aggregation_method"] output_keys = ["protein_df"] - def method(self, inputs): - return ms_fragger_import(**inputs) + calc_method = staticmethod(ms_fragger_import) class MetadataImport(ImportingStep): @@ -65,11 +59,9 @@ class MetadataImport(ImportingStep): operation = "metadataimport" method_description = "Import metadata" - input_keys = ["file_path", "feature_orientation", "protein_df"] output_keys = ["metadata_df"] - def method(self, inputs): - return metadata_import_method(**inputs) + calc_method = staticmethod(metadata_import_method) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["protein_df"] = steps.get_step_output(ImportingStep, "protein_df") @@ -81,11 +73,9 @@ class MetadataImportMethodDiann(ImportingStep): operation = "metadataimport" method_description = "Import metadata for run relationships of DIA-NN" - input_keys = ["file_path", "groupby_sample", "protein_df"] output_keys =
["metadata_df", "protein_df"] - def method(self, inputs): - return metadata_import_method_diann(**inputs) + calc_method = staticmethod(metadata_import_method_diann) def insert_dataframes(self, steps: StepManager, inputs) -> dict: inputs["protein_df"] = steps.get_step_output(DiannImport, "protein_df") @@ -99,16 +89,9 @@ class MetadataColumnAssignment(ImportingStep): "Assign columns to metadata categories, repeatable for each category" ) - input_keys = [ - "metadata_required_column", - "metadata_unknown_column", - "protein_df", - "metadata_df", - ] output_keys = ["metadata_df", "protein_df"] - def method(self, inputs): - return metadata_column_assignment(**inputs) + calc_method = staticmethod(metadata_column_assignment) def insert_dataframes(self, steps: StepManager, inputs: dict) -> dict: inputs["protein_df"] = steps.get_step_output(ImportingStep, "protein_df") @@ -123,11 +106,9 @@ class PeptideImport(ImportingStep): operation = "peptide_import" method_description = "Import peptide data" - input_keys = ["file_path", "intensity_name", "map_to_uniprot"] output_keys = ["peptide_df"] - def method(self, inputs): - return peptide_import(**inputs) + calc_method = staticmethod(peptide_import) class EvidenceImport(ImportingStep): @@ -135,8 +116,6 @@ class EvidenceImport(ImportingStep): operation = "peptide_import" method_description = "Import an evidence file" - input_keys = ["file_path", "intensity_name", "map_to_uniprot"] output_keys = ["peptide_df"] - def method(self, inputs): - return evidence_import(**inputs) \ No newline at end of file + calc_method = staticmethod(evidence_import) \ No newline at end of file diff --git a/protzilla/runner.py b/protzilla/runner.py index 5c78c57c..18ddbac6 100644 --- a/protzilla/runner.py +++ b/protzilla/runner.py @@ -81,8 +81,7 @@ def compute_workflow(self): if step.section == "importing": self._insert_commandline_inputs(step) self._perform_current_step(step.form_inputs) - if self.all_plots and step.section == "data_preprocessing": - 
step.plot() + if step.plots and not step.plots.empty: self._save_plots_html(step) diff --git a/protzilla/steps.py b/protzilla/steps.py index 16795259..9448d26e 100644 --- a/protzilla/steps.py +++ b/protzilla/steps.py @@ -29,16 +29,15 @@ class Step: display_name: str = None operation: str = None method_description: str = None - input_keys: list[str] = [] output_keys: list[str] = [] calculation_status: Literal["complete", "outdated", "incomplete", "failed"] = "incomplete" def __init__(self, instance_identifier: str | None = None): self.form_inputs: dict = {} self.inputs: dict = {} - self.messages: Messages = Messages([]) self.output: Output = Output() self.plots: Plots = Plots() + self.messages: Messages = Messages([]) self.instance_identifier = instance_identifier if self.instance_identifier is None: @@ -60,7 +59,6 @@ def __eq__(self, other): def updateInputs(self, inputs: dict) -> None: if inputs: self.inputs = inputs.copy() - self.form_inputs = self.inputs.copy() def calculate(self, steps: StepManager, inputs: dict) -> bool: """ @@ -73,24 +71,30 @@ def calculate(self, steps: StepManager, inputs: dict) -> bool: stepIndex = steps.all_steps.index(self) previousStep = steps.all_steps[stepIndex-1] - try: - if (previousStep.calculation_status == "outdated" ): - if not previousStep.calculate(steps,inputs): - return False + if (previousStep.calculation_status == "outdated" ): + if not previousStep.calculate(steps,inputs): + return False + + if (steps.current_step_index == stepIndex): + self.updateInputs(inputs) + self.messages.clear() + - if (steps.current_step_index == stepIndex): - self.updateInputs(inputs) - self.messages.clear() + try: self.insert_dataframes(steps, self.inputs) - self.validate_inputs() - output_dict = self.method(self.inputs) - self.handle_outputs(output_dict) - self.handle_messages(output_dict,steps,stepIndex) - self.validate_outputs() + if self.calc_method: + calc_output = self.calc_method(**self.calculation_input) + 
self.handle_calc_outputs(calc_output) + self.validate_outputs() + self.calculation_status = "complete" if (steps.failed_step_index == stepIndex): - steps.failed_step_index = -1 - return True + steps.failed_step_index = -1 + + if self.plot_method: + plot_output = self.plot_method(**self.plot_input) + self.handle_plot_outputs(plot_output) + except NotImplementedError as e: self.messages.append( dict( @@ -116,28 +120,30 @@ ) ) except Exception as e: - self.messages.append( - dict( - level=logging.ERROR, - msg=( - f"An error occurred while calculating this step: {e.__class__.__name__} {e} " - f"Please check your parameters or report a potential programming issue." - ), - trace=format_trace(traceback.format_exception(e)), - ) + self.messages.append( + dict( + level=logging.ERROR, + msg=( + f"An error occurred while calculating this step: {e.__class__.__name__} {e} " + f"Please check your parameters or report a potential programming issue." + ), + trace=format_trace(traceback.format_exception(e)), ) - return False + ) + + if self.calculation_status != "complete": + self.calculation_status = "failed" + steps.failed_step_index = stepIndex - def method(self, **kwargs) -> dict: - raise NotImplementedError("This method must be implemented in a subclass.") + return self.calculation_status == "complete" def insert_dataframes(self, steps: StepManager, inputs: dict) -> dict: return inputs - def handle_outputs(self, outputs: dict) -> None: + def handle_calc_outputs(self, outputs: dict) -> None: """ Handles the dictionary from the calculation method and creates an Output object from it. - Responsible for checking if the output is a dictionary and if it is empty, and setting the output attribute of the instance. + Responsible for checking that the output is a dictionary and not empty, and setting the output attribute of the instance. 
:param outputs: A dictionary received after the calculation :return: None @@ -149,7 +155,29 @@ def handle_outputs(self, outputs: dict) -> None: raise ValueError("Output of calculation is empty.") self.output = Output(outputs) - def handle_messages(self, outputs: dict, steps: StepManager, stepIndex: int) -> None: + self.handle_messages(outputs) + + def handle_plot_outputs(self, outputs: dict|list) -> None: + """ + Handles the dictionary from the plot method and creates a Plots object from it. + Responsible for clearing and setting the plots attribute of the class. + :param outputs: A dictionary or a list received after the plot method + :return: None + """ + + if not isinstance(outputs, dict) and not isinstance(outputs, list): + raise TypeError("Output of plot method is not a dictionary or a list.") + + if isinstance(outputs, dict): + plots = outputs.pop("plots", []) + self.output.output.update(outputs) + self.handle_messages(outputs) + else: + plots = outputs + + self.plots = Plots(plots) + + def handle_messages(self, outputs: dict) -> None: """ Handles the messages from the calculation method and creates a Messages object from it. Responsible for clearing and setting the messages attribute of the class. @@ -158,58 +186,68 @@ def handle_messages(self, outputs: dict, steps: StepManager, stepIndex: int) -> """ messages = outputs.get("messages", []) self.messages.extend(messages) - for message in messages: - if message["level"] == logging.ERROR: - self.calculation_status = "failed" - steps.failed_step_index = stepIndex - raise Exception("Calculation failed") - - def plot(self, inputs: dict = None) -> None: - raise NotImplementedError( - f"Plotting is not implemented for this step ({self.display_name}). Only preprocessing methods can have additional plots." - ) - def validate_inputs(self, required_keys: list[str] = None) -> bool: - """ - Validates the inputs of the step. If required_keys is not specified, the input_keys of the method class are used. 
- Will delete unnecessary keys from the inputs dictionary to avoid passing unwanted parameters to the method. - :param required_keys: The keys that are required in the inputs dictionary (optional) - :return: True if the inputs are valid, False otherwise - :raises ValueError: If a required key is missing in the inputs - """ - if required_keys is None: - required_keys = self.input_keys + calc_method = None + plot_method = None # if the plot method uses the output of the calculation method, it should be prefixed with "output_" + + @property + def calculation_input(self) -> dict: + input_parameters = inspect.signature(self.calc_method).parameters + required_keys = [ + key + for key, param in input_parameters.items() + if param.default == inspect.Parameter.empty + ] for key in required_keys: if key not in self.inputs: - raise ValueError(f"Missing input {key} in inputs") - - # Deleting all unnecessary keys as to avoid "too many parameters" error - for key in self.inputs.copy().keys(): - if key not in required_keys: - logging.info( - f"Removing unnecessary key {key} from inputs. If this is not wanted, add the key to input_keys of the method class." + raise ValueError( + f"Missing required input '{key}' for the calculation method" ) - self.inputs.pop(key) - return True + return { + key: self.inputs[key] + for key in input_parameters.keys() + if key in self.inputs + } - def validate_outputs( - self, required_keys: list[str] = None, soft_check: bool = False - ) -> bool: - """ - Validates the outputs of the step. If required_keys is not specified, the output_keys of the method class are used. 
+ @property + def plot_input(self) -> dict: + # if the plot method uses the output of the calculation method, it should be prefixed with "output_" + prefixed_output = { + "output_" + key: value for key, value in self.output.output.items() + } + plot_input = self.inputs | prefixed_output + + input_parameters = inspect.signature(self.plot_method).parameters + required_keys = [ + key + for key, param in input_parameters.items() + if param.default == inspect.Parameter.empty + ] + for key in required_keys: + if key not in plot_input: + raise ValueError(f"Missing required input '{key}' for the plot method") - :param required_keys: The keys that are required in the outputs dictionary (optional) + return { + key: plot_input[key] for key in input_parameters.keys() if key in plot_input + } + + def validate_outputs(self, soft_check: bool = False) -> bool: + """ + Validates the outputs of the step. Uses the output_keys attribute to check if all required keys are present in the output dictionary. :param soft_check: Whether to raise errors or just return False if the output is invalid :return: True if the outputs are valid, False otherwise :raises ValueError: If a required key is missing in the outputs """ - inspect.signature(self.method).parameters - if required_keys is None: - required_keys = self.output_keys - for key in required_keys: + + print("Val0.0") + for key in self.output_keys: + print("Val0.5") if key not in self.output or self.output[key] is None: + print("Val0.7") if not soft_check: + + print("val1.0") raise ValueError( f"Output validation failed: missing output {key} in outputs." 
) @@ -219,7 +257,8 @@ class Output: - def __init__(self, output: dict = None): + + def __init__(self, output: dict | None = None): if output is None: output = {} @@ -257,7 +296,7 @@ def __getitem__(self, key): return self.messages[key] def __repr__(self): - return f"Messages: {[message['message'] for message in self.messages]}" + return f"Messages: {[message['msg'] for message in self.messages]}" def append(self, param): self.messages.append(param) @@ -533,15 +572,13 @@ def current_location(self) -> tuple[str, str, str]: def protein_df(self) -> pd.DataFrame: from protzilla.steps import Step - df = self.get_step_output(Step, "protein_df") - return df + return self.get_step_output(Step, "protein_df") @property def metadata_df(self) -> pd.DataFrame | None: from protzilla.methods.importing import ImportingStep return self.get_step_output(ImportingStep, "metadata_df") - logging.warning("No metadata_df found in steps") @property def preprocessed_output(self) -> Output: diff --git a/tests/protzilla/data_analysis/test_classification.py b/tests/protzilla/data_analysis/test_classification.py index dfb8238e..38cbea7c 100644 --- a/tests/protzilla/data_analysis/test_classification.py +++ b/tests/protzilla/data_analysis/test_classification.py @@ -103,7 +103,7 @@ def random_forest_out( meta_df, "Group", n_estimators=3, - test_validate_split=0.20, + train_validate_split=0.20, model_selection=model_selection, validation_strategy=validation_strategy, random_state=42, diff --git a/tests/protzilla/data_integration/test_enrichment_analysis.py b/tests/protzilla/data_integration/test_enrichment_analysis.py index d028998d..1ebaf15f 100644 --- a/tests/protzilla/data_integration/test_enrichment_analysis.py +++ b/tests/protzilla/data_integration/test_enrichment_analysis.py @@ -430,7 +430,7 @@ def test_GO_analysis_with_no_gene_sets_input(): assert "messages" in current_out assert "No gene sets provided" in current_out["messages"][0]["msg"] - +@pytest.mark.skip(reason="The api doesn't
work") @patch("protzilla.data_integration.database_query.uniprot_groups_to_genes") def test_GO_analysis_with_Enrichr(mock_uniprot_groups_to_gene, data_folder_tests): if biomart_availability == False: @@ -510,7 +510,7 @@ def test_GO_analysis_with_Enrichr(mock_uniprot_groups_to_gene, data_folder_tests assert "No background provided" in current_out["messages"][0]["msg"] assert "Some proteins could not be mapped" in current_out["messages"][1]["msg"] - +@pytest.mark.skip(reason="The api doesn't work") def test_GO_analysis_Enrichr_wrong_background_file(data_folder_tests): if biomart_availability == False: pytest.skip("BioMart servers are not available. Skipping related tests.") @@ -715,9 +715,8 @@ def test_GO_analysis_offline_no_protein_sets(): proteins_df=proteins_df, gene_sets_path="", differential_expression_col="fold_change", - direction="up", - background=None, gene_mapping_df=pd.DataFrame(columns=["Protein ID", "Gene"]), + direction="up", ) assert "messages" in current_out @@ -736,7 +735,6 @@ def test_GO_analysis_offline_invalid_protein_set_file(): gene_sets_path="an_invalid_filetype.png", differential_expression_col="fold_change", direction="up", - background="", gene_mapping_df=pd.DataFrame(columns=["Protein ID", "Gene"]), ) diff --git a/tests/protzilla/data_preprocessing/test_filter_proteins.py b/tests/protzilla/data_preprocessing/test_filter_proteins.py index e1d2cb0e..e38500af 100644 --- a/tests/protzilla/data_preprocessing/test_filter_proteins.py +++ b/tests/protzilla/data_preprocessing/test_filter_proteins.py @@ -80,7 +80,7 @@ def test_filter_proteins_by_missing_samples( filter_proteins_by_samples_missing_df, peptide_df=None, percentage=1.0 ) - fig = by_samples_missing_plot(filter_proteins_df, method_output, "Pie chart")[0] + fig = by_samples_missing_plot(method_output["remaining_proteins"], method_output["filtered_proteins"], "Pie chart")[0] if show_figures: fig.show() assert method_output["filtered_proteins"] == [ diff --git
a/tests/protzilla/data_preprocessing/test_filter_samples.py b/tests/protzilla/data_preprocessing/test_filter_samples.py index 07a35a0a..d4e8853b 100644 --- a/tests/protzilla/data_preprocessing/test_filter_samples.py +++ b/tests/protzilla/data_preprocessing/test_filter_samples.py @@ -74,7 +74,7 @@ def test_by_proteins_missing(filter_samples_df, show_figures, peptides_df): list_samples_excluded_2 = method_output2["filtered_samples"] list_samples_excluded_3 = method_output3["filtered_samples"] - fig = by_proteins_missing_plot(method_input1, method_output1, "Pie chart")[0] + fig = by_proteins_missing_plot(method_output1["protein_df"],method_output1["filtered_samples"], "Pie chart")[0] if show_figures: fig.show() @@ -127,11 +127,11 @@ def test_filter_samples_by_protein_count(filter_samples_df, show_figures, peptid list_samples_excluded_1 = method_output1["filtered_samples"] list_samples_excluded_2 = method_output2["filtered_samples"] - fig = by_protein_count_plot(method_input1, method_output1, "Pie chart")[0] + fig = by_protein_count_plot(method_output1["protein_df"], method_output1["filtered_samples"], "Pie chart")[0] if show_figures: fig.show() - fig = by_protein_count_plot(method_input1, method_output1, "Bar chart")[0] + fig = by_protein_count_plot(method_output1["protein_df"], method_output1["filtered_samples"], "Bar chart")[0] if show_figures: fig.show() @@ -179,7 +179,7 @@ def test_filter_samples_by_protein_intensity_sum( list_samples_excluded_1 = method_output1["filtered_samples"] list_samples_excluded_2 = method_output2["filtered_samples"] - fig = by_protein_intensity_sum_plot(method_input1, method_output1, "Pie chart")[0] + fig = by_protein_intensity_sum_plot(method_output1["protein_df"], method_output1["filtered_samples"], "Pie chart")[0] if show_figures: fig.show() diff --git a/tests/protzilla/data_preprocessing/test_imputation.py b/tests/protzilla/data_preprocessing/test_imputation.py index 93953095..a78c645c 100644 --- 
a/tests/protzilla/data_preprocessing/test_imputation.py +++ b/tests/protzilla/data_preprocessing/test_imputation.py @@ -159,8 +159,8 @@ def test_imputation_min_value_per_df( method_outputs = by_min_per_dataset(**method_inputs) fig1, fig2 = by_min_per_dataset_plot( - method_inputs, - method_outputs, + input_imputation_df, + method_outputs["protein_df"], "Boxplot", "Bar chart", "Sample", @@ -194,8 +194,8 @@ def test_imputation_min_value_per_sample( method_outputs = by_min_per_sample(**method_inputs) fig1, fig2 = by_min_per_sample_plot( - method_inputs, - method_outputs, + input_imputation_df, + method_outputs["protein_df"], "Boxplot", "Bar chart", "Sample", @@ -229,8 +229,8 @@ def test_imputation_min_value_per_protein( method_outputs = by_min_per_protein(**method_inputs) fig1, fig2 = by_min_per_protein_plot( - method_inputs, - method_outputs, + input_imputation_df, + method_outputs["protein_df"], "Boxplot", "Bar chart", "Sample", @@ -264,8 +264,8 @@ def test_imputation_mean_per_protein( method_outputs = by_simple_imputer(**method_inputs) fig1, fig2 = by_simple_imputer_plot( - method_inputs, - method_outputs, + input_imputation_df, + method_outputs["protein_df"], "Boxplot", "Bar chart", "Sample", @@ -297,8 +297,8 @@ def test_imputation_knn(show_figures, input_imputation_df, assertion_df_knn): method_outputs = by_knn(**method_inputs) fig1, fig2 = by_knn_plot( - method_inputs, - method_outputs, + input_imputation_df, + method_outputs["protein_df"], "Boxplot", "Bar chart", "Sample", @@ -339,8 +339,8 @@ def test_imputation_normal_distribution_sampling(show_figures, input_imputation_ ) fig1, fig2 = by_normal_distribution_sampling_plot( - method_inputs_perProtein, - method_outputs_perProtein, + input_imputation_df, + method_outputs_perProtein["protein_df"], "Boxplot", "Bar chart", "Sample", diff --git a/tests/protzilla/data_preprocessing/test_normalisation.py b/tests/protzilla/data_preprocessing/test_normalisation.py index 8e5c2a45..29aa5d6e 100644 --- 
a/tests/protzilla/data_preprocessing/test_normalisation.py +++ b/tests/protzilla/data_preprocessing/test_normalisation.py @@ -302,10 +302,9 @@ def expected_df_by_ref_protein_normalisation(): def test_normalisation_by_z_score( normalisation_df, expected_df_by_z_score_normalisation, show_figures ): - method_input = {"protein_df": normalisation_df} - method_outputs = by_z_score(**method_input) + method_outputs = by_z_score(normalisation_df) - fig = by_z_score_plot(method_input, method_outputs, "Boxplot", "Sample", "log10")[0] + fig = by_z_score_plot(normalisation_df, method_outputs["protein_df"], "Boxplot", "Sample", "log10")[0] if show_figures: fig.show() @@ -320,10 +319,9 @@ def test_normalisation_by_z_score( def test_normalisation_by_median( normalisation_df, expected_df_by_median_normalisation, show_figures ): - method_inputs = {"protein_df": normalisation_df} - method_outputs = by_median(**method_inputs) + method_outputs = by_median(normalisation_df) - fig = by_median_plot(method_inputs, method_outputs, "Boxplot", "Sample", "log10")[0] + fig = by_median_plot(normalisation_df, method_outputs["protein_df"], "Boxplot", "Sample", "log10")[0] if show_figures: fig.show() @@ -346,10 +344,9 @@ def test_normalisation_by_median_invalid_percentile(normalisation_df): def test_totalsum_normalisation( normalisation_df, expected_df_by_totalsum_normalisation, show_figures ): - method_inputs = {"protein_df": normalisation_df} - method_outputs = by_totalsum(**method_inputs) + method_outputs = by_totalsum(normalisation_df) - fig = by_totalsum_plot(method_inputs, method_outputs, "Boxplot", "Sample", "log10")[0] + fig = by_totalsum_plot(normalisation_df, method_outputs["protein_df"], "Boxplot", "Sample", "log10")[0] if show_figures: fig.show() @@ -376,7 +373,7 @@ def test_ref_protein_normalisation( } method_outputs = by_reference_protein(**method_input) - fig = by_reference_protein_plot(method_input, method_outputs, "Boxplot", "Sample", "log10")[ + fig = 
by_reference_protein_plot(normalisation_by_ref_protein_df, method_outputs["protein_df"], "Boxplot", "Sample", "log10")[ 0 ] if show_figures: diff --git a/tests/protzilla/data_preprocessing/test_outlier_detection.py b/tests/protzilla/data_preprocessing/test_outlier_detection.py index e94f84b2..4e4ececb 100644 --- a/tests/protzilla/data_preprocessing/test_outlier_detection.py +++ b/tests/protzilla/data_preprocessing/test_outlier_detection.py @@ -75,7 +75,7 @@ def test_outlier_detection_with_isolation_forest( "n_jobs": -1, } method_outputs = by_isolation_forest(**method_inputs) - fig = by_isolation_forest_plot(method_inputs, method_outputs)[0] + fig = by_isolation_forest_plot(method_outputs["anomaly_df"])[0] if show_figures: fig.show() @@ -110,7 +110,7 @@ def test_outlier_detection_by_local_outlier_factor( "n_jobs": -1, } method_outputs = by_local_outlier_factor(**method_inputs) - fig = by_local_outlier_factor_plot(method_inputs, method_outputs)[0] + fig = by_local_outlier_factor_plot(method_outputs["anomaly_df"])[0] if show_figures: fig.show() assert_peptide_filtering_matches_protein_filtering( @@ -144,7 +144,7 @@ def test_outlier_detection_with_pca(show_figures, outlier_detection_df, peptides "number_of_components": 3, } method_outputs = by_pca(**method_inputs) - fig = by_pca_plot(method_inputs, method_outputs)[0] + fig = by_pca_plot(method_outputs["pca_df"], method_outputs["number_of_components"], method_outputs["explained_variance_ratio"])[0] if show_figures: fig.show() diff --git a/tests/protzilla/data_preprocessing/test_transformation.py b/tests/protzilla/data_preprocessing/test_transformation.py index 09827157..55351117 100644 --- a/tests/protzilla/data_preprocessing/test_transformation.py +++ b/tests/protzilla/data_preprocessing/test_transformation.py @@ -185,7 +185,7 @@ def test_log2_transformation( } method_outputs = by_log(**method_inputs) - fig = by_log_plot(method_inputs, method_outputs, "Boxplot", "Protein ID")[0] + fig = 
by_log_plot(log2_transformation_df, method_outputs["protein_df"], "Boxplot", "Protein ID")[0] if show_figures: fig.show() @@ -218,8 +218,8 @@ def test_log10_transformation( method_output = by_log(**method_inputs) fig = by_log_plot( - method_inputs, - method_output, + log10_transformation_df, + method_output["protein_df"], "Boxplot", "Protein ID", )[0] diff --git a/tests/protzilla/test_run.py b/tests/protzilla/test_run.py index 1e469c58..75894417 100644 --- a/tests/protzilla/test_run.py +++ b/tests/protzilla/test_run.py @@ -60,16 +60,15 @@ def test_step_plot(self, run_imported): step = ImputationByMinPerProtein() run_imported.step_add(step) run_imported.step_next() - run_imported.step_calculate(inputs={"shrinking_value": 0.5}) - assert run_imported.current_step == step - run_imported.step_plot( - inputs={ + run_imported.step_calculate( + inputs={"shrinking_value": 0.5, "graph_type": "Boxplot", "graph_type_quantities": "Pie chart", "group_by": "None", "visual_transformation": "linear", } ) + assert run_imported.current_step == step print(run_imported.current_step.plots) assert not run_imported.current_step.plots.empty diff --git a/tests/protzilla/test_runner.py b/tests/protzilla/test_runner.py index a3260cad..50171434 100644 --- a/tests/protzilla/test_runner.py +++ b/tests/protzilla/test_runner.py @@ -152,8 +152,6 @@ def test_runner_calculates(monkeypatch, tests_folder_name, ms_data_path, metadat mock_plot = mock_perform_plot(runner) monkeypatch.setattr(runner, "_perform_current_step", mock_method) - for step in runner.run.steps.data_preprocessing: - monkeypatch.setattr(step, "plot", mock_plot) runner.compute_workflow() @@ -188,30 +186,6 @@ def test_runner_calculates_logging(caplog, tests_folder_name, ms_data_path): assert "FileNotFoundError" in caplog.text -def test_runner_plots(monkeypatch, tests_folder_name, ms_data_path, metadata_path): - plot_args = [ - "only_import_and_filter_proteins", - ms_data_path, - 
f"--run_name={tests_folder_name}/test_runner_{random_string()}", - f"--meta_data_path={metadata_path}", - "--all_plots", - ] - kwargs = args_parser().parse_args(plot_args).__dict__ - runner = Runner(**kwargs) - - mock_method = mock_perform_method(runner) - mock_plot = mock_perform_plot(runner) - - monkeypatch.setattr(runner, "_perform_current_step", mock_method) - for step in runner.run.steps.data_preprocessing: - monkeypatch.setattr(step, "plot", mock_plot) - - runner.compute_workflow() - - assert mock_plot.call_count == 1 - assert mock_plot.inputs == [{"graph_type": "Bar chart"}] - - def test_serialize_graphs(): pre_graphs = [ # this is what the "graphs" section of a step should look like {"graph_type": "Bar chart", "group_by": "Sample"}, diff --git a/ui/db.sqlite3 b/ui/db.sqlite3 index 86540294..eaf5ca4a 100644 Binary files a/ui/db.sqlite3 and b/ui/db.sqlite3 differ diff --git a/ui/runs/form_mapping.py b/ui/runs/form_mapping.py index 7221bba6..171b8df8 100644 --- a/ui/runs/form_mapping.py +++ b/ui/runs/form_mapping.py @@ -81,29 +81,11 @@ data_integration.PlotGSEAEnrichmentPlot: data_integration_forms.PlotGSEAEnrichmentPlotForm, } -_forward_mapping_plots = { - data_preprocessing.FilterProteinsBySamplesMissing: data_preprocessing_forms.FilterProteinsBySamplesMissingPlotForm, - data_preprocessing.FilterByProteinsCount: data_preprocessing_forms.FilterByProteinsCountPlotForm, - data_preprocessing.FilterSamplesByProteinsMissing: data_preprocessing_forms.FilterSamplesByProteinsMissingPlotForm, - data_preprocessing.FilterSamplesByProteinIntensitiesSum: data_preprocessing_forms.FilterSamplesByProteinIntensitiesSumPlotForm, - data_preprocessing.TransformationLog: data_preprocessing_forms.TransformationLogPlotForm, - data_preprocessing.NormalisationByZScore: data_preprocessing_forms.NormalisationByZscorePlotForm, - data_preprocessing.NormalisationByTotalSum: data_preprocessing_forms.NormalisationByTotalSumPlotForm, - data_preprocessing.NormalisationByMedian: 
data_preprocessing_forms.NormalisationByMedianPlotForm, - data_preprocessing.NormalisationByReferenceProtein: data_preprocessing_forms.NormalisationByReferenceProteinPlotForm, - data_preprocessing.ImputationByMinPerDataset: data_preprocessing_forms.ImputationByMinPerDatasetPlotForm, - data_preprocessing.ImputationByMinPerProtein: data_preprocessing_forms.ImputationByMinPerProteinPlotForm, - data_preprocessing.ImputationByMinPerSample: data_preprocessing_forms.ImputationByMinPerSamplePlotForm, - data_preprocessing.SimpleImputationPerProtein: data_preprocessing_forms.SimpleImputationPerProteinPlotForm, - data_preprocessing.ImputationByKNN: data_preprocessing_forms.ImputationByKNNPlotForm, - data_preprocessing.ImputationByNormalDistributionSampling: data_preprocessing_forms.ImputationByNormalDistributionSamplingPlotForm, - data_preprocessing.FilterPeptidesByPEPThreshold: data_preprocessing_forms.FilterPeptidesByPEPThresholdPlotForm, -} - _reverse_mapping = {v: k for k, v in _forward_mapping.items()} +# all methods of all steps saved as: dict[section][operation][name] : class def generate_hierarchical_dict() -> dict[str, dict[str, dict[str, type[Step]]]]: # Initialize an empty dictionary hierarchical_dict = {} @@ -134,10 +116,6 @@ def _get_form_class_by_step(step: Step) -> type[MethodForm]: raise ValueError(f"No form has been provided for {type(step).__name__} step.") -def _get_plot_form_class_by_step(step: Step) -> type[MethodForm]: - return _forward_mapping_plots.get(type(step)) - - def _get_step_class_by_form(form: MethodForm) -> type[Step]: step_class = _reverse_mapping.get(type(form)) if step_class: @@ -150,11 +128,6 @@ def get_empty_form_by_method(step: Step, run: Run) -> MethodForm: return _get_form_class_by_step(step)(run=run) -def get_empty_plot_form_by_method(step: Step, run: Run) -> MethodForm: - plot_form_class = _get_plot_form_class_by_step(step) - return plot_form_class(run=run) if plot_form_class else None - - def get_filled_form_by_method( step: Step, 
run: Run, in_history: bool = False ) -> MethodForm: diff --git a/ui/runs/forms/data_preprocessing.py b/ui/runs/forms/data_preprocessing.py index 08590b79..bd09f49d 100644 --- a/ui/runs/forms/data_preprocessing.py +++ b/ui/runs/forms/data_preprocessing.py @@ -63,9 +63,6 @@ class FilterProteinsBySamplesMissingForm(MethodForm): step_size=0.1, initial=0.5, ) - - -class FilterProteinsBySamplesMissingPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BarAndPieChart, label="Graph type", @@ -79,9 +76,6 @@ class FilterByProteinsCountForm(MethodForm): min_value=0, initial=2, ) - - -class FilterByProteinsCountPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BarAndPieChart, label="Graph type", @@ -97,9 +91,6 @@ class FilterSamplesByProteinsMissingForm(MethodForm): step_size=0.1, initial=0.5, ) - - -class FilterSamplesByProteinsMissingPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BarAndPieChart, label="Graph type", @@ -113,9 +104,6 @@ class FilterSamplesByProteinIntensitiesSumForm(MethodForm): min_value=0, initial=2, ) - - -class FilterSamplesByProteinIntensitiesSumPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BarAndPieChart, label="Graph type", @@ -162,9 +150,6 @@ class TransformationLogForm(MethodForm): label="Log transformation base:", initial=LogTransformationBaseType.log2, ) - - -class TransformationLogPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -176,10 +161,6 @@ class TransformationLogPlotForm(MethodForm): class NormalisationByZScoreForm(MethodForm): - pass - - -class NormalisationByZscorePlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -196,10 +177,6 @@ class NormalisationByZscorePlotForm(MethodForm): class NormalisationByTotalSumForm(MethodForm): - pass - - -class NormalisationByTotalSumPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ 
-223,9 +200,6 @@ class NormalisationByMedianForm(MethodForm): step_size=0.1, initial=0.5, ) - - -class NormalisationByMedianPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -254,9 +228,6 @@ class NormalisationByReferenceProteinForms(MethodForm): "protein in each sample. Samples where this value is zero will be " "removed and returned separately." ) - - -class NormalisationByReferenceProteinPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -283,9 +254,6 @@ class ImputationByMinPerDatasetForm(MethodForm): step_size=0.1, initial=0.5, ) - - -class ImputationByMinPerDatasetPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -317,9 +285,6 @@ class ImputationByMinPerProteinForm(MethodForm): step_size=0.1, initial=0.5, ) - - -class ImputationByMinPerProteinPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -348,9 +313,6 @@ class ImputationByMinPerSampleForms(MethodForm): step_size=0.1, initial=0.5, ) - - -class ImputationByMinPerSamplePlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -377,9 +339,6 @@ class SimpleImputationPerProteinForm(MethodForm): label="Strategy", initial=SimpleImputerStrategyType.mean, ) - - -class SimpleImputationPerProteinPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -407,9 +366,6 @@ class ImputationByKNNForms(MethodForm): step_size=1, initial=5, ) - - -class ImputationByKNNPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -442,9 +398,6 @@ class ImputationByNormalDistributionSamplingForm(MethodForm): scaling_factor = CustomFloatField( label="Scaling factor", min_value=0, max_value=1, step_size=0.1, initial=0.5 ) - - -class 
ImputationByNormalDistributionSamplingPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BoxAndHistogramGraph, label="Graph type", @@ -470,9 +423,6 @@ class FilterPeptidesByPEPThresholdForm(MethodForm): label="Threshold value for PEP", min_value=0, initial=0 ) peptide_df = CustomChoiceField(choices=EmptyEnum, label="peptide_df") - - -class FilterPeptidesByPEPThresholdPlotForm(MethodForm): graph_type = CustomChoiceField( choices=BarAndPieChart, label="Graph type", diff --git a/ui/runs/static/runs/runs.js b/ui/runs/static/runs/runs.js index bd1f0f82..bc8d1220 100644 --- a/ui/runs/static/runs/runs.js +++ b/ui/runs/static/runs/runs.js @@ -55,22 +55,6 @@ $(document).ready(function () { }); }); - // Plot button spinner - $('#plot_form').on('submit', function() { - $('#plot_parameters_submit').html(` - - Plotting... - `); - $('#plot_parameters_submit').prop('disabled', true); - }); - $("#calculateForm").find("#plot_parameters_submit").click(function() { - $(this).html(` - - Plotting... - `); - $(this).prop('disabled', true); - }); - // save current state of accordion in sessionStorage function saveAccordionState() { const panels = []; diff --git a/ui/runs/templates/runs/details.html b/ui/runs/templates/runs/details.html index bb4d85ca..f6b1ddec 100644 --- a/ui/runs/templates/runs/details.html +++ b/ui/runs/templates/runs/details.html @@ -14,8 +14,6 @@ const staticUrl = "{% static 'img/' %}"; - {# TODO 129 Better buttons for analysis and importing #} - {% endblock %} {% block title %} @@ -120,105 +118,41 @@

{{ step.section_heading }} {% endif %} {# show current step #} - {% if not end_of_run %} -
-

{{ display_name }}

-
{{ description }}
- {# if there are plot parameters, display method and plot parameters next to each other #} - {% if plot_form %} -
-
- {% csrf_token %} -
- {{ method_dropdown }} -
-
- {% for field in method_form %} -
- {% if field.field.label %}{{ field.label_tag }}{% endif %} - {{ field }} -
- {% endfor %} -
-
+
+

{{ display_name }}

+
{{ description }}
+ + {# calculateForm #} +
+
+ + {% csrf_token %} +
+ {{ method_dropdown }}
-
- - {% csrf_token %} -
-
- {{ plot_form }} -
-
- - +
+ {% for field in method_form %} +
+ {% if field.field.label %}{{ field.label_tag }}{% endif %} + {{ field }} +
+ {% endfor %}
- {% else %} -
- {% if step != "plot" %} -
- {% csrf_token %} -
- {{ method_dropdown }} -
-
- {% for field in method_form %} -
- {% if field.field.label %}{{ field.label_tag }}{% endif %} - {{ field }} -
- {% endfor %} -
-
- {% if section == "data_preprocessing" %} - - Plot - - {% endif %} - {% if show_table %} - View table - {% endif %} - {% if show_protein_graph %} - View Protein Graph - {% endif %} -
-
- {% else %} -
- {% csrf_token %} - {{ method_dropdown }} -
- {% for field in method_form %} -
- {% if field.field.label %}{{ field.label_tag }}{% endif %} - {{ field }} -
- {% endfor %} -
-
- -
-
+
+ {% if show_table %} + View table + {% endif %} + {% if show_protein_graph %} + View Protein Graph {% endif %}
- {% endif %} +
- {% else %} -

You are at the end of the run. Go back to add more steps of the same section, or add steps of - the next sections on the right.

- {% endif %} +
{# show current plots #} {% if current_plots %} diff --git a/ui/runs/urls.py b/ui/runs/urls.py index f595cfa7..83d92293 100644 --- a/ui/runs/urls.py +++ b/ui/runs/urls.py @@ -9,7 +9,6 @@ path("continue", views.continue_, name="continue"), path("delete", views.delete_, name="delete"), path("detail/", views.detail, name="detail"), - path("/plot", views.plot, name="plot"), path("/tables/", views.tables, name="tables_nokey"), path("/tables//", views.tables, name="tables"), path( diff --git a/ui/runs/views.py b/ui/runs/views.py index 9d96c01b..e21b7b2d 100644 --- a/ui/runs/views.py +++ b/ui/runs/views.py @@ -17,11 +17,11 @@ ) from django.shortcuts import render from django.urls import reverse -from django.conf import settings -from protzilla.run import Run, get_available_run_names -from protzilla.run_v2 import delete_run_folder +from protzilla.constants.paths import WORKFLOWS_PATH +from protzilla.run import Run, get_available_run_names from protzilla.run_helper import log_messages +from protzilla.run_v2 import delete_run_folder from protzilla.stepfactory import StepFactory from protzilla.steps import Step from protzilla.utilities.utilities import ( @@ -29,10 +29,10 @@ format_trace, get_memory_usage, name_to_title, - parameters_from_post, + clean_uniprot_id, + unique_justseen, ) from protzilla.workflow import get_available_workflow_names -from protzilla.constants.paths import WORKFLOWS_PATH from ui.runs.fields import ( make_displayed_history, make_method_dropdown, @@ -42,7 +42,6 @@ from ui.runs.views_helper import display_message, display_messages from .form_mapping import ( - get_empty_plot_form_by_method, get_filled_form_by_method, get_filled_form_by_request, ) @@ -66,6 +65,7 @@ def detail(request: HttpRequest, run_name: str): :return: the rendered details page :rtype: HttpResponse """ + # get current run instance if run_name not in active_runs: active_runs[run_name] = Run(run_name) run: Run = active_runs[run_name] @@ -73,21 +73,16 @@ def detail(request: HttpRequest, 
run_name: str): request.session['last_view'] = "runs:detail" request.session['run_name'] = run_name - # section, step, method = run.current_run_location() - # end_of_run = not step - if request.POST: method_form = get_filled_form_by_request( request, run ) # TODO maybe not do this as it is done after the calculation if method_form.is_valid(): method_form.submit(run) - plot_form = get_empty_plot_form_by_method(run.current_step, run) # in case the fill_form now would change it method_form.fill_form(run) else: method_form = get_filled_form_by_method(run.current_step, run) - plot_form = get_empty_plot_form_by_method(run.current_step, run) description = run.current_step.method_description @@ -164,7 +159,6 @@ def detail(request: HttpRequest, run_name: str): description=description, method_form=method_form, is_form_dynamic=method_form.is_dynamic, - plot_form=plot_form, current_step_index=run.steps.current_step_index, ), ) @@ -242,6 +236,7 @@ def continue_(request: HttpRequest): return HttpResponseRedirect(reverse("runs:detail", args=(run_name,))) + def delete_(request: HttpRequest): """ Deletes an existing run. The user is redirected to the index page. @@ -249,15 +244,15 @@ def delete_(request: HttpRequest): :param request: the request object :type request: HttpRequest - + :return: the rendered details page of the run :rtype: HttpResponse """ run_name = request.POST["run_name"] if run_name in active_runs: del active_runs[run_name] - - try: + + try: delete_run_folder(run_name) except Exception as e: display_message( @@ -316,35 +311,6 @@ def back(request, run_name): return HttpResponseRedirect(reverse("runs:detail", args=(run_name,))) -def plot(request, run_name): - """ - Creates a plot from the current step/method of the run. - This is only called by the plot button in the data preprocessing section aka when a plot is - simultaneously a step on its own. - Django messages are used to display additional information, warnings and errors to the user. 
- - :param request: the request object - :type request: HttpRequest - :param run_name: the name of the run - :type run_name: str - - :return: the rendered detail page of the run, now with the plot - :rtype: HttpResponse - """ - if run_name not in active_runs: - active_runs[run_name] = Run(run_name) - run = active_runs[run_name] - parameters = parameters_from_post(request.POST) - - if run.current_step.display_name == "plot": - del parameters["chosen_method"] - run.step_calculate(parameters) - else: - run.current_step.plot(parameters) - - return HttpResponseRedirect(reverse("runs:detail", args=(run_name,))) - - def tables(request, run_name, index, key=None): if run_name not in active_runs: active_runs[run_name] = Run(run_name)