diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..53c4e0e --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,129 @@ +name: Build XAInyPredictor + +on: + push: + workflow_dispatch: + +env: + APP_NAME: XAInyPredictor + APP_VERSION: 1.0.0-beta.1 + +jobs: + build: + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [windows-latest, ubuntu-latest, macos-latest] + + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Dependencias sistema Linux + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install -y libgl1 + + - name: Instalar dependencias Python + run: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + python -m pip install -e . + python -m pip install pyinstaller + + # ========================= + # Windows + # ========================= + + - name: Build Windows app folder + if: runner.os == 'Windows' + run: | + pyinstaller --onedir --noconsole ` + --name "${{ env.APP_NAME }}" ` + --icon "installer/windows/assets/XAInyPredictor.ico" ` + --paths src ` + --collect-all shiny ` + --add-data "src/XAInyPredictor/data;XAInyPredictor/data" ` + --add-data "src/XAInyPredictor/shinyapp/www;XAInyPredictor/shinyapp/www" ` + "src/XAInyPredictor/app.py" + + - name: Install Inno Setup + if: runner.os == 'Windows' + run: | + choco install innosetup --yes --no-progress + + - name: Build Windows installer + if: runner.os == 'Windows' + run: | + & "C:\Program Files (x86)\Inno Setup 6\ISCC.exe" "installer\windows\XAInyPredictor.iss" + + - name: Upload Windows installer + if: runner.os == 'Windows' + uses: actions/upload-artifact@v4 + with: + name: XAInyPredictor-Windows-Installer + path: dist-installer/*.exe + + # ========================= + # Linux + # ========================= + + - name: Build Linux app folder + if: runner.os == 'Linux' + run: | + pyinstaller --onedir \ + --name "${APP_NAME}" \ + --paths src \ + --collect-all shiny \ + --add-data "src/XAInyPredictor/data:XAInyPredictor/data" \ + --add-data "src/XAInyPredictor/shinyapp/www:XAInyPredictor/shinyapp/www" \ + "src/XAInyPredictor/app.py" + + - name: Package Linux + if: runner.os == 'Linux' + run: | + cd dist + tar -czf "${APP_NAME}-${APP_VERSION}-Linux-x86_64.tar.gz" "${APP_NAME}" + + - name: Upload Linux package + if: runner.os == 'Linux' + uses: actions/upload-artifact@v4 + with: + name: XAInyPredictor-Linux-x86_64 + path: dist/XAInyPredictor-*-Linux-x86_64.tar.gz + + # ========================= + # macOS + # ========================= + + - name: Build macOS app + if: runner.os == 'macOS' + run: | + pyinstaller --onedir --windowed \ + --name "${APP_NAME}" \ + --paths src \ + --collect-all shiny \ + --add-data "src/XAInyPredictor/data:XAInyPredictor/data" \ + --add-data "src/XAInyPredictor/shinyapp/www:XAInyPredictor/shinyapp/www" \ + "src/XAInyPredictor/app.py" + + - name: Package macOS app + if: runner.os == 'macOS' + run: | + cd dist + ditto -c -k --keepParent "${APP_NAME}.app" "${APP_NAME}-${APP_VERSION}-macOS.zip" + + - name: Upload macOS package + if: runner.os == 'macOS' + uses: actions/upload-artifact@v4 + with: + name: XAInyPredictor-macOS + path: dist/XAInyPredictor-*-macOS.zip diff --git a/AGENTS.md b/AGENTS.md index 3083f67..3b306fa 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,7 +14,7 @@ shiny run src/XAInyPredictor/app.py --port 8001 # Alternative - **Generate docs**: `doxygen ./docs/Doxyfile` ## Testing -No tests exist. `test/tests.py` is a broken template leftover (imports `mymodule.template` which never existed). No test framework is properly configured. +No tests exist. No test framework is properly configured. ## Type Checking None configured (no mypy, pyright, etc.). @@ -51,4 +51,4 @@ Each use case is a subfolder under `data/` with 4 required files. Add a new use Pre-commit uses black, reorder-python-imports, hadolint, commitizen. Config in `.pre-commit-config.yaml`. ## Commits -Do not commit changes without explicit user approval. \ No newline at end of file +Do not commit changes without explicit user approval. diff --git a/README.md b/README.md index 30851d3..1134455 100644 --- a/README.md +++ b/README.md @@ -1,240 +1,171 @@ # XAInyPredictor -XAInyPredictor is a Python-based application built using Shiny for running the Radioiodine (RAI) classification model trained with a dataset of 294 records and 18 features and visually exploring the features that contribute most to the classificaiton. +XAInyPredictor is a Shiny for Python research prototype for interpretable stratification support. It lets users select a configured use case, build a working input set, confirm that set for analysis, inspect feature context, review model-based stratification outputs, compare profiles, and export results for review. -[![Commitizen friendly](https://img.shields.io/badge/commitizen-friendly-brightgreen.svg)](http://commitizen.github.io/cz-cli/) +The app supports both patient-level cohorts and candidate-level sets. Labels such as patient, cohort, candidate, candidate set, and reference candidates are configured per use case in `config.yml`. -## Table of Contents +> Research prototype notice: outputs support exploratory clinical or translational research workflows. Workflow utility, model interpretation, and clinical relevance should be validated with appropriate domain collaborators before operational use. -- [XAInyPredictor](#xainypredictor) - - [Table of Contents](#table-of-contents) - - [Getting Started](#getting-started) - - [What is XAInyPredictor?](#what-is-xainypredictor) - - [Installation](#installation) - - [Initiate the app](#initiate-the-app) - - [Usage](#usage) - - [1. Data Input](#1-data-input) - - [2. Data Exploration](#2-data-exploration) - - [3. Prediction](#3-prediction) - - [Development](#development) - - [Architecture](#architecture) - - [Testing](#testing) - - [Building](#building) - - [Documentation](#documentation) +## Run The App -## Getting Started +From the repository root: -### What is XAInyPredictor? - -XAInyPredictor is a Python-based application built using Shiny that allows to run the **Radioiodine (RAI) classification model** on patient data and explore the results through intuitive data tables and plots. - -### What is the RAI classification model? - -The RAI classification model is a machine learning algorithm based on Kolmogorov-Arnold Networks (KAN) that is used to identify thyroid cancer patients that are more prone to develop Radioiodine-refractory disease (RAI-R). - -RAI-R typically develops in a subset of patients with differentiated thyroid cancer —primarily papillary and follicular thyroid carcinomas—who lose the ability to uptake or respond to radioactive iodine (I-131). This condition is associated with poor prognosis and presents significant therapeutic challenges. - -The RAI classification model was trained in a dataset of 295 patients and 16 features. - -### Installation - -Install XAInyPredictor using `pip`. First, make sure that pip is updated: - -```bash -python -m pip install --upgrade pip +```powershell +xainypredictor ``` -Then, there are two options to install the app: - -* From a cloned repository: +The default URL is: -```bash -git clone https://github.com/REPO4EU/XAInyPredictor.git -pip install path/to/XAInyPredictor +```text +http://127.0.0.1:8001 ``` -* From a `.zip` file. In this case, first download the zip file from the [GitHub website](https://github.com/REPO4EU/XAInyPredictor/archive/refs/heads/main.zip) and then execute the following command: +Alternative command: -```bash -pip install XAInyPredictor_version.zip +```powershell +shiny run src\XAInyPredictor\app.py --port 8001 ``` -### Initiate the app +If port `8001` is already in use, choose another port: -The basic command to start the application is the following: - -```bash -xainypredictor +```powershell +shiny run src\XAInyPredictor\app.py --port 8002 ``` -This will automatically open a browser web page [http://127.0.0.1:8001](http://127.0.0.1:8001) to access the XAInyPredictor interface. +## Supported Use Cases -Alternatively, you can use this command: +Use cases live under `src/XAInyPredictor/data/`. -```bash -shiny run src/XAInyPredictor/app.py -``` +- `rai`: RAI-R Predictor for thyroid cancer patient stratification. +- `mock`: Diabetes Risk Predictor demo use case. +- `neoag`: Neoantigen Candidate Prioritizer for peptide-HLA candidate prioritization. -Upon running the application, open a web browser and navigate to [http://127.0.0.1:8000](http://127.0.0.1:8000) to access the XAInyPredictor interface. +The startup modal selects the initial use case. The top navigation selector can switch use case at runtime; switching clears the current working set and shows a loading state while the selected model context is prepared. -If you get the following error: +## Workflow -```bash -ERROR: [Errno 48] error while attempting to bind on address ('127.0.0.1', 8000): [errno 48] address already in use -``` +1. Start the app and choose a use case. +2. In **Data Input**, review the selected use case summary. +3. Build the current working set using one or more sources: + - **Manual Entry** + - **Upload File** + - **Example Cohort** or **Example Candidate Set** +4. Review the table at the top of **Data Input**. The table remains the working set even when switching between input-source tabs. +5. Use **Delete Selected** or **Reset** if needed. +6. Click **Confirm** to lock the current working set for downstream analysis. +7. Continue to **Cohort/Candidate Context** and **Patient/Candidate Stratification**. -It means that the port 8000 is already being used. To solve this, specify a different port, like this: +Steps 2 and 3 are blocked until the current input set is confirmed. If the working set changes, it must be confirmed again so downstream plots, tables, and exports use the latest data. -```bash -shiny run src/XAInyPredictor/app.py --port 8787 -``` +## Data Input -Another alternative is to use Docker. First build the docker image: +Uploaded CSV, TSV, or Excel files must contain the feature columns defined by the active use case. The safest option is to download the CSV template and data dictionary from **Data Input** after selecting the use case. -```bash -docker build -t xainypredictor-docker . +General format: + +```csv +ID,Feature 1,Feature 2,Feature 3 +1,value,value,value ``` -And then, run the docker image: +Notes: -```bash -docker run -p 8000:8000 xainypredictor-docker -``` +- `ID` is optional. If it is missing, the app creates sequential IDs. +- Numeric fields accept comma or dot decimal separators in manual entry. +- Numeric fields must respect configured min/max ranges. +- Categorical fields must use the allowed values from the active use case configuration. +- Text fields can be used for identifiers such as peptide sequences or HLA alleles. +- If required columns are missing, categories are invalid, or values are out of range, the app shows validation feedback before the set can be confirmed. -The app should be available in [http://localhost:8000](http://localhost:8000). +## Analysis Views +The labels below adapt to each use case. -## Usage +- **Cohort/Candidate Context**: feature distributions, selected record highlighting, reference population selection, and feature summary statistics. +- **Patient/Candidate Stratification**: selected record score, decision threshold, assigned class/group, stratification table, and profile comparison controls. +- **Cohort/Candidate Set**: set-level summary and score distribution. +- **Model**: model card, intended use, validation status, limitations, and global feature importance. +- **Reference Patients/Candidates**: closest reference-neighbor context when a reference set is configured. -### 1. Data Input +For the neoantigen use case, closest reference candidates come from the model development reference dataset included with the package. The app uses that dataset internally for scaling, thresholding, and similarity calculations, but the reference view reports anonymized neighbor ranks, distances, scores, and classes rather than row-level reference features. -The first step for using the app is to provide input data. The app gives you three options in the left panel: -* **Manual Entry**: This option provides you with a form in which patient data can be entered manually. Every time a patient is introduced, it appears in the table below the form (see **Current Patient Cohort**). Patient data can also be removed by clicking at one or multiple rows of the table (by using *Control + Click*) and clicking the Delete Selected button. -* **Upload File**: This gives you the possibility to upload a file containing all patient data that you want to analyze. The data can be introduced using a tab-separated file (TSV), comma-separated file (CSV) or Excel file that contains the required columns. Please, use the [Excel template](https://github.com/REPO4EU/XAInyPredictor/blob/development/src/XAInyPredictor/shinyapp/www/data_template.xlsx) as reference. -* **Example Cohort**: You can also use the app with the example cohort, which is the data that has been used to train the model. +## Downloads -The features used to train the model are the following: +The stratification view includes a **Download outputs** menu: -* **Age**: Age of the participant in years. -* **Gender**: The gender of the participant. Possible values: M (Male) or F (Female). -* **BMI**: The Body Mass Index of the participant. -* **Tumor size (cm)**: The size of the tumor in centimeters. -* **Tumor stage**: The stage of the tumor, based on standard staging criteria. Possible values: T1a, T1b, T2, T3a, T3b, T4a, T4b, Tx. -* **Node stage**: The stage of the lymph node involvement. Possible values: N0, N1a, N1b, Nx. -* **Metastases**: Indicates the presence of metastases. Possible values: Mx or M1. -* **ATA risk**: The risk classification according to the American Thyroid Association. Possible values: Low, Intermediate, High. -* **Histology**: The type of tumor histology observed in the participant. Possible values: OTC (oncocytic thyroid carcinoma), PTC (papillary thyroid carcinoma), FTC (follicular thyroid carcinoma), PDTC (poorly differentiated thyroid carcinoma). -* **ETE**: Whether extrathyroidal extension was observed. Possible values: YES or NO. -* **Multifocality**: Indicates whether the tumor was multifocal. Possible values: YES or NO. -* **Vascular invasion**: Whether there was evidence of vascular invasion. Possible values: YES or NO. -* **Resection**: The status of the surgical resection. Possible values: R0, R1, R2. -* **Goal of RAI**: The goal of the radioiodine treatment. Possible values: THERAPEUTIC, ’ADIUVANT, ABLATION. +- Report package +- Stratification results +- Closest reference records +- Cohort/candidate set summary +## Use Case Configuration -You can use this [template](https://github.com/REPO4EU/XAInyPredictor/blob/development/src/XAInyPredictor/shinyapp/www/data_template.xlsx) as example: +Each use case folder requires: -```tsv -ID Age Gender BMI Tumor size (cm) Tumor stage Node stage Metastases ATA risk Histology ETE Multifocality Vascular invasion Resection Goal of RAI -RAI001 50 M 27 5 T1a N0 M1 Low PTC NO NO NO R0 ABLATION -RAI002 23 F 19 2 T1b N1a Mx Intermediate OTC NO YES NO R1 THERAPEUTIC +```text +config.yml +model.pkl +feature_order.txt +example_data.csv ``` +Use cases may also include additional files such as `reference_data.csv`. -### 2. Data Exploration +`config.yml` controls: -The second step (**Data Exploration**) allows you to explore the distribution (center panel) and statistics (right panel) of each features in your input data. The three options in the left panel allow you to control the visualization: -* **Feature to plot**: To select the feature to visualize. A numeric feature (e.g., Age, BMI...) will produce a hisogram plot, and a categorical feature (e.g., Tumor stage, Node stage...) will produce a bar plot. -* **Reference population**: This allows you to select the population that you want to show in the distribution plot. The options are **Input data** (it will show the distribution of the input data provided by the user) and **Model** (it will show the distribution of the data used to train the model). -* **Select patient**: This allows you to compare the feature value from a specific participant with the reference population selected. The value of the patient will be highlighted in red in the plot. +- model name, description, and metadata; +- feature definitions, labels, defaults, ranges, allowed values, and roles; +- target column and positive/negative classes; +- entity labels used across the UI; +- tab titles, help text, validation copy, and download labels; +- model-specific options such as false-negative-rate defaults and reference-set limits. +When a new use case folder is added under `src/XAInyPredictor/data/`, restart or refresh the app to make it available in the startup modal and navigation selector. The GitHub Actions build packages the full `data/` directory, so committed use cases are included in generated desktop artifacts. -### 3. Prediction +## Development -The third step (**Prediction**) shows the prediction results and some plots to help interpret the prediction. The results are shown in the center, divided in two panels: -* **Prediction Results Table**: The table shows the predictions of the model, in two columns: - * **Probability**: Indicates the probability of a participant to be RAI-Refractory (resistant). The higher the probability, the more likely it is. - * **RAI-R class**: Classifies the patient in Refractory or Not Refractory. YES = Refractory (i.e., resistant to the treatment); NO = Not Refractory (non-resistant to the treatment). +Install development dependencies: -* **Results Plot**: There are two types of plots (depending on the view selected in the left panel): - * **Feature Analysis (Radar)**: The radar plot compares the values of the features from a selected patient with three different average values. Comparing these values help to evaluate if a patient has been correctly classified or not. The average values are: - * Average all patients: The average values from all patients in the model. - * Avg. RAI-R negative: The average values from all RAI-R negative patients in the model (green). - * Avg. RAI-R positive: The average values from all RAI-R positive patients in the model (red). - * **Distance Analysis (Curves)**: The curves plot displays the distribution of values from a specific feature across the patients in the model, ordered from lowest to highest (blue dots). It highlights in red the patient selected, putting in context the feature values of the patient in comparison with the values from the rest of patients. +```powershell +pip install -e ".[dev]" +``` -The left panel gives you different options to modify the visualization of the results: -* **Select patient**: This allows to select the participant from which the results will be visualized. This prediction results will be highlighte in yellow in the Prediction Results Table, and the values of the features will be shown in red in the Radar and Curve plots. -* **Min. % false negatives**: This threshold controls the False Negative Ratio of the model. - * 0% False Negative Ratio: We refuse to miss any true RAI-Refractory patients. The model will classify a patient as Refractory even if the probability is low (High Sensitivity). - * Higher False Negative Ratio: We accept missing some Refractory patients to ensure those we treat are definitely Refractory (High Specificity). -* **Select View**: Allows you to change the results plot (either Radar or Curve plot). -* **Select features to view**: Allows you to select the features to be visualized in the Results Plot. -* **Radar Plot Elements:**: Allows you to control the different elements shown in the Radar Plot. +Run pre-commit: +```powershell +pre-commit run --all-files +``` -## Development +Build the package: -### Architecture +```powershell +python -m build +``` -The organization of the XAInyPredictor repository is summarized here: +Generate docs: -``` -XAInyPredictor/ -├── build -├── CHANGELOG.md -├── Dockerfile <- Definition of the Docker image -├── docs <- Documentation generated using Doxygen -├── pyproject.toml <- Definition of the python project -├── README.md <- The top-level README for developers using this package. -├── requirements.txt <- Python packages required -├── setup.md -├── sonar-project.properties -├── src <- Source code -│ ├── XAInyPredictor -│ │ ├── __init__.py -│ │ ├── __pycache__ -│ │ ├── app.py <- Main script that defines the app -| | |-- data -| | | |-- Radioiodine\ database_October_RUMC.xlsx <- Dataset used to train the model -| | | |-- current_best_formula.txt <- Formula obtained from training the model -| | | |-- formula.pkl <- Pickle file with formula obtained from training the model -| | | |-- mock.csv <- Mock raw data -| | | |-- mock_processed.csv <- Mock processed data -| | | |-- processed_data.csv <- Dataset used to train the model normalized -| | | `-- raw_data.csv <- Dataset used to train the model in csv -│ │ ├── data_processing_test.ipynb <- Python notebook to test the data processing pipeline step by step (could be removed) -│ │ ├── main.py <- Script to launch the app -│ │ ├── modules -| | | |-- data_processing.py <- Script the data processing functions (to clean and normalize the data) -| | | `-- rai.py <- Script containing the functions of the RAI model -| | `-- shinyapp <- Scripts for specific aspects of the app -| | |-- __init__.py -| | |-- data_exploration.py <- Data Exploration tab -| | |-- data_input.py <- Data Input tab -| | |-- prediction.py <- Prediction tab -| | `-- www <- Scripts related with the layout of the app -│ └── XAInyPredictor.egg-info -└── test <- Unit tests +```powershell +doxygen .\docs\Doxyfile ``` -### Testing +## GitHub Actions Build -No tests have been implemented for the moment. +The repository includes a **Build XAInyPredictor** workflow in `.github/workflows/build.yml`. -### Building +It runs on every push and can also be started manually from GitHub: -Build XAInyPredictor with: +1. Open the repository on GitHub. +2. Go to **Actions**. +3. Select **Build XAInyPredictor**. +4. Click **Run workflow**. +5. When the run finishes, download the generated artifacts from the workflow run summary. -```bash -python3 -m build -``` - -## Documentation +The workflow builds: -Documentation is generated in HTML during CI/CD, but it can be manually generated with: +- a Windows installer; +- a Linux tarball; +- a macOS zip package. -```bash -doxygen ./docs/Doxyfile -``` +## Testing Status -The generated html files will be available at `docs/html` folder. +No automated test suite is currently configured. diff --git a/installer/windows/XAInyPredictor.iss b/installer/windows/XAInyPredictor.iss new file mode 100644 index 0000000..1c2a9ff --- /dev/null +++ b/installer/windows/XAInyPredictor.iss @@ -0,0 +1,42 @@ +#define MyAppName "XAInyPredictor" +#define MyAppVersion "1.0.0" +#define MyAppPublisher "REPO4EU Consortium" +#define MyAppURL "https://repo4.eu/" +#define MyAppExeName "XAInyPredictor.exe" + +[Setup] +AppId={{6F8D6C10-6D5C-4D7E-8B92-5E2561F35E21} +AppName={#MyAppName} +AppVersion={#MyAppVersion} +AppPublisher={#MyAppPublisher} +AppPublisherURL={#MyAppURL} +AppSupportURL={#MyAppURL} +AppUpdatesURL={#MyAppURL} +DefaultDirName={localappdata}\Programs\{#MyAppName} +DefaultGroupName={#MyAppName} +OutputDir=..\..\dist-installer +OutputBaseFilename=XAInyPredictor-Setup-{#MyAppVersion}-Windows +Compression=lzma +SolidCompression=yes +WizardStyle=modern +PrivilegesRequired=lowest +SetupIconFile=assets\XAInyPredictor.ico +UninstallDisplayIcon={app}\XAInyPredictor.ico +WizardSmallImageFile=assets\wizard-small.png +WizardImageFile=assets\wizard-large.png +CloseApplications=yes +RestartApplications=no + +[Files] +Source: "..\..\dist\XAInyPredictor\*"; DestDir: "{app}"; Flags: ignoreversion recursesubdirs createallsubdirs +Source: "assets\XAInyPredictor.ico"; DestDir: "{app}"; Flags: ignoreversion + +[Icons] +Name: "{group}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\XAInyPredictor.ico" +Name: "{autodesktop}\{#MyAppName}"; Filename: "{app}\{#MyAppExeName}"; IconFilename: "{app}\XAInyPredictor.ico"; Tasks: desktopicon + +[Tasks] +Name: "desktopicon"; Description: "Create a desktop shortcut"; GroupDescription: "Additional icons:"; Flags: unchecked + +[Run] +Filename: "{app}\{#MyAppExeName}"; Description: "Launch {#MyAppName}"; Flags: nowait postinstall skipifsilent diff --git a/installer/windows/assets/XAInyPredictor.ico b/installer/windows/assets/XAInyPredictor.ico new file mode 100644 index 0000000..d4c51dc Binary files /dev/null and b/installer/windows/assets/XAInyPredictor.ico differ diff --git a/installer/windows/assets/XAInyPredictor.ico.source.png b/installer/windows/assets/XAInyPredictor.ico.source.png new file mode 100644 index 0000000..f4327b6 Binary files /dev/null and b/installer/windows/assets/XAInyPredictor.ico.source.png differ diff --git a/installer/windows/assets/XAInyPredictor_full_logo.ico b/installer/windows/assets/XAInyPredictor_full_logo.ico new file mode 100644 index 0000000..53163ca Binary files /dev/null and b/installer/windows/assets/XAInyPredictor_full_logo.ico differ diff --git a/installer/windows/assets/wizard-large.png b/installer/windows/assets/wizard-large.png new file mode 100644 index 0000000..b6f9a75 Binary files /dev/null and b/installer/windows/assets/wizard-large.png differ diff --git a/installer/windows/assets/wizard-small.png b/installer/windows/assets/wizard-small.png new file mode 100644 index 0000000..5b831a1 Binary files /dev/null and b/installer/windows/assets/wizard-small.png differ diff --git a/pyproject.toml b/pyproject.toml index f937ec0..d9022fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,8 +23,7 @@ where = ["src"] "XAInyPredictor" = [ "data/**/*", "shinyapp/www/*", - "shinyapp/www/static/img/*", - "shinyapp/www/pwa/*" + "shinyapp/www/static/img/*" ] # set the entrypoint to have executable packages diff --git a/src/XAInyPredictor/app.py b/src/XAInyPredictor/app.py index eb17687..d9315b3 100644 --- a/src/XAInyPredictor/app.py +++ b/src/XAInyPredictor/app.py @@ -1,18 +1,100 @@ """Module providing a shiny UI.""" -import pandas as pd +import os from pathlib import Path +import sys + +import time +import socket +import matplotlib +import threading +import webbrowser +import pandas as pd + +from sklearn.preprocessing import StandardScaler from shiny import App, reactive, Session, ui +matplotlib.use("Agg") + from XAInyPredictor.shinyapp import data_input, data_exploration, prediction from XAInyPredictor.modules.model_registry import discover_use_cases, load_use_case, UseCaseNotFoundError from XAInyPredictor.modules.data_processing import clean_data, process_raw_data, process_raw_form_data +from XAInyPredictor.modules.neoag_processing import prepare_neoantigen_features from XAInyPredictor.modules.xai import delta_xai, read_delta_xai_formula, split_data_with_known_target, threshold_for_target_fnr -WWW_DIR = Path(__file__).parent / "shinyapp" / "www" +def _package_resource_dir(*parts: str) -> Path: + candidates = [] + + if getattr(sys, "frozen", False): + bundle_root = Path(getattr(sys, "_MEIPASS", Path(sys.executable).parent)) + candidates.extend( + [ + bundle_root / "XAInyPredictor" / Path(*parts), + Path(sys.executable).parent / "_internal" / "XAInyPredictor" / Path(*parts), + ] + ) + + candidates.append(Path(__file__).resolve().parent / Path(*parts)) + + for candidate in candidates: + if candidate.exists(): + return candidate + + return candidates[0] + + +WWW_DIR = _package_resource_dir("shinyapp", "www") -USE_CASES = discover_use_cases() _USE_CASE_CACHE = {} +_ACTIVE_SESSION_COUNT = 0 +_SESSION_SHUTDOWN_LOCK = threading.Lock() +_SESSION_SHUTDOWN_TIMER = None + + +def _exit_on_last_session_end_enabled() -> bool: + return os.environ.get("XAINYPREDICTOR_EXIT_ON_LAST_SESSION_END") == "1" + + +def _shutdown_process_if_idle() -> None: + with _SESSION_SHUTDOWN_LOCK: + if _ACTIVE_SESSION_COUNT > 0: + return + os._exit(0) + + +def _register_session_lifecycle(session: Session) -> None: + if not _exit_on_last_session_end_enabled(): + return + + global _ACTIVE_SESSION_COUNT, _SESSION_SHUTDOWN_TIMER + + with _SESSION_SHUTDOWN_LOCK: + _ACTIVE_SESSION_COUNT += 1 + if _SESSION_SHUTDOWN_TIMER is not None: + _SESSION_SHUTDOWN_TIMER.cancel() + _SESSION_SHUTDOWN_TIMER = None + + def _on_session_ended() -> None: + global _ACTIVE_SESSION_COUNT, _SESSION_SHUTDOWN_TIMER + + with _SESSION_SHUTDOWN_LOCK: + _ACTIVE_SESSION_COUNT = max(0, _ACTIVE_SESSION_COUNT - 1) + if _ACTIVE_SESSION_COUNT > 0: + return + + _SESSION_SHUTDOWN_TIMER = threading.Timer(3.0, _shutdown_process_if_idle) + _SESSION_SHUTDOWN_TIMER.daemon = True + _SESSION_SHUTDOWN_TIMER.start() + + session.on_ended(_on_session_ended) + +def _get_use_cases() -> dict: + return discover_use_cases() + + +def _use_case_choices() -> dict: + return {name: info["display_name"] for name, info in _get_use_cases().items()} + def _load_model_data_cached(use_case: str): if use_case in _USE_CASE_CACHE: @@ -23,35 +105,45 @@ def _load_model_data_cached(use_case: str): def build_startup_modal(): - use_case_choices = {name: info["display_name"] for name, info in USE_CASES.items()} + use_case_choices = _use_case_choices() return ui.modal( ui.tags.div( ui.tags.h3("Select Use Case", style="color: #007bff; font-weight: bold; text-align: center;"), - ui.tags.p("Choose which prediction model to use:", style="text-align: center;"), + ui.tags.p("Interpretable stratification support for clinical research workflows.", style="text-align: center; color: #5f6f7f;"), + ui.tags.p("Choose which stratification support model to use:", style="text-align: center;"), ui.br(), ui.input_select( id="startup_use_case", label="Available Use Cases:", choices=use_case_choices, - selected="rai", + selected=DEFAULT_USE_CASE if DEFAULT_USE_CASE in use_case_choices else None, ), ), title="Welcome to XAInyPredictor", easy_close=False, footer=ui.TagList( - ui.input_action_button("confirm_startup_use_case", "Start", class_="btn-primary", onclick="Shiny.setInputValue('startup_confirmed', 'yes');") + ui.input_action_button( + "confirm_startup_use_case", + "Start", + class_="btn-primary", + ) ) ) def build_page_header(config: dict, current_use_case: str): - use_case_choices = {name: info["display_name"] for name, info in USE_CASES.items()} + use_case_choices = _use_case_choices() return ui.tags.div( ui.tags.div( - ui.tags.h3( - config.get("titles", {}).get("app_title", "XAInyPredictor"), - style="margin: 0; color: #007bff; font-weight: 800; letter-spacing: -1px;" + ui.tags.div( + ui.tags.h3( + config.get("titles", {}).get("app_title", "XAInyPredictor"), + ), + ui.tags.span( + config.get("titles", {}).get("app_subtitle", "Interpretable Stratification Support"), + class_="navigation-subtitle", + ), ), id="app-title", class_="navigation-title", @@ -68,7 +160,7 @@ def build_page_header(config: dict, current_use_case: str): ui.tags.div( ui.input_action_button( id="tab_data_exploration", - label=config.get("titles", {}).get("tab_data_exploration", "2. Data Exploration"), + label=config.get("titles", {}).get("tab_data_exploration", "2. Cohort Context"), class_="navbar-button", ), id="div-navbar-map", @@ -76,7 +168,7 @@ def build_page_header(config: dict, current_use_case: str): ui.tags.div( ui.input_action_button( id="tab_prediction", - label=config.get("titles", {}).get("tab_prediction", "3. Prediction"), + label=config.get("titles", {}).get("tab_prediction", "3. Patient Stratification"), class_="navbar-button", ), id="div-navbar-plot", @@ -118,9 +210,9 @@ def build_page_header(config: dict, current_use_case: str): def build_ui(config: dict, current_use_case: str): page_dependencies = ui.tags.head( - ui.tags.link(rel="stylesheet", type="text/css", href="layout.css"), - ui.tags.link(rel="stylesheet", type="text/css", href="style.css"), - ui.tags.script(src="index.js"), + ui.tags.link(rel="stylesheet", type="text/css", href="layout.css?v=20260608a"), + ui.tags.link(rel="stylesheet", type="text/css", href="style.css?v=20260608a"), + ui.tags.script(src="index.js?v=20260608a"), ui.tags.meta(name="description", content=config.get("description", "XAI Predictor")), ui.tags.meta(name="theme-color", content="#000000"), ui.tags.meta(name="viewport", content="width=device-width, initial-scale=1"), @@ -147,15 +239,61 @@ def build_ui(config: dict, current_use_case: str): id="run-analysis-container", class_="page-main" ), + ui.tags.div( + ui.tags.div( + ui.tags.div(class_="loading-spinner"), + ui.tags.div("Loading use case...", id="loading-overlay-title", class_="loading-title"), + ui.tags.div("Preparing model outputs and reference context.", id="loading-overlay-subtitle", class_="loading-subtitle"), + class_="loading-panel", + ), + id="use-case-loading-overlay", + class_="loading-overlay", + **{"aria-live": "polite"}, + ), class_="page-layout" ), + ui.tags.div( + config.get("titles", {}).get( + "footer_text", + "Research prototype: clinical workflow utility and interpretation should be validated with clinical collaborators.", + ), + class_="app-validation-footer", + ), title=config.get("titles", {}).get("app_title", "XAInyPredictor"), ) +async def update_navigation_labels(session: Session, config: dict): + titles = config.get("titles", {}) + labels = config.get("labels", {}) + entity = labels.get("entity", {}) + ui.update_action_button("tab_data_input", label=titles.get("tab_data_input", "1. Data Input")) + ui.update_action_button("tab_data_exploration", label=titles.get("tab_data_exploration", "2. Cohort Context")) + ui.update_action_button("tab_prediction", label=titles.get("tab_prediction", "3. Patient Stratification")) + await session.send_custom_message( + "setStratificationTabLabels", + { + "patient": entity.get("singular_title", "Patient"), + "cohort": entity.get("set_title", "Cohort"), + "reference": entity.get("reference_title", "Reference Patients"), + }, + ) + await session.send_custom_message( + "setDataInputMethodLabels", + { + "form": labels.get("manual_entry", "Manual Entry"), + "file": labels.get("upload_file", "Upload File"), + "example": labels.get("example_cohort", "Example Cohort"), + }, + ) + + def load_model_data(use_case_name: str): use_case_data = load_use_case(use_case_name) config = use_case_data["config"] + if config.get("use_case_type") == "neoantigen": + return load_neoag_model_data(use_case_data) + encoding_dict = config.get("encoding", {}) example_raw_df = use_case_data["example_data"].copy() feature_order = use_case_data["feature_order"] @@ -216,6 +354,116 @@ def load_model_data(use_case_name: str): } +def _with_id(df: pd.DataFrame) -> pd.DataFrame: + df = df.copy().reset_index(drop=True) + if "ID" not in df.columns: + df.insert(0, "ID", range(1, len(df) + 1)) + return df + + +def _scale_features(features: pd.DataFrame, scaler: StandardScaler, feature_order: list[str]) -> pd.DataFrame: + return pd.DataFrame( + scaler.transform(features[feature_order]), + columns=feature_order, + index=features.index, + ) + + +def _limit_neoag_reference_df(reference_df: pd.DataFrame, target_col: str, max_candidates: int | None) -> pd.DataFrame: + if not max_candidates or len(reference_df) <= max_candidates: + return reference_df + + if target_col not in reference_df.columns: + return reference_df.sample(n=max_candidates, random_state=42).reset_index(drop=True) + + sampled_parts = [] + class_counts = reference_df[target_col].value_counts() + for class_value, class_count in class_counts.items(): + class_fraction = class_count / len(reference_df) + class_sample_size = max(1, round(max_candidates * class_fraction)) + class_df = reference_df[reference_df[target_col] == class_value] + sampled_parts.append( + class_df.sample(n=min(class_sample_size, len(class_df)), random_state=42) + ) + + sampled_df = pd.concat(sampled_parts, ignore_index=True) + if len(sampled_df) > max_candidates: + sampled_df = sampled_df.sample(n=max_candidates, random_state=42) + + return sampled_df.sample(frac=1, random_state=42).reset_index(drop=True) + + +def load_neoag_model_data(use_case_data: dict): + config = use_case_data["config"] + use_case_path = use_case_data["path"] + feature_order = use_case_data["feature_order"] + model_file = use_case_path / "model.pkl" + formula_file = use_case_path / "feature_order.txt" + target_col = config.get("target_column", "Qualitative_Measure") + + reference_raw_df = pd.read_csv(use_case_path / "reference_data.csv") + reference_raw_df = _limit_neoag_reference_df( + reference_raw_df, + target_col=target_col, + max_candidates=config.get("max_reference_candidates"), + ) + example_raw_df = use_case_data["example_data"].copy() + + reference_prepared = prepare_neoantigen_features(reference_raw_df, feature_order, target=target_col) + example_prepared = prepare_neoantigen_features(example_raw_df, feature_order, target=target_col) + + scaler = StandardScaler() + train_scaled = pd.DataFrame( + scaler.fit_transform(reference_prepared.features[feature_order]), + columns=feature_order, + index=reference_prepared.features.index, + ) + test_scaled = _scale_features(example_prepared.features, scaler, feature_order) + + delta_formula, _, _, feats_in_formula = read_delta_xai_formula( + str(model_file), + str(formula_file), + simplify_formula=False, + ) + combined_scaled = pd.concat([train_scaled, test_scaled], ignore_index=True) + d_combined = delta_xai(delta_formula, combined_scaled, feature_order) + d_train = d_combined.iloc[:len(train_scaled)].reset_index(drop=True) + d_test = d_combined.iloc[len(train_scaled):].reset_index(drop=True) + + y_train = reference_prepared.target if reference_prepared.target is not None else pd.Series([0] * len(reference_raw_df)) + y_test = example_prepared.target if example_prepared.target is not None else pd.Series([0] * len(example_raw_df)) + default_threshold, _ = threshold_for_target_fnr( + y_train.to_numpy(), + d_train["pred_prob"].to_numpy(), + target_fnr=float(config.get("false_negative_rate", 0.10)), + ) + + x_train = _with_id(reference_prepared.features) + x_test = _with_id(example_prepared.features) + x_train_raw = _with_id(reference_raw_df.drop(columns=[target_col], errors="ignore")) + x_test_raw = _with_id(example_raw_df.drop(columns=[target_col], errors="ignore")) + + return { + "config": config, + "delta_formula": delta_formula, + "scaler": scaler, + "example_raw_df": _with_id(example_raw_df.drop(columns=[target_col], errors="ignore")), + "example_proc_df": _with_id(example_prepared.features), + "X_TRAIN_RAW": x_train_raw, + "X_TEST_RAW": x_test_raw, + "X_TRAIN": x_train, + "X_TEST": x_test, + "Y_TRAIN": y_train, + "Y_TEST": y_test, + "D_TRAIN": d_train, + "D_TEST": d_test, + "FEATURE_ORDER_CLEAN": feature_order, + "FEATURE_ORDER_DISPLAY": [feat.replace("_", " ") for feat in feature_order], + "FEATS_IN_FORMULA": [feat.replace("_", " ") for feat in feats_in_formula], + "DEFAULT_THRESHOLD": default_threshold, + } + + DEFAULT_USE_CASE = "rai" MODEL_DATA = load_model_data(DEFAULT_USE_CASE) DEFAULT_CONFIG = MODEL_DATA["config"] @@ -224,6 +472,8 @@ def load_model_data(use_case_name: str): def server(input, output, session: Session): + _register_session_lifecycle(session) + current_use_case = reactive.Value(DEFAULT_USE_CASE) model_data = reactive.Value(MODEL_DATA) config = reactive.Value(DEFAULT_CONFIG) @@ -231,6 +481,7 @@ def server(input, output, session: Session): patient_selected = reactive.Value(None) prob_threshold = reactive.Value(MODEL_DATA["DEFAULT_THRESHOLD"]) data_available = reactive.Value(False) + pending_use_case = reactive.Value(None) delta_test_reactive = reactive.Value(MODEL_DATA["D_TEST"]) x_test_reactive = reactive.Value(MODEL_DATA["X_TEST"]) @@ -239,12 +490,21 @@ def server(input, output, session: Session): def _show_startup_modal(): ui.modal_show(build_startup_modal()) + @reactive.Effect + def _refresh_use_case_selector(): + choices = _use_case_choices() + ui.update_select( + "use_case_selector", + choices=choices, + selected=current_use_case.get() if current_use_case.get() in choices else None, + ) + startup_initialized = reactive.Value(False) @reactive.Effect @reactive.event(input.confirm_startup_use_case) async def _on_startup_confirm(): - if input.startup_confirmed() == "yes" and not startup_initialized.get(): + if not startup_initialized.get(): selected_use_case = input.startup_use_case() startup_initialized.set(True) @@ -256,18 +516,21 @@ async def _on_startup_confirm(): except (UseCaseNotFoundError, Exception) as e: ui.notification_show(f"Error loading use case: {e}", type="error") startup_initialized.set(False) + await session.send_custom_message("setUseCaseLoading", {"visible": False}) return current_use_case.set(selected_use_case) model_data.set(new_model_data) config.set(new_model_data["config"]) + await update_navigation_labels(session, new_model_data["config"]) delta_test_reactive.set(new_model_data["D_TEST"]) x_test_reactive.set(new_model_data["X_TEST"]) prob_threshold.set(new_model_data["DEFAULT_THRESHOLD"]) - ui.update_select("use_case_selector", selected=selected_use_case) + ui.update_select("use_case_selector", choices=_use_case_choices(), selected=selected_use_case) ui.modal_remove() + await session.send_custom_message("setUseCaseLoading", {"visible": False}) @reactive.Effect @reactive.event(input.use_case_selector) @@ -276,82 +539,139 @@ def _switch_use_case(): return new_use_case = input.use_case_selector() if new_use_case and new_use_case != current_use_case.get(): + choices = _use_case_choices() + if new_use_case not in choices: + ui.notification_show(f"Use case '{new_use_case}' is no longer available.", type="error") + ui.update_select("use_case_selector", choices=choices, selected=current_use_case.get()) + return + pending_use_case.set(new_use_case) + ui.update_select("use_case_selector", selected=current_use_case.get()) ui.modal_show( ui.modal( ui.tags.p("Switching use case will clear current data. Continue?"), title="Confirm Use Case Change", footer=ui.TagList( - ui.modal_button("Cancel", onclick="Shiny.setInputValue('use_case_confirm', 'cancel');"), - ui.input_action_button("confirm_switch", "Confirm", class_="btn-primary", onclick="Shiny.setInputValue('use_case_confirm', 'confirm');") + ui.input_action_button("cancel_switch", "Cancel", class_="btn-default"), + ui.input_action_button("confirm_switch", "Confirm", class_="btn-primary") ) ) ) + @reactive.Effect + @reactive.event(input.cancel_switch) + def _cancel_switch(): + pending_use_case.set(None) + ui.update_select("use_case_selector", selected=current_use_case.get()) + ui.modal_remove() + @reactive.Effect @reactive.event(input.confirm_switch) async def _confirm_switch(): - if input.use_case_confirm() == "cancel": + new_use_case = pending_use_case.get() + if not new_use_case: + ui.update_select("use_case_selector", selected=current_use_case.get()) + ui.modal_remove() + await session.send_custom_message("setUseCaseLoading", {"visible": False}) return - new_use_case = input.use_case_selector() try: new_model_data = _load_model_data_cached(new_use_case) current_use_case.set(new_use_case) model_data.set(new_model_data) config.set(new_model_data["config"]) + await update_navigation_labels(session, new_model_data["config"]) delta_test_reactive.set(new_model_data["D_TEST"]) x_test_reactive.set(new_model_data["X_TEST"]) prob_threshold.set(new_model_data["DEFAULT_THRESHOLD"]) data_available.set(False) patient_selected.set(None) + pending_use_case.set(None) + ui.update_select("use_case_selector", selected=new_use_case) ui.notification_show(f"Switched to {new_model_data['config'].get('name', new_use_case)}", type="message") await session.send_custom_message("toggleActiveTab", {"activeTab": "data_input"}) ui.modal_remove() + await session.send_custom_message("setUseCaseLoading", {"visible": False}) except (UseCaseNotFoundError, Exception) as e: ui.notification_show(f"Error loading use case: {e}", type="error") - - @reactive.Effect - @reactive.event(input.btn_cancel_switch) - def _cancel_switch(): - pass + pending_use_case.set(None) + ui.update_select("use_case_selector", selected=current_use_case.get()) + await session.send_custom_message("setUseCaseLoading", {"visible": False}) input_results = data_input.server("data_input", model_data, DEFAULT_CONFIG, config) analysis_data = input_results["data"] is_custom_data = input_results["is_custom"] + is_input_confirmed = input_results["is_confirmed"] + + def _confirmation_required_message(): + labels = config.get().get("labels", {}) + entity = labels.get("entity", {}) + set_lower = entity.get("set_lower", "cohort") + return f"Confirm the {set_lower} in Data Input first." @reactive.Effect - def _update_analysis_context(): + async def _update_analysis_context(): current_df = analysis_data.get() is_custom = is_custom_data.get() + input_confirmed = is_input_confirmed.get() md = model_data.get() cfg = config.get() - if current_df is None: + if not input_confirmed or current_df is None or current_df.empty: data_available.set(False) + delta_test_reactive.set(pd.DataFrame()) + x_test_reactive.set(pd.DataFrame()) + patient_selected.set(None) + await session.send_custom_message("setUseCaseLoading", {"visible": False}) return - data_available.set(True) - - if is_custom: - allowed_columns = [f["name"] for f in cfg.get("features", [])] - current_proc_df = process_raw_form_data( - raw_df=current_df, - example_raw_df=md["example_raw_df"], - encoding_config=cfg.get("encoding", {}), - allowed_columns=allowed_columns - ) - current_proc_df.columns = [col.replace(' ', '_') for col in current_proc_df.columns] - d_test_curr = delta_xai(md.get("delta_formula"), current_proc_df, md["FEATURE_ORDER_CLEAN"]) - delta_test_reactive.set(d_test_curr) - x_test_reactive.set(current_proc_df) - else: - delta_test_reactive.set(md["D_TEST"]) - x_test_reactive.set(md["X_TEST"]) + try: + if is_custom: + if cfg.get("use_case_type") == "neoantigen": + raw_for_model = current_df.drop(columns=["ID"], errors="ignore") + prepared = prepare_neoantigen_features( + raw_for_model, + md["FEATURE_ORDER_CLEAN"], + target=cfg.get("target_column", "Qualitative_Measure"), + ) + current_proc_df = _with_id(prepared.features) + scaled_features = _scale_features( + prepared.features, + md["scaler"], + md["FEATURE_ORDER_CLEAN"], + ) + d_test_curr = delta_xai(md.get("delta_formula"), scaled_features, md["FEATURE_ORDER_CLEAN"]) + else: + allowed_columns = [f["name"] for f in cfg.get("features", [])] + current_proc_df = process_raw_form_data( + raw_df=current_df, + example_raw_df=md["example_raw_df"], + encoding_config=cfg.get("encoding", {}), + allowed_columns=allowed_columns + ) + current_proc_df.columns = [col.replace(' ', '_') for col in current_proc_df.columns] + d_test_curr = delta_xai(md.get("delta_formula"), current_proc_df, md["FEATURE_ORDER_CLEAN"]) + delta_test_reactive.set(d_test_curr.reset_index(drop=True)) + x_test_reactive.set(current_proc_df.reset_index(drop=True)) + else: + delta_test_reactive.set(md["D_TEST"].reset_index(drop=True)) + x_test_reactive.set(md["X_TEST"].reset_index(drop=True)) + + first_id = int(current_df["ID"].iloc[0]) if "ID" in current_df.columns and not current_df.empty else None + patient_selected.set(first_id) + data_available.set(True) + await session.send_custom_message("setUseCaseLoading", {"visible": False}) + except Exception as e: + data_available.set(False) + delta_test_reactive.set(pd.DataFrame()) + x_test_reactive.set(pd.DataFrame()) + patient_selected.set(None) + ui.notification_show(f"Error preparing analysis context: {e}", type="error") + await session.send_custom_message("setUseCaseLoading", {"visible": False}) data_exploration.server( "data_exploration", @@ -374,6 +694,13 @@ def _update_analysis_context(): config ) + @reactive.Effect + async def _sync_step_lock_state(): + await session.send_custom_message( + "setAnalysisStepsLocked", + {"locked": not bool(is_input_confirmed.get())}, + ) + @reactive.Effect @reactive.event(input.tab_data_input) async def _(): @@ -384,8 +711,8 @@ async def _(): @reactive.Effect @reactive.event(input.tab_data_exploration) async def _(): - if not data_available(): - ui.notification_show("Please enter or upload data first.", type="warning") + if not is_input_confirmed.get() or not data_available(): + ui.notification_show(_confirmation_required_message(), type="warning") return await session.send_custom_message( "toggleActiveTab", {"activeTab": "data_exploration"} @@ -394,31 +721,49 @@ async def _(): @reactive.Effect @reactive.event(input.tab_prediction) async def _(): - if not data_available(): - ui.notification_show("Please enter or upload data first.", type="warning") + if not is_input_confirmed.get() or not data_available(): + ui.notification_show(_confirmation_required_message(), type="warning") + await session.send_custom_message("setUseCaseLoading", {"visible": False}) return await session.send_custom_message( "toggleActiveTab", {"activeTab": "prediction"} ) + await session.send_custom_message("setUseCaseLoading", {"visible": False}) @reactive.Effect @reactive.event(input.info_icon) def _show_help_modal(): cfg = config.get() labels = cfg.get("labels", {}) + text = labels.get("text", {}) titles = cfg.get("titles", {}) + features = cfg.get("features", []) + threshold = float(prob_threshold.get()) if prob_threshold.get() is not None else 0 + positive_label = labels.get("positive_class_label", cfg.get("positive_class", "Positive")) + negative_label = labels.get("negative_class_label", cfg.get("negative_class", "Negative")) m = ui.modal( ui.div( ui.tags.h4(f"Welcome to {titles.get('app_title', 'XAInyPredictor')}", style="color: #007bff; font-weight: bold; margin-top: 0;"), ui.p(cfg.get("description", ""), style="font-style: italic; color: #666;"), + ui.tags.div( + ui.tags.b("Research prototype context"), + ui.p( + text.get( + "info_context", + "This app supports patient stratification research by comparing patient-level inputs with a reference cohort and assigning a model-based patient group. Its workflow utility and interpretation should be validated with clinical collaborators.", + ), + style="margin: 6px 0 0;", + ), + style="background: #f4f9ff; border-left: 4px solid #007bff; padding: 10px 12px; margin: 10px 0;" + ), ui.hr(), ui.row( ui.column(2, ui.h1("1", style="background: #e9ecef; border-radius: 50%; width: 40px; height: 40px; text-align: center; line-height: 40px; font-size: 20px; color: #495057; margin: 0 auto;")), ui.column(10, - ui.h5("Input Patient Data", style="font-weight: bold; margin-top: 5px;"), - ui.p(f"Go to the ", ui.tags.b(titles.get('tab_data_input', 'Data Input')), " tab. You can upload a file or manually enter patient details.") + ui.h5(text.get("info_step_input_title", "Input Patient Data"), style="font-weight: bold; margin-top: 5px;"), + ui.p(text.get("info_step_input_prefix", "Go to the "), ui.tags.b(titles.get('tab_data_input', 'Data Input')), text.get("info_step_input_suffix", " tab. You can upload a file or manually enter patient details.")) ) ), ui.br(), @@ -426,8 +771,8 @@ def _show_help_modal(): ui.row( ui.column(2, ui.h1("2", style="background: #e9ecef; border-radius: 50%; width: 40px; height: 40px; text-align: center; line-height: 40px; font-size: 20px; color: #495057; margin: 0 auto;")), ui.column(10, - ui.h5("Explore Features", style="font-weight: bold; margin-top: 5px;"), - ui.p(f"In the ", ui.tags.b(titles.get('tab_data_exploration', 'Data Exploration')), " tab, compare patient features against the reference population.") + ui.h5(text.get("info_step_context_title", "Explore Features"), style="font-weight: bold; margin-top: 5px;"), + ui.p(text.get("info_step_context_prefix", "In the "), ui.tags.b(titles.get('tab_data_exploration', 'Cohort Context')), text.get("info_step_context_suffix", " tab, compare patient features against the reference population.")) ) ), ui.br(), @@ -435,24 +780,65 @@ def _show_help_modal(): ui.row( ui.column(2, ui.h1("3", style="background: #e9ecef; border-radius: 50%; width: 40px; height: 40px; text-align: center; line-height: 40px; font-size: 20px; color: #495057; margin: 0 auto;")), ui.column(10, - ui.h5("Run Prediction", style="font-weight: bold; margin-top: 5px;"), - ui.p(f"Click ", ui.tags.b(titles.get('tab_prediction', 'Prediction')), f" to see the {labels.get('probability_column', 'Probability')}. Use the charts to identify which features are driving the prediction."), + ui.h5(text.get("info_step_output_title", "Review Stratification"), style="font-weight: bold; margin-top: 5px;"), + ui.p( + text.get("info_step_output_prefix", "Click "), + ui.tags.b(titles.get('tab_prediction', 'Patient Stratification')), + text.get("info_step_output_middle", f" to see the {labels.get('probability_column', 'Stratification Score')} and assigned patient group. "), + f"The current decision threshold is {threshold:.3f}. ", + text.get("info_step_output_suffix", "Use the charts to understand the patient's profile against the reference cohort."), + ), ui.tags.div( - ui.tags.span(f"{labels.get('negative_class', 'Negative')}", style="background: #d1e7dd; color: #0f5132; padding: 2px 6px; border-radius: 4px; font-size: 0.8em;"), - " = Low Risk | ", - ui.tags.span(f"{labels.get('positive_class', 'Positive')}", style="background: #f8d7da; color: #842029; padding: 2px 6px; border-radius: 4px; font-size: 0.8em;"), - " = High Risk", + ui.tags.span(f"{negative_label}", style="background: #d1e7dd; color: #0f5132; padding: 2px 6px; border-radius: 4px; font-size: 0.8em;"), + " = below threshold | ", + ui.tags.span(f"{positive_label}", style="background: #f8d7da; color: #842029; padding: 2px 6px; border-radius: 4px; font-size: 0.8em;"), + " = at or above threshold", style="margin-top: 5px; font-size: 0.9em;" ) ) ), + ui.br(), + ui.tags.div( + ui.tags.b("Model summary"), + ui.tags.ul( + ui.tags.li(f"Use case: {cfg.get('name', 'Patient stratification model')}"), + ui.tags.li(f"Target: {cfg.get('target_column', 'target')}"), + ui.tags.li(f"Input variables: {len(features)}"), + ui.tags.li(f"{text.get('groups_label', 'Patient groups')}: {negative_label} / {positive_label}"), + ), + style="background: #fff8e6; border: 1px solid #f1d28c; border-radius: 6px; padding: 10px 12px;" + ), ), - title="How to use this App", + title="Research Prototype Info", easy_close=True, - footer=ui.modal_button("Got it!"), + footer=ui.modal_button("Close"), size="l" ) ui.modal_show(m) -app = App(app_ui, server, static_assets=WWW_DIR) \ No newline at end of file +# 7. Launch App +# Executable launch commands +app = App(app_ui, server, static_assets=WWW_DIR) +def find_free_port(): + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] +PORT = find_free_port() +def open_browser(): + time.sleep(2) + webbrowser.open(f"http://127.0.0.1:{PORT}", new=1) +def ensure_stdio_streams(): + if sys.stdin is None: + sys.stdin = open(os.devnull, "r") + if sys.stdout is None: + sys.stdout = open(os.devnull, "w") + if sys.stderr is None: + sys.stderr = open(os.devnull, "w") + + +if __name__ == "__main__": + os.environ["XAINYPREDICTOR_EXIT_ON_LAST_SESSION_END"] = "1" + ensure_stdio_streams() + threading.Thread(target=open_browser, daemon=True).start() + app.run(host="127.0.0.1", port=PORT, log_level="critical") diff --git a/src/XAInyPredictor/data/mock/config.yml b/src/XAInyPredictor/data/mock/config.yml index cd390ce..244f90f 100644 --- a/src/XAInyPredictor/data/mock/config.yml +++ b/src/XAInyPredictor/data/mock/config.yml @@ -1,5 +1,12 @@ name: "Diabetes Risk Predictor" description: "Predict risk of Type 2 Diabetes development based on clinical parameters" +model_metadata: + model_version: "Demo prototype" + training_date: "Synthetic demonstration dataset" + cohort_source: "Mock reference cohort for workflow demonstration" + intended_use: "Demonstrate the patient stratification workflow on a simple risk use case" + validation_status: "Clinical usefulness should be validated with domain experts before real-world deployment" + limitations: "Demo use case with synthetic/example data; not intended to represent a validated diabetes model" target_column: "Diabetes-Risk" positive_class: "High" @@ -14,7 +21,7 @@ features: default: 45 - name: "BMI" type: "numeric" - display_name: "BMI (kg/m²)" + display_name: "BMI (kg/m2)" input_type: "numeric" min: 15 max: 60 @@ -90,27 +97,56 @@ encoding: labels: positive_class_label: "High Risk" negative_class_label: "Low Risk" - probability_column: "Risk Probability" - class_column: "Risk Level" + probability_column: "Stratification Score" + class_column: "Patient Group" + entity: + singular: "patient" + plural: "patients" + singular_title: "Patient" + plural_title: "Patients" + set_lower: "cohort" + set_title: "Cohort" + reference_plural: "reference patients" + reference_title: "Reference Patients" form: add_patient_button: "Add Patient to Cohort" delete_selected_button: "Delete Selected" + reset_cohort_button: "Reset Cohort" cohort_title: "Current Patient Cohort" + manual_source: "manual cohort" + uploaded_source: "uploaded file" + example_source: "example cohort" + added_message: "Patient added." + deleted_message: "Deleted {count} patient(s)." + reset_message: "Cohort reset." + text: + analysis_title: "Stratification Analysis" + selected_output_title: "Selected Patient Stratification" + groups_label: "Patient groups" + select_entity_label: "Select patient:" + search_entity_placeholder: "Search by patient ID..." + profile_comparison_title: "Patient Profile Comparison " + set_summary_title: "Cohort Stratification Summary" + set_distribution_title: "Cohort Score Distribution " + closest_reference_title: "Closest Reference Patients " titles: app_title: "XAInyPredictor" + app_subtitle: "Interpretable Stratification Support" tab_data_input: "1. Data Input" - tab_data_exploration: "2. Data Exploration" - tab_prediction: "3. Prediction" + tab_data_exploration: "2. Cohort Context" + tab_prediction: "3. Patient Stratification" data_source: "Data Source" select_method: "Select Method:" manual_entry: "Manual Entry" upload_file: "Upload File" example_cohort: "Example Cohort" - feature_analysis: "Risk Analysis" - distance_analysis: "Feature Analysis" - prediction_results: "Risk Assessment" + feature_analysis: "Patient Profile Comparison" + feature_analysis_short: "Profile Comparison" + distance_analysis: "Feature Context Curves" + distance_analysis_short: "Feature Curves" + prediction_results: "Stratification Results" help_texts: - fnr_threshold: "Adjust the sensitivity of the risk assessment." + fnr_threshold: "Adjust how many high-risk patients the model is allowed to miss when assigning patient groups." fnr_zero: "Classify all high-risk patients regardless of probability (high sensitivity)." fnr_higher: "Focus on high-confidence predictions (high specificity)." reference_population_input: "Show distribution of your patient data." @@ -119,4 +155,4 @@ help_texts: radar_average_all: "Average of all patients in the model." radar_average_negative: "Average of low-risk patients (green)." radar_average_positive: "Average of high-risk patients (red)." - curves_plot: "Show how each feature value relates to risk." \ No newline at end of file + curves_plot: "Show how each feature value relates to risk." diff --git a/src/XAInyPredictor/data/neoag/config.yml b/src/XAInyPredictor/data/neoag/config.yml new file mode 100644 index 0000000..92876d8 --- /dev/null +++ b/src/XAInyPredictor/data/neoag/config.yml @@ -0,0 +1,245 @@ +name: "Neoantigen Candidate Prioritizer" +description: "Prioritize peptide-HLA neoantigen candidates by predicted immunogenicity" +use_case_type: "neoantigen" +false_negative_rate: 0.10 +max_reference_candidates: 40 +default_selected_features: 5 +max_curve_features: 6 +target_column: "Qualitative_Measure" +positive_class: 1 +negative_class: 0 + +model_metadata: + model_version: "KAN symbolic prototype" + training_date: "Neoantigen dataset snapshot" + cohort_source: "Reference peptide-HLA candidate dataset from the neoantigen model development workflow" + intended_use: "Interpretable stratification support for neoantigen candidate prioritization and exploratory translational research" + validation_status: "Clinical utility and workflow fit should be validated with immunology and oncology collaborators" + limitations: "Candidate-level prototype; patient-level stratification requires patient identifiers, patient-specific candidate sets, and clinical validation." + +features: + - name: "Peptide" + label: "Peptide sequence" + type: "text" + role: "identifier" + plot: false + default: "YMDGTMSQV" + description: "Mutated peptide sequence presented by the HLA molecule." + + - name: "hla" + label: "HLA allele" + type: "text" + default: "HLA-B35:01" + description: "Presenting HLA allele." + + - name: "pseudosequence" + label: "HLA pseudosequence" + type: "text" + default: "YYATYRNIFTNTYESNLYIRYDSYTWAVLAYLWY" + description: "HLA binding-groove pseudosequence." + + - name: "Molecular_weight" + label: "Molecular weight" + type: "numeric" + min: 500 + max: 2500 + default: 1065.28 + step: 0.01 + + - name: "GRAVY" + label: "GRAVY" + type: "numeric" + min: -5 + max: 5 + default: 0.29 + step: 0.01 + + - name: "instability" + label: "Instability index" + type: "numeric" + min: 0 + max: 200 + default: 55.97 + step: 0.01 + + - name: "mhcflurry_affinity_percentile" + label: "MHCflurry affinity percentile" + type: "numeric" + min: 0 + max: 100 + default: 0.036 + step: 0.001 + + - name: "pI" + label: "Isoelectric point" + type: "numeric" + min: 0 + max: 14 + default: 5.84 + step: 0.01 + + - name: "mhcflurry_processing_score" + label: "MHCflurry processing score" + type: "numeric" + min: 0 + max: 1 + default: 0.758 + step: 0.001 + + - name: "mhcflurry_presentation_score" + label: "MHCflurry presentation score" + type: "numeric" + min: 0 + max: 1 + default: 0.974 + step: 0.001 + + - name: "Half_life" + label: "Half-life" + type: "numeric" + min: 0 + max: 100 + default: 5.5 + step: 0.1 + + - name: "charge" + label: "Charge" + type: "numeric" + min: -5 + max: 5 + default: 0 + step: 1 + + - name: "score" + label: "Expression score" + type: "numeric" + min: -5 + max: 5 + default: -0.0026 + step: 0.0001 + + - name: "hydro_P2" + label: "Hydrophobicity P2" + type: "numeric" + min: -5 + max: 5 + default: 2.8 + step: 0.1 + + - name: "hydro_P9" + label: "Hydrophobicity P9" + type: "numeric" + min: -5 + max: 5 + default: 1.9 + step: 0.1 + +encoding: {} + +labels: + data_source: "Data Source" + select_method: "Select Method:" + manual_entry: "Manual Entry" + upload_file: "Upload File" + example_cohort: "Example Candidate Set" + positive_class_label: "Immunogenic" + negative_class_label: "Non-immunogenic" + positive_class_color: "#198754" + negative_class_color: "#dc3545" + probability_column: "Immunogenicity Score" + class_column: "Candidate Class" + + entity: + singular: "candidate" + plural: "candidates" + singular_title: "Candidate" + plural_title: "Candidates" + set_lower: "candidate set" + set_title: "Candidate Set" + reference_plural: "reference candidates" + reference_title: "Reference Candidates" + + form: + add_patient_button: "Add Candidate to Set" + delete_selected_button: "Delete Selected" + reset_cohort_button: "Reset Candidate Set" + cohort_title: "Current Candidate Set" + no_data_message: "No candidates uploaded." + manual_source: "manual candidate set" + uploaded_source: "uploaded file" + example_source: "example candidate set" + added_message: "Candidate added." + deleted_message: "Deleted {count} candidate(s)." + reset_message: "Candidate set reset." + + text: + analysis_title: "Neoantigen Stratification Support" + selected_output_title: "Selected Candidate Stratification" + groups_label: "Candidate classes" + select_entity_label: "Select candidate:" + search_entity_placeholder: "Search by candidate ID..." + profile_comparison_title: "Candidate Profile Comparison " + set_summary_title: "Candidate Set Summary" + set_distribution_title: "Candidate Score Distribution " + set_distribution_help: "Candidates are ordered by immunogenicity score. The dashed line marks the current decision threshold and the selected candidate is highlighted." + closest_reference_title: "Closest Reference Candidates " + closest_reference_help_title: "Closest reference candidates:" + closest_reference_help: "Reference candidates are ranked by distance in the model contribution space. Row-level reference features are not exposed." + results_table_help: "The table shows model-based neoantigen candidate prioritization:" + score_help: " Indicates the score used to assign the candidate class." + class_help: " Assigns the candidate to Immunogenic or Non-immunogenic." + empty_results_message: "Add or upload candidates to generate stratification support outputs." + empty_individual_summary: "Add or upload candidates to review individual candidate outputs." + selected_entity_not_found: "Selected candidate not found." + empty_set_summary: "Add or upload candidates to summarize the candidate set." + reference_set_label: "Reference candidates" + reference_source_label: "Reference candidate source" + reference_source_missing: "Reference candidate source not specified." + no_set_scores: "No candidate scores available" + set_scores_plot_title: "Candidate immunogenicity scores" + set_scores_plot_xlabel: "Candidates ordered by score" + current_set_label: "Current candidate set" + reference_comparison_empty: "Reference comparison requires a selected candidate and available reference candidates." + closest_reference_narrative: "The selected {entity} is most similar to {reference_plural} mostly assigned to {top_group} ({top_count}/{total}).{score_text}" + closest_reference_privacy_note: "Reference candidates come from the model reference population, not from the current candidate set. The table reports anonymized neighbor ranks, distances, scores, and classes without exposing row-level reference features." + stratification_interpretation: "This {entity} is assigned to {group} because the immunogenicity score ({score}) is {direction} the decision threshold ({threshold})." + stratification_disclaimer: "Research prototype output: this candidate-level assessment supports translational stratification research and should be interpreted with immunology and oncology collaborators during utility validation." + report_context: "This package supports neoantigen candidate prioritization research. Clinical utility and interpretation should be validated with domain collaborators." + context_selection_hint: "Select a candidate to highlight its value in the distribution." + context_card_title: "Candidate Feature Context" + context_card_tooltip: "Blue bars: Reference candidate distribution. Red dashed line: Selected candidate value." + info_context: "This use case supports neoantigen candidate prioritization by assigning peptide-HLA candidates to immunogenicity classes and contextualizing their model evidence. Patient-level stratification would require aggregation across patient-specific candidate sets." + info_step_input_title: "Input Candidate Data" + info_step_input_suffix: " tab. You can upload a file or manually enter candidate details." + info_step_context_suffix: " tab, compare candidate features against the reference candidate population." + info_step_output_middle: " to see the Immunogenicity Score and assigned candidate class. " + info_step_output_suffix: "Use the charts to understand the candidate's profile against reference candidates." + +titles: + app_title: "XAInyPredictor" + app_subtitle: "Interpretable Stratification Support" + tab_data_input: "1. Data Input" + tab_data_exploration: "2. Candidate Context" + tab_prediction: "3. Stratification Support" + data_source: "Data Source" + select_method: "Select Method:" + manual_entry: "Manual Entry" + upload_file: "Upload File" + example_cohort: "Example Candidate Set" + feature_analysis: "Candidate Profile Comparison" + feature_analysis_short: "Profile Comparison" + distance_analysis: "Feature Context Curves" + distance_analysis_short: "Feature Curves" + prediction_results: "Candidate Stratification Results" + +help_texts: + fnr_threshold: "This setting controls how many immunogenic candidates the model is allowed to miss when assigning candidate classes." + fnr_zero: "Classify candidates as Immunogenic whenever needed to avoid missing true immunogenic candidates." + fnr_higher: "Accept missing some immunogenic candidates to focus on higher-confidence candidate prioritization." + feature_analysis: "Compare selected candidate features against the reference candidate distribution." + distance_analysis: "Inspect how feature values relate to the model score and candidate class across reference candidates." + prediction_results: "Model-based candidate prioritization results for the current candidate set." + profile_comparison: "Compare the selected candidate profile with similar reference candidates and class-level averages." + feature_importance: "Global feature contribution summary derived from the symbolic model outputs." + cohort_score_distribution: "Distribution of model scores in the current candidate set." + closest_patients: "Reference candidates closest to the selected candidate in model contribution space." diff --git a/src/XAInyPredictor/data/neoag/example_data.csv b/src/XAInyPredictor/data/neoag/example_data.csv new file mode 100644 index 0000000..4e22bfb --- /dev/null +++ b/src/XAInyPredictor/data/neoag/example_data.csv @@ -0,0 +1,9 @@ +Peptide,hla,Molecular_weight,GRAVY,instability,mhcflurry_affinity_percentile,pI,mhcflurry_processing_score,mhcflurry_presentation_score,mhcflurry_presentation_percentile,Half_life,charge,score,allele,pseudosequence,hydro_P2,hydro_P9,Qualitative_Measure +IMGVIFLI,HLA-A02:01,1006.3,2.79,8.89,0.077625,5.18,0.220265971496701,0.907581526459468,0.109809782608707,7.2,0,0.2349,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,1.9,4.5,0 +EHSLQVAY,HLA-B44:03,1017.09,-0.08,28.36,0.012625,5.24,0.952577814459801,0.983803438125947,0.0045108695652231,4.4,0,-0.21811,HLA-B44:03,YYTKYREISTNTYENTAYIRYDDYTWAVLAYLSY,-3.2,-1.3,1 +LLQRLHLL,HLA-A02:01,1104.39,1.33,45.11,0.15975,9.73,0.415324414148927,0.93250565362658,0.0733967391304446,100.0,1,-0.05938,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,3.8,3.8,0 +LARWLPPV,HLA-A02:01,1022.24,0.76,150.44,0.0115,9.8,0.298499772325158,0.952881693055639,0.0440760869565366,4.4,1,0.2542,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,1.8,4.2,1 +ERSILSSF,HLA-B44:02,1066.17,-0.31,121.91,0.0445,6.0,0.257026167586446,0.815047220933975,0.254293478260862,0.8,0,-0.26679,HLA-B44:02,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,-4.5,2.8,0 +LMPPDRTAV,HLA-A02:01,1127.36,-0.41,48.63,0.26025,8.75,0.306393058504909,0.868101768065977,0.17184782608696,1.3,1,0.04698,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,1.9,4.2,1 +VEEVSLRK,HLA-A11:01,1030.18,-0.24,66.51,0.1915,6.19,0.241995816119015,0.812400923219697,0.258967391304324,4.4,0,0.0386,HLA-A11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,-3.5,-3.9,1 +LFGDIYLAI,HLA-A02:01,1081.26,1.56,9.0,0.210875,4.05,0.147020794451237,0.815747731009217,0.254293478260862,30.0,0,0.23332,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,2.8,4.5,1 diff --git a/src/XAInyPredictor/data/neoag/feature_order.txt b/src/XAInyPredictor/data/neoag/feature_order.txt new file mode 100644 index 0000000..080d3c2 --- /dev/null +++ b/src/XAInyPredictor/data/neoag/feature_order.txt @@ -0,0 +1,14 @@ +KAN Regression Formula (Current Best) +Logit 0: +0.127852266742022*x_1 - 7.66626337514084e-9*x_11 - 0.0129264124643102*x_12 - 0.181616344552348*x_13 + 0.0361174068698439*x_17 + 0.0265596549129137*x_18 + 0.227783100798148*x_20 - 1.39062275695197e-9*(-4.67559957504272*x_10 - 6.9117603302002)**2 - 1.75703167570873e-8*(0.000144128748564608*x_14 + 10.0147029245461)**5 - 1.16805821903654e-7*(-0.00295263396371789*x_15 - 9.99698831773312)**5 + 0.00205329852178693*(-0.204507530245635*x_16 - 7.46691533381111)**2 - 0.00729791773483157*(0.135128486990095*x_19 - 7.60327980291321)**2 - 0.0204238966107368*(0.894582437679474*x_4 - 3.43709696468926)**2 + 2.05431127958035e-10*(-6.15823984146118*x_5 - 8.74216079711914)**2 + 0.00899958703666925*(0.834729358083372*x_6 - 4.90151420969558)**2 - 1.06551702572233e-7*(-0.001744153152582*x_9 - 9.93806966192114)**5 + 0.00821205973625183*(-1.39534490300304*x_2 + 0.397368211672379*x_3 - 6.45577223527574)**2 + 2.18861941903015e-8*(-0.00036172543988807*x_7 + 0.000238585078382507*(-1.61199986934662*x_8 - 3.18815994262695)**2 - 10.0022653995627)**5 + 0.094709151091283 + +Logit 1: +-0.0111273478956243*x_22 + 0.0710576972793644*x_23 + 0.0623118077754002*x_26 + 0.311000747853518*x_27 + 0.117561980493237*x_28 + 0.00469614488037243*x_30 + 0.0935879529113407*x_31 - 0.117330822391121*x_32 - 0.356531882178684*x_33 - 0.0878837491649806*x_36 + 0.0894674790738564*x_40 - 0.00930960010737181*(7.76707477087396 - 1.89227154219443*x_29)**2 + 5.00421997173817e-8*(-0.00113741347586672*x_21 - 10.0038585638509)**5 + 0.00274473871104419*(-2.15227396236957*x_24 - 9.02433428052053)**2 - 5.245092182804e-8*(-0.00215816496955429*x_25 - 9.9974698169172)**5 - 0.00261832657270133*(0.0595521395132764*x_34 - 8.8377975429162)**2 + 7.38819304858879e-9*(-9.99847984313965*x_35 - 0.39999982714653)**2 + 1.83338738679595e-6*(0.0165293791175806*x_39 + 10.0113233322097)**4 + 0.00299422070384026*(2.60283304596429*x_38 - 0.0260650039770089*(-8.52567958831787*x_37 - 1.78815996646881)**2 - 5.79957169014943)**2 + 0.673137153304119 + +Train acc: 0.7061367380560132 +Test acc: 0.7174629324546952 +Train formula-model MAE: 6.023937948227181e-05 +Test formula-model MAE: 5.925585671958872e-05 +Train formula-model corr: 0.9999999562772879 +Test formula-model corr: 0.9999999571007347 +Variables: ['Molecular_weight', 'GRAVY', 'instability', 'mhcflurry_affinity_percentile', 'pI', 'mhcflurry_processing_score', 'mhcflurry_presentation_score', 'Half_life', 'charge', 'score', 'hydro_P2', 'hydro_P9', 'peptide_length', 'anchor_hydro_mean', 'anchor_hydro_diff', 'peptide_hydro_mean', 'peptide_estimated_sidechain_charge', 'peptide_frac_hydrophobic', 'peptide_frac_aromatic', 'peptide_frac_polar', 'peptide_frac_basic', 'peptide_frac_acidic', 'peptide_frac_small', 'peptide_frac_special', 'peptide_frac_aliphatic', 'peptide_frac_amide', 'pseudo_hydro_mean', 'pseudo_estimated_sidechain_charge', 'pseudo_frac_hydrophobic', 'pseudo_frac_aromatic', 'pseudo_frac_polar', 'pseudo_frac_basic', 'pseudo_frac_acidic', 'pseudo_frac_small', 'pseudo_frac_special', 'pseudo_frac_aliphatic', 'pseudo_frac_amide', 'hla_locus_A', 'hla_locus_B', 'hla_locus_C'] diff --git a/src/XAInyPredictor/data/neoag/model.pkl b/src/XAInyPredictor/data/neoag/model.pkl new file mode 100644 index 0000000..bdc2b89 Binary files /dev/null and b/src/XAInyPredictor/data/neoag/model.pkl differ diff --git a/src/XAInyPredictor/data/neoag/reference_data.csv b/src/XAInyPredictor/data/neoag/reference_data.csv new file mode 100644 index 0000000..5266b36 --- /dev/null +++ b/src/XAInyPredictor/data/neoag/reference_data.csv @@ -0,0 +1,41 @@ +Peptide,hla,Molecular_weight,GRAVY,instability,mhcflurry_affinity_percentile,pI,mhcflurry_processing_score,mhcflurry_presentation_score,mhcflurry_presentation_percentile,Half_life,charge,score,allele,pseudosequence,hydro_P2,hydro_P9,Qualitative_Measure +YLGTGHTL,HLA-A24:02,989.13,-0.33,-18.36,0.004625,8.6,0.90272356569767,0.986617704091726,0.0024999999999977,1.3,1,0.15018,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,3.8,3.8,1 +VASAQIHLY,HLA-A01:01,1102.24,0.66,21.34,0.188125,6.4,0.462824888527393,0.769690295934435,0.326820652173893,7.2,0,-0.08237,HLA-A01:01,YFAMYQENMAHTDANTLYIIYRDYTWVARVYRGY,1.8,-1.3,0 +ENAAYQVL,HLA-B44:02,1008.08,-0.1,21.91,0.135125,4.05,0.489908669143915,0.884998432106704,0.144782608695664,7.2,0,-0.00175,HLA-B44:02,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,-3.5,3.8,0 +MTENIVEV,HLA-A02:01,1090.25,-0.1,66.48,0.033125,4.53,0.893273741006851,0.992785879193505,0.0007065217391186,1.0,0,0.32567,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-0.7,4.2,1 +LHESDPMVI,HLA-A02:01,1187.36,0.46,20.72,0.476875,4.35,0.826960079371929,0.956972867139642,0.0369565217391425,1.1,-1,-0.16277,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-3.2,4.5,0 +ILHRLHEI,HLA-A02:01,1177.4,0.56,51.69,0.0585,6.92,0.169697714969516,0.900868002027345,0.121195652173924,1.1,1,0.15471,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,3.8,4.5,0 +LAYVSVLY,HLA-A29:02,1040.25,2.02,8.89,0.0305,5.52,0.0539091379614546,0.776582008742521,0.315244565217355,5.5,0,-0.07819,HLA-A29:02,YTAMYLQNVAQTDANTLYIMYRDYTWAVLAYTWY,1.8,-1.3,1 +YIEMLKSI,HLA-A24:02,1127.42,0.79,74.59,0.0355,5.75,0.590252012014389,0.9483693784867,0.049510869565239,30.0,0,-0.31615,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,4.5,4.5,1 +PFDKATIM,HLA-B35:01,1035.26,0.57,55.97,0.046375,5.84,0.778687968850136,0.973771676145048,0.0153804347826422,5.5,0,-0.00233,HLA-B35:01,YYATYRNIFTNTYESNLYIRYDSYTWAVLAYLWY,2.8,1.9,1 +ECSQELQKF,HLA-B44:02,1212.33,-1.03,143.32,0.02075,4.53,0.228030471131205,0.811656283121126,0.260380434782604,7.2,0,-0.43347,HLA-B44:02,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,2.5,2.8,0 +RAMYSVEL,HLA-C07:01,1096.3,-0.26,71.42,0.01025,8.59,0.389982596039772,0.907603641503208,0.109809782608707,1.3,1,-0.22999,HLA-C07:01,YDSGYRENYRQADVSNLYLRYDSYTLAALAYTWY,1.8,3.8,1 +FFDRIRFF,HLA-A24:02,1246.46,0.82,14.22,0.114125,9.57,0.669801857322454,0.945672059266013,0.0538586956521953,100.0,1,0.34808,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,2.8,2.8,0 +LMMTLPSI,HLA-A02:01,1018.33,1.84,45.29,0.097,5.53,0.0637162479106336,0.838095880911685,0.218288043478239,5.5,0,-0.31236,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,1.9,4.5,0 +LLPVLFGV,HLA-A02:06,956.22,2.76,51.69,0.0051249999999999,5.49,0.531753820367157,0.979356304013592,0.0084239130434866,100.0,0,0.1336,HLA-A02:06,YYAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,3.8,4.2,1 +LHDRQHSI,HLA-A02:01,1118.25,-0.73,21.91,0.1885,6.92,0.160899348556995,0.83227758882377,0.226847826086967,5.5,1,-0.09518,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-3.2,4.5,1 +ELKKTSSF,HLA-B44:02,1053.17,-1.17,20.86,0.0595,8.59,0.700152177363634,0.952866083294727,0.0440760869565366,1.4,1,-0.63034,HLA-B44:02,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,3.8,2.8,0 +ISCLTFGR,HLA-A33:03,1033.2,0.44,129.04,0.04025,8.26,0.0368481220211833,0.754300693261813,0.351358695652166,3.5,1,0.03639,HLA-A33:03,YTAMYRNNVAHIDVDTLYIMYQDYTWAVLAYTWY,-0.8,-4.5,1 +VHQIFGSAY,HLA-B15:01,1152.32,0.6,9.0,0.14825,6.49,0.425033454317599,0.87306837846772,0.164103260869567,30.0,0,0.05268,HLA-B15:01,YYAMYREISTNTYESNLYLRYDSYTWAEWAYLWY,-3.2,-1.3,1 +YVSYHTVY,HLA-A24:02,1160.23,-0.41,5.84,0.968374999999999,5.24,0.892197452485561,0.771193866161023,0.323505434782604,1.0,0,-0.06934,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,4.2,-1.3,0 +MYLPGAEV,HLA-A02:01,966.11,0.46,47.17,0.20825,4.05,0.187197149265557,0.836326724057346,0.221249999999983,1.9,0,0.10026,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-1.3,4.2,1 +LILFLLFV,HLA-A02:01,1140.46,3.13,30.29,0.016875,5.52,0.0088876856752904,0.871459214579331,0.167038043478243,2.8,0,0.19464,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,4.5,4.2,0 +YIVMPVFI,HLA-A24:02,1094.41,2.63,69.68,0.090375,5.52,0.861093997955322,0.974056660629577,0.0141576086956831,20.0,0,0.00654,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,4.5,4.5,0 +LYDFDYYII,HLA-A02:01,1325.46,0.4,49.82,0.362,4.05,0.330958516336977,0.841088775084575,0.213750000000004,7.2,-1,0.22788,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-1.3,4.5,0 +MINGVVKL,HLA-A02:01,960.19,1.11,-20.44,0.084625,8.47,0.516151152551174,0.96424617835298,0.0267391304348052,1.9,0,0.01739,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,4.5,3.8,1 +LENVAFNV,HLA-A02:01,992.08,0.61,8.89,0.285375,4.05,0.0927903840783983,0.742745675670397,0.368532608695645,1.9,0,0.19804,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-3.5,4.2,1 +YVSVMNFI,HLA-A24:02,1071.29,1.8,-16.19,0.04975,5.49,0.732103321701288,0.965953891241422,0.0255434782608858,100.0,0,-0.21523,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,4.2,4.5,0 +VGGVFTSV,HLA-A02:01,851.94,1.37,13.17,0.150625,5.24,0.482934825122356,0.948142629795685,0.049510869565239,1.9,0,0.1316,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-0.4,4.2,1 +LLVHPQWV,HLA-A02:01,1048.24,0.71,17.66,0.181125,6.74,0.110304668778554,0.810879643516455,0.260380434782604,30.0,0,0.09066,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,3.8,4.2,1 +QIEWRMKK,HLA-A11:01,1219.46,-1.61,32.63,0.155875,9.99,0.16683382546762,0.783683257517365,0.304076086956485,7.2,2,0.13417,HLA-A11:01,YYAMYQENVAQTDVDTLYIIYRDYTWAAQAYRWY,4.5,-3.9,0 +PQQKGLRL,HLA-B07:02,1067.29,-1.52,55.28,0.029125,11.17,0.66546992957592,0.959263332653379,0.0344021739130369,1.3,3,-0.31138,HLA-B07:02,YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY,-3.5,3.8,1 +PTPKGTVM,HLA-B07:02,986.19,-0.81,11.42,0.002125,11.0,0.693685218691826,0.969286824269444,0.0200815217391578,1.0,2,-0.11978,HLA-B07:02,YYSEYRNIYAQTDESNLYLSYDYYTWAERAYEWY,-0.7,1.9,1 +RKLFGTHF,HLA-B27:05,1062.22,-0.41,-21.5,0.091,11.0,0.474968678783625,0.874248038095832,0.162499999999994,30.0,2,0.1164,HLA-B27:05,YHTEYREICAKTDEDTLYLNYHDYTWAVLAYEWY,-3.9,2.8,1 +SGKDIYTY,HLA-A01:01,1093.19,-0.51,-9.98,0.085,5.83,0.255498605780303,0.748021217455179,0.360951086956504,1.1,0,-0.03956,HLA-A01:01,YFAMYQENMAHTDANTLYIIYRDYTWVARVYRGY,-0.4,-1.3,0 +LYDVVTKL,HLA-A02:01,1021.21,0.93,16.33,0.033125,5.88,0.908561088144779,0.993148781082687,0.000489130434758,4.4,0,0.00694,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,-1.3,3.8,1 +MACVGFFL,HLA-A02:01,974.2,2.07,53.61,0.021625,5.24,0.0229798836808186,0.872096257419025,0.165597826086937,1.9,0,0.19775,HLA-A02:01,YFAMYGEKVAHTHVDTLYVRYHYYTWAVLAYTWY,1.8,3.8,1 +SLAPEIRMYW,HLA-B57:01,1322.53,-0.09,72.01,0.027125,6.0,0.605688747018576,0.951891813556119,0.0451630434782828,30.0,0,0.15043,HLA-B57:01,YYAMYGENMASTYENIAYIVYDSYTWAVLAYLWY,3.8,-1.3,1 +YSSPNVSI,HLA-A24:02,965.06,0.46,70.73,0.090375,5.49,0.149138833861798,0.752388037767572,0.354239130434763,100.0,0,-0.29888,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,-0.8,4.5,0 +EIKRGLFF,HLA-B44:02,1096.28,0.09,69.04,0.035125,8.46,0.607210163027048,0.940078816672612,0.0617934782608813,1.9,1,-0.03246,HLA-B44:02,YYTKYREISTNTYENTAYIRYDDYTWAVDAYLSY,4.5,2.8,0 +YPDAPPLRL,HLA-A24:02,1154.36,-0.09,71.58,0.064375,5.84,0.761047102510929,0.966713647931235,0.0240489130435008,5.5,0,0.05646,HLA-A24:02,YSAMYEEKVAHTDENIAYLMFHYYTWAVQAYTGY,-1.6,3.8,1 +VYEGVWKK,HLA-A03:01,1136.34,-1.04,-17.3,0.00325,9.53,0.742044918239117,0.980109695019429,0.0072554347826212,1.3,2,0.23235,HLA-A03:01,YFAMYQENVAQTDVDTLYIIYRDYTWAELAYTWY,-1.3,-3.9,1 diff --git a/src/XAInyPredictor/data/rai/config.yml b/src/XAInyPredictor/data/rai/config.yml index 254a6d2..733bc1c 100644 --- a/src/XAInyPredictor/data/rai/config.yml +++ b/src/XAInyPredictor/data/rai/config.yml @@ -1,5 +1,12 @@ name: "RAI-R Predictor" description: "Predict Radioiodine Refractoriness (RAI-R) in thyroid cancer patients" +model_metadata: + model_version: "Prototype v1" + training_date: "Research dataset snapshot" + cohort_source: "Reference cohort from the RAI-R model development dataset" + intended_use: "Cohort-level patient stratification research and exploratory decision support" + validation_status: "Utility and workflow fit should be validated with clinical collaborators before operational use" + limitations: "Small research cohort and prototype workflow; outputs require clinical review and prospective validation" target_column: "RAI-R" positive_class: "YES" @@ -23,7 +30,7 @@ features: default: "M" - name: "BMI" type: "numeric" - display_name: "BMI (kg/m²)" + display_name: "BMI (kg/m2)" input_type: "numeric" min: 10 max: 60 @@ -193,28 +200,57 @@ encoding: labels: positive_class_label: "Refractory" negative_class_label: "Non-Refractory" - probability_column: "Probability" - class_column: "Classification" + probability_column: "Stratification Score" + class_column: "Patient Group" + entity: + singular: "patient" + plural: "patients" + singular_title: "Patient" + plural_title: "Patients" + set_lower: "cohort" + set_title: "Cohort" + reference_plural: "reference patients" + reference_title: "Reference Patients" form: add_patient_button: "Add Patient to Cohort" delete_selected_button: "Delete Selected" + reset_cohort_button: "Reset Cohort" cohort_title: "Current Patient Cohort" no_data_message: "No data uploaded." + manual_source: "manual cohort" + uploaded_source: "uploaded file" + example_source: "example cohort" + added_message: "Patient added." + deleted_message: "Deleted {count} patient(s)." + reset_message: "Cohort reset." + text: + analysis_title: "Stratification Analysis" + selected_output_title: "Selected Patient Stratification" + groups_label: "Patient groups" + select_entity_label: "Select patient:" + search_entity_placeholder: "Search by patient ID..." + profile_comparison_title: "Patient Profile Comparison " + set_summary_title: "Cohort Stratification Summary" + set_distribution_title: "Cohort Score Distribution " + closest_reference_title: "Closest Reference Patients " titles: app_title: "XAInyPredictor" + app_subtitle: "Interpretable Stratification Support" tab_data_input: "1. Data Input" - tab_data_exploration: "2. Data Exploration" - tab_prediction: "3. Prediction" + tab_data_exploration: "2. Cohort Context" + tab_prediction: "3. Patient Stratification" data_source: "Data Source" select_method: "Select Method:" manual_entry: "Manual Entry" upload_file: "Upload File" example_cohort: "Example Cohort" - feature_analysis: "Feature Analysis" - distance_analysis: "Distance Analysis" - prediction_results: "Prediction Results" + feature_analysis: "Patient Profile Comparison" + feature_analysis_short: "Profile Comparison" + distance_analysis: "Feature Context Curves" + distance_analysis_short: "Feature Curves" + prediction_results: "Stratification Results" help_texts: - fnr_threshold: "This controls the safety margin of the model." + fnr_threshold: "This setting controls how many positive patients the model is allowed to miss when assigning patient groups." fnr_zero: "We refuse to miss any true RAI-Refractory patients. The model will classify a patient as Refractory even if the probability is low (High Sensitivity)." fnr_higher: "We accept missing some Refractory patients to ensure those we treat are definitely Refractory (High Specificity)." reference_population_input: "It will show the distribution of the input data." @@ -223,4 +259,4 @@ help_texts: radar_average_all: "The average values from all patients in the model." radar_average_negative: "The average values from all RAI-R negative patients in the model (green)." radar_average_positive: "The average values from all RAI-R positive patients in the model (red)." - curves_plot: "The curves plot displays the distribution of values from a specific feature across the patients in the model, ordered from lowest to highest." \ No newline at end of file + curves_plot: "The curves plot displays the distribution of values from a specific feature across the patients in the model, ordered from lowest to highest." diff --git a/src/XAInyPredictor/main.py b/src/XAInyPredictor/main.py index 082aa73..78554f4 100644 --- a/src/XAInyPredictor/main.py +++ b/src/XAInyPredictor/main.py @@ -1,7 +1,10 @@ +import os + from shiny import run_app from XAInyPredictor.app import app def main(): + os.environ["XAINYPREDICTOR_EXIT_ON_LAST_SESSION_END"] = "1" run_app( app, host="127.0.0.1", diff --git a/src/XAInyPredictor/modules/data_processing.py b/src/XAInyPredictor/modules/data_processing.py index ca1265a..98d706b 100644 --- a/src/XAInyPredictor/modules/data_processing.py +++ b/src/XAInyPredictor/modules/data_processing.py @@ -1,5 +1,6 @@ import os import math +import logging import matplotlib.pyplot as plt from matplotlib.figure import Figure import numpy as np @@ -10,6 +11,8 @@ from sklearn.impute import IterativeImputer from sklearn.preprocessing import StandardScaler +logger = logging.getLogger(__name__) + def process_raw_data(raw_df: pd.DataFrame, output_file: str | None = None, encoding_config: dict = None, target: str = 'class_target', allowed_columns: list | None = None) -> pd.DataFrame: """ @@ -28,7 +31,7 @@ def process_raw_data(raw_df: pd.DataFrame, output_file: str | None = None, encod :rtype: DataFrame """ if output_file != None and os.path.exists(output_file): - print("Processed dataset already exists! Skipping processing.") + logger.debug("Processed dataset already exists. Skipping processing.") df = pd.read_csv(output_file) return df @@ -51,20 +54,33 @@ def process_raw_data(raw_df: pd.DataFrame, output_file: str | None = None, encod # Process data clean_df = clean_data(raw_df, allowed_columns=allowed_columns) + generated_figures = [] if target: - target_prop_plot = check_target_proportion(clean_df) - non_numeric_summary = check_non_numeric_features(clean_df) + generated_figures.append(check_target_proportion(clean_df)) + check_non_numeric_features(clean_df) clean_nomissing_df = handle_missing_values(clean_df) num_feat_results = plot_distribution_numerical_features(clean_nomissing_df) cat_feat_results = plot_distribution_categorical_features(clean_nomissing_df) - (outliers_iqr, outliers_z) = detect_outliers(clean_nomissing_df) + detect_outliers(clean_nomissing_df) cramers_results = calculate_cramers_v_matrix(clean_nomissing_df) + generated_figures.extend( + fig + for fig in [ + num_feat_results.get('histogram'), + num_feat_results.get('boxplot'), + cat_feat_results.get('plot'), + cramers_results.get('plot'), + ] + if fig is not None + ) norm_df = normalize_and_encode_features(clean_nomissing_df, encoding_dict) + for fig in generated_figures: + plt.close(fig) # Save the processed dataset to a CSV file if output_file != None: norm_df.to_csv(output_file, index=False) - print(f"Processed dataset saved to {output_file}") + logger.debug("Processed dataset saved to %s", output_file) return norm_df @@ -156,8 +172,7 @@ def check_target_proportion( """ # Check the proportions of the target feature target_counts = df[target].value_counts(normalize=True) - print(f"Proportions of the target feature:") - print(target_counts) + logger.debug("Proportions of the target feature:\n%s", target_counts) # Visualize the proportions using a bar plot sns.barplot(y=target_counts.values, hue=target_counts.index, palette='viridis') @@ -190,7 +205,7 @@ def check_non_numeric_features(df: pd.DataFrame) -> dict: # Display the non-numeric values in each column for column, values in non_numeric_summary.items(): - print(f"Non-numeric entries in '{column}':", values) + logger.debug("Non-numeric entries in '%s': %s", column, values) return non_numeric_summary @@ -205,17 +220,16 @@ def handle_missing_values(df: pd.DataFrame) -> pd.DataFrame: :return: Pandas dataframe containing the new data frame with the imputated values. :rtype: pd.DataFrame """ - missing_percent_df = check_percent_missing_values(df) + check_percent_missing_values(df) proc_df = df.copy() proc_df = imputate_numerical_values_with_median(proc_df) proc_df = imputate_categorical_values_with_mode(proc_df) - missing_percent_df2 = check_percent_missing_values(proc_df) + check_percent_missing_values(proc_df) - print("Dataset after handling missing values:") - print(proc_df.head()) + logger.debug("Dataset after handling missing values:\n%s", proc_df.head()) return proc_df @@ -242,8 +256,7 @@ def check_percent_missing_values(df: pd.DataFrame) -> pd.DataFrame: '% Total Values': missing_percent }) - print("Missing Values Summary:") - print(missing_percent_df) + logger.debug("Missing Values Summary:\n%s", missing_percent_df) return missing_percent_df @@ -266,7 +279,7 @@ def imputate_values_for_feature_with_reference(df: pd.DataFrame, :return: Pandas dataframe containing the data frame with the imputated values. :rtype: pd.DataFrame """ - print(f"Imputating values for feature {feature_to_imputate}...") + logger.debug("Imputing values for feature %s...", feature_to_imputate) # Check if feature_to_imputate in reference_features if feature_to_imputate not in reference_features: reference_features.append(feature_to_imputate) @@ -284,7 +297,7 @@ def imputate_values_for_feature_with_reference(df: pd.DataFrame, # Stop if non-numeric values other than nan are found if non_numeric_mask.any(): - print(f"Non-numeric column found: {column}. Imputation of values for {feature_to_imputate} is not possible!") + logger.debug("Non-numeric column found: %s. Imputation of values for %s is not possible.", column, feature_to_imputate) return df # Initialize and apply the IterativeImputer @@ -296,7 +309,7 @@ def imputate_values_for_feature_with_reference(df: pd.DataFrame, # Check if any missing values remain in 'BMI' remaining_missing_values = df[feature_to_imputate].isna().sum() - print(f"Missing values remaining after imputation for {feature_to_imputate}: {remaining_missing_values}") + logger.debug("Missing values remaining after imputation for %s: %s", feature_to_imputate, remaining_missing_values) return df @@ -346,7 +359,7 @@ def plot_distribution_numerical_features(df: pd.DataFrame) -> dict: numerical_columns = df.select_dtypes(include=['number']).columns.tolist() if not numerical_columns: - print("No numerical columns found in the dataframe.") + logger.debug("No numerical columns found in the dataframe.") return { 'histogram': None, 'boxplot': None, @@ -354,12 +367,11 @@ def plot_distribution_numerical_features(df: pd.DataFrame) -> dict: 'distribution_info': None } - print("Numerical columns:", numerical_columns) + logger.debug("Numerical columns: %s", numerical_columns) # Get distribution summary distribution_info = df.describe() - print("Numerical Variables Distribution Summary:") - print(distribution_info) + logger.debug("Numerical Variables Distribution Summary:\n%s", distribution_info) # Set up a grid for numerical columns num_cols = len(numerical_columns) @@ -422,10 +434,10 @@ def plot_distribution_categorical_features(df: pd.DataFrame) -> dict: categorical_columns = df.select_dtypes(exclude=['number']).columns.tolist() if not categorical_columns: - print("No categorical columns found in the dataframe.") + logger.debug("No categorical columns found in the dataframe.") return {'plot': None, 'columns': [], 'distribution_info': None} - print("Categorical columns:", categorical_columns) + logger.debug("Categorical columns: %s", categorical_columns) # Set up a grid for categorical columns num_cols = len(categorical_columns) @@ -526,14 +538,14 @@ def detect_outliers(df: pd.DataFrame) -> tuple: # Detect outliers using IQR outliers_iqr = detect_outliers_iqr(df, numerical_cols) for col, outliers in outliers_iqr.items(): - print(f"{col} has {len(outliers)} outliers based on IQR.") + logger.debug("%s has %s outliers based on IQR.", col, len(outliers)) # Identify outliers with z-score > 3 or < -3 outliers_z = detect_outliers_zscore(df, numerical_cols, 3) # Display outliers using z-score for col in numerical_cols: - print(f"{col} has {len(outliers_z)} outliers based on z-scores.") + logger.debug("%s has %s outliers based on z-scores.", col, len(outliers_z)) return (outliers_iqr, outliers_z) @@ -552,6 +564,8 @@ def cramers_v(x: list | pd.Series, y: list | pd.Series) -> float: chi2 = chi2_contingency(contingency_table)[0] n = contingency_table.sum().sum() r, k = contingency_table.shape + if n == 0 or min(r, k) <= 1: + return np.nan return np.sqrt(chi2 / (n * (min(r, k) - 1))) @@ -581,12 +595,11 @@ def calculate_cramers_v_matrix(df: pd.DataFrame) -> dict: mask = np.triu(np.ones_like(cramers_v_matrix, dtype=bool)) # Visualize the lower triangle of the Cramér's V matrix - sns.heatmap(cramers_v_matrix, mask=mask, annot=True, cmap='coolwarm', square=True, cbar_kws={'label': "Cramér's V"}) + fig, ax = plt.subplots() + sns.heatmap(cramers_v_matrix, mask=mask, annot=True, cmap='coolwarm', square=True, cbar_kws={'label': "Cramér's V"}, ax=ax) plt.title("Cramér's V Matrix for Categorical Variables") plt.xticks(rotation=45, ha='right') plt.yticks(rotation=0) - plt.show() - fig = plt.gcf() # get current figure return {'plot': fig, 'matrix': cramers_v_matrix} @@ -671,6 +684,5 @@ def normalize_and_encode_features( """ df = normalize_numerical_features(df) df = encode_categorical_features(df, encoding_dict) - print("Dataset after standardization of numerical and categorical features:") - print(df.head()) + logger.debug("Dataset after standardization of numerical and categorical features:\n%s", df.head()) return df diff --git a/src/XAInyPredictor/modules/neoag_processing.py b/src/XAInyPredictor/modules/neoag_processing.py new file mode 100644 index 0000000..5d204be --- /dev/null +++ b/src/XAInyPredictor/modules/neoag_processing.py @@ -0,0 +1,178 @@ +import re +from dataclasses import dataclass + +import numpy as np +import pandas as pd + + +AA_HYDRO = { + "A": 1.8, + "C": 2.5, + "D": -3.5, + "E": -3.5, + "F": 2.8, + "G": -0.4, + "H": -3.2, + "I": 4.5, + "K": -3.9, + "L": 3.8, + "M": 1.9, + "N": -3.5, + "P": -1.6, + "Q": -3.5, + "R": -4.5, + "S": -0.8, + "T": -0.7, + "V": 4.2, + "W": -0.9, + "Y": -1.3, +} + +AA_GROUPS = { + "hydrophobic": set("AVILMFWYC"), + "aromatic": set("FWYH"), + "polar": set("STNQYC"), + "basic": set("KRH"), + "acidic": set("DE"), + "small": set("AGCSTP"), + "special": set("GPC"), + "aliphatic": set("AVILM"), + "amide": set("NQ"), +} + +AA_FORMAL_CHARGE = {"K": 1.0, "R": 1.0, "H": 0.1, "D": -1.0, "E": -1.0} + +BASE_NUMERIC_COLUMNS = [ + "Molecular_weight", + "GRAVY", + "instability", + "mhcflurry_affinity_percentile", + "pI", + "mhcflurry_processing_score", + "mhcflurry_presentation_score", + "Half_life", + "charge", + "score", + "hydro_P2", + "hydro_P9", +] + +METADATA_COLUMNS = ["Peptide", "hla", "pseudosequence"] + + +@dataclass +class NeoagPreparedData: + features: pd.DataFrame + metadata: pd.DataFrame + target: pd.Series | None + + +def _safe_fraction(seq: str, residue_set: set[str]) -> float: + length = max(len(seq), 1) + return sum(aa in residue_set for aa in seq) / length + + +def _mean_hydrophobicity(seq: str) -> float: + vals = [AA_HYDRO[aa] for aa in seq if aa in AA_HYDRO] + return float(np.mean(vals)) if vals else 0.0 + + +def _approx_sidechain_charge(seq: str) -> float: + return float(sum(AA_FORMAL_CHARGE.get(aa, 0.0) for aa in seq)) + + +def _sequence_group_features(seq: str, prefix: str) -> dict[str, float]: + features = { + f"{prefix}_hydro_mean": _mean_hydrophobicity(seq), + f"{prefix}_estimated_sidechain_charge": _approx_sidechain_charge(seq), + } + for group_name, residue_set in AA_GROUPS.items(): + features[f"{prefix}_frac_{group_name}"] = _safe_fraction(seq, residue_set) + return features + + +def _extract_hla_tokens(hla: str) -> dict[str, str]: + match = re.match(r"HLA-([A-Z])(\d+):(\d+)", str(hla)) + if match is None: + return {"hla_locus": "other", "hla_group": "unknown", "hla_protein": "unknown"} + locus, group, protein = match.groups() + return {"hla_locus": locus, "hla_group": f"{locus}{group}", "hla_protein": protein} + + +def _clean_raw_df(df: pd.DataFrame, target: str | None) -> pd.DataFrame: + required = set(METADATA_COLUMNS + BASE_NUMERIC_COLUMNS) + if target: + required.add(target) + missing = required - set(df.columns) + if missing: + raise ValueError(f"Missing required neoantigen column(s): {sorted(missing)}") + + clean_df = df.copy() + for col in METADATA_COLUMNS: + clean_df[col] = clean_df[col].where(clean_df[col].notna(), "") + + clean_df["Peptide"] = clean_df["Peptide"].astype(str).str.strip().str.upper() + clean_df["pseudosequence"] = clean_df["pseudosequence"].astype(str).str.strip().str.upper() + clean_df["hla"] = clean_df["hla"].astype(str).str.strip() + + empty_text_cols = [col for col in METADATA_COLUMNS if clean_df[col].eq("").any()] + if empty_text_cols: + raise ValueError(f"Neoantigen text column(s) contain empty values: {empty_text_cols}") + + for col in BASE_NUMERIC_COLUMNS: + clean_df[col] = pd.to_numeric(clean_df[col], errors="coerce") + + if clean_df[BASE_NUMERIC_COLUMNS].isna().any().any(): + bad_cols = clean_df[BASE_NUMERIC_COLUMNS].columns[ + clean_df[BASE_NUMERIC_COLUMNS].isna().any() + ].tolist() + raise ValueError(f"Neoantigen numeric column(s) contain invalid values: {bad_cols}") + + if target and target in clean_df.columns: + clean_df[target] = pd.to_numeric(clean_df[target], errors="coerce").astype("Int64") + + return clean_df + + +def prepare_neoantigen_features( + raw_df: pd.DataFrame, + feature_order: list[str], + target: str | None = "Qualitative_Measure", +) -> NeoagPreparedData: + target_col = target if target in raw_df.columns else None + clean_df = _clean_raw_df(raw_df, target_col) + metadata = clean_df[METADATA_COLUMNS].copy().reset_index(drop=True) + features = clean_df[BASE_NUMERIC_COLUMNS].copy() + + peptide_features = ( + clean_df["Peptide"].apply(lambda seq: _sequence_group_features(seq, "peptide")).apply(pd.Series) + ) + pseudo_features = ( + clean_df["pseudosequence"].apply(lambda seq: _sequence_group_features(seq, "pseudo")).apply(pd.Series) + ) + features["peptide_length"] = clean_df["Peptide"].str.len().astype(float) + features["anchor_hydro_mean"] = (clean_df["hydro_P2"] + clean_df["hydro_P9"]) / 2.0 + features["anchor_hydro_diff"] = clean_df["hydro_P2"] - clean_df["hydro_P9"] + + hla_tokens = clean_df["hla"].apply(_extract_hla_tokens).apply(pd.Series) + hla_locus_dummies = pd.get_dummies(hla_tokens["hla_locus"], prefix="hla_locus") + for col in ["hla_locus_A", "hla_locus_B", "hla_locus_C"]: + if col not in hla_locus_dummies.columns: + hla_locus_dummies[col] = 0 + + features = pd.concat( + [ + features.reset_index(drop=True), + peptide_features.reset_index(drop=True), + pseudo_features.reset_index(drop=True), + hla_locus_dummies[["hla_locus_A", "hla_locus_B", "hla_locus_C"]].reset_index(drop=True), + ], + axis=1, + ) + features = features.reindex(columns=feature_order, fill_value=0.0).astype(float) + + target_series = None + if target_col: + target_series = clean_df[target_col].astype(int).reset_index(drop=True) + + return NeoagPreparedData(features=features, metadata=metadata, target=target_series) diff --git a/src/XAInyPredictor/modules/xai.py b/src/XAInyPredictor/modules/xai.py index dd89123..045cab5 100644 --- a/src/XAInyPredictor/modules/xai.py +++ b/src/XAInyPredictor/modules/xai.py @@ -3,14 +3,18 @@ from matplotlib.projections import register_projection from matplotlib.projections.polar import PolarAxes import matplotlib.pyplot as plt -from matplotlib.spines import Spine -from matplotlib.transforms import Affine2D -import numpy as np -import pandas as pd -import re +from matplotlib.spines import Spine +from matplotlib.transforms import Affine2D +import logging +import numpy as np +import pandas as pd +import re from sklearn.model_selection import train_test_split -import sympy # See https://github.com/knottwill/CoxKAN/blob/main/reprod/results.ipynb - +import sympy # See https://github.com/knottwill/CoxKAN/blob/main/reprod/results.ipynb + +logger = logging.getLogger(__name__) +_DELTA_XAI_FN_CACHE = {} + def split_data_with_known_target(input_df, target='class_target', test_split=0.2): """ @@ -36,7 +40,7 @@ def split_data_with_known_target(input_df, target='class_target', test_split=0.2 return x_train, x_test, y_train, y_test -def read_delta_xai_formula(formula_pkl_file: str, current_best_formula_file: str): +def read_delta_xai_formula(formula_pkl_file: str, current_best_formula_file: str, simplify_formula: bool = True): """ Reads the XAI output pickle file formula.pkl, which contains the formula of the XAI model, and simplifies it using the package sympy. Reads the file @@ -55,12 +59,14 @@ def read_delta_xai_formula(formula_pkl_file: str, current_best_formula_file: str formula = pd.read_pickle(f) # Show the formula saved - print(f"Logit 0 formula: {formula[0]}") - print(f"Logit 1 formula: {formula[1]}") + logger.debug("Logit 0 formula: %s", formula[0]) + logger.debug("Logit 1 formula: %s", formula[1]) - # Simplify formula - delta_formula = sympy.simplify(formula[1] - formula[0]) - print(f"Delta formula simplified: {delta_formula}") + # Simplify formula + delta_formula = formula[1] - formula[0] + if simplify_formula: + delta_formula = sympy.simplify(delta_formula) + logger.debug("Delta formula simplified: %s", delta_formula) # Read best formula file and extract variables with open(current_best_formula_file, 'r') as f: @@ -68,7 +74,7 @@ def read_delta_xai_formula(formula_pkl_file: str, current_best_formula_file: str if line.startswith('Variables:'): feature_order = eval(line.strip().split(': ')[1]) feature_mappings = {str(idx+1) : feat_name for idx, feat_name in enumerate(feature_order)} - print(f"Feature mappings: {'; '.join([f'x_{idx} = {feat_name}' for idx, feat_name in feature_mappings.items()])}") + logger.debug("Feature mappings: %s", '; '.join([f'x_{idx} = {feat_name}' for idx, feat_name in feature_mappings.items()])) # Get features in formula vars_in_formula = set(re.findall(r'x_(\d+)', str(delta_formula))) @@ -77,7 +83,7 @@ def read_delta_xai_formula(formula_pkl_file: str, current_best_formula_file: str return delta_formula, feature_mappings, feature_order, features_in_formula -def delta_xai(d_form, x_in: pd.DataFrame, feature_order: list | None = None): +def delta_xai(d_form, x_in: pd.DataFrame, feature_order: list | None = None): """ Calculates probability associated to the classification of the target class based on the KAAM formula calculated using the default training set. @@ -100,27 +106,33 @@ def delta_xai(d_form, x_in: pd.DataFrame, feature_order: list | None = None): # Reorder DataFrame columns to match feature_order x_in = x_in[feature_order] - x_d = x_in.to_numpy() - delta = np.zeros( - (x_d.shape[0], x_d.shape[1] + 1)) # One input per covariate, one extra output for the constant term - for i in range(x_d.shape[0]): - for j in range(x_d.shape[1]): - func = d_form.copy() - for k in range(x_d.shape[1]): - if k == j: - func = func.subs(f'x_{k + 1}', x_d[i, j]) - else: - func = func.subs(f'x_{k + 1}', 0) - delta[i, j] = float(func) - # Get the constant term - func = d_form.copy() - for k in range(x_d.shape[1]): - func = func.subs(f'x_{k + 1}', 0) - delta[i, -1] = float(func) - delta[i, :-1] -= delta[i, -1] # Subtract the constant term from all the other terms (only account for it once) - delta = pd.DataFrame(delta, index=x_in.index, columns=x_in.columns.tolist() + ['const']) - delta["pred_prob"] = 1. / (1 + np.exp(-delta.values.sum(axis=1))) - return delta + x_d = x_in.to_numpy() + n_features = x_d.shape[1] + delta = np.zeros( + (x_d.shape[0], n_features + 1)) # One input per covariate, one extra output for the constant term + + cache_key = (str(d_form), tuple(x_in.columns)) + if cache_key not in _DELTA_XAI_FN_CACHE: + symbols = [sympy.Symbol(f"x_{idx + 1}") for idx in range(n_features)] + zero_subs = {symbol: 0 for symbol in symbols} + const = float(d_form.subs(zero_subs)) + partial_fns = [] + for symbol in symbols: + partial_subs = {other_symbol: 0 for other_symbol in symbols if other_symbol != symbol} + partial_expr = d_form.subs(partial_subs) + partial_fns.append(sympy.lambdify(symbol, partial_expr, "numpy")) + _DELTA_XAI_FN_CACHE[cache_key] = const, partial_fns + + const, partial_fns = _DELTA_XAI_FN_CACHE[cache_key] + for j, partial_fn in enumerate(partial_fns): + values = partial_fn(x_d[:, j]) + delta[:, j] = np.asarray(values, dtype=float) + + delta[:, -1] = const + delta[:, :-1] -= const # Subtract the constant term from all the other terms (only account for it once) + delta = pd.DataFrame(delta, index=x_in.index, columns=x_in.columns.tolist() + ['const']) + delta["pred_prob"] = 1. / (1 + np.exp(-delta.values.sum(axis=1))) + return delta def threshold_for_target_fnr(y_true: list, y_prob: list, target_fnr: float=0): @@ -183,12 +195,13 @@ def analyze_patient( show_closest_radial=True, show_average_radial=True, show_average_class0_radial=True, - show_average_class1_radial=True, - neg_class_label="Negative", - pos_class_label="Positive", - ): + show_average_class1_radial=True, + neg_class_label="Negative", + pos_class_label="Positive", + entity_label="Patient", + ): - print(f"Analyzing patient {patient_id}") + logger.debug("Analyzing patient %s", patient_id) # Get feature values from patient selected and save them in delta_patient patient_index = df[df['ID'] == int(patient_id)].index.to_list()[0] @@ -237,9 +250,9 @@ def analyze_patient( # avg_proba = 1 / (1 + np.exp(-delta_train_values_noconst.mean(axis=0))) # avg_proba_class0 = 1 / (1 + np.exp(-delta_train_class0_values_noconst.mean(axis=0))) # avg_proba_class1 = 1 / (1 + np.exp(-delta_train_class1_values_noconst.mean(axis=0))) - title = f"Patient {patient_id} (prob = {pred_proba_patient:.2f}, mean = {avg_proba[0]:.2f})" - fig_radar, ax = plt.subplots(figsize=(6, 6), subplot_kw=dict(projection='radar')) - # fig.subplots_adjust(top=0.85, bottom=0.05) + title = f"{entity_label} {patient_id} (prob = {pred_proba_patient:.2f}, mean = {avg_proba[0]:.2f})" + fig_radar, ax = plt.subplots(figsize=(7.8, 5.8), subplot_kw=dict(projection='radar')) + fig_radar.subplots_adjust(left=0.32, right=0.96, top=0.86, bottom=0.08) ax.set_rgrids([0.2, 0.4, 0.6, 0.8]) ax.set_title(title, position=(0.5, 1.1), ha='center', weight='bold') ax.set_yticklabels([]) # Remove radial axis numbers @@ -278,40 +291,65 @@ def analyze_patient( ax.fill(theta, pat_proba[feature_idx], alpha=0.1, color='r') ax.set_varlabels([label.replace("_", " ") for label in feature_names]) - plt.legend(loc='best', ncol=2, bbox_to_anchor=(0.8, -0.1)) # Note: this can be uncommented, but may clutter the plot + ax.legend(loc='center left', ncol=1, bbox_to_anchor=(-0.55, 0.5), frameon=True) # Curves plot: show only the ones that do matter!! # Revert again the order to have the right plot order feature_idx_inv = feature_idx[::-1] cols_vars = [delta_test.columns[i] for i in feature_idx_inv] - #print(cols_vars) n_feats = min(len(cols_vars), max_plot_curves) # Number of features to show in the curves plot if n_feats > 0: # There is something to show dtrain_df = pd.DataFrame(delta_train_values, columns=feature_names+['const']) fig_curve, axs = plt.subplots(n_feats + 1, 1, figsize=(6, 2 * (n_feats + 1))) - x_vals = np.arange(delta_train_values.sum(axis=1).min(), delta_train_values.sum(axis=1).max(), 0.01) - theor_proba = 1 / (1 + np.exp(-x_vals)) - axs[0].plot(x_vals, theor_proba, 'b', alpha=0.2) - axs[0].scatter(delta_patient.values.sum(), pred_proba_patient, color='r') + clipped_train_prob = np.clip(delta_train["pred_prob"].to_numpy(dtype=float), 1e-9, 1 - 1e-9) + train_logits = np.log(clipped_train_prob / (1 - clipped_train_prob)) + clipped_patient_prob = np.clip(pred_proba_patient, 1e-9, 1 - 1e-9) + patient_logit = float(np.log(clipped_patient_prob / (1 - clipped_patient_prob))) + x_min = min(float(train_logits.min()), patient_logit) + x_max = max(float(train_logits.max()), patient_logit) + padding = max((x_max - x_min) * 0.08, 0.1) + x_vals = np.linspace(x_min - padding, x_max + padding, 300) + theor_proba = 1 / (1 + np.exp(-x_vals)) + axs[0].plot(x_vals, theor_proba, 'b', alpha=0.2) + axs[0].scatter(patient_logit, pred_proba_patient, color='r') axs[0].set_xlabel('Logit') axs[0].set_ylabel('Probability') axs[0].set_title(f'Patient results') - for idj, feat_name in enumerate(cols_vars): - if idj < n_feats: # Only plot the first n_feats features - j = idj + 1 # The first plot is already used for the theoretical curve - # Keep only unique values of x_test[feat_name] - idxs = np.unique(x_train[feat_name].values, return_index=True)[1] - axs[j].plot(x_train[feat_name].values[idxs], dtrain_df[feat_name].values[idxs], color='b') - axs[j].scatter(x_train[feat_name].values[idxs], dtrain_df[feat_name].values[idxs], - color='b', alpha=0.1) - for jj in range(n_dists): - axs[j].scatter(x_train.iloc[idx_closest[jj]][feat_name], dtrain_df.iloc[idx_closest[jj]][feat_name], - color='lightsalmon', alpha=1) - - axs[j].scatter(patient_info[feat_name], delta_patient[feat_name], color='r') - axs[j].set_ylabel(f"logit {feat_name.replace('_', ' ')}") + for idj, feat_name in enumerate(cols_vars): + if idj < n_feats: # Only plot the first n_feats features + j = idj + 1 # The first plot is already used for the theoretical curve + curve_df = pd.DataFrame( + { + "x": x_train[feat_name].values, + "y": dtrain_df[feat_name].values, + } + ).dropna().sort_values("x") + curve_df = curve_df.drop_duplicates("x", keep="first") + if len(curve_df) > 1: + axs[j].plot(curve_df["x"].values, curve_df["y"].values, color='b') + axs[j].scatter(curve_df["x"].values, curve_df["y"].values, color='b', alpha=0.1) + for jj in range(n_dists): + axs[j].scatter(x_train.iloc[idx_closest[jj]][feat_name], dtrain_df.iloc[idx_closest[jj]][feat_name], + color='lightsalmon', alpha=1) + + patient_x = float(patient_info[feat_name].iloc[0]) + patient_y = float(delta_patient[feat_name].iloc[0]) + axs[j].scatter(patient_x, patient_y, color='r') + x_axis_values = curve_df["x"].tolist() + [patient_x] + y_axis_values = curve_df["y"].tolist() + [patient_y] + if x_axis_values: + x_axis_min = min(x_axis_values) + x_axis_max = max(x_axis_values) + x_padding = max((x_axis_max - x_axis_min) * 0.08, 0.05) + axs[j].set_xlim(x_axis_min - x_padding, x_axis_max + x_padding) + if y_axis_values: + y_axis_min = min(y_axis_values) + y_axis_max = max(y_axis_values) + y_padding = max((y_axis_max - y_axis_min) * 0.08, 0.05) + axs[j].set_ylim(y_axis_min - y_padding, y_axis_max + y_padding) + axs[j].set_ylabel(f"logit {feat_name.replace('_', ' ')}") return fig_radar, fig_curve @@ -332,14 +370,14 @@ def analyze_patient_new( show_average_class1_radial=True ): - print(f"Analyzing patient {patient_id} with Enhanced Visualization...") + logger.debug("Analyzing patient %s with enhanced visualization.", patient_id) # --- 1. DATA PREPARATION --- # Locate patient patient_ids = df['ID'].astype(int).tolist() if int(patient_id) not in patient_ids: - print(f"Error: Patient {patient_id} not found.") + logger.debug("Patient %s not found.", patient_id) return None, None patient_index = df[df['ID'] == int(patient_id)].index[0] diff --git a/src/XAInyPredictor/shinyapp/data_exploration.py b/src/XAInyPredictor/shinyapp/data_exploration.py index 077e23b..e2e1da5 100644 --- a/src/XAInyPredictor/shinyapp/data_exploration.py +++ b/src/XAInyPredictor/shinyapp/data_exploration.py @@ -1,15 +1,16 @@ -from os import path import matplotlib.pyplot as plt +import numpy as np import pandas as pd -import seaborn as sns -from shiny import Inputs, Outputs, Session, module, reactive, render, req, ui +from shiny import Inputs, Outputs, Session, module, reactive, render, ui @module.ui def data_exploration_ui(config=None): if config is None: config = {} - + labels = config.get("labels", {}) + entity = labels.get("entity", {}) + text = labels.get("text", {}) return ui.layout_sidebar( ui.sidebar( ui.output_ui("feature_dropdown"), @@ -35,30 +36,27 @@ def data_exploration_ui(config=None): selected='Input data', ), ui.output_ui("patient_selector_ui"), - ui.p("Select a patient to highlight their value in the distribution.", style="font-size: 0.8em; color: gray;"), + ui.output_ui("context_selection_hint"), + class_="context-sidebar", ), ui.page_fluid( ui.layout_columns( ui.card( - ui.card_header( - ui.div( - "Feature Distribution Analysis", - ui.tooltip( - ui.span(ui.tags.i(class_="glyphicon glyphicon-question-sign"), style="font-size: 0.8em; color: gray; margin-left: 5px;"), - "Blue bars: Reference population distribution. Red dashed line: Selected patient's value.", - ) - ) - ), + ui.card_header(ui.output_ui("context_card_header")), + ui.output_ui("context_plot_context"), ui.output_plot("plot_feature_distribution"), - full_screen=True + full_screen=True, + class_="cohort-context-plot-card", ), ui.card( - ui.card_header("Feature Statistics"), + ui.card_header("Feature summary"), ui.output_ui("feature_stats_ui"), - full_screen=True + full_screen=True, + class_="feature-summary-card", ), col_widths=(9, 3), ), + class_="cohort-context-page", ), ) @@ -68,25 +66,126 @@ def server(input: Inputs, output: Outputs, session: Session, global_input_data, display_to_name = reactive.Value({}) + def _text_labels(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + labels = current_cfg.get("labels", {}) + entity = labels.get("entity", {}) + return { + "singular": entity.get("singular", "patient"), + "singular_title": entity.get("singular_title", "Patient"), + "plural": entity.get("plural", "patients"), + "set_lower": entity.get("set_lower", "cohort"), + "set_title": entity.get("set_title", "Cohort"), + "select_label": labels.get("text", {}).get("select_entity_label", "Select patient:"), + "search_placeholder": labels.get("text", {}).get("search_entity_placeholder", "Search by patient ID..."), + "context_selection_hint": labels.get("text", {}).get( + "context_selection_hint", + f"Highlighted {entity.get('singular', 'patient')} updates the plot.", + ), + "context_card_title": labels.get("text", {}).get( + "context_card_title", + f"Feature distribution in current {entity.get('set_lower', 'cohort')}", + ), + "context_card_tooltip": labels.get("text", {}).get( + "context_card_tooltip", + f"Blue bars: Reference population distribution. Red dashed line: Selected {entity.get('singular', 'patient')}'s value.", + ), + } + + def _feature_label(feature: dict) -> str: + return feature.get("display_name") or feature.get("label") or feature.get("name", "Feature") + + def _is_plot_feature(feature: dict) -> bool: + if feature.get("plot") is False: + return False + if feature.get("role") == "identifier": + return False + return True + + def _normalize_columns(df): + if df is None: + return None + df = df.copy() + df.columns = [col.replace('_', ' ') for col in df.columns] + return df + + def _selected_feature_name(df=None): + display_name = str(input.feature_to_plot()) + col_name = display_to_name.get().get(display_name, display_name) + if df is not None and col_name not in df.columns: + underscored = col_name.replace(' ', '_') + for column in df.columns: + if column.replace(' ', '_') == underscored: + return column + return col_name + + def _selected_feature_label(): + return str(input.feature_to_plot()) + + def _reference_dataframe(): + reference_data = input.reference_data() + if reference_data == 'Input data': + df = global_input_data.get() + else: + md = model_data.get() + df = md.get("X_TRAIN_RAW") if md else None + return _normalize_columns(df) + + def _selected_patient_row(): + df = _normalize_columns(global_input_data.get()) + patient_id = patient_selected_id.get() + if df is None or df.empty or 'ID' not in df.columns or patient_id is None: + return None + row = df[df['ID'] == patient_id] + return None if row.empty else row + + def _format_value(value): + if pd.isna(value): + return "Missing" + if isinstance(value, (int, float, np.integer, np.floating)): + if float(value).is_integer(): + return str(int(value)) + return f"{float(value):.2f}" + return str(value) + @reactive.Effect def _rebuild_feature_list(): current_cfg = config_reactive.get() if config_reactive else (config_init or {}) features = current_cfg.get("features", []) if current_cfg else [] - display_to_name.set({f["display_name"]: f["name"] for f in features}) + features = [feature for feature in features if _is_plot_feature(feature)] + display_to_name.set({_feature_label(f): f["name"] for f in features}) @output @render.ui def feature_dropdown(): current_cfg = config_reactive.get() if config_reactive else (config_init or {}) features = current_cfg.get("features", []) if current_cfg else [] - choices = [(f["display_name"], f["display_name"]) for f in features] + features = [feature for feature in features if _is_plot_feature(feature)] + choices = [(_feature_label(f), _feature_label(f)) for f in features] if not choices: return ui.p("No features configured", style="font-size: 0.85em; color: gray;") return ui.input_select( id="feature_to_plot", label="Feature to plot:", choices=dict(choices), - selected=features[0].get("display_name") if features else None, + selected=_feature_label(features[0]) if features else None, + ) + + @output + @render.ui + def context_selection_hint(): + return ui.p(_text_labels()["context_selection_hint"], class_="sidebar-help-text") + + @output + @render.ui + def context_card_header(): + labels = _text_labels() + return ui.div( + labels["context_card_title"], + ui.tooltip( + ui.span(ui.tags.i(class_="glyphicon glyphicon-question-sign"), class_="card-help-icon"), + labels["context_card_tooltip"], + ), ) @reactive.Effect @@ -128,17 +227,15 @@ def patient_selector_ui(): selected_val = all_ids[0] else: selected_val = int(current_selection) - print(f"Patient selected (inspect): {selected_val}") - return ui.input_selectize( id="local_patient_select", - label="Select patient:", + label=_text_labels()["select_label"], choices=all_ids, selected=selected_val, options={ "create": False, # Don't allow creating new options "allowEmptyOption": False, - "placeholder": "Search by patient ID...", + "placeholder": _text_labels()["search_placeholder"], "maxItems": 1 } ) @@ -170,59 +267,86 @@ def feature_stats_ui(): """ Display statistics for the selected feature """ - reference_data = input.reference_data() - if reference_data == 'Input data': - df = global_input_data.get() - else: - md = model_data.get() - df = md.get("X_TRAIN_RAW") if md else None + df = _reference_dataframe() if df is None or df.empty: - return ui.p("No data available") - - # Ensure column names have the proper format for the UI - df = df.copy() - df.columns = [col.replace('_', ' ') for col in df.columns] + return ui.div("No data available", class_="feature-summary-empty") - display_name = str(input.feature_to_plot()) - col_name = display_to_name.get().get(display_name, display_name) - # Also try with space replacement as fallback - if col_name not in df.columns: - col_name_alt = col_name.replace(' ', '_') - if col_name_alt in [c.replace(' ', '_') for c in df.columns]: - col_name = col_name_alt + feature = _selected_feature_name(df) - feature = col_name if feature not in df.columns: - return ui.p("Feature not found") + return ui.div("Feature not found", class_="feature-summary-empty") - # Get the feature data feature_data = df[feature].dropna() + missing = int(df[feature].isnull().sum()) + + def metric(label, value, class_name=""): + return ui.div( + ui.div(label, class_="feature-metric-label"), + ui.div(value, class_="feature-metric-value"), + class_=f"feature-metric {class_name}".strip(), + ) if pd.api.types.is_numeric_dtype(feature_data): - stats = { - 'Mean': f"{feature_data.mean():.2f}", - 'Median': f"{feature_data.median():.2f}", - 'Std Dev': f"{feature_data.std():.2f}", - 'Min': f"{feature_data.min():.2f}", - 'Max': f"{feature_data.max():.2f}", - 'Count': f"{len(feature_data)}", - 'Missing': f"{df[feature].isnull().sum()}" - } + if feature_data.empty: + return ui.div("No non-missing values", class_="feature-summary-empty") + + return ui.div( + metric("Count", str(len(feature_data)), "feature-metric-primary"), + metric("Missing", str(missing)), + metric("Mean", f"{feature_data.mean():.2f}"), + metric("Median", f"{feature_data.median():.2f}"), + metric("Range", f"{feature_data.min():.2f} - {feature_data.max():.2f}"), + metric("Std dev", f"{feature_data.std():.2f}" if len(feature_data) > 1 else "N/A"), + class_="feature-metric-grid", + ) else: - stats = { - 'Unique Values': f"{feature_data.nunique()}", - 'Most Common': f"{feature_data.mode().iloc[0] if not feature_data.mode().empty else 'N/A'}", - 'Count': f"{len(feature_data)}", - 'Missing': f"{df[feature].isnull().sum()}" - } - - # Create statistics table - stat_items = [] - for key, value in stats.items(): - stat_items.append(ui.p(ui.tags.b(key + ": "), value)) - - return ui.div(*stat_items, style="font-size: 0.9em;") + unique_count = feature_data.nunique() + mode = feature_data.mode().iloc[0] if not feature_data.mode().empty else "N/A" + return ui.div( + metric("Count", str(len(feature_data)), "feature-metric-primary"), + metric("Missing", str(missing)), + metric("Unique", str(unique_count)), + metric("Most common", str(mode), "feature-metric-wide"), + class_="feature-metric-grid", + ) + + @output + @render.ui + def context_plot_context(): + df = _reference_dataframe() + patient_row = _selected_patient_row() + labels = _text_labels() + + if df is None or df.empty: + return None + + feature = _selected_feature_name(df) + feature_label = _selected_feature_label() + patient_id = patient_selected_id.get() + patient_value = None + if patient_row is not None and feature in patient_row.columns: + patient_value = patient_row[feature].values[0] + + chips = [ + ui.span(f"Reference: {input.reference_data()}", class_="context-chip"), + ui.span(f"n={len(df)}", class_="context-chip"), + ] + if patient_id is not None: + chips.append(ui.span(f"Highlighted {labels['singular']}: {patient_id}", class_="context-chip context-chip-highlight")) + if patient_value is not None: + chips.append(ui.span(f"Value: {_format_value(patient_value)}", class_="context-chip context-chip-highlight")) + + notices = [] + if len(df) < 10: + notices.append( + ui.div( + f"Small {labels['set_lower']}: distribution is based on {len(df)} records.", + class_="context-small-sample context-small-sample-info", + ) + ) + + return ui.div(ui.div(*chips, class_="context-chip-row"), *notices, class_="context-plot-context") @output @render.plot @@ -231,24 +355,16 @@ def plot_feature_distribution(): Plot distribution of a feature within the input dataset and highlight the value from a specific patient. """ - df = global_input_data.get() - if df is None or df.empty: + input_df = _normalize_columns(global_input_data.get()) + df = _reference_dataframe() + if input_df is None or input_df.empty or df is None or df.empty: fig, ax = plt.subplots() ax.set_axis_off() ax.text(0.5, 0.5, "No data available", ha="center", va="center") return fig - - # Ensure column names have the proper format for the UI - df = df.copy() - df.columns = [col.replace('_', ' ') for col in df.columns] - display_name = str(input.feature_to_plot()) - col_name = display_to_name.get().get(display_name, display_name) - if col_name not in df.columns: - col_name_alt = col_name.replace(' ', '_') - if col_name_alt in [c.replace(' ', '_') for c in df.columns]: - col_name = col_name_alt - feature = col_name + feature = _selected_feature_name(df) + feature_label = _selected_feature_label() patient_id = patient_selected_id.get() if feature not in df.columns: @@ -258,62 +374,119 @@ def plot_feature_distribution(): return fig # Filter for specific patient - patient_row = df[df['ID'] == patient_id] + patient_row = input_df[input_df['ID'] == patient_id] if patient_row.empty: fig, ax = plt.subplots(figsize=(10, 6)) ax.set_axis_off() - ax.text(0.5, 0.5, f"Patient {patient_id} not found in data!", ha="center", va="center", fontsize=16) + ax.text(0.5, 0.5, f"{_text_labels()['singular_title']} {patient_id} not found in data!", ha="center", va="center", fontsize=16) return fig patient_val = patient_row[feature].values[0] - - # Get model data if necessary - reference_data = input.reference_data() - if reference_data == 'Model': - md = model_data.get() - df = md.get("X_TRAIN_RAW") if md else None - if df is not None: - df = pd.concat([df, patient_row]) + plot_series = df[feature].dropna() # Create figure - fig, ax = plt.subplots() + fig, ax = plt.subplots(figsize=(10, 5.6)) # Handle categorical vs numerical - if pd.api.types.is_numeric_dtype(df[feature]): - # Create histogram for numerical data - ax.hist(df[feature].dropna(), bins=30, color='#4682B4', alpha=0.5, edgecolor='black', linewidth=0.5) - ax.set_title(f"Distribution of {feature}\nPatient {patient_id} value: {patient_val}", - fontsize=14, fontweight='bold', pad=20) - ax.set_xlabel(feature, fontsize=12) - ax.set_ylabel('Frequency', fontsize=12) + if pd.api.types.is_numeric_dtype(plot_series): + numeric_values = pd.to_numeric(plot_series, errors="coerce").dropna() + if len(numeric_values) < 10: + y = [1] * len(numeric_values) + ax.scatter(numeric_values, y, color='#4682B4', alpha=0.75, s=80, edgecolor='black', linewidth=0.4, label="Reference records") + ax.set_yticks([]) + ax.set_ylabel("") + else: + bins = min(30, max(6, int(len(numeric_values) ** 0.5))) + ax.hist(numeric_values, bins=bins, color='#4682B4', alpha=0.5, edgecolor='black', linewidth=0.5) + ax.set_ylabel('Frequency', fontsize=12) + + ax.set_title(f"Distribution of {feature_label}", fontsize=14, fontweight='bold', pad=16) + ax.set_xlabel(feature_label, fontsize=12) # Add vertical line for patient value ax.axvline(patient_val, color='red', linewidth=3, linestyle='--', - label=f'Patient {patient_id} ({patient_val})') + label=f"{_text_labels()['singular_title']} {patient_id} ({_format_value(patient_val)})") # Add statistics text - mean_val = df[feature].mean() + mean_val = numeric_values.mean() ax.axvline(mean_val, color='blue', linewidth=2, linestyle='-', label=f'Mean ({mean_val:.2f})') else: - # Create bar plot for categorical data - value_counts = df[feature].value_counts() - bars = ax.bar(value_counts.index, value_counts.values, - color='#4682B4', alpha=0.5, edgecolor='black', linewidth=0.5) - ax.set_title(f"Distribution of {feature}\nPatient {patient_id} value: {patient_val}", - fontsize=14, fontweight='bold', pad=20) - ax.set_xlabel(feature, fontsize=12) - ax.set_ylabel('Count', fontsize=12) - ax.tick_params(axis='x', rotation=45) - - # Highlight patient's value - if patient_val in value_counts.index: - patient_idx = list(value_counts.index).index(patient_val) - bars[patient_idx].set_color('red') - bars[patient_idx].set_alpha(1.0) - bars[patient_idx].set_edgecolor('black') - bars[patient_idx].set_linewidth(2) + value_counts = plot_series.astype(str).value_counts() + unique_ratio = value_counts.size / max(len(plot_series), 1) + if value_counts.size == 1: + ax.set_axis_off() + ax.text( + 0.5, + 0.58, + f"All records share the same {feature_label} value", + ha="center", + va="center", + fontsize=15, + fontweight="bold", + ) + ax.text( + 0.5, + 0.43, + f"Value: {_format_value(patient_val)}", + ha="center", + va="center", + fontsize=12, + color="#405261", + ) + return fig + + if value_counts.size > 12 or unique_ratio > 0.8: + ax.set_axis_off() + ax.text( + 0.5, + 0.58, + f"{feature_label} has mostly unique values", + ha="center", + va="center", + fontsize=15, + fontweight="bold", + ) + ax.text( + 0.5, + 0.43, + f"Selected {_text_labels()['singular']} value: {_format_value(patient_val)}", + ha="center", + va="center", + fontsize=12, + color="#405261", + ) + return fig + + patient_val_str = str(patient_val) + bar_colors = ['#9fbed6'] * len(value_counts) + edge_colors = ['#6889a5'] * len(value_counts) + line_widths = [0.8] * len(value_counts) + if patient_val_str in value_counts.index: + patient_idx = list(value_counts.index).index(patient_val_str) + bar_colors[patient_idx] = '#bfd7ec' + edge_colors[patient_idx] = '#d62728' + line_widths[patient_idx] = 2.5 + + ax.barh( + value_counts.index, + value_counts.values, + color=bar_colors, + alpha=0.95, + edgecolor=edge_colors, + linewidth=line_widths, + ) + ax.set_title(f"Distribution of {feature_label}", fontsize=14, fontweight='bold', pad=16) + ax.set_xlabel('Count', fontsize=12) + ax.set_ylabel(feature_label, fontsize=12) + ax.invert_yaxis() + + if patient_val_str in value_counts.index: + ax.barh([], [], color='#bfd7ec', edgecolor='#d62728', linewidth=2.5, label=f"Selected value: {_format_value(patient_val)}") ax.grid(True, alpha=0.3) - ax.legend() + handles, _ = ax.get_legend_handles_labels() + if handles: + ax.legend() + fig.tight_layout() return fig diff --git a/src/XAInyPredictor/shinyapp/data_input.py b/src/XAInyPredictor/shinyapp/data_input.py index 0bdbf5b..6718795 100644 --- a/src/XAInyPredictor/shinyapp/data_input.py +++ b/src/XAInyPredictor/shinyapp/data_input.py @@ -3,6 +3,12 @@ from XAInyPredictor.modules.data_processing import clean_data +def _decimal_text(value): + if value in (None, ""): + return "" + return str(value).replace(",", ".") + + def build_form_fields(config: dict) -> tuple: """Build form fields dynamically from config. Returns (columns, labels_dict) where columns is a list of ui.div for layout, @@ -14,15 +20,30 @@ def build_form_fields(config: dict) -> tuple: if not features: return None, {} - columns = [] + sections = {} + section_order = [] labels_dict = {} - current_column = [] - column_count = 0 + + def _section_for_feature(feature: dict) -> str: + name = f"{feature.get('name', '')} {feature.get('display_name', '')} {feature.get('label', '')}".lower() + if any(token in name for token in ["peptide", "hla", "allele", "pseudo"]): + return "Candidate identity" + if any(token in name for token in ["molecular", "gravy", "instability", "isoelectric", "half", "charge"]): + return "Molecular properties" + if any(token in name for token in ["mhcflurry", "score", "affinity", "presentation", "processing"]): + return "Presentation features" + if any(token in name for token in ["age", "gender", "bmi"]): + return "Patient profile" + if any(token in name for token in ["tumor", "histology", "node", "metastases", "extension", "multifocality", "vascular", "resection"]): + return "Tumor characteristics" + if any(token in name for token in ["ata", "rai", "treatment"]): + return "Treatment context" + return "Input variables" for idx, feature in enumerate(features): feat_name = feature.get("name", "") - display_name = feature.get("display_name", feat_name) - input_type = feature.get("input_type", "numeric") + display_name = feature.get("display_name") or feature.get("label") or feat_name + input_type = feature.get("input_type") or feature.get("type", "numeric") values = feature.get("values", []) display_values = feature.get("display_values", {}) default = feature.get("default", "") @@ -37,14 +58,19 @@ def build_form_fields(config: dict) -> tuple: field_ui = None if input_type == "numeric": - field_ui = ui.input_numeric( - input_id, - display_name, - value=default if default else min_val, - min=min_val, - max=max_val, - step=step + default_value = default if default not in (None, "") else min_val + field_ui = ui.div( + ui.input_text( + input_id, + display_name, + value=_decimal_text(default_value), + placeholder=_decimal_text(default_value), + autocomplete="off", + ), + class_="manual-decimal-input", ) + elif input_type == "text": + field_ui = ui.input_text(input_id, display_name, value=str(default or "")) elif input_type == "select" and values: choices = {} for v in values: @@ -56,26 +82,29 @@ def build_form_fields(config: dict) -> tuple: field_ui = ui.div( field_ui, ui.tooltip( - ui.span(ui.tags.i(class_="glyphicon glyphicon-question-sign"), style="cursor: pointer; color: #007bc2; margin-left: 5px;"), + ui.span(ui.tags.i(class_="glyphicon glyphicon-question-sign"), class_="form-help-icon"), help_text, placement="right" ), - style="display: flex; align-items: center;" + class_="form-field-with-help", ) - current_column.append(ui.div(field_ui, style="margin-bottom: 10px;")) - - if (idx + 1) % 4 == 0 or idx == len(features) - 1: - if current_column: - col_ui = ui.div( - *current_column, - style="padding: 10px;" - ) - columns.append(col_ui) - current_column = [] - column_count += 1 + section_name = _section_for_feature(feature) + if section_name not in sections: + sections[section_name] = [] + section_order.append(section_name) + sections[section_name].append(ui.div(field_ui, class_="manual-entry-field")) + + section_cards = [ + ui.div( + ui.div(section_name, class_="manual-entry-section-title"), + ui.div(*sections[section_name], class_="manual-entry-section-fields"), + class_="manual-entry-section", + ) + for section_name in section_order + ] - return columns, labels_dict + return section_cards, labels_dict @module.ui @@ -84,60 +113,169 @@ def data_input_ui(config=None): config = {} labels = config.get("labels", {}) - titles = config.get("titles", {}) - - form_columns, labels_dict = build_form_fields(config) - - return ui.layout_sidebar( - ui.sidebar( - ui.h4(labels.get("data_source", "Data Source")), - ui.input_radio_buttons( - "input_method", - labels.get("select_method", "Select Method:"), - {"form": labels.get("manual_entry", "Manual Entry"), "file": labels.get("upload_file", "Upload File"), "example": labels.get("example_cohort", "Example Cohort")} - ), - ui.panel_conditional( - "input.input_method == 'file'", - ui.input_file( - "input_dataset_file", - "Upload File (TSV, CSV, Excel)", - accept=[".tsv", ".csv", ".xlsx"], - multiple=False, - ), - ui.p("Ensure columns match the template.", style="font-size: 0.8em; color: gray;"), + return ui.page_fluid( + ui.div( + ui.card( + ui.card_header("Selected Use Case"), + ui.output_ui("use_case_summary"), + class_="use-case-summary-card", ), - ), - - ui.panel_conditional( - "input.input_method == 'form'", ui.card( - ui.card_header(ui.tags.b("➕ " + (labels.get("form", {}).get("add_patient_button", "New Patient Entry")))), - ui.output_ui("form_fields"), + ui.card_header(ui.output_ui("cohort_table_header")), + ui.output_ui("current_set_body"), + ui.output_ui("input_validation_ui"), + ui.output_ui("input_validation_success_ui"), ui.card_footer( - ui.input_action_button("btn_add_form", labels.get("form", {}).get("add_patient_button", "Add Patient to Cohort"), class_="btn-primary", width="100%") - ) + ui.div( + ui.output_ui("apply_data_ui"), + ui.div( + ui.output_ui("delete_selected_action"), + ui.output_ui("reset_cohort_action"), + class_="cohort-secondary-actions", + ), + class_="cohort-actions", + ) + ), + class_="current-set-card", + full_screen=True, ), - ui.br() - ), - - ui.card( - ui.card_header(ui.tags.b(labels.get("form", {}).get("cohort_title", "Current Patient Cohort"))), - ui.output_data_frame("out_patient_table"), - ui.panel_conditional( - "input.input_method == 'form'", - ui.input_action_button("btn_delete_selected", labels.get("form", {}).get("delete_selected_button", "Delete Selected"), class_="btn-danger btn-sm", width="100%") + ui.div( + ui.output_ui("input_method_header"), + ui.navset_tab( + ui.nav_panel( + labels.get("manual_entry", "Manual Entry"), + ui.card( + ui.card_header(ui.output_ui("manual_entry_header")), + ui.output_ui("form_fields"), + ui.card_footer(ui.output_ui("manual_entry_action")), + ), + value="form", + ), + ui.nav_panel( + labels.get("upload_file", "Upload File"), + ui.card( + ui.card_header(ui.output_ui("upload_entry_header")), + ui.div( + ui.div( + ui.input_file( + "input_dataset_file", + "Upload File (TSV, CSV, Excel)", + accept=[".tsv", ".csv", ".xlsx"], + multiple=False, + ), + ui.p("CSV, TSV, or Excel file with the current use case columns.", class_="upload-template-hint"), + class_="upload-file-picker", + ), + ui.div( + ui.download_button( + "download_input_template", + "Template", + class_="btn-default btn-sm input-template-download", + ), + ui.download_button( + "download_data_dictionary", + "Data dictionary", + class_="btn-default btn-sm input-template-download", + ), + class_="upload-resource-actions", + ), + class_="upload-file-panel", + ), + ui.div( + ui.output_ui("upload_file_requirements"), + ui.input_action_button("btn_load_file", "Load file into current set", class_="btn-primary"), + class_="upload-file-actions", + ), + ), + value="file", + ), + ui.nav_panel( + labels.get("example_cohort", "Example Cohort"), + ui.card( + ui.card_header(ui.output_ui("example_entry_header")), + ui.div( + ui.div( + ui.div("Configured sample dataset", class_="example-set-title"), + ui.p( + "Load a ready-to-use dataset for the selected use case.", + class_="upload-template-hint", + ), + class_="example-set-copy", + ), + ui.output_ui("example_set_summary"), + class_="example-set-panel", + ), + ui.div( + ui.input_action_button("btn_load_example", "Load example into current set", class_="btn-primary"), + class_="example-set-actions", + ), + ), + value="example", + ), + id="input_method", + selected="form", + ), + class_="data-input-method-tabs", ), - full_screen=True - ) + class_="data-input-page", + ), ) @module.server def server(input: Inputs, output: Outputs, session: Session, model_data, config_init, config_reactive=None): labels_dict = reactive.Value({}) + input_types = reactive.Value({}) feature_cols = reactive.Value([]) form_df = reactive.Value(pd.DataFrame()) id_counter = reactive.Value(1) + validation_errors = reactive.Value([]) + validation_success = reactive.Value(None) + pending_data = reactive.Value(None) + last_valid_data = reactive.Value(None) + pending_is_custom = reactive.Value(False) + pending_signature = reactive.Value(None) + confirmed_signature = reactive.Value(None) + confirmed_ready = reactive.Value(False) + output_data = reactive.Value(None) + output_is_custom = reactive.Value(False) + working_source = reactive.Value("form") + + def _current_text_labels(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + labels = current_cfg.get("labels", {}) + entity = labels.get("entity", {}) + form_labels = labels.get("form", {}) + return { + "singular": entity.get("singular", "patient"), + "plural": entity.get("plural", "patients"), + "set_lower": entity.get("set_lower", "cohort"), + "set_title": entity.get("set_title", "Cohort"), + "manual_source": form_labels.get("manual_source", "manual cohort"), + "uploaded_source": form_labels.get("uploaded_source", "uploaded file"), + "example_source": form_labels.get("example_source", "example cohort"), + "added_message": form_labels.get("added_message", "Patient added."), + "deleted_message": form_labels.get("deleted_message", "Deleted {count} patient(s)."), + "reset_message": form_labels.get("reset_message", "Cohort reset."), + "groups_label": labels.get("text", {}).get("groups_label", "Patient groups"), + "add_button": form_labels.get("add_patient_button", "Add Patient to Cohort"), + "delete_button": form_labels.get("delete_selected_button", "Delete Selected"), + "reset_button": form_labels.get("reset_cohort_button", "Reset Cohort"), + "cohort_title": form_labels.get( + "cohort_title", + f"Current {entity.get('set_title', 'Patient Cohort')}", + ), + "no_data_message": form_labels.get("no_data_message", "No data uploaded."), + } + + def _input_method_choices(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + labels = current_cfg.get("labels", {}) + return { + "form": labels.get("manual_entry", "Manual Entry"), + "file": labels.get("upload_file", "Upload File"), + "example": labels.get("example_cohort", "Example Cohort"), + } @reactive.Effect def _rebuild_fields(): @@ -145,6 +283,7 @@ def _rebuild_fields(): features = current_cfg.get("features", []) if current_cfg else [] new_labels = {} + new_input_types = {} new_cols = [] if features: new_cols = [f["name"] for f in features] @@ -152,19 +291,30 @@ def _rebuild_fields(): feat_name = f.get("name", "") input_id = f"in_{feat_name.replace(' ', '_').replace('(', '').replace(')', '')}" new_labels[input_id] = feat_name + new_input_types[feat_name] = f.get("input_type") or f.get("type", "numeric") labels_dict.set(new_labels) + input_types.set(new_input_types) feature_cols.set(new_cols) form_df.set(pd.DataFrame(columns=['ID'] + new_cols)) + id_counter.set(1) + validation_errors.set([]) + validation_success.set(None) + pending_data.set(None) + last_valid_data.set(None) + pending_is_custom.set(False) + pending_signature.set(None) + confirmed_signature.set(None) + confirmed_ready.set(False) - # Force the radio button back to form on use-case change - ui.update_radio_buttons("input_method", selected="form") + working_source.set("form") - output_data = reactive.Value(None) - output_is_custom = reactive.Value(False) + # Force the tabs back to manual entry on use-case change. + ui.update_navset("input_method", selected="form") def _ensure_id(df): if df is None: return None + df = df.copy().reset_index(drop=True) if 'ID' not in df.columns: df.insert(0, 'ID', range(1, len(df) + 1)) else: @@ -172,27 +322,217 @@ def _ensure_id(df): df = df[cols] return df + def _numeric_error_message(value, feature_name): + current_cfg = config_reactive.get() if config_reactive else config_init + for feature in current_cfg.get("features", []) if current_cfg else []: + if feature.get("name") == feature_name: + display_name = feature.get("display_name") or feature.get("label") or feature_name + return f"{display_name} must be numeric. Invalid value(s): {value}." + return f"{feature_name} must be numeric. Invalid value(s): {value}." + + def _data_signature(df, is_custom): + if df is None or df.empty: + return None + normalized = df.copy().reset_index(drop=True) + normalized = normalized.reindex(sorted(normalized.columns), axis=1) + row_hash = pd.util.hash_pandas_object(normalized.astype(str), index=False).sum() + return f"{bool(is_custom)}:{len(normalized)}:{list(normalized.columns)}:{int(row_hash)}" + + def _format_values(values, max_items=5): + clean_values = [str(v) for v in values if pd.notna(v)] + preview = clean_values[:max_items] + suffix = "" if len(clean_values) <= max_items else f" and {len(clean_values) - max_items} more" + return ", ".join(preview) + suffix + + def _template_values(feature, n_rows=5): + if "default" in feature: + default = feature.get("default") + else: + default = "" + + if feature.get("input_type") == "select": + values = feature.get("values", []) + if not values: + return [default] * n_rows + start_idx = values.index(default) if default in values else 0 + return [values[(start_idx + i) % len(values)] for i in range(n_rows)] + + min_val = feature.get("min") + max_val = feature.get("max") + step = feature.get("step", 1) + + if min_val is None or max_val is None: + return [default] * n_rows + + min_float = float(_decimal_text(min_val)) + max_float = float(_decimal_text(max_val)) + step_float = float(_decimal_text(step)) + base_value = float(_decimal_text(default)) if default != "" else (min_float + max_float) / 2 + midpoint = (min_float + max_float) / 2 + numeric_values = [ + base_value, + min_val, + max_val, + midpoint, + min(max(base_value + step_float, min_float), max_float), + ] + + return [round(value, 2) if isinstance(value, float) and not value.is_integer() else int(value) if isinstance(value, float) else value for value in numeric_values[:n_rows]] + + def _template_dataframe(config): + features = config.get("features", []) if config else [] + n_rows = 5 + rows = [{"ID": idx + 1} for idx in range(n_rows)] + for feature in features: + feat_name = feature.get("name") + if feat_name: + values = _template_values(feature, n_rows) + for idx, row in enumerate(rows): + row[feat_name] = values[idx] + return pd.DataFrame(rows) + + def _data_dictionary_dataframe(config): + rows = [] + features = config.get("features", []) if config else [] + for feature in features: + allowed_values = feature.get("values", []) + rows.append( + { + "Variable": feature.get("name", ""), + "Display Name": feature.get("display_name") or feature.get("label") or feature.get("name", ""), + "Type": feature.get("type", ""), + "Input Type": feature.get("input_type", ""), + "Required": "YES", + "Allowed Values": ", ".join(map(str, allowed_values)), + "Min": feature.get("min", ""), + "Max": feature.get("max", ""), + "Default": feature.get("default", ""), + "Description": feature.get("help_text", ""), + } + ) + return pd.DataFrame(rows) + + def _validation_success_message(df, config, method): + features = config.get("features", []) if config else [] + required_cols = [feature.get("name") for feature in features if feature.get("name")] + matched_cols = [col for col in required_cols if col in df.columns] + text_labels = _current_text_labels() + source = { + "form": text_labels["manual_source"], + "file": text_labels["uploaded_source"], + "example": text_labels["example_source"], + }.get(method, text_labels["set_lower"]) + return f"{len(df)} {text_labels['singular']}(s) loaded from {source}. {len(matched_cols)}/{len(required_cols)} required columns matched. No validation issues found." + + def _validate_input_data(df, config): + if df is None or df.empty or not config: + return [] + + errors = [] + features = config.get("features", []) + required_cols = [feature.get("name") for feature in features if feature.get("name")] + missing_cols = [col for col in required_cols if col not in df.columns] + + if missing_cols: + errors.append(f"Missing required column(s): {', '.join(missing_cols)}.") + + for feature in features: + feat_name = feature.get("name") + if not feat_name or feat_name not in df.columns: + continue + + display_name = feature.get("display_name") or feature.get("label") or feat_name + input_type = feature.get("input_type") or feature.get("type", "numeric") + series = df[feat_name] + non_missing = series[series.notna()] + + if non_missing.empty: + continue + + if input_type == "numeric": + numeric_series = pd.to_numeric(non_missing.astype(str).str.replace(",", ".", regex=False), errors="coerce") + invalid_values = non_missing[numeric_series.isna()].drop_duplicates().tolist() + if invalid_values: + errors.append(f"{display_name} must be numeric. Invalid value(s): {_format_values(invalid_values)}.") + continue + + min_val = feature.get("min") + max_val = feature.get("max") + range_mask = pd.Series(False, index=numeric_series.index) + if min_val is not None: + range_mask = range_mask | (numeric_series < min_val) + if max_val is not None: + range_mask = range_mask | (numeric_series > max_val) + + if range_mask.any(): + bad_values = non_missing[range_mask].drop_duplicates().tolist() + if min_val is not None and max_val is not None: + range_text = f"between {min_val} and {max_val}" + elif min_val is not None: + range_text = f"greater than or equal to {min_val}" + else: + range_text = f"less than or equal to {max_val}" + errors.append(f"{display_name} must be {range_text}. Out-of-range value(s): {_format_values(bad_values)}.") + + elif input_type == "select": + allowed_values = feature.get("values", []) + if not allowed_values: + continue + + invalid_values = sorted(set(non_missing.astype(str)) - set(map(str, allowed_values))) + if invalid_values: + errors.append( + f"{display_name} has unexpected categor{('y' if len(invalid_values) == 1 else 'ies')}: " + f"{_format_values(invalid_values)}. Allowed: {', '.join(map(str, allowed_values))}." + ) + elif input_type == "text": + continue + + return errors + @reactive.Effect @reactive.event(input.btn_add_form) def _add_patient(): current_df = form_df.get() new_id = id_counter() current_labels = labels_dict.get() + current_input_types = input_types.get() + current_cfg = config_reactive.get() if config_reactive else config_init row_data = {'ID': new_id} for input_id, feat_name in current_labels.items(): val = getattr(input, input_id, None)() if val is not None: - if isinstance(val, str): + if current_input_types.get(feat_name) == "numeric": + numeric_value = pd.to_numeric(_decimal_text(val), errors="coerce") + if pd.isna(numeric_value): + error = _numeric_error_message(val, feat_name) + validation_errors.set([error]) + validation_success.set(None) + ui.notification_show("Input validation failed. See details in Data Input.", type="error") + return + row_data[feat_name] = float(numeric_value) + elif isinstance(val, str): row_data[feat_name] = val else: row_data[feat_name] = float(val) new_row = pd.DataFrame([row_data]) + validation_df = pd.concat([current_df, new_row], ignore_index=True) + errors = _validate_input_data(validation_df, current_cfg) + if errors: + validation_errors.set(errors) + validation_success.set(None) + ui.notification_show("Input validation failed. See details in Data Input.", type="error") + return + updated_df = pd.concat([current_df, new_row], ignore_index=True) + validation_errors.set([]) + validation_success.set(None) form_df.set(updated_df) + working_source.set("form") id_counter.set(new_id + 1) - ui.notification_show("Patient added.", type="message") + ui.notification_show(_current_text_labels()["added_message"], type="message") @reactive.Effect @reactive.event(input.btn_delete_selected) @@ -202,82 +542,401 @@ def _delete_patient(): ui.notification_show("No rows selected.", type="warning") return - current_df = form_df.get() - updated_df = current_df.drop(index=list(selected_rows)).reset_index(drop=True) + current_df = _ensure_id(form_df.get()) + selected_positions = sorted({int(row) for row in selected_rows}, reverse=True) + valid_positions = [row for row in selected_positions if 0 <= row < len(current_df)] + if not valid_positions: + ui.notification_show("Selected rows are no longer available.", type="warning") + return + + updated_df = current_df.drop(current_df.index[valid_positions]).reset_index(drop=True) form_df.set(updated_df) - ui.notification_show(f"Deleted {len(selected_rows)} patient(s).", type="message") + working_source.set("form") + ui.notification_show(_current_text_labels()["deleted_message"].format(count=len(valid_positions)), type="message") @reactive.Effect - def _update_output_pipeline(): - method = str(input.input_method()) - raw_df = None - is_custom = False + @reactive.event(input.btn_reset_cohort) + def _reset_cohort(): + current_cols = feature_cols.get() + form_df.set(pd.DataFrame(columns=['ID'] + current_cols)) + id_counter.set(1) + validation_errors.set([]) + validation_success.set(None) + pending_data.set(None) + last_valid_data.set(None) + pending_is_custom.set(False) + pending_signature.set(None) + confirmed_signature.set(None) + confirmed_ready.set(False) + output_data.set(None) + output_is_custom.set(False) + working_source.set("form") + ui.update_navset("input_method", selected="form") + ui.notification_show(_current_text_labels()["reset_message"], type="message") - md = model_data.get() - example_raw_data = md.get("X_TEST_RAW") + @reactive.Effect + @reactive.event(input.btn_load_file) + def _load_file_into_set(): + file_info = input.input_dataset_file() + if not file_info: + ui.notification_show("Select a file before loading it into the current set.", type="warning") + return - if method == "example": - if example_raw_data is not None: - raw_df = example_raw_data.copy() - is_custom = False + file_path = file_info[0]["datapath"] + try: + if file_path.endswith(".csv"): + df = pd.read_csv(file_path) + elif file_path.endswith(".tsv"): + df = pd.read_csv(file_path, sep="\t") + elif file_path.endswith(".xlsx"): + df = pd.read_excel(file_path) + else: + raise ValueError("Unsupported file format!") + form_df.set(_ensure_id(clean_data(df))) + working_source.set("file") + ui.notification_show("File loaded into the current set.", type="message") + except Exception as e: + ui.notification_show(f"Error reading file: {e}", type="error") + validation_errors.set([f"Could not read uploaded file: {e}"]) + validation_success.set(None) - elif method == "form": - raw_df = form_df.get() - is_custom = True + @reactive.Effect + @reactive.event(input.btn_load_example) + def _load_example_into_set(): + md = model_data.get() + example_raw_data = md.get("X_TEST_RAW") if md else None + if example_raw_data is None or example_raw_data.empty: + ui.notification_show("No example data available for this use case.", type="warning") + return - elif method == "file": - file_info = input.input_dataset_file() - is_custom = True - if not file_info: - output_data.set(None) - output_is_custom.set(is_custom) - return + form_df.set(_ensure_id(example_raw_data.copy())) + working_source.set("example") + ui.notification_show(f"Example {_current_text_labels()['set_lower']} loaded.", type="message") - file_path = file_info[0]["datapath"] - try: - if file_path.endswith(".csv"): - df = pd.read_csv(file_path) - elif file_path.endswith(".tsv"): - df = pd.read_csv(file_path, sep="\t") - elif file_path.endswith(".xlsx"): - df = pd.read_excel(file_path) - else: - raise ValueError("Unsupported file format!") - raw_df = clean_data(df) - except Exception as e: - ui.notification_show(f"Error reading file: {e}", type="error") - raw_df = None + @reactive.Effect + def _update_output_pipeline(): + method = working_source.get() + raw_df = form_df.get() + is_custom = method != "example" + current_cfg = config_reactive.get() if config_reactive else config_init if raw_df is not None and not raw_df.empty: clean_df = _ensure_id(raw_df) - output_data.set(clean_df) + errors = _validate_input_data(clean_df, current_cfg) + validation_errors.set(errors) + if errors: + validation_success.set(None) + pending_is_custom.set(is_custom) + pending_signature.set(None) + confirmed_ready.set(False) + output_is_custom.set(False) + ui.notification_show("Input validation failed. See details in Data Input.", type="error") + return + validation_success.set(_validation_success_message(clean_df, current_cfg, method)) + pending_data.set(clean_df) + last_valid_data.set(clean_df.copy()) + pending_is_custom.set(is_custom) + current_signature = _data_signature(clean_df, is_custom) + pending_signature.set(current_signature) + if current_signature != confirmed_signature.get(): + confirmed_ready.set(False) + output_data.set(None) + output_is_custom.set(False) else: + validation_errors.set([]) + validation_success.set(None) + pending_data.set(None) + last_valid_data.set(None) + pending_is_custom.set(is_custom) + pending_signature.set(None) + confirmed_ready.set(False) output_data.set(None) + output_is_custom.set(False) + + @reactive.Effect + @reactive.event(input.btn_use_data) + def _use_pending_data(): + clean_df = pending_data.get() + if clean_df is None or clean_df.empty: + ui.notification_show("No valid data available to use.", type="warning") + return + + output_data.set(clean_df.copy()) + output_is_custom.set(bool(pending_is_custom.get())) + confirmed_signature.set(pending_signature.get()) + confirmed_ready.set(True) + ui.notification_show(f"{_current_text_labels()['set_title']} confirmed for analysis.", type="message") + + @output + @render.ui + def input_validation_ui(): + errors = validation_errors.get() + if not errors: + return None + + return ui.div( + ui.div("Input validation failed", class_="input-validation-title"), + ui.tags.ul(*[ui.tags.li(error) for error in errors]), + class_="input-validation-alert", + ) + + @output + @render.ui + def input_validation_success_ui(): + message = validation_success.get() + if not message: + return None + + return ui.div( + ui.div("Input validation passed", class_="input-validation-title"), + ui.p(message, class_="input-validation-success-copy"), + class_="input-validation-success", + ) + + @output + @render.ui + def apply_data_ui(): + clean_df = pending_data.get() + if clean_df is None or clean_df.empty: + return None + + labels = _current_text_labels() + is_confirmed = confirmed_ready.get() and pending_signature.get() == confirmed_signature.get() + if is_confirmed: + return ui.div( + ui.div( + ui.div(f"{labels['set_title']} confirmed", class_="confirm-data-title"), + ui.p( + f"{len(clean_df)} {labels['plural']} are locked for Candidate Context and Stratification Support.", + class_="apply-data-hint", + ), + class_="confirm-data-copy", + ), + class_="apply-data-panel apply-data-panel-confirmed", + ) - output_is_custom.set(is_custom) + return ui.div( + ui.input_action_button( + "btn_use_data", + f"Confirm {labels['set_lower']}", + class_="btn-primary", + ), + ui.p( + "Candidate Context and Stratification Support will unlock after confirmation.", + class_="apply-data-hint", + ), + class_="apply-data-panel", + ) + + @render.download(filename="input_template.csv") + def download_input_template(): + current_cfg = config_reactive.get() if config_reactive else config_init + yield _template_dataframe(current_cfg or {}).to_csv(index=False) + + @render.download(filename="data_dictionary.csv") + def download_data_dictionary(): + current_cfg = config_reactive.get() if config_reactive else config_init + yield _data_dictionary_dataframe(current_cfg or {}).to_csv(index=False) + + @output + @render.ui + def manual_entry_header(): + return ui.tags.b("+ " + _current_text_labels()["add_button"]) + + @output + @render.ui + def upload_entry_header(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + labels = current_cfg.get("labels", {}) + return ui.tags.b(labels.get("upload_file", "Upload File")) + + @output + @render.ui + def upload_file_requirements(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + features = current_cfg.get("features", []) + return ui.div( + ui.span(str(len(features)), class_="upload-requirement-count"), + ui.span(" required variables"), + class_="upload-requirement-pill", + ) + + @output + @render.ui + def example_entry_header(): + current_cfg = config_reactive.get() if config_reactive else (config_init or {}) + labels = current_cfg.get("labels", {}) + return ui.tags.b(labels.get("example_cohort", "Example Cohort")) + + @output + @render.ui + def example_set_summary(): + md = model_data.get() + example_raw_data = md.get("X_TEST_RAW") if md else None + row_count = 0 if example_raw_data is None else len(example_raw_data) + labels = _current_text_labels() + return ui.div( + ui.span(str(row_count), class_="upload-requirement-count"), + ui.span(f" example {labels['plural']}"), + class_="upload-requirement-pill", + ) + + @output + @render.ui + def manual_entry_action(): + return ui.input_action_button( + "btn_add_form", + _current_text_labels()["add_button"], + class_="btn-primary manual-entry-submit", + ) + + @output + @render.ui + def delete_selected_action(): + df = pending_data.get() + if df is None or df.empty: + df = last_valid_data.get() + if df is None or df.empty: + df = output_data.get() + if df is None or df.empty: + return None + + return ui.input_action_button( + "btn_delete_selected", + _current_text_labels()["delete_button"], + class_="btn-danger btn-sm cohort-action-button", + ) + + @output + @render.ui + def reset_cohort_action(): + return ui.input_action_button( + "btn_reset_cohort", + _current_text_labels()["reset_button"], + class_="btn-default btn-sm cohort-action-button", + ) + + @output + @render.ui + def cohort_table_header(): + df = pending_data.get() + if df is None or df.empty: + df = last_valid_data.get() + if df is None or df.empty: + df = output_data.get() + if df is None or df.empty: + df = form_df.get() + row_count = 0 if df is None else len(df) + is_confirmed = ( + confirmed_ready.get() + and pending_signature.get() is not None + and pending_signature.get() == confirmed_signature.get() + ) + + if row_count == 0: + status_label = "Empty" + status_class = "set-status-empty" + elif is_confirmed: + status_label = "Confirmed" + status_class = "set-status-confirmed" + else: + status_label = "Confirmation required" + status_class = "set-status-pending" + + labels = _current_text_labels() + return ui.div( + ui.div( + ui.tags.b(labels["cohort_title"]), + ui.span(f"{row_count} {labels['plural']}", class_="set-row-count"), + class_="cohort-title-group", + ), + ui.span(status_label, class_=f"set-status-pill {status_class}"), + class_="cohort-table-title", + ) + + @output + @render.ui + def input_method_header(): + labels = _current_text_labels() + return ui.div( + ui.div("Add or load records", class_="data-input-method-title"), + ui.div(f"Working {labels['set_lower']}", class_="data-input-method-subtitle"), + class_="data-input-method-header", + ) + + @output + @render.ui + def use_case_summary(): + current_cfg = config_reactive.get() if config_reactive else config_init + if not current_cfg: + return ui.p("No use case selected.") + + features = current_cfg.get("features", []) + labels = current_cfg.get("labels", {}) + positive_label = labels.get("positive_class_label", current_cfg.get("positive_class", "Positive")) + negative_label = labels.get("negative_class_label", current_cfg.get("negative_class", "Negative")) + + return ui.div( + ui.div( + ui.div("Model objective", class_="use-case-summary-label"), + ui.div(current_cfg.get("description", "Patient stratification model"), class_="use-case-summary-value"), + class_="use-case-summary-item use-case-summary-wide", + ), + ui.div( + ui.div(_current_text_labels()["groups_label"], class_="use-case-summary-label"), + ui.div(f"{negative_label} / {positive_label}", class_="use-case-summary-value"), + class_="use-case-summary-item", + ), + ui.div( + ui.div("Input variables", class_="use-case-summary-label"), + ui.div(str(len(features)), class_="use-case-summary-value"), + class_="use-case-summary-item", + ), + class_="use-case-summary", + ) @output @render.ui def form_fields(): current_cfg = config_reactive.get() if config_reactive else config_init - form_columns, _ = build_form_fields(current_cfg or {}) - if form_columns: - return ui.layout_columns( - *form_columns, col_widths=tuple([4] * len(form_columns)) - ) + form_sections, _ = build_form_fields(current_cfg or {}) + if form_sections: + return ui.div(*form_sections, class_="manual-entry-grid") return ui.p("No features configured.") + @output + @render.ui + def current_set_body(): + display_df = pending_data.get() + if display_df is None or display_df.empty: + display_df = last_valid_data.get() + if display_df is None or display_df.empty: + display_df = output_data.get() + + if display_df is None or display_df.empty: + labels = _current_text_labels() + return ui.div( + ui.div("No records yet", class_="empty-set-title"), + ui.div( + f"Add manually, upload a file, or load the example {labels['set_lower']}.", + class_="empty-set-copy", + ), + class_="empty-set-state", + ) + + return ui.output_data_frame("out_patient_table") + @output @render.data_frame def out_patient_table(): - display_df = output_data.get() + display_df = pending_data.get() + if display_df is None or display_df.empty: + display_df = last_valid_data.get() + if display_df is None or display_df.empty: + display_df = output_data.get() if display_df is None or display_df.empty: - return render.DataGrid( - pd.DataFrame({"Message": ["No data uploaded."]}), - width="100%", - selection_mode="none" - ) + return render.DataGrid(pd.DataFrame(), width="100%", height="1px", selection_mode="none") display_df = display_df.copy() display_df.columns = [c.replace('_', ' ') for c in display_df.columns] @@ -286,7 +945,8 @@ def out_patient_table(): return render.DataGrid( display_df, width="100%", + height="260px", selection_mode="rows" ) - return {"data": output_data, "is_custom": output_is_custom} \ No newline at end of file + return {"data": output_data, "is_custom": output_is_custom, "is_confirmed": confirmed_ready} diff --git a/src/XAInyPredictor/shinyapp/prediction.py b/src/XAInyPredictor/shinyapp/prediction.py index 9ad5548..7978420 100644 --- a/src/XAInyPredictor/shinyapp/prediction.py +++ b/src/XAInyPredictor/shinyapp/prediction.py @@ -1,3 +1,7 @@ +from datetime import datetime +from io import BytesIO +from zipfile import ZIP_DEFLATED, ZipFile + import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -15,11 +19,15 @@ def prediction_ui(config=None): titles = config.get("titles", {}) labels = config.get("labels", {}) help_texts = config.get("help_texts", {}) + entity = labels.get("entity", {}) + text = labels.get("text", {}) pos_class_label = labels.get("positive_class_label", "Positive") neg_class_label = labels.get("negative_class_label", "Negative") - prob_col = labels.get("probability_column", "Probability") - class_col = labels.get("class_column", "Class") + singular_title = entity.get("singular_title", "Patient") + plural = entity.get("plural", "patients") + set_title = entity.get("set_title", "Cohort") + reference_title = entity.get("reference_title", "Reference Patients") return ui.layout_sidebar( ui.sidebar( @@ -28,7 +36,7 @@ def prediction_ui(config=None): ui.input_numeric( id="fnr_threshold", label= ui.div( - "Min. % false negatives: ", + "Allowed false-negative rate: ", ui.popover( ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), ui.tags.div( @@ -46,80 +54,455 @@ def prediction_ui(config=None): value=0, min=0, max=100, step=1 ), ), - ui.hr(), - ui.h5("Visualization Settings"), - ui.input_select( - id="view_mode", - label="Select View:", - choices={ - "radar": titles.get("feature_analysis", "Feature Analysis (Radar)"), - "curve": titles.get("distance_analysis", "Distance Analysis (Curve)") - }, - selected="radar", - ), - ui.output_ui("features_to_plot_ui"), - ui.panel_conditional( - "input.view_mode == 'radar'", - ui.input_checkbox_group( - id="radar_plot_elements", - label="Radar Plot Elements:", - choices={ - "closest": "Closest patients", - "average": "Average all patients", - "average_0": f"Avg. {neg_class_label}", - "average_1": f"Avg. {pos_class_label}", - }, - selected=["closest", "average", "average_0", "average_1"], - ), - ), ), ui.page_fluid( - ui.layout_columns( - ui.card( - ui.card_header( + ui.div( + ui.div(ui.output_ui("analysis_title_header"), class_="stratification-tabset-title"), + ui.navset_tab( + ui.nav_panel( + singular_title, + ui.div( + ui.tags.details( + ui.tags.summary(text.get("download_menu", "Download outputs")), + ui.div( + ui.download_button("download_report_package", text.get("download_report_package", "Report package"), class_="btn-primary btn-sm report-package-download"), + ui.download_button("download_stratification_results", text.get("download_results", "Stratification results"), class_="btn-default btn-sm"), + ui.download_button("download_closest_patients", text.get("download_closest", "Closest reference candidates"), class_="btn-default btn-sm"), + ui.download_button("download_cohort_summary", text.get("download_summary", "Cohort summary"), class_="btn-default btn-sm"), + class_="stratification-download-menu", + ), + class_="stratification-download-details", + ), + class_="stratification-downloads", + ), + ui.layout_columns( + ui.card( + ui.card_header(ui.output_ui("selected_output_header")), + ui.output_ui("stratification_summary"), + ui.output_ui("stratification_interpretation"), + class_="selected-patient-stratification-card", + ), + col_widths=12, + ), + ui.layout_columns( + ui.card( + ui.card_header(ui.output_ui("results_table_header")), + ui.output_data_frame("results_table_output"), + height="240px", + ), + col_widths=12, + ), + ui.card( + ui.card_header("Profile controls"), ui.div( - titles.get("prediction_results", "Prediction results "), - ui.popover( - ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), - ui.tags.div( - ui.tags.b("Prediction results:"), - ui.tags.p("The table shows the predictions of the model:"), - ui.tags.ul( - ui.tags.li(ui.tags.b(f"{prob_col}:"), f" Indicates the probability. The higher, the more likely."), - ui.tags.li(ui.tags.b(f"{class_col}:"), f" Classifies as {pos_class_label} or {neg_class_label}."), + ui.div( + ui.div( + ui.input_select( + id="view_mode", + label="Select View:", + choices={ + "radar": titles.get("feature_analysis_short", "Profile Comparison"), + "curve": titles.get("distance_analysis_short", "Feature Curves") + }, + selected="radar", ), - style="width: 250px;" + class_="patient-visual-setting patient-visual-setting-view", ), - placement="right" - ) - ) + ui.div( + ui.input_action_button("btn_select_default_features", "Top features", class_="btn-default btn-sm feature-control-button"), + ui.input_action_button("btn_clear_features", "Clear", class_="btn-default btn-sm feature-control-button"), + class_="feature-control-row", + ), + class_="patient-visual-controls-top", + ), + ui.div( + ui.div( + ui.output_ui("features_to_plot_ui"), + class_="feature-selector-inline", + ), + class_="patient-visual-setting patient-visual-setting-features", + ), + ui.div( + ui.panel_conditional( + "input.view_mode == 'radar'", + ui.div( + ui.input_checkbox_group( + id="radar_plot_elements", + label="Radar Plot Elements:", + choices={ + "closest": text.get("radar_closest", f"Closest {plural}"), + "average": text.get("radar_average", f"Average all {plural}"), + "average_0": f"Avg. {neg_class_label}", + "average_1": f"Avg. {pos_class_label}", + }, + selected=["closest", "average", "average_0", "average_1"], + ), + class_="patient-visual-setting patient-visual-setting-radar", + ), + ), + class_="patient-visual-radar-wrap", + ), + class_="patient-visual-settings-grid", + ), + class_="patient-visual-settings-card", ), - ui.output_data_frame("results_table_output"), - height="300px", + ui.output_ui("dynamic_plot_container"), + value="patient", ), - col_widths=12 - ), - ui.output_ui("dynamic_plot_container") + ui.nav_panel( + set_title, + ui.layout_columns( + ui.card( + ui.card_header(ui.output_ui("set_summary_header")), + ui.output_ui("cohort_stratification_summary"), + height="180px", + ), + col_widths=12, + ), + ui.layout_columns( + ui.card( + ui.card_header(ui.output_ui("set_distribution_header")), + ui.output_plot("cohort_score_distribution_plot", height="420px", width="100%"), + full_screen=True, + ), + col_widths=12, + ), + value="cohort", + ), + ui.nav_panel( + "Model", + ui.layout_columns( + ui.card( + ui.card_header("Model Card"), + ui.output_ui("model_card"), + class_="model-card", + ), + col_widths=12, + ), + ui.layout_columns( + ui.card( + ui.card_header( + ui.div( + "Global Feature Importance ", + ui.popover( + ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), + ui.tags.div( + ui.tags.b("Global feature importance:"), + ui.tags.p("Importance is estimated from the variance of each feature contribution in the model explanation matrix."), + style="width: 250px;" + ), + placement="right" + ) + ) + ), + ui.output_plot("global_feature_importance_plot", height="460px", width="100%"), + full_screen=True, + ), + col_widths=12, + ), + value="model", + ), + ui.nav_panel( + reference_title, + ui.layout_columns( + ui.card( + ui.card_header(ui.output_ui("closest_reference_header")), + ui.output_ui("closest_reference_patients_narrative"), + ui.output_ui("closest_reference_patients_table"), + height="420px", + full_screen=True, + ), + col_widths=12, + ), + value="reference", + ), + id="stratification_tabs", + selected="patient", + ), + class_="stratification-tabset", + ) ), ) @module.server def server(input: Inputs, output: Outputs, session: Session, global_input_data, patient_selected_id, model_data, delta_test_reactive, x_test_reactive, prob_threshold, config_init, config_reactive=None): + selection_revision = reactive.Value(0) @reactive.Calc def current_labels(): current_cfg = config_reactive.get() if config_reactive else (config_init or {}) labels = current_cfg.get("labels", {}) + entity = labels.get("entity", {}) + positive_color = labels.get("positive_class_color", "#dc3545") + negative_color = labels.get("negative_class_color", "#198754") return { "positive_class": current_cfg.get("positive_class", "YES"), "negative_class": current_cfg.get("negative_class", "NO"), "positive_class_label": labels.get("positive_class_label", "Positive"), "negative_class_label": labels.get("negative_class_label", "Negative"), - "probability_column": labels.get("probability_column", "Probability"), - "class_column": labels.get("class_column", "Class") + "positive_class_color": positive_color, + "negative_class_color": negative_color, + "positive_class_bg": labels.get("positive_class_bg", f"{positive_color}22"), + "negative_class_bg": labels.get("negative_class_bg", f"{negative_color}22"), + "probability_column": labels.get("probability_column", "Stratification Score"), + "class_column": labels.get("class_column", "Patient Group"), + "singular": entity.get("singular", "patient"), + "plural": entity.get("plural", "patients"), + "singular_title": entity.get("singular_title", "Patient"), + "plural_title": entity.get("plural_title", "Patients"), + "set_lower": entity.get("set_lower", "cohort"), + "set_title": entity.get("set_title", "Cohort"), + "reference_plural": entity.get("reference_plural", "reference patients"), + "reference_title": entity.get("reference_title", "Reference Patients"), + "text": labels.get("text", {}), } + @reactive.Calc + def current_config(): + return config_reactive.get() if config_reactive else (config_init or {}) + + @reactive.Calc + def stratification_results_df(): + selection_revision.get() + df = global_input_data.get() + delta_test = delta_test_reactive.get() + prob_thr = prob_threshold.get() + lbls = current_labels() + + if df is None or delta_test is None or df.empty or delta_test.empty: + return None + if "ID" not in df.columns or "pred_prob" not in delta_test.columns: + return None + if len(df) != len(delta_test): + return None + + threshold = float(prob_thr) if prob_thr is not None else 0 + res_df = pd.concat( + [ + df[["ID"]].reset_index(drop=True), + delta_test[["pred_prob"]].reset_index(drop=True), + ], + axis=1, + ) + res_df["ID"] = res_df["ID"].astype(int) + res_df = res_df.rename(columns={"pred_prob": lbls["probability_column"]}) + res_df[lbls["class_column"]] = np.where( + res_df[lbls["probability_column"]] >= threshold, + lbls["positive_class_label"], + lbls["negative_class_label"], + ) + return res_df.sort_values(by=["ID"]).reset_index(drop=True) + + def _feature_display_name(feature_name: str) -> str: + return str(feature_name).replace("_", " ") + + def _contribution_columns(delta_df): + if delta_df is None: + return [] + return [col for col in delta_df.columns if col not in ["const", "pred_prob"]] + + def _selected_entity_id(): + sel_id = patient_selected_id.get() + if sel_id is None or sel_id == "" or sel_id == "None": + return None + try: + return int(sel_id) + except (TypeError, ValueError): + return None + + @reactive.Calc + def closest_reference_patients_df(): + selection_revision.get() + raw_df = global_input_data.get() + delta_test = delta_test_reactive.get() + md = model_data.get() + sel_id = _selected_entity_id() + lbls = current_labels() + threshold = float(prob_threshold.get()) if prob_threshold.get() is not None else 0 + + if ( + raw_df is None + or delta_test is None + or md is None + or sel_id is None + or raw_df.empty + or delta_test.empty + ): + return None + if len(raw_df) != len(delta_test): + return None + + raw_df = raw_df.reset_index(drop=True) + delta_test = delta_test.reset_index(drop=True) + + delta_train = md.get("D_TRAIN") + if delta_train is None or delta_train.empty: + return None + + patient_rows = raw_df.index[raw_df["ID"].astype(int) == sel_id].tolist() + if not patient_rows: + return None + patient_pos = patient_rows[0] + + feature_cols = [ + col + for col in _contribution_columns(delta_train) + if col in delta_test.columns + ] + if not feature_cols or patient_pos >= len(delta_test): + return None + + train_matrix = delta_train[feature_cols].to_numpy(dtype=float) + patient_vector = delta_test.iloc[patient_pos][feature_cols].to_numpy(dtype=float) + distances = np.linalg.norm(train_matrix - patient_vector, axis=1) + closest_positions = np.argsort(distances)[:5] + + rows = [] + for pos in closest_positions: + score = float(delta_train.iloc[int(pos)]["pred_prob"]) + group = lbls["positive_class_label"] if score >= threshold else lbls["negative_class_label"] + row = { + "Reference Rank": len(rows) + 1, + "Distance": round(float(distances[int(pos)]), 3), + lbls["probability_column"]: round(score, 3), + lbls["class_column"]: group, + } + rows.append(row) + + return pd.DataFrame(rows) + + @reactive.Calc + def cohort_summary_df(): + res_df = stratification_results_df() + lbls = current_labels() + + if res_df is None or res_df.empty: + return None + + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + total = len(res_df) + rows = [ + {"Metric": lbls["plural_title"], "Value": total}, + {"Metric": "Mean score", "Value": round(float(res_df[prob_col].mean()), 3)}, + {"Metric": "Decision threshold", "Value": round(float(prob_threshold.get() or 0), 3)}, + ] + for group, count in res_df[class_col].value_counts().items(): + rows.append({"Metric": f"{group} count", "Value": int(count)}) + rows.append({"Metric": f"{group} percentage", "Value": round(float(count) / total * 100, 1)}) + return pd.DataFrame(rows) + + def _export_metadata(): + cfg = current_config() + lbls = current_labels() + input_df = global_input_data.get() + item_count = 0 if input_df is None else len(input_df) + threshold = float(prob_threshold.get()) if prob_threshold.get() is not None else 0 + + return { + "Use Case": cfg.get("name", "Unknown use case"), + "Model Objective": cfg.get("description", "Patient stratification model"), + "Target": cfg.get("target_column", "target"), + "Decision Threshold": round(threshold, 3), + "Negative Group": lbls["negative_class_label"], + "Positive Group": lbls["positive_class_label"], + f"{lbls['singular_title']} Count": item_count, + "Exported At": datetime.now().isoformat(timespec="seconds"), + "Output Type": "Research prototype; clinical workflow utility should be validated with collaborators", + } + + def _attach_export_metadata(df): + metadata = _export_metadata() + export_df = df.copy() if df is not None else pd.DataFrame() + for key, value in reversed(list(metadata.items())): + export_df.insert(0, key, value) + return export_df + + def _cohort_summary_export_df(): + summary_df = cohort_summary_df() + if summary_df is None: + summary_df = pd.DataFrame() + metadata_df = pd.DataFrame( + [{"Metric": key, "Value": value} for key, value in _export_metadata().items()] + ) + if not summary_df.empty: + return pd.concat([metadata_df, pd.DataFrame([{"Metric": "", "Value": ""}]), summary_df], ignore_index=True) + return metadata_df + + def _metadata_export_df(): + return pd.DataFrame( + [{"Field": key, "Value": value} for key, value in _export_metadata().items()] + ) + + def _report_readme_text(): + metadata = _export_metadata() + lbls = current_labels() + text = lbls["text"] + closest_reference_file = f"closest_reference_{_slug_for_filename(lbls['plural'])}.csv" + set_summary_file = f"{_slug_for_filename(lbls['set_lower'])}_stratification_summary.csv" + return ( + "XAInyPredictor report package\n" + "============================\n\n" + f"Use case: {metadata.get('Use Case', 'Unknown use case')}\n" + f"Exported at: {metadata.get('Exported At', '')}\n\n" + "Files included:\n" + f"- metadata.csv: use case, threshold, {lbls['class_column'].lower()}, export timestamp, and prototype context.\n" + f"- stratification_results.csv: {lbls['singular']}-level stratification score and assigned group.\n" + f"- {set_summary_file}: {lbls['set_lower']}-level counts, score summary, and metadata.\n" + f"- {closest_reference_file}: anonymized closest-reference ranks, distances, scores, and classes for the selected {lbls['singular']} when available.\n\n" + "Prototype context:\n" + + text.get( + "report_context", + "This package supports patient stratification research. Clinical workflow utility and interpretation should be validated with clinical collaborators.", + ) + "\n" + ) + + def _slug_for_filename(value): + slug = "".join(ch.lower() if ch.isalnum() else "_" for ch in str(value)) + return "_".join(part for part in slug.split("_") if part) + + def _closest_reference_filename(): + lbls = current_labels() + return f"closest_reference_{_slug_for_filename(lbls['plural'])}.csv" + + def _set_summary_filename(): + lbls = current_labels() + return f"{_slug_for_filename(lbls['set_lower'])}_stratification_summary.csv" + + def _run_patient_analysis(data, feats, opts=None): + if not data: + return None, None + + opts = opts or [] + lbls = current_labels() + try: + result = analyze_patient( + patient_id=data["patient_id"], + df=data["df"], + delta_train=data["delta_train"], + delta_test=data["delta_test"], + x_train=data["x_train"], + y_train=data["y_train"], + x_test=data["x_test"], + features_to_plot=feats, + n_dists=3, + show_closest_radial="closest" in opts, + show_average_radial="average" in opts, + show_average_class0_radial="average_0" in opts, + show_average_class1_radial="average_1" in opts, + neg_class_label=lbls["negative_class_label"], + pos_class_label=lbls["positive_class_label"], + entity_label=lbls["singular_title"], + ) + except Exception: + return None, None + + if not result: + return None, None + + return result + @reactive.Effect def _update_sidebar_labels(): lbls = current_labels() @@ -129,15 +512,121 @@ def _update_sidebar_labels(): ui.update_checkbox_group( "radar_plot_elements", choices={ - "closest": "Closest patients", - "average": "Average all patients", + "closest": lbls["text"].get("radar_closest", f"Closest {lbls['plural']}"), + "average": lbls["text"].get("radar_average", f"Average all {lbls['plural']}"), "average_0": f"Avg. {neg}", "average_1": f"Avg. {pos}", }, selected=list(input.radar_plot_elements()) if input.radar_plot_elements() else ["closest", "average", "average_0", "average_1"] ) - @output + @reactive.Effect + def _sync_selection_with_active_data(): + df = global_input_data.get() + if df is None or df.empty or "ID" not in df.columns: + patient_selected_id.set(None) + return + + all_ids = sorted(df["ID"].astype(int).tolist()) + current_selection = _selected_entity_id() + if current_selection is None or int(current_selection) not in all_ids: + patient_selected_id.set(all_ids[0]) + ui.update_selectize("local_patient_select", selected=all_ids[0]) + + @output(suspend_when_hidden=False) + @render.ui + def analysis_title_header(): + lbls = current_labels() + return lbls["text"].get("analysis_title", "Stratification Analysis") + + @output(suspend_when_hidden=False) + @render.ui + def selected_output_header(): + lbls = current_labels() + return lbls["text"].get( + "selected_output_title", + f"Selected {lbls['singular_title']} Stratification", + ) + + @output(suspend_when_hidden=False) + @render.ui + def results_table_header(): + cfg = current_config() + titles = cfg.get("titles", {}) + lbls = current_labels() + text = lbls["text"] + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + return ui.div( + titles.get("prediction_results", "Prediction results "), + ui.popover( + ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), + ui.tags.div( + ui.tags.b(text.get("results_table_help_title", "Stratification results:")), + ui.tags.p(text.get("results_table_help", f"The table shows model-based {lbls['singular']} stratification:")), + ui.tags.ul( + ui.tags.li(ui.tags.b(f"{prob_col}:"), text.get("score_help", f" Indicates the score used to assign the {lbls['singular']} class.")), + ui.tags.li(ui.tags.b(f"{class_col}:"), text.get("class_help", f" Assigns the {lbls['singular']} to {lbls['positive_class_label']} or {lbls['negative_class_label']}.")), + ), + style="width: 250px;", + ), + placement="right", + ), + ) + + @output(suspend_when_hidden=False) + @render.ui + def set_summary_header(): + lbls = current_labels() + return lbls["text"].get("set_summary_title", f"{lbls['set_title']} Stratification Summary") + + @output(suspend_when_hidden=False) + @render.ui + def set_distribution_header(): + lbls = current_labels() + text = lbls["text"] + return ui.div( + text.get("set_distribution_title", f"{lbls['set_title']} Score Distribution "), + ui.popover( + ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), + ui.tags.div( + ui.tags.b(text.get("set_distribution_help_title", f"{lbls['set_title']} score distribution:")), + ui.tags.p( + text.get( + "set_distribution_help", + f"{lbls['plural_title']} are ordered by stratification score. The dashed line marks the current decision threshold and the selected {lbls['singular']} is highlighted.", + ) + ), + style="width: 250px;", + ), + placement="right", + ), + ) + + @output(suspend_when_hidden=False) + @render.ui + def closest_reference_header(): + lbls = current_labels() + text = lbls["text"] + return ui.div( + text.get("closest_reference_title", f"Closest {lbls['reference_title']} "), + ui.popover( + ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), + ui.tags.div( + ui.tags.b(text.get("closest_reference_help_title", f"Closest {lbls['reference_plural']}:")), + ui.tags.p( + text.get( + "closest_reference_help", + f"{lbls['reference_title']} are ranked by distance in the model contribution space, using the same feature effects that drive the explanation plots.", + ) + ), + style="width: 260px;", + ), + placement="right", + ), + ) + + @output(suspend_when_hidden=False) @render.ui def patient_selector_ui(): """ @@ -145,37 +634,37 @@ def patient_selector_ui(): global_input_data variable and according to this, it creates a dropdown to select patients. """ + lbls = current_labels() df = global_input_data.get() if df is None or df.empty or 'ID' not in df.columns: - return ui.p("No valid data loaded.") + return ui.p(lbls["text"].get("empty_results_message", f"Add or upload {lbls['plural']} in Data Input to generate stratification outputs.")) # Get list of patient IDs all_ids = sorted(df['ID'].astype(int).tolist()) # Check current selection from global state - current_selection = patient_selected_id.get() + current_selection = _selected_entity_id() # Default to first if no current selection - if current_selection == None or int(current_selection) not in all_ids: + if current_selection is None or int(current_selection) not in all_ids: selected_val = all_ids[0] else: selected_val = int(current_selection) - print(f"Patient selected (run analysis): {selected_val}") return ui.input_selectize( id="local_patient_select", - label="Select patient:", + label=lbls["text"].get("select_entity_label", f"Select {lbls['singular']}:"), choices=all_ids, selected=selected_val, options={ "create": False, # Don't allow creating new options "allowEmptyOption": False, - "placeholder": "Search by patient ID...", + "placeholder": lbls["text"].get("search_entity_placeholder", f"Search by {lbls['singular']} ID..."), "maxItems": 1 } ) - @output + @output(suspend_when_hidden=False) @render.ui def features_to_plot_ui(): """ @@ -184,21 +673,48 @@ def features_to_plot_ui(): features that appear in the formula (discarding the ones that do not appear). """ md = model_data.get() + cfg = current_config() feature_names = md.get("FEATURE_ORDER_DISPLAY", []) if md else [] features_to_plot = md.get("FEATS_IN_FORMULA", []) if md else [] + max_default_features = cfg.get("default_selected_features") + if max_default_features: + features_to_plot = features_to_plot[:int(max_default_features)] return ui.input_selectize( id="features_to_plot", label="Select features to view:", choices=feature_names, selected=features_to_plot, multiple=True, + options={ + "plugins": ["remove_button"], + "maxOptions": 100, + "placeholder": "Choose features", + }, ) - @output + @reactive.Effect + @reactive.event(input.btn_select_default_features) + def _select_default_features(): + md = model_data.get() + cfg = current_config() + if not md: + return + features_to_plot = md.get("FEATS_IN_FORMULA", []) + max_default_features = cfg.get("default_selected_features", 5) + features_to_plot = features_to_plot[:int(max_default_features)] + ui.update_selectize("features_to_plot", selected=features_to_plot) + + @reactive.Effect + @reactive.event(input.btn_clear_features) + def _clear_features(): + ui.update_selectize("features_to_plot", selected=[]) + + @output(suspend_when_hidden=False) @render.ui def dynamic_plot_container(): mode = input.view_mode() selected_features = input.features_to_plot() + cfg = current_config() lbls = current_labels() neg = lbls["negative_class_label"] @@ -207,24 +723,25 @@ def dynamic_plot_container(): # Calculate dynamic height for Curve plot # Base height + (pixels per feature * number of features) n_feats = len(selected_features) if selected_features else 0 - curve_height_px = 300 + (n_feats * 250) + max_curve_features = int(cfg.get("max_curve_features", 8)) + curve_height_px = 300 + (min(n_feats, max_curve_features) * 210) # Define UI cards radar_card = ui.card( ui.card_header( ui.div( - "Feature Analysis (Radar) ", + lbls["text"].get("profile_comparison_title", "Patient Profile Comparison "), ui.popover( ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), ui.tags.div( ui.tags.b("Understanding the radar plot:"), - ui.tags.p("The radar plot compares the values of the features from a selected patient (red) with three different distributions:"), + ui.tags.p(lbls["text"].get("radar_help_intro", f"The radar plot compares the values of the features from a selected {lbls['singular']} (red) with three different distributions:")), ui.tags.ul( - ui.tags.li(ui.tags.b("Average all patients:"), " The average values from all patients in the model (blue)."), - ui.tags.li(ui.tags.b(f"Avg. {neg}:"), f" The average values from all {neg} patients in the model (green)."), - ui.tags.li(ui.tags.b(f"Avg. {pos}:"), f" The average values from all {pos} patients in the model (yellow)."), + ui.tags.li(ui.tags.b(f"Average all {lbls['plural']}:"), f" The average values from all {lbls['plural']} in the model (blue)."), + ui.tags.li(ui.tags.b(f"Avg. {neg}:"), f" The average values from all {neg} {lbls['plural']} in the model (green)."), + ui.tags.li(ui.tags.b(f"Avg. {pos}:"), f" The average values from all {pos} {lbls['plural']} in the model (yellow)."), ), - ui.tags.p("Comparing these values help to evaluate if a patient has been correctly classified or not."), + ui.tags.p(lbls["text"].get("radar_help_outro", f"Comparing these values helps contextualize why the selected {lbls['singular']} falls into a given stratification group.")), style="width: 250px;" ), placement="right" @@ -239,12 +756,12 @@ def dynamic_plot_container(): curve_card = ui.card( ui.card_header( ui.div( - "Distance Analysis (Curves) ", + "Feature Context Curves ", ui.popover( ui.span(ui.tags.i(class_="glyphicon glyphicon-info-sign"), "", style="color: #007bc2; cursor: pointer; font-size: 0.9em;"), ui.tags.div( ui.tags.b("Understanding the curves plot:"), - ui.tags.p("The curves plot displays the distribution of values from a specific feature across the patients in the model, ordered from lowest to highest (blue dots). It highlights in red the patient selected, put in context the feature values of the patient in comparison with the values from the rest of patients."), + ui.tags.p(lbls["text"].get("curves_help", f"The curves plot displays the distribution of values from a specific feature across the {lbls['plural']} in the model, ordered from lowest to highest (blue dots). It highlights in red the selected {lbls['singular']}, putting its feature values in context.")), style="width: 250px;" ), placement="right" @@ -279,6 +796,8 @@ def _sync_selection(): try: new_value = int(val) patient_selected_id.set(new_value) + with reactive.isolate(): + selection_revision.set(selection_revision.get() + 1) except ValueError: # If conversion fails, don't update and optionally show notification # ui.notification_show("Invalid patient ID selected", type="error") @@ -295,7 +814,14 @@ def _calculate_probability_threshold(): md = model_data.get() y_test = md.get("Y_TEST") if md else None - if y_test is None or delta_test is None: + if ( + y_test is None + or delta_test is None + or delta_test.empty + or "pred_prob" not in delta_test.columns + ): + return + if len(y_test) != len(delta_test): return try: @@ -318,12 +844,12 @@ def _calculate_probability_threshold(): ) prob_threshold.set(threshold) - print(f"Updated Threshold: {threshold:.3f} (Target FNR: {target_fnr})") # --- Calculation --- @reactive.Calc def get_patient_data_context(): + selection_revision.get() """ Gathers all the dataframes and IDs needed for analysis. Returns a dictionary of data or None if invalid. @@ -331,7 +857,7 @@ def get_patient_data_context(): raw_df = global_input_data.get() delta_test = delta_test_reactive.get() x_test = x_test_reactive.get() - patient_id = patient_selected_id.get() + patient_id = _selected_entity_id() md = model_data.get() if md is None: @@ -344,6 +870,12 @@ def get_patient_data_context(): # Validation checks if any((x is None) or (isinstance(x, pd.DataFrame) and x.empty) for x in [raw_df, delta_train, delta_test, x_train, y_train, patient_id]): return None + if len(raw_df) != len(delta_test) or len(raw_df) != len(x_test): + return None + + raw_df = raw_df.reset_index(drop=True) + delta_test = delta_test.reset_index(drop=True) + x_test = x_test.reset_index(drop=True) all_ids = sorted(raw_df['ID'].astype(int).tolist()) if patient_id == None or int(patient_id) not in all_ids: @@ -370,73 +902,496 @@ def get_patient_data_context(): # --- Outputs --- - @output + @output(suspend_when_hidden=False) + @render.ui + def stratification_summary(): + try: + sel_id = _selected_entity_id() + prob_thr = prob_threshold.get() + lbls = current_labels() + res_df = stratification_results_df() + + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + + if res_df is None or res_df.empty or sel_id is None: + return ui.div(lbls["text"].get("empty_individual_summary", f"Add or upload {lbls['plural']} to review individual stratification outputs."), class_="stratification-summary-empty") + + patient_row = res_df[res_df["ID"].astype(int) == sel_id] + if patient_row.empty: + return ui.div(lbls["text"].get("selected_entity_not_found", f"Selected {lbls['singular']} not found."), class_="stratification-summary-empty") + + score = float(patient_row[prob_col].iloc[0]) + threshold = float(prob_thr) if prob_thr is not None else 0 + group = str(patient_row[class_col].iloc[0]) + group_color = lbls["positive_class_color"] if group == lbls["positive_class_label"] else lbls["negative_class_color"] + + return ui.div( + ui.div( + ui.div(lbls["singular_title"], class_="stratification-summary-label"), + ui.div(f"{int(sel_id)}", class_="stratification-summary-value"), + class_="stratification-summary-item", + ), + ui.div( + ui.div(prob_col, class_="stratification-summary-label"), + ui.div(f"{score:.3f}", class_="stratification-summary-value"), + class_="stratification-summary-item", + ), + ui.div( + ui.div("Decision Threshold", class_="stratification-summary-label"), + ui.div(f"{threshold:.3f}", class_="stratification-summary-value"), + class_="stratification-summary-item", + ), + ui.div( + ui.div(class_col, class_="stratification-summary-label"), + ui.div(group, class_="stratification-summary-value", style=f"color: {group_color};"), + class_="stratification-summary-item", + ), + class_="stratification-summary", + ) + except Exception as exc: + return ui.div(f"Unable to update selected stratification: {exc}", class_="stratification-summary-empty") + + @output(suspend_when_hidden=False) + @render.ui + def stratification_interpretation(): + try: + sel_id = _selected_entity_id() + prob_thr = prob_threshold.get() + lbls = current_labels() + res_df = stratification_results_df() + + if res_df is None or res_df.empty or sel_id is None: + return None + + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + patient_row = res_df[res_df["ID"].astype(int) == sel_id] + if patient_row.empty: + return None + + score = float(patient_row[prob_col].iloc[0]) + threshold = float(prob_thr) if prob_thr is not None else 0 + group = str(patient_row[class_col].iloc[0]) + direction = "above or equal to" if score >= threshold else "below" + positive_label = lbls["positive_class_label"] + negative_label = lbls["negative_class_label"] + interpretation = lbls["text"].get( + "stratification_interpretation", + f"This {lbls['singular']} is assigned to {group} because the stratification score ({score:.3f}) is {direction} the decision threshold ({threshold:.3f}).", + ) + + return ui.div( + ui.p( + interpretation.format( + entity=lbls["singular"], + entity_title=lbls["singular_title"], + group=group, + score=f"{score:.3f}", + threshold=f"{threshold:.3f}", + direction=direction, + ), + class_="stratification-interpretation-main", + ), + ui.div( + ui.div( + ui.tags.b(f"Below threshold: "), + f"assigned to {negative_label}.", + class_="stratification-rule stratification-rule-negative", + style=f"background: {lbls['negative_class_bg']}; color: {lbls['negative_class_color']};", + ), + ui.div( + ui.tags.b(f"At or above threshold: "), + f"assigned to {positive_label}.", + class_="stratification-rule stratification-rule-positive", + style=f"background: {lbls['positive_class_bg']}; color: {lbls['positive_class_color']};", + ), + class_="stratification-rules", + ), + ui.p( + lbls["text"].get( + "stratification_disclaimer", + "Research prototype output: this stratification supports cohort-level clinical decision support research and should be interpreted with clinical collaborators during utility validation.", + ), + class_="stratification-disclaimer", + ), + class_="stratification-interpretation", + ) + except Exception as exc: + return ui.div(f"Unable to update selected stratification explanation: {exc}", class_="stratification-summary-empty") + + @output(suspend_when_hidden=False) + @render.ui + def cohort_stratification_summary(): + res_df = stratification_results_df() + lbls = current_labels() + + if res_df is None or res_df.empty: + return ui.div(lbls["text"].get("empty_set_summary", f"Add or upload {lbls['plural']} to summarize {lbls['set_lower']}-level stratification."), class_="stratification-summary-empty") + + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + group_counts = res_df[class_col].value_counts() + total = len(res_df) + positive_label = lbls["positive_class_label"] + negative_label = lbls["negative_class_label"] + + positive_count = int(group_counts.get(positive_label, 0)) + negative_count = int(group_counts.get(negative_label, 0)) + mean_score = float(res_df[prob_col].mean()) + positive_pct = (positive_count / total * 100) if total else 0 + + return ui.div( + ui.div( + ui.div(lbls["plural_title"], class_="stratification-summary-label"), + ui.div(str(total), class_="stratification-summary-value"), + class_="stratification-summary-item", + ), + ui.div( + ui.div("Mean score", class_="stratification-summary-label"), + ui.div(f"{mean_score:.3f}", class_="stratification-summary-value"), + class_="stratification-summary-item", + ), + ui.div( + ui.div(negative_label, class_="stratification-summary-label"), + ui.div(str(negative_count), class_="stratification-summary-value", style=f"color: {lbls['negative_class_color']};"), + class_="stratification-summary-item", + ), + ui.div( + ui.div(positive_label, class_="stratification-summary-label"), + ui.div(f"{positive_count} ({positive_pct:.0f}%)", class_="stratification-summary-value", style=f"color: {lbls['positive_class_color']};"), + class_="stratification-summary-item", + ), + class_="stratification-summary", + ) + + @output(suspend_when_hidden=False) + @render.ui + def model_card(): + cfg = current_config() + md = model_data.get() + lbls = current_labels() + features = cfg.get("features", []) + metadata = cfg.get("model_metadata", {}) + threshold = float(prob_threshold.get()) if prob_threshold.get() is not None else 0 + reference_n = 0 + if md is not None and md.get("X_TRAIN_RAW") is not None: + reference_n = len(md.get("X_TRAIN_RAW")) + + feature_names = [ + feature.get("display_name", feature.get("name", "Feature")) + for feature in features + ] + feature_preview = ", ".join(feature_names[:6]) + if len(feature_names) > 6: + feature_preview += f", and {len(feature_names) - 6} more" + + return ui.div( + ui.div( + ui.div( + ui.div("Use case", class_="model-card-label"), + ui.div(cfg.get("name", "Patient stratification model"), class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div("Target", class_="model-card-label"), + ui.div(cfg.get("target_column", "Target"), class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div("Decision threshold", class_="model-card-label"), + ui.div(f"{threshold:.3f}", class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div(lbls["text"].get("reference_set_label", "Reference cohort"), class_="model-card-label"), + ui.div(str(reference_n), class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div("Input variables", class_="model-card-label"), + ui.div(str(len(features)), class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div(lbls["text"].get("groups_label", "Patient groups"), class_="model-card-label"), + ui.div(f"{lbls['negative_class_label']} / {lbls['positive_class_label']}", class_="model-card-value"), + class_="model-card-item model-card-groups", + ), + ui.div( + ui.div("Model version", class_="model-card-label"), + ui.div(metadata.get("model_version", "Prototype"), class_="model-card-value"), + class_="model-card-item", + ), + ui.div( + ui.div("Training snapshot", class_="model-card-label"), + ui.div(metadata.get("training_date", "Not specified"), class_="model-card-value"), + class_="model-card-item", + ), + class_="model-card-grid", + ), + ui.div( + ui.div("Model objective", class_="model-card-label"), + ui.p(cfg.get("description", "Patient stratification for cohort-level clinical decision support research."), class_="model-card-copy"), + class_="model-card-note", + ), + ui.div( + ui.div(lbls["text"].get("reference_source_label", "Reference cohort source"), class_="model-card-label"), + ui.p(metadata.get("cohort_source", lbls["text"].get("reference_source_missing", "Reference cohort source not specified.")), class_="model-card-copy"), + class_="model-card-note", + ), + ui.div( + ui.div("Intended use", class_="model-card-label"), + ui.p(metadata.get("intended_use", "Patient stratification research and exploratory decision support."), class_="model-card-copy"), + class_="model-card-note", + ), + ui.div( + ui.div("Variables used", class_="model-card-label"), + ui.p(feature_preview or "No input variables configured.", class_="model-card-copy"), + class_="model-card-note", + ), + ui.div( + ui.div("Validation status", class_="model-card-label"), + ui.p(metadata.get("validation_status", "Clinical utility should be validated with domain experts."), class_="model-card-copy"), + class_="model-card-note model-card-validation", + ), + ui.div( + ui.div("Limitations", class_="model-card-label"), + ui.p(metadata.get("limitations", "Research prototype for patient stratification; outputs require expert review and validation."), class_="model-card-copy"), + class_="model-card-note model-card-warning", + ), + ) + + @output(suspend_when_hidden=False) + @render.plot + def cohort_score_distribution_plot(): + if input.stratification_tabs() != "cohort": + return _empty_plot("Open the set tab to view this plot") + + res_df = stratification_results_df() + sel_id = _selected_entity_id() + threshold = float(prob_threshold.get()) if prob_threshold.get() is not None else 0 + lbls = current_labels() + + if res_df is None or res_df.empty: + return _empty_plot(lbls["text"].get("no_set_scores", f"No {lbls['set_lower']} scores available")) + + prob_col = lbls["probability_column"] + class_col = lbls["class_column"] + sorted_df = res_df.sort_values(prob_col, ascending=False).reset_index(drop=True) + colors = np.where( + sorted_df[class_col] == lbls["positive_class_label"], + lbls["positive_class_color"], + lbls["negative_class_color"], + ) + + fig, ax = plt.subplots(figsize=(9, 4)) + ax.bar(np.arange(len(sorted_df)), sorted_df[prob_col], color=colors, alpha=0.82, width=0.85) + ax.axhline(threshold, color="#1f2d3d", linestyle="--", linewidth=1.5, label=f"Threshold {threshold:.3f}") + + if sel_id is not None and sel_id in sorted_df["ID"].astype(int).values: + selected_idx = int(sorted_df.index[sorted_df["ID"].astype(int) == sel_id][0]) + selected_score = float(sorted_df.loc[selected_idx, prob_col]) + ax.scatter(selected_idx, selected_score, color="#ffc107", edgecolor="#1f2d3d", s=110, zorder=5, label=f"{lbls['singular_title']} {int(sel_id)}") + + ax.set_title(lbls["text"].get("set_scores_plot_title", "Cohort stratification scores"), fontweight="bold") + ax.set_xlabel(lbls["text"].get("set_scores_plot_xlabel", f"{lbls['plural_title']} ordered by score")) + ax.set_ylabel(prob_col) + ax.set_ylim(0, max(1.0, float(sorted_df[prob_col].max()) * 1.08)) + ax.grid(axis="y", alpha=0.25) + ax.legend(loc="best") + fig.tight_layout() + return fig + + @output(suspend_when_hidden=False) + @render.plot + def global_feature_importance_plot(): + if input.stratification_tabs() != "model": + return _empty_plot("Open the model tab to view this plot") + + md = model_data.get() + delta_test = delta_test_reactive.get() + lbls = current_labels() + + if md is None: + return _empty_plot("No model data available") + + delta_train = md.get("D_TRAIN") + if delta_train is None or delta_train.empty: + return _empty_plot("No reference explanation matrix available") + + train_cols = _contribution_columns(delta_train) + if not train_cols: + return _empty_plot("No feature contributions available") + + train_importance = delta_train[train_cols].var(axis=0).sort_values(ascending=False) + top_features = train_importance.head(10).index.tolist() + + fig, ax = plt.subplots(figsize=(9, 4.8)) + y_pos = np.arange(len(top_features)) + ax.barh( + y_pos, + train_importance.loc[top_features].to_numpy(), + color="#007bff", + alpha=0.78, + label="Reference", + ) + + if delta_test is not None and not delta_test.empty: + current_cols = [col for col in top_features if col in delta_test.columns] + if current_cols: + current_importance = delta_test[current_cols].var(axis=0).reindex(top_features).fillna(0) + ax.scatter( + current_importance.to_numpy(), + y_pos, + color="#fd7e14", + s=42, + label=lbls["text"].get("current_set_label", "Current cohort"), + zorder=4, + ) + + ax.set_yticks(y_pos) + ax.set_yticklabels([_feature_display_name(col) for col in top_features]) + ax.invert_yaxis() + ax.set_xlabel("Importance") + ax.set_title("Global feature importance", fontweight="bold") + ax.grid(axis="x", alpha=0.25) + ax.legend(loc="best") + fig.tight_layout() + return fig + + @output(suspend_when_hidden=False) + @render.ui + def closest_reference_patients_table(): + closest_df = closest_reference_patients_df() + if closest_df is None or closest_df.empty: + return None + + return ui.tags.div( + ui.HTML(closest_df.to_html(index=False, classes="reference-candidates-table", border=0)), + class_="reference-candidates-table-wrap", + ) + + @output(suspend_when_hidden=False) + @render.ui + def closest_reference_patients_narrative(): + closest_df = closest_reference_patients_df() + lbls = current_labels() + + if closest_df is None or closest_df.empty: + return ui.div(lbls["text"].get("reference_comparison_empty", f"Reference comparison requires a selected {lbls['singular']} and an available reference {lbls['set_lower']}."), class_="closest-patients-narrative") + + class_col = lbls["class_column"] + if class_col not in closest_df.columns: + return None + + counts = closest_df[class_col].value_counts() + top_group = str(counts.index[0]) + top_count = int(counts.iloc[0]) + total = len(closest_df) + score_col = lbls["probability_column"] + mean_score = float(closest_df[score_col].mean()) if score_col in closest_df.columns else None + + score_text = f" Their mean stratification score is {mean_score:.3f}." if mean_score is not None else "" + return ui.div( + ui.div( + lbls["text"].get( + "closest_reference_narrative", + f"The selected {lbls['singular']} is most similar to {lbls['reference_plural']} mostly assigned to {top_group} ({top_count}/{total}).{score_text}", + ).format( + entity=lbls["singular"], + reference_plural=lbls["reference_plural"], + top_group=top_group, + top_count=top_count, + total=total, + score_text=score_text, + ), + class_="closest-patients-narrative", + ), + ui.div( + lbls["text"].get( + "closest_reference_privacy_note", + "Reference records come from the model reference population, not from the current input set. If the reference population contains private or sensitive records, expose only anonymized summaries or distances rather than row-level reference data.", + ), + class_="reference-privacy-note", + ), + ) + + @render.download(filename="stratification_results.csv") + def download_stratification_results(): + res_df = stratification_results_df() + if res_df is None: + res_df = pd.DataFrame() + yield _attach_export_metadata(res_df).to_csv(index=False) + + @render.download(filename="xainypredictor_report_package.zip") + def download_report_package(): + res_df = stratification_results_df() + closest_df = closest_reference_patients_df() + if res_df is None: + res_df = pd.DataFrame() + if closest_df is None: + closest_df = pd.DataFrame() + + buffer = BytesIO() + with ZipFile(buffer, mode="w", compression=ZIP_DEFLATED) as report_zip: + report_zip.writestr("README.txt", _report_readme_text()) + report_zip.writestr("metadata.csv", _metadata_export_df().to_csv(index=False)) + report_zip.writestr("stratification_results.csv", _attach_export_metadata(res_df).to_csv(index=False)) + report_zip.writestr(_closest_reference_filename(), _attach_export_metadata(closest_df).to_csv(index=False)) + report_zip.writestr(_set_summary_filename(), _cohort_summary_export_df().to_csv(index=False)) + + buffer.seek(0) + yield buffer.getvalue() + + @render.download(filename=_closest_reference_filename) + def download_closest_patients(): + closest_df = closest_reference_patients_df() + if closest_df is None: + closest_df = pd.DataFrame() + yield _attach_export_metadata(closest_df).to_csv(index=False) + + @render.download(filename=_set_summary_filename) + def download_cohort_summary(): + yield _cohort_summary_export_df().to_csv(index=False) + + @output(suspend_when_hidden=False) @render.plot def radar_plot(): + if input.stratification_tabs() != "patient" or input.view_mode() != "radar": + return _empty_plot("Open profile comparison to view this plot") + data = get_patient_data_context() if not data: return _empty_plot("Radar data unavailable") - opts = input.radar_plot_elements() - feats = input.features_to_plot() + opts = list(input.radar_plot_elements() or []) + feats = list(input.features_to_plot() or []) if len(feats) < 3: return _empty_plot("Please, select at least 3 features to view") - # Fetch current labels reactively - lbls = current_labels() - - # We call the plotting function INSIDE render.plot - # This ensures it redraws correctly on resize. - fig_radar, _ = analyze_patient( - patient_id=data["patient_id"], - df=data["df"], - delta_train=data["delta_train"], - delta_test=data["delta_test"], - x_train=data["x_train"], - y_train=data["y_train"], - x_test=data["x_test"], - features_to_plot=feats, - n_dists=3, - show_closest_radial="closest" in opts, - show_average_radial="average" in opts, - show_average_class0_radial="average_0" in opts, - show_average_class1_radial="average_1" in opts, - neg_class_label=lbls["negative_class_label"], - pos_class_label=lbls["positive_class_label"], - ) + fig_radar, _ = _run_patient_analysis(data, feats, opts) return fig_radar if fig_radar else _empty_plot("Radar data unavailable") - @output + @output(suspend_when_hidden=False) @render.plot def curve_plot(): + if input.stratification_tabs() != "patient" or input.view_mode() != "curve": + return _empty_plot("Open feature curves to view this plot") + data = get_patient_data_context() if not data: return _empty_plot("Curve data unavailable") - feats = input.features_to_plot() + cfg = current_config() + feats = list(input.features_to_plot() or []) + max_curve_features = int(cfg.get("max_curve_features", 8)) + if len(feats) > max_curve_features: + feats = feats[:max_curve_features] if len(feats) < 3: return _empty_plot("Please, select at least 3 features to view") - # Fetch current labels reactively - lbls = current_labels() - - # Retrieve the second figure (Curve) from the function - # Note: Ideally you'd split analyze_patient into two functions to save performance, - # but calling it again here is the safest way to ensure the plot exists for this context. - _, fig_curve = analyze_patient( - patient_id=data["patient_id"], - df=data["df"], - delta_train=data["delta_train"], - delta_test=data["delta_test"], - x_train=data["x_train"], - y_train=data["y_train"], - x_test=data["x_test"], - features_to_plot=feats, - n_dists=3, - neg_class_label=lbls["negative_class_label"], - pos_class_label=lbls["positive_class_label"], - # We don't care about radar options here - ) + _, fig_curve = _run_patient_analysis(data, feats) return fig_curve if fig_curve else _empty_plot("Curve data unavailable") def _empty_plot(msg): @@ -445,13 +1400,11 @@ def _empty_plot(msg): ax.set_axis_off() return fig - @output + @output(suspend_when_hidden=False) @render.data_frame def results_table_output(): - df = global_input_data.get() - delta_test = delta_test_reactive.get() - sel_id = patient_selected_id.get() - prob_thr = prob_threshold.get() + sel_id = _selected_entity_id() + res_df = stratification_results_df() # Fetch current labels reactively lbls = current_labels() @@ -460,21 +1413,12 @@ def results_table_output(): prob_col = lbls["probability_column"] class_col = lbls["class_column"] - if df is None or delta_test is None or df.empty or delta_test.empty: + if res_df is None or res_df.empty: return render.DataGrid( - pd.DataFrame({"Message": ["No data available."]}), + pd.DataFrame({"Message": [lbls["text"].get("empty_results_message", f"Add or upload {lbls['plural']} to generate stratification results.")]}), width="100%" ) - if prob_thr is None: - prob_thr = 0 - - res_df = df.copy() - if "pred_prob" in delta_test.columns: - res_df = pd.concat([res_df, delta_test[["pred_prob"]]], axis=1) - res_df = res_df.rename(columns={"pred_prob": prob_col}).sort_values(by=['ID']) - res_df[class_col] = np.where(res_df[prob_col] >= prob_thr, pos_class_label, neg_class_label) - columns_to_show = "ID", prob_col, class_col final_cols = [col for col in columns_to_show if col in res_df.columns] res_df = res_df[final_cols].reset_index(drop=True) @@ -482,13 +1426,13 @@ def results_table_output(): res_df[prob_col] = res_df[prob_col].round(3) table_styles = None - if sel_id: - if (res_df[res_df['ID'] == sel_id].empty): + if sel_id is not None: + if (res_df[res_df['ID'].astype(int) == sel_id].empty): return render.DataGrid( pd.DataFrame({"Message": ["Selected ID not found."]}), width="100%" ) - sel_id_row = int(res_df[res_df['ID'] == sel_id].index[0]) + sel_id_row = int(res_df[res_df['ID'].astype(int) == sel_id].index[0]) pos_rows = res_df[res_df[class_col] == pos_class_label].index.to_list() neg_rows = res_df[res_df[class_col] == neg_class_label].index.to_list() table_styles = [ @@ -504,7 +1448,7 @@ def results_table_output(): "rows": pos_rows, "cols": list(range(len(res_df.columns))), "style": { - "color": "#dc3545", + "color": lbls["positive_class_color"], "font-weight": "bold", }, }, @@ -512,7 +1456,7 @@ def results_table_output(): "rows": neg_rows, "cols": list(range(len(res_df.columns))), "style": { - "color": "#198754", + "color": lbls["negative_class_color"], "font-weight": "bold", }, } diff --git a/src/XAInyPredictor/shinyapp/www/data_template.tsv b/src/XAInyPredictor/shinyapp/www/data_template.tsv deleted file mode 100644 index e1cde05..0000000 --- a/src/XAInyPredictor/shinyapp/www/data_template.tsv +++ /dev/null @@ -1,2 +0,0 @@ -ID Age Gender BMI Tumor size (cm) Tumor stage Node stage Metastases ATA risk Histology ETE Multifocality Vascular invasion Resection Goal of RAI - "[M; F] (i.e., M = Male, F = Female)" [T1a; T1b; T2; T3a; T3b; T4a; T4b] [N0; N1a; N1b; Nx] [M1; Mx] [Low; Intermediate; High] [OTC; PTC; FTC; PDTC] [YES; NO] [YES; NO] [YES; NO] [R0; R1; R2] [ABLATION; ADIUVANT; THERAPEUTIC] \ No newline at end of file diff --git a/src/XAInyPredictor/shinyapp/www/data_template.xlsx b/src/XAInyPredictor/shinyapp/www/data_template.xlsx deleted file mode 100644 index 2973d51..0000000 Binary files a/src/XAInyPredictor/shinyapp/www/data_template.xlsx and /dev/null differ diff --git a/src/XAInyPredictor/shinyapp/www/index.js b/src/XAInyPredictor/shinyapp/www/index.js index ca8a274..48b47c7 100644 --- a/src/XAInyPredictor/shinyapp/www/index.js +++ b/src/XAInyPredictor/shinyapp/www/index.js @@ -1,6 +1,134 @@ $(() => { let currentTab = "data_input"; // default + const setUseCaseLoading = (visible, title, subtitle) => { + if (title) $("#loading-overlay-title").text(title); + if (subtitle) $("#loading-overlay-subtitle").text(subtitle); + $("#use-case-loading-overlay").toggleClass("is-visible", Boolean(visible)); + }; + + const showStartupLoading = () => { + const selectedUseCase = $("#startup_use_case option:selected").text() || "selected use case"; + $("body").addClass("startup-loading"); + setUseCaseLoading( + true, + "Loading selected use case...", + `Preparing ${selectedUseCase}, model outputs, and reference context.` + ); + }; + + const normalizeManualDecimalInputs = (root = document) => { + $(root) + .find(".manual-decimal-input input") + .addBack(".manual-decimal-input input") + .each(function () { + const normalized = String(this.value || "").replace(/,/g, "."); + if (this.value !== normalized) { + this.value = normalized; + $(this).trigger("change"); + } + this.setAttribute("inputmode", "decimal"); + }); + }; + + normalizeManualDecimalInputs(); + + $(document).on("input change blur", ".manual-decimal-input input", function () { + const normalized = String(this.value || "").replace(/,/g, "."); + if (this.value !== normalized) { + this.value = normalized; + $(this).trigger("change"); + } + }); + + new MutationObserver((mutations) => { + mutations.forEach((mutation) => { + mutation.addedNodes.forEach((node) => { + if (node.nodeType === 1) normalizeManualDecimalInputs(node); + }); + }); + }).observe(document.body, { childList: true, subtree: true }); + + $(document).on("click", "#confirm_switch", () => { + setUseCaseLoading(true, "Loading use case...", "Preparing model outputs and reference context."); + }); + + $(document).on("click", "#confirm_startup_use_case", function () { + window.setTimeout(() => { + showStartupLoading(); + $(this).prop("disabled", true).text("Loading..."); + }, 0); + }); + + $(document).on("click", "#tab_prediction", () => { + setUseCaseLoading(true, "Loading stratification support...", "Synchronizing inputs, scores, and reference context."); + }); + + $(document).on("click", "#data_input-btn_use_data", () => { + setUseCaseLoading(true, "Confirming data...", "Preparing the confirmed set for downstream analysis."); + }); + + Shiny.addCustomMessageHandler("setUseCaseLoading", (payload) => { + setUseCaseLoading(payload.visible, payload.title, payload.subtitle); + if (!payload.visible) { + $("body").removeClass("startup-loading"); + $("#confirm_startup_use_case").prop("disabled", false).text("Start"); + } + }); + + Shiny.addCustomMessageHandler("setStratificationTabLabels", (payload) => { + const labels = { + patient: payload.patient, + cohort: payload.cohort, + reference: payload.reference, + }; + + Object.entries(labels).forEach(([value, label]) => { + if (!label) return; + const currentText = { + patient: ["Patient", "Candidate"], + cohort: ["Cohort", "Candidate Set"], + reference: ["Reference Patients", "Reference Candidates"], + }[value] || []; + + $(".stratification-tabset a, .stratification-tabset button").each(function () { + const $el = $(this); + const dataValue = $el.attr("data-value") || $el.data("value"); + const text = $el.text().trim(); + if (dataValue === value || currentText.includes(text)) { + $el.text(label); + } + }); + }); + }); + + Shiny.addCustomMessageHandler("setAnalysisStepsLocked", (payload) => { + const locked = Boolean(payload.locked); + $("#tab_data_exploration, #tab_prediction") + .toggleClass("workflow-step-locked", locked) + .attr("aria-disabled", locked ? "true" : "false"); + }); + + Shiny.addCustomMessageHandler("setDataInputMethodLabels", (payload) => { + const fallbackText = { + form: ["Manual Entry"], + file: ["Upload File"], + example: ["Example Cohort", "Example Candidate Set"], + }; + + Object.entries(payload || {}).forEach(([value, label]) => { + if (!label) return; + $(".data-input-method-tabs a, .data-input-method-tabs button").each(function () { + const $el = $(this); + const dataValue = $el.attr("data-value") || $el.data("value"); + const text = $el.text().trim(); + if (dataValue === value || (fallbackText[value] || []).includes(text)) { + $el.text(label); + } + }); + }); + }); + Shiny.addCustomMessageHandler("toggleActiveTab", (payload) => { const activeTab = payload.activeTab; diff --git a/src/XAInyPredictor/shinyapp/www/layout.css b/src/XAInyPredictor/shinyapp/www/layout.css index 9f8432a..775a9e2 100644 --- a/src/XAInyPredictor/shinyapp/www/layout.css +++ b/src/XAInyPredictor/shinyapp/www/layout.css @@ -82,7 +82,7 @@ html, body, .container.fluid { .navbar-top { display: grid; /* Columns: Title, flexible space, Menu, flexible space, Logo, Info */ - grid-template-columns: auto 1fr auto 1fr auto 40px; + grid-template-columns: minmax(220px, auto) minmax(8px, 1fr) minmax(0, auto) minmax(8px, 1fr) auto 40px; grid-template-rows: 1fr; align-items: center; /* Vertically center everything */ @@ -91,8 +91,8 @@ html, body, .container.fluid { height: 100%; width: 100%; - padding: 0 30px; /* Added horizontal padding for a cleaner look */ - gap: 20px; + padding: 0 24px; + gap: 12px; } .navigation-title { @@ -106,7 +106,8 @@ html, body, .container.fluid { display: flex; justify-content: center; align-items: center; - gap: 5px; + gap: 6px; + min-width: 0; } .navigation-logo-right { @@ -125,6 +126,10 @@ html, body, .container.fluid { /* Mobile responsive override */ @media only screen and (max-width: 800px) { + .page-layout { + grid-template-rows: auto 1fr; + } + .navbar-top { grid-template-columns: 1fr auto; grid-template-rows: auto auto; @@ -133,4 +138,9 @@ html, body, .container.fluid { "menu logo"; padding: 10px; } + + .navigation-menu { + flex-wrap: wrap; + justify-content: flex-start; + } } diff --git a/src/XAInyPredictor/shinyapp/www/pwa/icon.png b/src/XAInyPredictor/shinyapp/www/pwa/icon.png deleted file mode 100644 index 97bafae..0000000 --- a/src/XAInyPredictor/shinyapp/www/pwa/icon.png +++ /dev/null @@ -1 +0,0 @@ -{"payload":{"allShortcutsEnabled":false,"fileTree":{"respiratory_disease_pyshiny/www/pwa":{"items":[{"name":"icon.png","path":"respiratory_disease_pyshiny/www/pwa/icon.png","contentType":"file"},{"name":"manifest.json","path":"respiratory_disease_pyshiny/www/pwa/manifest.json","contentType":"file"},{"name":"offline.html","path":"respiratory_disease_pyshiny/www/pwa/offline.html","contentType":"file"}],"totalCount":3},"respiratory_disease_pyshiny/www":{"items":[{"name":"pwa","path":"respiratory_disease_pyshiny/www/pwa","contentType":"directory"},{"name":"static","path":"respiratory_disease_pyshiny/www/static","contentType":"directory"},{"name":"index.js","path":"respiratory_disease_pyshiny/www/index.js","contentType":"file"},{"name":"layout.css","path":"respiratory_disease_pyshiny/www/layout.css","contentType":"file"},{"name":"pwa-service-worker.js","path":"respiratory_disease_pyshiny/www/pwa-service-worker.js","contentType":"file"},{"name":"style.css","path":"respiratory_disease_pyshiny/www/style.css","contentType":"file"}],"totalCount":6},"respiratory_disease_pyshiny":{"items":[{"name":"data","path":"respiratory_disease_pyshiny/data","contentType":"directory"},{"name":"modules","path":"respiratory_disease_pyshiny/modules","contentType":"directory"},{"name":"utils","path":"respiratory_disease_pyshiny/utils","contentType":"directory"},{"name":"www","path":"respiratory_disease_pyshiny/www","contentType":"directory"},{"name":"README.md","path":"respiratory_disease_pyshiny/README.md","contentType":"file"},{"name":"app.py","path":"respiratory_disease_pyshiny/app.py","contentType":"file"},{"name":"requirements.txt","path":"respiratory_disease_pyshiny/requirements.txt","contentType":"file"}],"totalCount":7},"":{"items":[{"name":"app-slug","path":"app-slug","contentType":"directory"},{"name":"covid19-dashboard","path":"covid19-dashboard","contentType":"directory"},{"name":"covid19-timeline","path":"covid19-timeline","contentType":"directory"},{"name":"cran-explorer","path":"cran-explorer","contentType":"directory"},{"name":"freedom-press-index","path":"freedom-press-index","contentType":"directory"},{"name":"genome-browser","path":"genome-browser","contentType":"directory"},{"name":"hangman-ru","path":"hangman-ru","contentType":"directory"},{"name":"hangman","path":"hangman","contentType":"directory"},{"name":"hotshot-dashboard","path":"hotshot-dashboard","contentType":"directory"},{"name":"lake-profile-dashboard","path":"lake-profile-dashboard","contentType":"directory"},{"name":"life-of-pi","path":"life-of-pi","contentType":"directory"},{"name":"nyc-metro-vis","path":"nyc-metro-vis","contentType":"directory"},{"name":"nz-trade-dash","path":"nz-trade-dash","contentType":"directory"},{"name":"one-source-indy","path":"one-source-indy","contentType":"directory"},{"name":"respiratory_disease_pyshiny","path":"respiratory_disease_pyshiny","contentType":"directory"},{"name":"shiny-decisions","path":"shiny-decisions","contentType":"directory"},{"name":"z_rmd-gallery","path":"z_rmd-gallery","contentType":"directory"},{"name":".gitignore","path":".gitignore","contentType":"file"},{"name":"README.md","path":"README.md","contentType":"file"},{"name":"shiny-gallery.Rproj","path":"shiny-gallery.Rproj","contentType":"file"}],"totalCount":20}},"fileTreeProcessingTime":7.154719,"foldersToFetch":[],"reducedMotionEnabled":null,"repo":{"id":235330724,"defaultBranch":"master","name":"shiny-gallery","ownerLogin":"rstudio","currentUserCanPush":false,"isFork":false,"isEmpty":false,"createdAt":"2020-01-21T11:51:28.000Z","ownerAvatar":"https://avatars.githubusercontent.com/u/513560?v=4","public":true,"private":false,"isOrgOwned":true},"symbolsExpanded":false,"treeExpanded":true,"refInfo":{"name":"master","listCacheKey":"v0:1579607491.0","canEdit":false,"refType":"branch","currentOid":"afd71b32bf28895672813acbba356361fdaa0989"},"path":"respiratory_disease_pyshiny/www/pwa/icon.png","currentUser":null,"blob":{"rawLines":null,"stylingDirectives":null,"csv":null,"csvError":null,"dependabotInfo":{"showConfigurationBanner":false,"configFilePath":null,"networkDependabotPath":"/rstudio/shiny-gallery/network/updates","dismissConfigurationNoticePath":"/settings/dismiss-notice/dependabot_configuration_notice","configurationNoticeDismissed":null,"repoAlertsPath":"/rstudio/shiny-gallery/security/dependabot","repoSecurityAndAnalysisPath":"/rstudio/shiny-gallery/settings/security_analysis","repoOwnerIsOrg":true,"currentUserCanAdminRepo":false},"displayName":"icon.png","displayUrl":"https://github.com/rstudio/shiny-gallery/blob/master/respiratory_disease_pyshiny/www/pwa/icon.png?raw=true","headerInfo":{"blobSize":"19.4 KB","deleteInfo":{"deleteTooltip":"You must be signed in to make or propose changes"},"editInfo":{"editTooltip":"You must be signed in to make or propose changes"},"ghDesktopPath":"https://desktop.github.com","gitLfsPath":null,"onBranch":true,"shortPath":"eae8caf","siteNavLoginPath":"/login?return_to=https%3A%2F%2Fgithub.com%2Frstudio%2Fshiny-gallery%2Fblob%2Fmaster%2Frespiratory_disease_pyshiny%2Fwww%2Fpwa%2Ficon.png","isCSV":false,"isRichtext":false,"toc":null,"lineInfo":{"truncatedLoc":null,"truncatedSloc":null},"mode":"file"},"image":true,"isCodeownersFile":null,"isPlain":false,"isValidLegacyIssueTemplate":false,"issueTemplateHelpUrl":"https://docs.github.com/articles/about-issue-and-pull-request-templates","issueTemplate":null,"discussionTemplate":null,"language":null,"languageID":null,"large":false,"loggedIn":false,"newDiscussionPath":"/rstudio/shiny-gallery/discussions/new","newIssuePath":"/rstudio/shiny-gallery/issues/new","planSupportInfo":{"repoIsFork":null,"repoOwnedByCurrentUser":null,"requestFullPath":"/rstudio/shiny-gallery/blob/master/respiratory_disease_pyshiny/www/pwa/icon.png","showFreeOrgGatedFeatureMessage":null,"showPlanSupportBanner":null,"upgradeDataAttributes":null,"upgradePath":null},"publishBannersInfo":{"dismissActionNoticePath":"/settings/dismiss-notice/publish_action_from_dockerfile","dismissStackNoticePath":"/settings/dismiss-notice/publish_stack_from_file","releasePath":"/rstudio/shiny-gallery/releases/new?marketplace=true","showPublishActionBanner":false,"showPublishStackBanner":false},"rawBlobUrl":"https://github.com/rstudio/shiny-gallery/raw/master/respiratory_disease_pyshiny/www/pwa/icon.png","renderImageOrRaw":true,"richText":null,"renderedFileInfo":null,"shortPath":null,"symbolsEnabled":true,"tabSize":8,"topBannersInfo":{"overridingGlobalFundingFile":false,"globalPreferredFundingPath":null,"repoOwner":"rstudio","repoName":"shiny-gallery","showInvalidCitationWarning":false,"citationHelpUrl":"https://docs.github.com/en/github/creating-cloning-and-archiving-repositories/creating-a-repository-on-github/about-citation-files","showDependabotConfigurationBanner":false,"actionsOnboardingTip":null},"truncated":false,"viewable":false,"workflowRedirectUrl":null,"symbols":null},"copilotInfo":null,"copilotAccessAllowed":false,"csrf_tokens":{"/rstudio/shiny-gallery/branches":{"post":"dvqwDASdSh8so2o_kVFUm9EtXPV-umI2wgyulKoABDpmkGllLgS1C9S89BI-G7zQix5Y-n2IEe98msQh2AFNMA"},"/repos/preferences":{"post":"_Iy0ll05MZr2VFNzRwOOT3MPgEYwGBxRPnZD7p12n1uPIjSUTbItIUA1-vmQjoVu4jzeJVKH3neHcqzrhVsq_A"}}},"title":"shiny-gallery/respiratory_disease_pyshiny/www/pwa/icon.png at master · rstudio/shiny-gallery"} \ No newline at end of file diff --git a/src/XAInyPredictor/shinyapp/www/style.css b/src/XAInyPredictor/shinyapp/www/style.css index 9cdc458..2a6b9d3 100644 --- a/src/XAInyPredictor/shinyapp/www/style.css +++ b/src/XAInyPredictor/shinyapp/www/style.css @@ -4,25 +4,38 @@ } .navbar-top { - background-color: #f0f8ff; + background: linear-gradient(180deg, #f5fbff 0%, #eef6fd 100%); + border-bottom: 1px solid #dbe8f3; + box-shadow: 0 2px 10px rgba(31, 45, 61, 0.08); } .navbar-button { - background-color: #f0f8ff; - height: 50px; - float: right; - border-radius: 4px; + background-color: transparent; + color: #2d3b48; + height: 44px; + float: none; + border-radius: 7px; margin: 0 2px; border: 1px solid transparent; + font-size: 0.88rem; + font-weight: 600; + line-height: 1.15; + max-width: 190px; + min-width: 0; + overflow: hidden; + padding: 0 12px; + text-overflow: ellipsis; transition: all 0.2s ease; + white-space: normal; } .navbar-button.active-tab { - background-color: #007bff !important; + background: #007bff !important; color: white !important; font-weight: bold; /* Emphasize active state */ - border-bottom: 3px solid #0056b3; /* Optional underline or bottom border */ - box-shadow: 0 2px 4px rgba(0,123,255,0.3); + border: 1px solid #006add; + border-bottom: 3px solid #0056b3; + box-shadow: 0 6px 14px rgba(0, 123, 255, 0.24); } .active-tab { @@ -31,10 +44,10 @@ } .navbar-info { - background-color: #f0f8ff; - height: 50px; + background-color: transparent; + height: 44px; border: none; - float: right; + float: none; transition: 0.3s; } @@ -47,23 +60,59 @@ } .navbar-button:hover { - background-color: #dde5ed; - color: #000; + background-color: #e3eff8; + color: #1f2d3d; +} + +.navbar-button.workflow-step-locked { + background: #f3f6f9; + border-color: #d9e1e8; + box-shadow: none; + color: #8a98a8; + cursor: not-allowed; + opacity: 0.72; +} + +.navbar-button.workflow-step-locked:hover { + background: #f3f6f9; + color: #8a98a8; } .navbar-button:focus { outline: none; - background-color: #9ccdf7; + background-color: #e3eff8; + color: #1f2d3d; +} + +.navbar-button.active-tab:focus { + background: #007bff !important; + color: white !important; } .navigation-title h3 { font-family: 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; text-transform: uppercase; - font-size: 2rem; + color: #007bff; + font-size: 1.75rem; + font-weight: 850; + letter-spacing: 0; + line-height: 1; + margin: 0; +} + +.navigation-subtitle { + color: #5f6f7f; + display: block; + font-size: 0.78rem; + font-weight: 600; + line-height: 1.2; + margin-top: 5px; + text-transform: none; } .navigation-logo-right img { transition: transform 0.2s; + filter: drop-shadow(0 2px 4px rgba(31, 45, 61, 0.12)); } .navigation-logo-right img:hover { @@ -71,31 +120,212 @@ } #div-navbar-map { - background-color: #f0f8ff; - height: 50px; + background-color: transparent; + height: 44px; + min-width: 0; transition: 0.2s; float: left; } #div-navbar-map:hover { - background-color: #dde5ed; + background-color: transparent; } #div-navbar-plot { - background-color: #f0f8ff; - height: 50px; + background-color: transparent; + height: 44px; + min-width: 0; transition: 0.3s; - float: right; + float: left; } #div-navbar-plot:hover { - background-color: #dde5ed; + background-color: transparent; } #div-navbar-plot:focus { - background-color: #f0f8ff; + background-color: transparent; +} + +#div-navbar-tabs { + background: rgba(255, 255, 255, 0.72); + border: 1px solid #dbe8f3; + border-radius: 9px; + box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.8); + padding: 5px; + min-width: 0; } +#div-navbar-tabs .form-group { + margin-bottom: 0; +} + +#div-navbar-tabs select, +#div-navbar-tabs .selectize-control { + min-width: 180px; + width: clamp(180px, 18vw, 260px); +} + +#div-navbar-tabs .selectize-input { + border-color: #b9c7d4; + border-radius: 6px; + box-shadow: none; + min-height: 38px; + padding: 8px 34px 8px 12px; +} + +.loading-overlay { + align-items: center; + background: rgba(245, 251, 255, 0.78); + backdrop-filter: blur(2px); + display: flex; + inset: 80px 0 0 0; + justify-content: center; + opacity: 0; + pointer-events: none; + position: fixed; + transition: opacity 0.18s ease; + z-index: 2000; +} + +.loading-overlay.is-visible { + opacity: 1; + pointer-events: all; +} + +body.startup-loading #use-case-loading-overlay { + background: rgba(245, 251, 255, 0.86); + inset: 0; + z-index: 10000; +} + +.loading-panel { + align-items: center; + background: #ffffff; + border: 1px solid #d4e6f5; + border-radius: 8px; + box-shadow: 0 12px 36px rgba(31, 45, 61, 0.18); + display: flex; + flex-direction: column; + gap: 8px; + max-width: 360px; + padding: 24px 28px; + text-align: center; +} + +.loading-spinner { + animation: loading-spin 0.9s linear infinite; + border: 4px solid #d8ebfb; + border-top-color: #007bff; + border-radius: 50%; + height: 42px; + width: 42px; +} + +.loading-title { + color: #1f2d3d; + font-size: 1rem; + font-weight: 700; +} + +.loading-subtitle { + color: #5f6f7f; + font-size: 0.86rem; + line-height: 1.35; +} + +.reference-privacy-note { + background: #fff8e6; + border: 1px solid #f3d78b; + border-left: 4px solid #f0b429; + border-radius: 6px; + color: #5a4300; + font-size: 0.86rem; + line-height: 1.4; + margin: 0 0 14px; + padding: 9px 12px; +} + +.reference-candidates-table-wrap { + overflow-x: auto; + width: 100%; +} + +.reference-candidates-table { + border-collapse: collapse; + font-size: 0.9rem; + margin-top: 12px; + width: 100%; +} + +.reference-candidates-table th, +.reference-candidates-table td { + border: 1px solid #d8e0e8; + padding: 8px 10px; + text-align: left; +} + +.reference-candidates-table th { + background: #f8fafc; + color: #1f2d3d; + font-weight: 700; +} + +.apply-data-panel { + align-items: center; + background: #fff8e6; + border: 1px solid #f3d78b; + border-radius: 7px; + display: flex; + gap: 14px; + justify-content: space-between; + margin: 12px 0; + padding: 12px 14px; +} + +.cohort-actions .apply-data-panel { + flex: 1 1 360px; + margin: 0; +} + +.apply-data-panel-confirmed { + background: #f3fbf6; + border-color: #b7dfc5; +} + +.confirm-data-title { + color: #0f5132; + font-weight: 750; + margin-bottom: 3px; +} + +.apply-data-panel .btn { + background: #806000; + border-color: #806000; + color: #ffffff; + min-width: 220px; +} + +.apply-data-panel .btn:hover, +.apply-data-panel .btn:focus { + background: #6e5200; + border-color: #6e5200; + color: #ffffff; +} + +.apply-data-hint { + color: #806000; + font-size: 0.86rem; + margin: 0; +} + +@keyframes loading-spin { + to { + transform: rotate(360deg); + } +} + + .bootstrap-switch .bootstrap-switch-handle-on.bootstrap-switch-primary { background: #9ccdf7; color: #000; @@ -136,3 +366,1314 @@ z-index: 1; visibility: visible; } + +.bslib-sidebar-layout > .sidebar, +.sidebar { + border-right: 1px solid #e5eaf0; + padding: 18px 16px; +} + +.sidebar h4, +.sidebar h5 { + color: #1f2d3d; + font-weight: 600; + line-height: 1.25; + margin: 0 0 18px; +} + +.sidebar h4 { + font-size: 1.35rem; +} + +.sidebar h5 { + font-size: 1.15rem; + margin-top: 12px; +} + +.sidebar .form-group, +.sidebar .shiny-input-container, +.sidebar .control-label { + max-width: 100%; +} + +.sidebar label, +.sidebar .control-label { + color: #2d3b48; + font-size: 0.92rem; + font-weight: 500; + line-height: 1.25; + margin-bottom: 6px; +} + +.sidebar .radio-group label.control-label, +.sidebar .checkbox-group label.control-label, +.sidebar .shiny-input-radiogroup > label, +.sidebar .shiny-input-checkboxgroup > label { + display: block; + margin-bottom: 10px; +} + +.sidebar .shiny-options-group { + margin-top: 10px; +} + +.sidebar [id$="radar_plot_elements"] .shiny-options-group, +.sidebar [id$="input_method"] .shiny-options-group { + margin-top: 14px !important; + padding-top: 2px; +} + +.sidebar .shiny-options-group .radio:first-child, +.sidebar .shiny-options-group .checkbox:first-child, +.sidebar .shiny-options-group .form-check:first-child { + margin-top: 0; +} + +.sidebar .radio, +.sidebar .checkbox, +.sidebar .form-check { + align-items: flex-start; + display: flex; + gap: 8px; + line-height: 1.25; + margin: 6px 0; +} + +.sidebar input[type="radio"], +.sidebar input[type="checkbox"] { + flex: 0 0 auto; + margin: 2px 0 0; +} + +.sidebar .radio label, +.sidebar .checkbox label, +.sidebar .form-check label, +.sidebar .form-check-label { + font-size: 0.95rem; + line-height: 1.25; + margin: 0; + overflow-wrap: anywhere; +} + +.sidebar select, +.sidebar .selectize-control { + width: 100% !important; +} + +.sidebar .input-group { + align-items: stretch; + display: flex; + flex-direction: column; + max-width: 100%; + width: 100%; +} + +.sidebar .input-group .form-control { + border-radius: 4px; + display: block; + font-size: 0.86rem; + height: auto; + line-height: 1.3; + margin-top: 6px; + min-width: 0; + overflow: hidden; + padding: 8px 10px; + text-overflow: ellipsis; + width: 100% !important; +} + +.sidebar .input-group-btn { + display: block; + width: 100%; +} + +.sidebar .input-group-btn .btn, +.sidebar .btn-file { + border-radius: 4px !important; + display: block; + white-space: nowrap; + width: 100%; +} + +.sidebar .input-template-download { + margin-top: 10px; + width: 100%; +} + +.sidebar .upload-template-hint { + color: #5f6f7f; + font-size: 0.82rem; + line-height: 1.35; + margin: 10px 0 0; +} + +.sidebar .selectize-input { + border-color: #aeb8c2; + border-radius: 4px; + box-shadow: none; + font-size: 0.9rem; + line-height: 1.3; + min-height: 36px; + padding: 7px 9px; +} + +.sidebar .selectize-control.single .selectize-input { + overflow: hidden; + padding-right: 28px; + text-overflow: ellipsis; + white-space: nowrap; +} + +.sidebar .selectize-control.multi .selectize-input { + max-height: 320px; + min-height: 118px; + overflow-y: auto; + white-space: normal; +} + +.sidebar .selectize-control.multi .selectize-input > div { + background: #f4f7fa; + border: 1px solid #e2e8ef; + border-radius: 4px; + color: #1f2d3d; + margin: 2px 4px 2px 0; + max-width: 100%; + overflow: hidden; + padding: 3px 7px; + text-overflow: ellipsis; +} + +.sidebar hr { + border-top: 1px solid #d8dee6; + margin: 22px 0; +} + +.feature-control-row { + display: flex; + gap: 8px; + margin: -8px 0 14px; +} + +.feature-control-button { + flex: 1 1 0; +} + +.patient-visual-settings-card { + border-color: #dbe8f3; + box-shadow: 0 4px 14px rgba(31, 45, 61, 0.06); + margin: 10px 0 18px; + overflow: visible; +} + +.patient-visual-settings-card .card-header { + background: #fbfdff; + border-bottom-color: #dbe8f3; + font-weight: 750; + padding: 8px 14px; +} + +.patient-visual-settings-card .card-body { + overflow: visible; + padding: 10px 14px; +} + +.patient-visual-settings-grid { + align-items: start; + display: grid; + column-gap: 28px; + grid-template-areas: + "top radar" + "features radar"; + grid-template-columns: minmax(520px, 1fr) minmax(360px, max-content); + row-gap: 12px; + overflow: visible; +} + +.patient-visual-setting { + min-width: 0; +} + +.patient-visual-controls-top { + align-items: end; + display: grid; + gap: 12px; + grid-template-columns: minmax(220px, 280px) auto; + grid-area: top; + justify-content: start; +} + +.patient-visual-setting .form-group, +.patient-visual-setting .shiny-input-container { + margin-bottom: 0; + width: 100%; +} + +.patient-visual-setting-features .selectize-control.multi .selectize-input { + max-height: none; + min-height: 42px; + overflow-y: visible; + white-space: normal; +} + +.patient-visual-setting-features .selectize-control { + position: relative; + z-index: 40; +} + +.patient-visual-setting-features .selectize-dropdown { + z-index: 4000; +} + +.patient-visual-setting-features .selectize-control.multi .selectize-input > div { + background: #f4f7fa; + border: 1px solid #e2e8ef; + border-radius: 4px; + color: #1f2d3d; + margin: 2px 4px 2px 0; + max-width: 100%; + overflow: hidden; + padding: 3px 7px; + text-overflow: ellipsis; +} + +.patient-visual-settings-card .feature-control-row { + align-items: center; + justify-content: flex-start; + margin: 0 0 2px; +} + +.patient-visual-settings-card .feature-control-button { + flex: 0 0 auto; + min-width: 92px; + padding: 6px 10px; +} + +.feature-selector-inline { + align-items: flex-start; + display: block; + max-width: 920px; +} + +.patient-visual-radar-wrap { + align-self: center; + grid-area: radar; + min-width: 0; +} + +.patient-visual-setting-features { + grid-area: features; +} + +.patient-visual-setting-radar { + padding-top: 18px; +} + +.patient-visual-setting-radar > .form-group > label, +.patient-visual-setting-radar > .shiny-input-container > label, +.patient-visual-setting-radar label.control-label { + display: block; + margin: 0 0 14px; +} + +.patient-visual-setting-radar .form-check { + margin: 0; +} + +.patient-visual-setting-radar .shiny-options-group { + display: flex; + flex-wrap: wrap; + gap: 10px 8px; + margin-top: 0; +} + +.patient-visual-setting-radar .checkbox, +.patient-visual-setting-radar .form-check { + background: #f4f7fa; + border: 1px solid #dbe8f3; + border-radius: 999px; + color: #1f2d3d; + display: inline-flex; + font-size: 0.8rem; + line-height: 1; + padding: 6px 9px; + width: fit-content; +} + +.patient-visual-setting-radar input[type="checkbox"] { + margin: 0 6px 0 0; +} + +.patient-visual-setting-radar label { + margin: 0; + white-space: nowrap; +} + +.sidebar-help-text { + color: gray; + font-size: 0.8em; + line-height: 1.35; +} + +.card-help-icon { + color: gray; + cursor: pointer; + font-size: 0.8em; + margin-left: 5px; +} + +.context-sidebar .sidebar-help-text { + color: #667789; + font-size: 0.8rem; + line-height: 1.35; + margin-top: 4px; +} + +.cohort-context-page { + padding: 18px 22px 44px; +} + +.cohort-context-plot-card, +.feature-summary-card { + border-color: #dbe8f3; + box-shadow: 0 8px 22px rgba(31, 45, 61, 0.08); +} + +.cohort-context-plot-card .card-header, +.feature-summary-card .card-header { + background: #fbfdff; + border-bottom-color: #dbe8f3; + font-weight: 750; + padding: 11px 14px; +} + +.context-plot-context { + border-bottom: 1px solid #e7eef5; + padding: 10px 12px; +} + +.context-chip-row { + display: flex; + flex-wrap: wrap; + gap: 8px; +} + +.context-chip { + background: #f3f6f9; + border: 1px solid #d9e1e8; + border-radius: 999px; + color: #405261; + font-size: 0.8rem; + font-weight: 650; + line-height: 1; + padding: 7px 10px; +} + +.context-chip-highlight { + background: #eef7ff; + border-color: #b9dcff; + color: #174a73; +} + +.context-small-sample { + background: #f5f9fd; + border: 1px solid #cfe0ef; + border-radius: 6px; + color: #405261; + font-size: 0.82rem; + line-height: 1.35; + margin-top: 10px; + padding: 8px 10px; +} + +.feature-metric-grid { + display: grid; + gap: 10px; + grid-template-columns: repeat(2, minmax(0, 1fr)); +} + +.feature-metric { + background: #f8fbff; + border: 1px solid #dbe8f3; + border-radius: 7px; + min-width: 0; + padding: 11px 12px; +} + +.feature-metric-primary { + background: #eef7ff; + border-color: #b9dcff; +} + +.feature-metric-wide { + grid-column: 1 / -1; +} + +.feature-metric-label { + color: #667789; + font-size: 0.76rem; + font-weight: 700; + margin-bottom: 5px; + text-transform: uppercase; +} + +.feature-metric-value { + color: #1f2d3d; + font-size: 1rem; + font-weight: 750; + line-height: 1.25; + overflow-wrap: anywhere; +} + +.feature-summary-empty { + background: #f8fbff; + border: 1px dashed #bfd5e8; + border-radius: 7px; + color: #667789; + font-size: 0.9rem; + padding: 16px; + text-align: center; +} + +.stratification-summary { + display: grid; + grid-template-columns: repeat(4, minmax(0, 1fr)); + gap: 12px; + align-items: stretch; +} + +.selected-patient-stratification-card, +.selected-patient-stratification-card .card-body { + height: auto !important; + max-height: none !important; + overflow: visible !important; +} + +.stratification-tabset { + padding-top: 6px; +} + +.stratification-tabset-title { + color: #1f2d3d; + font-size: 1.1rem; + font-weight: 750; + margin: 0 0 10px; +} + +.stratification-tabset .nav-tabs { + border-bottom: 1px solid #dbe8f3; + gap: 6px; + margin-bottom: 16px; +} + +.stratification-tabset .nav-tabs .nav-link { + border: 1px solid transparent; + border-radius: 7px 7px 0 0; + color: #405261; + font-weight: 650; + padding: 10px 16px; +} + +.stratification-tabset .nav-tabs .nav-link.active { + background: #f8fbff; + border-color: #dbe8f3 #dbe8f3 #f8fbff; + color: #007bff; +} + +.stratification-tabset .tab-content { + padding-top: 2px; +} + +.stratification-downloads { + align-items: center; + background: transparent; + border: none; + border-radius: 0; + display: flex; + flex-wrap: wrap; + gap: 8px; + justify-content: flex-start; + margin: 12px 0 10px; + padding: 0; +} + +.stratification-downloads .btn { + border-radius: 5px; + font-weight: 650; +} + +.stratification-download-details { + position: relative; +} + +.stratification-download-details summary { + background: #007bff; + border: 1px solid #006add; + border-radius: 7px; + box-shadow: 0 8px 18px rgba(0, 123, 255, 0.22); + color: #fff; + cursor: pointer; + font-size: 0.92rem; + font-weight: 700; + list-style: none; + padding: 9px 16px; + transition: background-color 0.16s ease, box-shadow 0.16s ease, transform 0.16s ease; +} + +.stratification-download-details summary::after { + border-left: 4px solid transparent; + border-right: 4px solid transparent; + border-top: 5px solid currentColor; + content: ""; + display: inline-block; + margin-left: 10px; + transform: translateY(-1px); +} + +.stratification-download-details summary:hover { + background: #006fe6; + box-shadow: 0 10px 22px rgba(0, 123, 255, 0.28); + transform: translateY(-1px); +} + +.stratification-download-details summary::-webkit-details-marker { + display: none; +} + +.stratification-download-menu { + background: #ffffff; + border: 1px solid #dbe8f3; + border-radius: 7px; + box-shadow: 0 10px 24px rgba(31, 45, 61, 0.14); + display: grid; + gap: 8px; + min-width: 250px; + padding: 10px; + position: absolute; + left: 0; + top: calc(100% + 6px); + z-index: 30; +} + +.stratification-download-menu .btn { + text-align: left; + width: 100%; +} + +.closest-patients-narrative { + border-left: 4px solid #007bff; + background: #f4f9ff; + color: #2d3b48; + font-size: 0.95rem; + line-height: 1.4; + margin: 4px 0 12px; + padding: 10px 12px; +} + +.model-card .card-body { + padding-bottom: 14px; +} + +.model-card-grid { + display: grid; + grid-template-columns: repeat(4, minmax(0, 1fr)); + gap: 10px; + margin-bottom: 12px; +} + +.model-card-item, +.model-card-note { + background: #f8fbff; + border: 1px solid #d7e5f2; + border-radius: 6px; + padding: 10px 12px; +} + +.model-card-item { + min-height: 70px; +} + +.model-card-label { + color: #5f6f7f; + font-size: 0.83rem; + font-weight: 700; + margin-bottom: 5px; +} + +.model-card-value { + color: #1f2d3d; + font-size: 0.98rem; + font-weight: 750; + line-height: 1.25; + overflow-wrap: anywhere; +} + +.model-card-groups .model-card-value { + font-size: 0.94rem; +} + +.model-card-copy { + color: #2d3b48; + font-size: 0.94rem; + line-height: 1.38; + margin: 0; +} + +.model-card-note { + margin-top: 8px; +} + +.model-card-warning { + background: #f8fbff; + border-color: #d7e5f2; + border-left: 4px solid #f0b429; +} + +.model-card-validation { + background: #f4f9ff; + border-color: #cfe2ff; + border-left: 4px solid #007bff; +} + +.stratification-summary-item { + border: 1px solid #d7e5f2; + border-radius: 6px; + background: #f8fbff; + padding: 12px 14px; +} + +.stratification-summary-label { + color: #5f6f7f; + font-size: 0.85rem; + font-weight: 600; + margin-bottom: 6px; +} + +.stratification-summary-value { + color: #1f2d3d; + font-size: 1.25rem; + font-weight: 700; + overflow-wrap: anywhere; +} + +.stratification-positive { + color: #dc3545; +} + +.stratification-negative { + color: #198754; +} + +.stratification-summary-empty { + color: #5f6f7f; + padding: 12px 0; +} + +.stratification-interpretation { + border-left: 4px solid #007bff; + color: #2d3b48; + font-size: 0.95rem; + line-height: 1.4; + margin-top: 12px; + padding: 10px 12px; + background: #f4f9ff; +} + +.stratification-interpretation-main { + margin: 0 0 10px; +} + +.stratification-rules { + display: flex; + flex-wrap: wrap; + gap: 8px; + margin-bottom: 10px; +} + +.stratification-rule { + border-radius: 5px; + font-size: 0.88rem; + padding: 5px 8px; +} + +.stratification-rule-negative { + background: #d1e7dd; + color: #0f5132; +} + +.stratification-rule-positive { + background: #f8d7da; + color: #842029; +} + +.stratification-disclaimer { + color: #5f6f7f; + font-size: 0.84rem; + line-height: 1.35; + margin: 0; +} + +.use-case-summary { + display: grid; + grid-template-columns: minmax(0, 2fr) repeat(2, minmax(150px, 1fr)); + gap: 12px; +} + +.data-input-page { + margin: 0 auto; + max-width: 1680px; + padding: 18px 22px 44px; +} + +.use-case-summary-card { + margin-bottom: 0; +} + +.current-set-card { + border-color: #dbe8f3; + box-shadow: 0 8px 22px rgba(31, 45, 61, 0.08); + margin-top: 14px; +} + +.current-set-card .card-header { + background: #fbfdff; + border-bottom-color: #dbe8f3; + padding: 11px 14px; +} + +.current-set-card .card-body { + padding: 12px 14px 10px; +} + +.empty-set-state { + align-items: center; + background: #f8fbff; + border: 1px dashed #bfd5e8; + border-radius: 7px; + display: flex; + flex-direction: column; + justify-content: center; + min-height: 94px; + padding: 18px; + text-align: center; +} + +.empty-set-title { + color: #1f2d3d; + font-size: 1rem; + font-weight: 750; + margin-bottom: 4px; +} + +.empty-set-copy { + color: #5f6f7f; + font-size: 0.86rem; + line-height: 1.35; +} + +.cohort-table-title { + align-items: center; + display: flex; + gap: 12px; + justify-content: space-between; +} + +.cohort-title-group { + align-items: baseline; + display: flex; + flex-wrap: wrap; + gap: 10px; + min-width: 0; +} + +.set-row-count { + color: #5f6f7f; + font-size: 0.82rem; + font-weight: 600; +} + +.set-status-pill { + border: 1px solid transparent; + border-radius: 999px; + font-size: 0.76rem; + font-weight: 750; + line-height: 1; + padding: 6px 10px; + white-space: nowrap; +} + +.set-status-empty { + background: #f3f6f9; + border-color: #d9e1e8; + color: #667789; +} + +.set-status-pending { + background: #fff8e6; + border-color: #f3d78b; + color: #806000; +} + +.set-status-confirmed { + background: #edf8f1; + border-color: #b7dfc5; + color: #0f5132; +} + +.data-input-method-tabs { + background: #ffffff; + border: 1px solid #dbe8f3; + border-radius: 7px; + box-shadow: 0 6px 18px rgba(31, 45, 61, 0.06); + margin-top: 16px; + overflow: hidden; +} + +.data-input-method-header { + align-items: baseline; + background: #fbfdff; + border-bottom: 1px solid #dbe8f3; + display: flex; + flex-wrap: wrap; + gap: 8px 14px; + justify-content: flex-start; + padding: 10px 14px; +} + +.data-input-method-title { + color: #1f2d3d; + font-size: 1rem; + font-weight: 750; +} + +.data-input-method-subtitle { + color: #5f6f7f; + font-size: 0.82rem; + font-weight: 600; + line-height: 1.35; +} + +.data-input-method-tabs > .nav { + background: #f8fbff; + border-bottom: 1px solid #dbe8f3; + padding: 0 12px; +} + +.data-input-method-tabs .nav-link { + border-radius: 6px 6px 0 0; + font-weight: 650; + padding: 10px 14px; +} + +.data-input-method-tabs .tab-content { + padding: 12px; +} + +.data-input-method-tabs .card { + border-color: #dbe8f3; + box-shadow: none; + margin-bottom: 0; +} + +.data-input-method-tabs .card-header { + background: #ffffff; + border-bottom-color: #e7eef5; + padding: 10px 14px; +} + +.data-input-method-tabs .card-body { + padding: 12px 16px; +} + +.data-input-method-tabs .card-footer { + background: #fbfdff; + border-top-color: #e7eef5; + padding: 12px 16px; + text-align: right; +} + +.data-input-method-tabs .row, +.data-input-method-tabs .layout-column-wrap { + row-gap: 4px; +} + +.data-input-method-tabs .upload-template-hint { + color: #5f6f7f; + font-size: 0.86rem; + line-height: 1.35; + margin: 8px 0 0; +} + +.data-input-method-tabs .input-template-download { + min-width: 132px; +} + +.manual-entry-grid { + display: grid; + gap: 14px; + grid-template-columns: repeat(3, minmax(240px, 1fr)); +} + +.manual-entry-section { + background: #ffffff; + border: 1px solid #e1ebf4; + border-radius: 7px; + min-width: 0; + padding: 14px; +} + +.manual-entry-section-title { + color: #1f2d3d; + font-size: 0.9rem; + font-weight: 750; + margin-bottom: 12px; +} + +.manual-entry-section-fields { + display: grid; + gap: 12px; +} + +.manual-entry-field { + min-width: 0; +} + +.manual-entry-field .form-group, +.manual-entry-field .shiny-input-container { + margin-bottom: 0; + width: 100%; +} + +.manual-entry-field input, +.manual-entry-field select, +.manual-entry-field .selectize-control { + max-width: 340px; + width: 100% !important; +} + +.form-field-with-help { + align-items: center; + display: flex; + gap: 8px; +} + +.form-field-with-help > .form-group, +.form-field-with-help > .shiny-input-container { + flex: 1 1 auto; +} + +.form-help-icon { + color: #007bc2; + cursor: pointer; + display: inline-flex; + transform: translateY(12px); +} + +.manual-entry-submit { + min-width: 220px; +} + +.upload-file-panel { + align-items: start; + background: #f8fbff; + border: 1px solid #dbe8f3; + border-radius: 7px; + display: grid; + gap: 16px; + grid-template-columns: minmax(260px, 1fr) auto; + padding: 16px; +} + +.upload-file-picker .form-group, +.upload-file-picker .shiny-input-container { + margin-bottom: 0; +} + +.upload-file-picker .input-group { + max-width: 420px; +} + +.upload-resource-actions { + align-items: center; + display: flex; + flex-wrap: wrap; + gap: 8px; + justify-content: flex-end; +} + +.upload-file-actions { + align-items: center; + display: flex; + gap: 12px; + justify-content: flex-end; + margin-top: 12px; +} + +.upload-file-actions .btn-primary { + min-width: 220px; +} + +.upload-requirement-pill { + background: #eef7ff; + border: 1px solid #b9dcff; + border-radius: 999px; + color: #334b63; + font-size: 0.82rem; + font-weight: 650; + padding: 7px 11px; + white-space: nowrap; +} + +.upload-requirement-count { + color: #006add; + font-weight: 800; +} + +.example-set-panel { + align-items: center; + background: #f8fbff; + border: 1px solid #dbe8f3; + border-radius: 7px; + display: flex; + gap: 16px; + justify-content: space-between; + padding: 16px; +} + +.example-set-copy { + min-width: 0; +} + +.example-set-title { + color: #1f2d3d; + font-size: 0.94rem; + font-weight: 750; +} + +.example-set-actions { + display: flex; + justify-content: flex-end; + margin-top: 12px; +} + +.example-set-actions .btn-primary { + min-width: 240px; +} + +.use-case-summary-item { + border: 1px solid #d7e5f2; + border-radius: 6px; + background: #f8fbff; + padding: 12px 14px; +} + +.use-case-summary-label { + color: #5f6f7f; + font-size: 0.85rem; + font-weight: 600; + margin-bottom: 6px; +} + +.use-case-summary-value { + color: #1f2d3d; + font-size: 0.98rem; + font-weight: 650; + line-height: 1.35; +} + +.input-validation-alert { + background: #fff5f5; + border: 1px solid #f1b6bd; + border-left: 4px solid #dc3545; + border-radius: 6px; + color: #842029; + margin: 0 0 12px; + padding: 12px 14px; +} + +.input-validation-title { + font-weight: 750; + margin-bottom: 6px; +} + +.input-validation-alert ul { + margin: 0; + padding-left: 18px; +} + +.input-validation-alert li { + line-height: 1.35; + margin: 4px 0; +} + +.input-validation-success { + background: #f3fbf6; + border: 1px solid #b7dfc5; + border-left: 4px solid #198754; + border-radius: 6px; + color: #0f5132; + margin: 0 0 12px; + padding: 12px 14px; +} + +.input-validation-success-copy { + line-height: 1.35; + margin: 0; +} + +.cohort-actions { + align-items: center; + display: flex; + flex-wrap: wrap; + gap: 10px; + justify-content: space-between; +} + +.cohort-actions .shiny-panel-conditional { + display: contents; +} + +.cohort-secondary-actions { + align-items: center; + display: flex; + flex: 0 0 auto; + flex-wrap: wrap; + gap: 8px; + justify-content: flex-end; +} + +.cohort-action-button { + min-width: 150px; +} + +.current-set-card .card-footer { + background: #fbfdff; + border-top-color: #dbe8f3; + padding: 10px 14px; +} + +.current-set-card .dataTables_wrapper, +.current-set-card .shiny-bound-output { + margin: 0; +} + +.current-set-card table { + margin-bottom: 0; +} + +.cohort-actions .apply-data-panel { + flex: 1 1 520px; +} + +.cohort-actions .apply-data-panel .btn { + box-shadow: 0 5px 12px rgba(0, 123, 255, 0.2); + font-weight: 750; +} + +.app-validation-footer { + background: #f8fbff; + border-top: 1px solid #dbe8f3; + bottom: 0; + color: #5f6f7f; + font-size: 0.78rem; + left: 0; + line-height: 1.3; + padding: 6px 14px; + position: fixed; + right: 0; + text-align: center; + z-index: 20; +} + +@media (max-width: 900px) { + .stratification-summary { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + + .model-card-grid { + grid-template-columns: repeat(2, minmax(0, 1fr)); + } + + .use-case-summary { + grid-template-columns: 1fr; + } + + .data-input-page { + padding: 12px 12px 44px; + } + + .manual-entry-grid, + .upload-file-panel { + grid-template-columns: 1fr; + } + + .manual-entry-field input, + .manual-entry-field select, + .manual-entry-field .selectize-control { + max-width: none; + } + + .upload-file-actions, + .upload-resource-actions, + .example-set-panel { + align-items: stretch; + flex-direction: column; + } + + .upload-file-actions .btn-primary, + .data-input-method-tabs .input-template-download, + .manual-entry-submit, + .example-set-actions .btn-primary { + width: 100%; + } + + .cohort-actions { + flex-direction: column; + align-items: stretch; + } + + .cohort-secondary-actions { + width: 100%; + } + + .cohort-action-button { + width: 100%; + } + + .cohort-table-title { + align-items: flex-start; + flex-direction: column; + } + + .data-input-method-header { + align-items: flex-start; + flex-direction: column; + } + + .patient-visual-settings-grid { + grid-template-areas: + "top" + "features" + "radar"; + grid-template-columns: 1fr; + } + + .patient-visual-controls-top { + grid-template-columns: 1fr; + } + + .patient-visual-settings-card .feature-control-row { + flex-direction: column; + } + + .patient-visual-settings-card .feature-control-button { + width: 100%; + } + + .cohort-context-page { + padding: 12px 12px 44px; + } + + .feature-metric-grid { + grid-template-columns: 1fr; + } +} diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test/tests.py b/test/tests.py deleted file mode 100644 index a091b8f..0000000 --- a/test/tests.py +++ /dev/null @@ -1,19 +0,0 @@ -import unittest -from typing import ClassVar - -from mymodule.template import TemplateClass - - -class TestMyModule(unittest.TestCase): - test: ClassVar[TemplateClass] - - @classmethod - def setUpClass(self): - self.test = TemplateClass("./test/data/testInput.tsv", threshold=10) - - def test_square(self): - self.assertEqual(self.test._square(10), 100) - - -if __name__ == "__main__": - unittest.main()