[96a5a0]: / patient_matching / testing.py

Download this file

198 lines (168 with data), 7.5 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
# patient_matching/testing.py
import json
import logging
import os
from typing import Dict
from patient_matching.aggregated_trial_truth import (
AggregatedTruthTable,
TruthValue,
aggregate_identified_trials,
)
from patient_matching.trial_overall_evaluator import evaluate_all_trials
from patient_matching.truth_evaluator import update_truth_table_with_user_response
from patient_matching.user_answer_parser import (
ConversationHistory,
parse_user_response,
)
from src.models.identified_criteria import IdentifiedTrial
from src.models.logical_criteria import LogicalTrial
from src.repositories.trial_repository import (
export_pydantic_to_json,
load_pydantic_from_json,
load_pydantic_models_from_folder,
)
from src.utils.config import DEFAULT_OUTPUT_DIR, setup_logging
from src.utils.helpers import get_non_empty_input
# Module-level side effect: configure logging (to file, INFO level) as soon as
# this module is imported, before any of the functions below run.
setup_logging(log_to_file=True, log_level=logging.INFO)
def get_trial_limit() -> int:
    """
    Prompt the user for how many trials to load.

    Loops until the user provides valid input.

    Returns:
        int: The requested number of trials; 0 means "load all trials"
        (returned when the user just presses enter).

    Raises:
        SystemExit: If the user types 'quit' or 'exit'.
    """
    while True:
        try:
            limit_str = input(
                "\nHow many trials would you like to load? (just press enter for all): "
            ).strip()
            if limit_str == "":
                # Empty input means "no limit"; callers treat 0 as "all trials".
                return 0
            if limit_str.lower() in ["quit", "exit"]:
                logging.info("Exiting program.")
                # Raise SystemExit directly instead of calling the
                # interactive-only `exit()` helper injected by the `site`
                # module (not guaranteed to exist when run with `python -S`).
                raise SystemExit(0)
            limit = int(limit_str)
            if limit >= 0:
                return limit
            print("Please enter a non-negative number.")
        except ValueError:
            # Non-numeric input from int(); re-prompt.
            print("Please enter a valid number.")
def has_unknowns(status_map: Dict[str, TruthValue]) -> bool:
    """
    Report whether any trial in the status map is still UNKNOWN.

    Args:
        status_map (Dict[str, TruthValue]): A dictionary mapping trial IDs to their truth values.

    Returns:
        bool: True if at least one truth value is UNKNOWN, False otherwise.
    """
    for truth in status_map.values():
        if truth == TruthValue.UNKNOWN:
            return True
    return False
