Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions src/XAInyPredictor/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def build_page_header(config: dict, current_use_case: str):
choices=use_case_choices,
selected=current_use_case,
),
style="display: flex; align-items: center; margin-left: 10px;"
class_="navbar-use-case-selector",
),
id="div-navbar-tabs",
class_="navigation-menu",
Expand Down Expand Up @@ -210,9 +210,9 @@ def build_page_header(config: dict, current_use_case: str):

def build_ui(config: dict, current_use_case: str):
page_dependencies = ui.tags.head(
ui.tags.link(rel="stylesheet", type="text/css", href="layout.css?v=20260608a"),
ui.tags.link(rel="stylesheet", type="text/css", href="style.css?v=20260608a"),
ui.tags.script(src="index.js?v=20260608a"),
ui.tags.link(rel="stylesheet", type="text/css", href="layout.css?v=20260615a"),
ui.tags.link(rel="stylesheet", type="text/css", href="style.css?v=20260615a"),
ui.tags.script(src="index.js?v=20260615a"),
ui.tags.meta(name="description", content=config.get("description", "XAI Predictor")),
ui.tags.meta(name="theme-color", content="#000000"),
ui.tags.meta(name="viewport", content="width=device-width, initial-scale=1"),
Expand Down Expand Up @@ -283,6 +283,7 @@ async def update_navigation_labels(session: Session, config: dict):
{
"form": labels.get("manual_entry", "Manual Entry"),
"file": labels.get("upload_file", "Upload File"),
"documents": labels.get("prepare_csv_documents", "Prepare CSV from Documents"),
"example": labels.get("example_cohort", "Example Cohort"),
},
)
Expand Down Expand Up @@ -530,6 +531,7 @@ async def _on_startup_confirm():
ui.update_select("use_case_selector", choices=_use_case_choices(), selected=selected_use_case)

ui.modal_remove()
await session.send_custom_message("resetDataInputUploads", {})
await session.send_custom_message("setUseCaseLoading", {"visible": False})

@reactive.Effect
Expand Down Expand Up @@ -592,6 +594,7 @@ async def _confirm_switch():
ui.notification_show(f"Switched to {new_model_data['config'].get('name', new_use_case)}", type="message")

await session.send_custom_message("toggleActiveTab", {"activeTab": "data_input"})
await session.send_custom_message("resetDataInputUploads", {})

