-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
508 lines (425 loc) · 22 KB
/
app.py
File metadata and controls
508 lines (425 loc) · 22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
import streamlit as st
import subprocess
import os
import pandas as pd
from pathlib import Path
import psutil
import time
# Configure the page: browser-tab title and the full-width ("wide") layout.
st.set_page_config(page_title="DP-Clinical-ICL Generator", layout="wide")
def check_system_requirements():
    """Collect the host's resources next to the minimums this app expects.

    RAM and free disk space are read through psutil; GPU information comes
    from ``nvidia-smi``. If ``nvidia-smi`` is missing, fails, or prints
    something that cannot be parsed as integers, the GPU entries are reported
    as absent instead of raising.

    Returns:
        dict: Maps a resource label ("RAM", "Disk", "GPU", "GPU Memory",
        "Number of GPUs") to a dict with the keys "required", "actual" and
        "unit" ("GB" or None).
    """
    bytes_per_gb = 1024 ** 3
    requirements = {
        "RAM": {"required": 16, "actual": round(psutil.virtual_memory().total / bytes_per_gb), "unit": "GB"},
        "Disk": {"required": 10, "actual": round(psutil.disk_usage('/').free / bytes_per_gb), "unit": "GB"},
    }

    # Check GPU using nvidia-smi
    try:
        gpu_info = subprocess.check_output(
            ['nvidia-smi', '--query-gpu=memory.total', '--format=csv,noheader,nounits'],
            universal_newlines=True,
        )
        # One line per GPU; blank lines are skipped. A non-numeric value makes
        # int() raise ValueError, which is handled below like "no GPU".
        gpu_memories = [int(line.strip()) for line in gpu_info.splitlines() if line.strip()]
    except (subprocess.CalledProcessError, OSError, ValueError):
        gpu_memories = []

    if gpu_memories:
        # Memory is summed over all GPUs (not the maximum of a single GPU).
        total_gpu_memory = sum(gpu_memories) / 1024  # Convert MB to GB
        requirements["GPU"] = {"required": True, "actual": True, "unit": None}
        requirements["GPU Memory"] = {"required": 14, "actual": round(total_gpu_memory, 1), "unit": "GB"}
        requirements["Number of GPUs"] = {"required": 1, "actual": len(gpu_memories), "unit": None}
    else:
        requirements["GPU"] = {"required": True, "actual": False, "unit": None}
        requirements["GPU Memory"] = {"required": 14, "actual": 0, "unit": "GB"}
        requirements["Number of GPUs"] = {"required": 1, "actual": 0, "unit": None}
    return requirements
def download_mimic_data(username, password):
    """Download the MIMIC-IV-Note 2.2 and MIMIC-IV 2.2 trees from PhysioNet.

    Both downloads are performed by ``wget`` running inside the ``data``
    directory. The working directory of this process is never changed, so a
    failed download cannot leave the app running from inside ``data``.

    Args:
        username: PhysioNet account name passed to ``wget --user``.
        password: PhysioNet account password passed to ``wget --password``.

    Returns:
        bool: True when both wget runs exit successfully; False when a run
        exits with a non-zero status or wget cannot be started.
    """
    urls = (
        "https://physionet.org/files/mimic-iv-note/2.2/",
        "https://physionet.org/files/mimiciv/2.2/",
    )
    os.makedirs("data", exist_ok=True)
    try:
        for url in urls:
            # Arguments are passed as a list so credentials that contain
            # whitespace stay a single argument.
            # NOTE(review): the password is still visible in the process list
            # while wget runs; consider a .netrc/.wgetrc based login.
            subprocess.run(
                ["wget", "-r", "-N", "-c", "-np", "--user", username, "--password", password, url],
                check=True,
                cwd="data",
            )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # Non-zero wget exit status, or wget is not installed.
        return False
    return True
def _pump_stream(stream, tag, sink):
    """Forward every line of *stream* to *sink* as (tag, line), then (tag, None)."""
    try:
        for line in stream:
            sink.put((tag, line))
    finally:
        # The sentinel tells the consumer that this stream is exhausted.
        sink.put((tag, None))