def main() -> None:  # sourcery skip: dict-assign-update-to-union
    """
    Interactive driver for trial truth evaluation.

    Loads identified/logical trial models, builds an aggregated truth table,
    then loops: asks the user a question + response, parses the response,
    updates the truth table, re-evaluates affected trials, and persists
    the aggregator, conversation history, and trial truth map to disk.
    The loop ends when no trial is left with an UNKNOWN overall truth.
    """
    # NOTE(review): setup_logging(...) already configured logging at import
    # time; this basicConfig call is likely redundant — confirm before removal.
    logging.basicConfig(level=logging.INFO)
    # 1. Get number of trials to load (0 == all; see get_trial_limit)
    limit = get_trial_limit()
    if limit == 0:
        logging.info("Loading all trials...")
    else:
        logging.info("Loading first %d trials...", limit)
    # 2. Load identified and logical trials from the output directory layout
    #    <DEFAULT_OUTPUT_DIR>/allTrials/{identified,logical}
    identified_folder = os.path.join(DEFAULT_OUTPUT_DIR, "allTrials", "identified")
    logical_folder = os.path.join(DEFAULT_OUTPUT_DIR, "allTrials", "logical")
    identified_trials = load_pydantic_models_from_folder(
        identified_folder,
        IdentifiedTrial,
        file_extension="_identified.json",
        limit=limit if limit > 0 else None,  # None == no limit for the loader
    )
    logging.info("Loaded %d identified trials", len(identified_trials))
    # Load the matching logical trial for each identified trial into memory,
    # keyed by NCT ID. Missing logical files are logged and skipped.
    logical_trials: Dict[str, LogicalTrial] = {}
    for trial in identified_trials:
        trial_id = trial.info.nct_id
        model = load_pydantic_from_json(
            logical_folder, f"{trial_id}_logical.json", LogicalTrial
        )
        if model is None:
            logging.warning("Logical model for trial %s not found", trial_id)
        else:
            logical_trials[trial_id] = model
    # 3. Create initial aggregated truth table using identified trials
    aggregator: AggregatedTruthTable = aggregate_identified_trials(identified_trials)
    # Initial status map: evaluate all trials once
    status_list = evaluate_all_trials(list(logical_trials.values()), aggregator)
    status_map: Dict[str, TruthValue] = {
        res.trial_id: res.overall_truth for res in status_list
    }
    matching_folder = os.path.join(DEFAULT_OUTPUT_DIR, "matching")
    os.makedirs(matching_folder, exist_ok=True)
    # 4. Save initial aggregator
    export_pydantic_to_json(aggregator, "aggregated_truth.json", matching_folder)
    logging.info(
        "Initial aggregator saved to %s",
        os.path.join(matching_folder, "aggregated_truth.json"),
    )
    # 5. Initialize conversation history
    conversation_history = ConversationHistory()
    # 6. Enter user input loop; runs until no trial is UNKNOWN
    logging.info("Entering user response loop. Type 'quit' or 'exit' to stop.")
    while has_unknowns(status_map):
        # Get user input (get_non_empty_input re-prompts on empty strings)
        question = get_non_empty_input("\nEnter your question (or 'quit' to exit): ")
        user_input = get_non_empty_input("\nEnter your response:")
        try:
            # Parse user response into structured criterion answers
            parsed_response = parse_user_response(user_input, question)
            conversation_history.add_response(question, parsed_response.parsed_answers)
            # Persist the conversation after every exchange
            export_pydantic_to_json(
                conversation_history,
                "conversation_history.json",
                matching_folder,
            )
            # Update truth table and track modified trials
            # (set union of trial IDs touched by any parsed criterion)
            if modified_trials := {
                trial_id
                for criterion in parsed_response.parsed_answers
                for trial_id in update_truth_table_with_user_response(
                    criterion, aggregator
                )
            }:
                logging.info("Updated truth values for %d trials", len(modified_trials))
                logging.info("Modified trials: %s", modified_trials)
            else:
                logging.info("No trials were affected by this response")
            # NOTE: we don't need to re-evaluate trials that are already False since they won't change
            # however once we change to an analog value for truth evaluation, we may need to re-evaluate all trials
            if to_eval_ids := [
                tid
                for tid in modified_trials
                # Unseen IDs default to UNKNOWN and are therefore re-evaluated
                if status_map.get(tid, TruthValue.UNKNOWN) != TruthValue.FALSE
            ]:
                to_eval_models = [logical_trials[tid] for tid in to_eval_ids]
                new_results = evaluate_all_trials(to_eval_models, aggregator)
            else:
                new_results = []
            status_map.update({res.trial_id: res.overall_truth for res in new_results})
            # Print affected trials summary, sorted by truth value then NCTID
            print(f"\n{len(new_results)} trials affected.\n")
            # Sort by truth value (as string for stable ordering), then by NCTID
            sorted_results = sorted(
                new_results, key=lambda res: (str(res.overall_truth), res.trial_id)
            )
            for res in sorted_results:
                print(f"{res.trial_id}: {res.overall_truth}")
            # Save updated aggregator after each round of updates
            export_pydantic_to_json(
                aggregator, "aggregated_truth.json", matching_folder
            )
            # Save trial truth map (best-effort: failure is logged, loop continues)
            try:
                truth_map_path = os.path.join(matching_folder, "trial_truth_map.json")
                with open(truth_map_path, "w", encoding="utf-8") as f:
                    json.dump(
                        {tid: val.value for tid, val in status_map.items()}, f, indent=4
                    )
                logging.info("Trial truth map saved to %s", truth_map_path)
            except Exception as e:
                logging.error("Failed to save trial truth map: %s", e)
        except Exception as e:
            # Broad catch is deliberate: a bad response should not kill the
            # interactive session — log and re-prompt.
            logging.error("Error processing user input: %s", e)
            continue
    logging.info("Exiting user response loop.")
if __name__ == "__main__":
main()