ui.modal_remove()
await session.send_custom_message("setUseCaseLoading", {"visible": False})
Expand Down
209 changes: 1 addition & 208 deletions src/XAInyPredictor/modules/xai.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,214 +354,7 @@ def analyze_patient(
return fig_radar, fig_curve


def analyze_patient_new(
    patient_id,
    df,
    delta_train,
    delta_test,
    x_train,
    y_train,
    features_to_plot=None,
    n_dists=3,  # Reduced default for clarity
    max_plot_curves=10,
    show_closest_radial=True,
    show_average_radial=True,
    show_average_class0_radial=True,
    show_average_class1_radial=True,
):
    """Build a radar figure and a stack of per-feature curves for one patient.

    Args:
        patient_id: Identifier looked up in ``df['ID']``; both sides are
            compared as ``int``.
        df: Frame holding an ``ID`` column and the raw feature values of the
            patients being analysed.
        delta_train: Per-feature contribution values for the training set,
            including a ``pred_prob`` column.
        delta_test: Per-feature contribution values for the rows of ``df``,
            including a ``pred_prob`` column.
            NOTE(review): rows are addressed with ``df``'s index labels, so
            ``delta_test`` is assumed to share ``df``'s index - confirm.
        x_train: Raw training feature values (same columns as the features
            being plotted).
        y_train: Training labels; compared against ``0`` and ``1`` and must
            expose ``.values`` (presumably a pandas Series - confirm).
        features_to_plot: Optional feature names; spaces are replaced by
            underscores and names missing from ``delta_test`` are dropped.
            When empty or ``None`` every column except ``const`` and
            ``pred_prob`` is used.
        n_dists: Number of nearest training rows (Euclidean distance over the
            selected contribution columns) drawn as "similar patients".
        max_plot_curves: Maximum number of feature panels in the curve figure.
        show_closest_radial: Draw the nearest training rows on the radar.
        show_average_radial: Draw the overall training average on the radar.
        show_average_class0_radial: Draw the class-0 average on the radar.
        show_average_class1_radial: Draw the class-1 average on the radar.

    Returns:
        ``(fig_radar, fig_curve)``, or ``(None, None)`` when ``patient_id`` is
        not present in ``df['ID']``.
    """
    logger.debug("Analyzing patient %s with enhanced visualization.", patient_id)

    # --- 1. DATA PREPARATION ---

    # Locate patient. The same int-cast comparison drives both the existence
    # check and the row lookup, so IDs stored as strings cannot pass the check
    # and then fail the lookup.
    id_matches = (df['ID'].astype(int) == int(patient_id)).to_numpy()
    if not id_matches.any():
        logger.debug("Patient %s not found.", patient_id)
        return None, None

    patient_index = df.index[id_matches][0]

    # Filter features
    excluded = ("const", "pred_prob")
    if features_to_plot:
        clean_feats = [feat.replace(' ', '_') for feat in features_to_plot]
        feature_names = [
            name for name in clean_feats
            if name in delta_test.columns and name not in excluded
        ]
    else:
        feature_names = [name for name in delta_test.columns if name not in excluded]

    # Extract data subsets
    d_train = delta_train[feature_names].values
    d_test_patient = delta_test.loc[patient_index, feature_names].values.flatten()

    # Probability: one label-based lookup, used for both the radar title and
    # the gauge (patient_index is an index label, not a position).
    pred_prob = delta_test.loc[patient_index, "pred_prob"]

    # Neighbor finding (Euclidean distance in contribution space)
    dists = np.linalg.norm(d_train - d_test_patient, axis=1)
    idx_closest = np.argsort(dists)[:n_dists]

    # Class membership, computed once and reused by both figures.
    is_class0 = y_train == 0
    is_class1 = y_train == 1

    # --- 2. RADAR PLOT ---

    # Min-max scale the contribution values against the training set:
    # 0 = lowest contribution observed in training, 1 = highest.
    mins = d_train.min(axis=0)
    maxs = d_train.max(axis=0)
    ranges = maxs - mins
    ranges[ranges == 0] = 1e-9  # Avoid division by zero for constant columns

    def normalize(v):
        return (v - mins) / ranges

    pat_norm = normalize(d_test_patient)

    # Reference averages
    avg_norm = normalize(d_train.mean(axis=0))
    avg_c0_norm = normalize(d_train[is_class0.values].mean(axis=0))
    avg_c1_norm = normalize(d_train[is_class1.values].mean(axis=0))

    theta = radar_factory(len(feature_names), frame='polygon')

    fig_radar, ax_radar = plt.subplots(figsize=(9, 9), subplot_kw=dict(projection='radar'))

    # Grid lines and labels
    ax_radar.set_rgrids([0.2, 0.4, 0.6, 0.8], labels=[], angle=0, color="grey", alpha=0.3)
    ax_radar.set_varlabels([f.replace("_", " ") for f in feature_names])
    ax_radar.tick_params(pad=15)  # Move labels out slightly

    # 1. Reference populations
    if show_average_class0_radial:
        ax_radar.plot(theta, avg_c0_norm, color='#2ca02c', linewidth=2, linestyle='--', label='Avg. negative')

    if show_average_class1_radial:
        ax_radar.plot(theta, avg_c1_norm, color='#d62728', linewidth=2, linestyle='--', label='Avg. positive')

    if show_average_radial:
        ax_radar.plot(theta, avg_norm, color='grey', linewidth=2, label='Population Average')

    # 2. Closest neighbors (lighter opacity); only the first one gets a
    # legend entry so the legend is not flooded with duplicates.
    if show_closest_radial:
        for i, idx in enumerate(idx_closest):
            neighbor_vals = normalize(d_train[idx])
            ax_radar.plot(
                theta,
                neighbor_vals,
                color='#ff7f0e',
                alpha=0.3,
                label='Similar Patients' if i == 0 else "",
            )

    # 3. Selected patient (outlined and lightly filled)
    ax_radar.plot(theta, pat_norm, color='#1f77b4', linewidth=2, label='Selected Patient')
    ax_radar.fill(theta, pat_norm, color='#1f77b4', alpha=0.1)

    # Styling
    title = f"Patient {patient_id} (prob = {pred_prob:.2f})"
    ax_radar.set_title(title, position=(0.5, 1.1), ha='center', weight='bold')
    ax_radar.legend(loc='upper right', bbox_to_anchor=(1.3, 1.1), fontsize='small', frameon=False)

    # --- 3. CURVES PLOT ---

    # Show the features with the largest absolute contribution for this
    # patient first.
    abs_contribution = np.abs(d_test_patient)
    top_indices = np.argsort(abs_contribution)[::-1][:max_plot_curves]
    top_features = [feature_names[i] for i in top_indices]

    n_curves = len(top_features) + 1  # +1 for the probability gauge on top

    # constrained_layout handles spacing between the stacked panels
    fig_curve, axs = plt.subplots(n_curves, 1, figsize=(8, 3 * n_curves), constrained_layout=True)

    if n_curves == 1:
        axs = [axs]  # plt.subplots returns a bare Axes for a single panel

    # -- A. Overall probability gauge (top panel) --
    ax_prob = axs[0]

    # Distribution of training-set predictions as background
    ax_prob.hist(delta_train['pred_prob'], bins=30, density=False, alpha=0.15, color='grey', label='Population Dist.')

    # Twin axis carries the probability scale next to the histogram
    ax_prob_twin = ax_prob.twinx()
    ax_prob_twin.plot([], [])  # Dummy to align colors
    ax_prob_twin.set_ylim(0, 1.1)
    ax_prob_twin.set_yticks([0, 0.5, 1])
    ax_prob_twin.set_ylabel("Probability")

    # Current patient marker
    ax_prob.axvline(pred_prob, color='#1f77b4', linewidth=3, linestyle='-', label=f'Patient: {pred_prob:.2f}')

    ax_prob.set_title("Overall Risk Prediction", weight='bold')
    ax_prob.set_xlabel("Predicted Probability")
    ax_prob.set_yticks([])  # Hide histogram counts
    ax_prob.legend(loc='upper left')

    # -- B. Feature contribution curves --

    for i, feat in enumerate(top_features):
        ax_feat = axs[i + 1]

        raw_vals = x_train[feat].values
        contrib_vals = delta_train[feat].values

        # 1. Background density: where most training patients lie
        ax_hist = ax_feat.twinx()
        ax_hist.hist(raw_vals, bins=20, color='grey', alpha=0.1, density=True)
        ax_hist.set_yticks([])  # Hide density labels

        # 2. Relationship curve; sort by raw value so the line runs left to right
        sort_idx = np.argsort(raw_vals)
        ax_feat.plot(
            raw_vals[sort_idx],
            contrib_vals[sort_idx],
            color='black',
            alpha=0.4,
            linewidth=1,
            label='Risk Trend',
        )

        # 3. Reference points (class averages)
        c0_mean_x = x_train.loc[is_class0, feat].mean()
        c0_mean_y = delta_train.loc[is_class0, feat].mean()
        c1_mean_x = x_train.loc[is_class1, feat].mean()
        c1_mean_y = delta_train.loc[is_class1, feat].mean()

        ax_feat.scatter(c0_mean_x, c0_mean_y, color='#2ca02c', s=100, marker='D', label='Avg Low Risk', zorder=3)
        ax_feat.scatter(c1_mean_x, c1_mean_y, color='#d62728', s=100, marker='D', label='Avg High Risk', zorder=3)

        # 4. Patient and neighbors
        pat_x = df.loc[patient_index, feat]
        pat_y = delta_test.loc[patient_index, feat]

        for n_idx in idx_closest:
            n_x = x_train.iloc[n_idx][feat]
            n_y = delta_train.iloc[n_idx][feat]
            ax_feat.scatter(n_x, n_y, color='#ff7f0e', alpha=0.6, s=30)

        # Patient (large dot, drawn on top)
        ax_feat.scatter(pat_x, pat_y, color='#1f77b4', s=150, edgecolors='white', linewidth=2, label='Patient', zorder=5)

        # Labels
        ax_feat.set_title(f"Feature: {feat.replace('_', ' ')}")
        ax_feat.set_xlabel(f"Value ({feat})")
        ax_feat.set_ylabel("Risk Contribution")
        ax_feat.grid(True, alpha=0.2)

        if i == 0:  # Only legend on first feature plot to save space
            ax_feat.legend(loc='best', fontsize='small')

    return fig_radar, fig_curve

def radar_factory(num_vars, frame='circle'):
def radar_factory(num_vars, frame='circle'):
"""
Create a radar chart with `num_vars` Axes.

Expand Down
Loading
Loading