def extract_data():
    """Run extract_data_amc.py and mirror its output into the Streamlit page.

    The function first verifies that the script and the five MIMIC-IV input
    files exist, then starts the script as a child process. stdout and stderr
    are each drained by a background thread and merged through a queue, so a
    silent stream can never block reading of the other one. Lines containing
    "warning" (case-insensitive) are not displayed.

    Returns:
        bool: True when the script exits with code 0 and
        ``data/mimiciv_icd10.feather`` exists afterwards; False otherwise
        (missing files, non-zero exit code, missing output, or any exception).
    """
    # Function-scope imports: only this function needs them.
    import queue
    import threading

    # First matching keyword wins, in this order.
    stage_messages = (
        ("Loading", "📂 Loading data files..."),
        ("Merging", "🔄 Merging datasets..."),
        ("Formatting", "✏️ Formatting ICD codes..."),
        ("Filtering", "🔍 Filtering records..."),
        ("Saving", "💾 Saving processed data..."),
    )

    try:
        # Create containers for progress
        status_container = st.empty()
        progress_container = st.container()
        status_container.info("Starting data extraction process...")

        # Check if the script exists
        if not os.path.exists("extract_data_amc.py"):
            status_container.error("❌ extract_data_amc.py not found in the current directory!")
            return False

        # Check if input files exist
        required_files = [
            "data/physionet.org/files/mimic-iv-note/2.2/note/discharge.csv.gz",
            "data/physionet.org/files/mimiciv/2.2/hosp/procedures_icd.csv.gz",
            "data/physionet.org/files/mimiciv/2.2/hosp/diagnoses_icd.csv.gz",
            "data/physionet.org/files/mimiciv/2.2/hosp/d_icd_procedures.csv.gz",
            "data/physionet.org/files/mimiciv/2.2/hosp/d_icd_diagnoses.csv.gz",
        ]
        for file in required_files:
            if not os.path.exists(file):
                status_container.error(f"❌ Required file not found: {file}")
                return False

        status_container.info("📂 All required files found. Starting extraction...")

        stderr_lines = []
        # The context manager closes both pipes and reaps the child on exit.
        with subprocess.Popen(
            ["python", "-u", "extract_data_amc.py"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            universal_newlines=True,
            env={**os.environ, 'PYTHONUNBUFFERED': '1', 'PYTHONWARNINGS': 'ignore'},
        ) as process:
            lines = queue.Queue()
            readers = [
                threading.Thread(target=_pump_stream, args=(process.stdout, "stdout", lines), daemon=True),
                threading.Thread(target=_pump_stream, args=(process.stderr, "stderr", lines), daemon=True),
            ]
            for reader in readers:
                reader.start()

            # Show output in real-time; all Streamlit calls stay on this thread.
            open_streams = len(readers)
            while open_streams:
                tag, line = lines.get()
                if line is None:
                    open_streams -= 1
                    continue
                if "warning" in line.lower():  # Skip warning messages
                    continue
                if tag == "stderr":
                    stderr_lines.append(line)
                    with progress_container:
                        st.error(line.strip())
                    continue
                with progress_container:
                    st.info(line.strip())
                # Update the status message
                for keyword, message in stage_messages:
                    if keyword in line:
                        status_container.info(message)
                        break

            for reader in readers:
                reader.join()
            return_code = process.wait()

        if return_code != 0:
            # Report the collected (non-warning) stderr text with the exit code.
            error_output = "".join(stderr_lines).strip()
            status_container.error(f"❌ Extraction failed with error code {return_code}: {error_output}")
            return False
        if os.path.exists("data/mimiciv_icd10.feather"):
            status_container.success("✅ Data extraction completed successfully!")
            return True
        status_container.error("❌ Extraction process completed but output file not found!")
        return False
    except Exception as e:
        st.error(f"❌ Error during extraction: {str(e)}")
        return False
def _materialize_custom_dataset(custom_dataset):
    """Return a filesystem path for *custom_dataset*, saving uploads to disk first."""
    if isinstance(custom_dataset, (str, os.PathLike)):
        return os.fspath(custom_dataset)
    # Anything else is treated as an in-memory upload exposing getvalue() and
    # optionally a name (the caller in this file passes the st.file_uploader
    # result). It must be written out because the script takes a path.
    upload_dir = os.path.join("data", "uploads")
    os.makedirs(upload_dir, exist_ok=True)
    # basename() drops any directory part of the client-supplied file name.
    file_name = os.path.basename(getattr(custom_dataset, "name", "") or "") or "custom_dataset.feather"
    target = os.path.join(upload_dir, file_name)
    with open(target, "wb") as out:
        out.write(custom_dataset.getvalue())
    return target


def generate_data(model_name, num_shots, dataset_size, temperature, prompt_index, custom_dataset=None, prompt=None, nonprivate=False, epsilons=None):
    """Run DP_ICL_gen.py with the given generation settings.

    Args:
        model_name: Value for ``--model_name``.
        num_shots: Value for ``--num_shots``.
        dataset_size: Value for ``--generated_dataset_size``.
        temperature: Value for ``--temperature``.
        prompt_index: Value for ``--prompt_index``; ignored (0 is sent) when
            *prompt* is given.
        custom_dataset: Optional path, or uploaded in-memory file, passed as
            ``--custom_dataset_path``. Uploads are first saved under
            ``data/uploads``.
        prompt: Optional custom prompt; after stripping whitespace it must end
            with ``ICD10-CODES=``.
        nonprivate: When true, ``--nonprivate`` is added.
        epsilons: Optional iterable of values passed after ``--epsilons``.

    Returns:
        bool: True when the script exits successfully; False when the custom
        prompt is malformed, the script exits with a non-zero status, or the
        process cannot be started. Failures are reported with ``st.error``.
    """
    try:
        # Create data and generated directories if they don't exist
        os.makedirs("data/generated", exist_ok=True)
        cmd = [
            "python", "DP_ICL_gen.py",
            "--model_name", model_name,
            "--num_shots", str(num_shots),
            "--generated_dataset_size", str(dataset_size),
            "--temperature", str(temperature),
        ]

        # Handle custom prompt vs prompt index
        if prompt:
            stripped_prompt = prompt.strip()
            # Validate custom prompt format
            if not stripped_prompt.endswith("ICD10-CODES="):
                error_message = f"""
⚠️ Custom Prompt Format Error
Your prompt must end with exactly 'ICD10-CODES=' (without quotes). This is required because:
1. The script needs to know where to insert the ICD-10 codes
2. The format must be exact (no extra spaces after the '=')
3. The codes will be inserted immediately after the '='
Your prompt ends with: '{stripped_prompt[-20:]}'
To fix this:
1. Check for any trailing spaces or newlines
2. Make sure 'ICD10-CODES=' is the last part of your prompt
3. Verify there are no extra characters after the '='
Example of correct prompt ending:
"... rest of your prompt text here ICD10-CODES="
"""
                st.error(error_message)
                return False
            cmd.extend(["--prompt", prompt])
            cmd.extend(["--prompt_index", "0"])  # Use 0 as default when custom prompt is provided
        else:
            cmd.extend(["--prompt_index", str(prompt_index)])

        if custom_dataset:
            # subprocess arguments must be path-like; an uploaded file object is not.
            cmd.extend(["--custom_dataset_path", _materialize_custom_dataset(custom_dataset)])
        if nonprivate:
            cmd.append("--nonprivate")
        if epsilons:
            cmd.extend(["--epsilons"] + [str(eps) for eps in epsilons])

        subprocess.run(cmd, check=True)
        return True
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers a missing interpreter and failures saving an upload.
        st.error(f"Generation failed with error: {str(e)}")
        return False
# ---- Streamlit UI ----
st.title("DP-Clinical-ICL Generator")

# The sidebar selectbox decides which workflow step is rendered below.
page = st.sidebar.selectbox(
    label="Step",
    options=["System Check", "Dataset Download", "Data Extraction", "Data Generation"],
)
if page == "System Check":
    st.header("System Requirements Check")
    requirements = check_system_requirements()
    for resource, details in requirements.items():
        # The same text is shown either way; only the box style differs.
        unit_suffix = f" {details['unit']}" if details["unit"] else ""
        message = (
            f"{resource}: {details['actual']}{unit_suffix}"
            f" (Required: {details['required']}{unit_suffix})"
        )
        # Green box when the requirement is met, red box otherwise.
        report = st.success if details["actual"] >= details["required"] else st.error
        report(message)
elif page == "Dataset Download":
st.header("MIMIC-IV Dataset Download")
username = st.text_input("PhysioNet Username")
password = st.text_input("PhysioNet Password", type="password")
if st.button("Download Dataset"):
if username and password:
with st.spinner("Downloading dataset..."):
if download_mimic_data(username, password):
st.success("Dataset downloaded successfully!")
else:
st.error("Dataset download failed. Please check your credentials.")
else:
st.warning("Please enter your PhysioNet credentials.")
elif page == "Data Extraction":
st.header("Data Extraction")
st.warning("Note: Data extraction may take a few minutes to start and several more minutes to complete. This is normal as it needs to process large files.")
if st.button("Extract Data"):
with st.spinner("Extracting data... This may take a few minutes."):
if extract_data():
st.success("Data extraction completed!")
else:
st.error("Data extraction failed. Please check the logs.")
elif page == "Data Generation":
st.header("Data Generation")
# Initialize session state for generated files and timestamp if not exists
if 'generated_files' not in st.session_state:
st.session_state.generated_files = []
if 'last_generation_time' not in st.session_state:
st.session_state.last_generation_time = None
col1, col2 = st.columns(2)
with col1:
model_name = st.text_input("Model Name", value="llama3.2", help="Enter the name of any model available in Ollama (e.g., llama2, mistral, mixtral, etc.)")
num_shots = st.number_input("Number of Shots", min_value=1, value=5)
dataset_size = st.number_input("Dataset Size", min_value=5, value=100)
with col2:
temperature = st.slider(
"Temperature",
min_value=0.1,
max_value=1.0,
value=0.7,
help="Controls the randomness in the model's output. Lower values (0.1-0.3) make the text more focused and deterministic, while higher values (0.7-1.0) make it more creative and diverse. For medical text generation, values between 0.5-0.7 often provide a good balance between accuracy and variation."
)
use_custom_prompt = st.checkbox("Use Custom Prompt", value=False)
custom_dataset = st.file_uploader("Custom Dataset (optional)", type=["feather"])
# Add time estimate warning
estimated_time = dataset_size * num_shots * 0.6 # roughly 0.6 minutes per sample with 5 shots on a 3090
hours = int(estimated_time // 60)
minutes = int(estimated_time % 60)
st.warning(f"""
⚠️ **Generation Time Estimate**
With your current settings ({dataset_size} samples, {num_shots}-shot), expect approximately:
- {hours} hours and {minutes} minutes on an NVIDIA RTX 3090
- Longer times on less powerful GPUs
- Much longer (hours or days) without a GPU
This is because:
- Each sample takes about 1-2 minutes to generate
- The number of shots affects generation time
- The total time scales with the number of samples
- GPU power significantly impacts speed
You can reduce the generation time by:
1. Reducing the number of samples
2. Using fewer shots
3. Running on a powerful GPU
""")
# Privacy settings
privacy_option = st.radio(
"Privacy Setting",
["Non-private", "Default epsilons [1,3,8]", "Custom epsilons"],
help="Choose privacy level for generation"
)
if privacy_option == "Custom epsilons":
custom_epsilons = st.text_input(
"Custom Epsilon Values",
value="1,3,8",
help="Enter comma-separated epsilon values (e.g., 1,3,8)"
)
epsilons = [float(eps.strip()) for eps in custom_epsilons.split(",") if eps.strip()]
elif privacy_option == "Default epsilons [1,3,8]":
epsilons = [1, 3, 8]
else: # Non-private
epsilons = None
# Custom prompt or prompt index based on checkbox
if use_custom_prompt:
custom_prompt = st.text_area(
"Custom Prompt",
value="""[EXAMPLE PROMPT]
Please generate a realistic, concise, and professional clinical discharge summary for a patient based on the following ICD-10 codes. Do not include the ICD-10 codes themselves in the report; instead, reference the medical conditions they represent. Before composing the summary, internally develop a logical and medically accurate patient case, including the timeline of symptom onset, diagnosis, interventions, and outcomes. Do not include this internal planning in the final summary.
The discharge summary should:
Use clinical language with standard medical abbreviations (e.g., CHF for congestive heart failure, N/V for nausea and vomiting).
Be succinct, focusing on essential clinical information without unnecessary explanations.
Reflect a coherent and medically plausible sequence of events with appropriate timing.
Represent a wide range of cases, including both common and rare conditions when specified.
Mimic the style and tone of actual clinical documentation used among healthcare professionals.
Format:
Patient Identification:
Name: [Use initials only, e.g., J.D.]
Age/Gender: [e.g., 45-year-old male]
Admission Date: [Realistic date]
Discharge Date: [Realistic date]
Admitting Diagnosis: [Primary reason for admission]
Discharge Diagnoses:
Primary: [State main condition]
Secondary: [List comorbidities or complications]
Hospital Course:
[Summarize key diagnostic findings, treatments, and patient response]
Discharge Instructions:
Medications: [List with dosages]
Follow-Up: [Appointments, referrals]
Activity: [Restrictions or recommendations]
Diet: [Instructions if applicable]
Warnings: [Symptoms that require immediate attention]
Additional Requirements:
Exclude any patient-identifiable information beyond initials.
Do not include the internal case planning or timeline in the summary.
Ensure medical accuracy and plausibility in terms of disease progression and treatment.
Use appropriate medical terminology relevant to the conditions.
ICD10-CODES= """,
height=600,
help="Enter your custom prompt. Make sure it ends with 'ICD10-CODES= ' for proper code insertion."
)
prompt_index = None # Not used when custom prompt is provided
else:
st.subheader("Choose a Prompt Template")
prompts = {
0: "Basic prompt - Generates a simple, straightforward discharge summary:\n\n" +
"Generate a clinical discharge summary of a patient who had the conditions and procedures described by the following codes ICD10-CODES= ",
1: "Detailed prompt - Focuses on clinical accuracy and standard formatting:\n\n" +
"""Please generate a realistic and concise clinical discharge summary for a patient based on the following ICD-10 codes. Do not include the ICD-10 codes themselves in the report; instead, refer to the medical conditions they represent. Before writing the summary, internally create a logical and medically accurate timeline of the patient's diagnosis, treatment, and progress. Use standard medical abbreviations where appropriate to mirror real clinical documentation. Focus on essential clinical information, avoiding unnecessary explanations or verbosity. Ensure that the report accurately reflects the management of both common and rare diseases as applicable.
Requirements:
- Use standard medical abbreviations (e.g., BP for blood pressure, HR for heart rate)
- Keep the summary concise and focused on relevant clinical details
- Ensure the sequence of events and timing make medical sense
- Cover a wide range of use cases, including rare diseases when specified
- Do not include any ICD-10 codes in the text of the report
Format:
Admission Date: [Date]
Discharge Date: [Date]
Discharge Summary:
Reason for Admission: [Brief]
History of Present Illness: [Concise]
Hospital Course: [Key events]
Discharge Plan: [Instructions]
ICD10-CODES= """,
2: "Professional prompt - Emphasizes medical documentation standards:\n\n" +
"""Please generate a realistic, concise, and professional clinical discharge summary for a patient based on the following ICD-10 codes. Do not include the ICD-10 codes themselves in the report; instead, reference the medical conditions they represent. Before composing the summary, internally develop a logical and medically accurate patient case, including the timeline of symptom onset, diagnosis, interventions, and outcomes.
The discharge summary should:
- Use clinical language with standard medical abbreviations
- Be succinct and focused on essential information
- Reflect a coherent sequence of events
- Mimic actual clinical documentation style
Format:
Patient Identification: [Initials only]
Age/Gender: [e.g., 45M]
Dates: [Admission/Discharge]
Diagnoses: [Primary/Secondary]
Hospital Course: [Key events]
Discharge Plan: [Complete instructions]
ICD10-CODES= """,
3: "Custom prompt template - Use this as a starting point for your own prompt:\n\n" +
"""[ADD YOUR CUSTOM PROMPT HERE]
Remember to:
- Include clear formatting instructions
- Specify medical terminology preferences
- Define documentation standards
- End with ICD10-CODES= """
}
prompt_index = st.selectbox(
"Prompt Index",
options=list(prompts.keys()),
format_func=lambda x: f"Prompt {x}",
help="Choose a predefined prompt template"
)
# Show the selected prompt
st.text_area(
"Selected Prompt Preview",
value=prompts[prompt_index],
height=400,
disabled=True
)
custom_prompt = None
# Add a note about available models
st.info("You can check available models at https://ollama.com/search. Make sure to pull your desired model first using 'ollama pull MODEL_NAME'")
if st.button("Generate Data"):
with st.spinner("Generating data..."):
# Record the start time of generation
generation_start_time = time.time()
if generate_data(
model_name, num_shots, dataset_size, temperature,
prompt_index, custom_dataset, custom_prompt,
nonprivate=(privacy_option == "Non-private"),
epsilons=epsilons
):
st.success("Data generation completed!")
# Wait a brief moment to ensure files are written
time.sleep(1)
# Update the list of generated files - only get files from this generation
st.session_state.last_generation_time = generation_start_time
# Find all generated files in the output directory and its subdirectories
current_files = []
for root, dirs, files in os.walk("data/generated"):
for file in files:
if file.endswith('.csv') and 'generated_dataset' in file:
file_path = Path(os.path.join(root, file))
if os.path.getctime(file_path) >= generation_start_time:
current_files.append(file_path)
st.session_state.generated_files = current_files
else:
st.error("Data generation failed. Please check the logs.")
# Display generated files section (always show if files exist)
if st.session_state.generated_files and st.session_state.last_generation_time:
st.subheader("Generated Files")
for file in st.session_state.generated_files:
if file.exists() and os.path.getctime(file) >= st.session_state.last_generation_time: # Only show files from last generation
try:
df = pd.read_csv(file)
st.write(f"File: {file.name}")
st.write(f"Location: {file}")
st.dataframe(df.head())
# Download button for each file
with open(file, "rb") as f:
st.download_button(
label=f"Download {file.name}",
data=f,
file_name=file.name,
mime="text/csv"
)
st.markdown("---") # Add a separator between files
except Exception as e:
st.error(f"Error reading file {file.name}: {str(e)}")
else:
# Show warning if no files exist
st.warning("No generated dataset files found. Use the form above to generate new files.")